gfn.gym.set_addition

Classes

SetAddition

Append-only MDP, similar to the one described in Remark 8 of Shen et al. (2023).

Module Contents

class gfn.gym.set_addition.SetAddition(n_items, max_items, reward_fn, fixed_length=False, device=None, debug=False)

Bases: gfn.env.DiscreteEnv

Append-only MDP, similar to the one described in Remark 8 of Shen et al. (2023), [Towards Understanding and Improving GFlowNet Training](https://proceedings.mlr.press/v202/shen23a.html).

The state is a binary vector of length n_items, where a 1 at position i indicates that item i is present. Actions are integers: actions 0 to n_items - 1 add the corresponding item, and action n_items is the exit action. Adding an item that is already present is invalid. A trajectory must terminate once max_items items are present.
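The validity rules above can be sketched as a forward action mask over a single state. This is a plain-Python illustration, not the library's implementation: the function name forward_mask is hypothetical, and the handling of fixed_length (exit allowed only once the set is full) is an assumption about what that flag means.

```python
from typing import List


def forward_mask(state: List[int], max_items: int, fixed_length: bool = False) -> List[bool]:
    """Hypothetical helper: one flag per action (n_items add-actions, then exit)."""
    n_items = len(state)
    n_present = sum(state)
    # Adding item i is valid only if i is absent and the set is not yet full.
    mask = [state[i] == 0 and n_present < max_items for i in range(n_items)]
    # Assumption: with fixed_length=True, exit is only allowed once max_items are present.
    can_exit = n_present == max_items if fixed_length else True
    mask.append(can_exit)
    return mask


# 4 items, item 1 already present, at most 2 items per set.
print(forward_mask([0, 1, 0, 0], max_items=2))  # [True, False, True, True, True]
```

Once the set is full, every add-action is masked and only exit remains, which is how "the trajectory must terminate" is enforced in this sketch.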

Recommended preprocessor: IdentityPreprocessor.

Parameters:
  • n_items (int)

  • max_items (int)

  • reward_fn (Callable)

  • fixed_length (bool)

  • device (Literal['cpu', 'cuda'] | torch.device | None)

  • debug (bool)

n_items

The number of distinct items that can be added (the length of the state vector).

Type:

int

max_items

The maximum number of items that can be added to the set.

Type:

int

reward_fn

The reward function.

Type:

Callable

fixed_length

Whether the trajectories have a fixed length.

Type:

bool

States: type[gfn.env.DiscreteStates]

property all_states: gfn.env.DiscreteStates

Returns all the states of the environment.

Return type:

gfn.env.DiscreteStates
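A minimal sketch of what enumerating all states could look like, assuming all_states contains every subset of at most max_items items (the empty set included). The function name enumerate_states is hypothetical, and the ordering used here (by set size, then lexicographically) may differ from the library's.

```python
from itertools import combinations
from math import comb
from typing import List


def enumerate_states(n_items: int, max_items: int) -> List[List[int]]:
    """Hypothetical helper: every binary vector with at most max_items ones."""
    states = []
    for k in range(max_items + 1):
        for items in combinations(range(n_items), k):
            state = [0] * n_items
            for i in items:
                state[i] = 1
            states.append(state)
    return states


states = enumerate_states(n_items=4, max_items=2)
print(len(states), sum(comb(4, k) for k in range(3)))  # 11 11
```

Under this assumption the state space has sum over k = 0..max_items of C(n_items, k) elements, which grows quickly with n_items.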

backward_step(states, actions)

Performs a backward step in the environment by removing the item selected by each action.

Parameters:
  • states (gfn.env.DiscreteStates) – The current states.

  • actions (gfn.env.Actions) – The actions to take.

Returns:

The previous states.

Return type:

gfn.env.DiscreteStates
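In an append-only MDP, the backward step undoes an addition. The sketch below shows this for a single state as a plain list; backward_step here is a hypothetical stand-in for the batched method, which operates on DiscreteStates tensors.

```python
from typing import List


def backward_step(state: List[int], action: int) -> List[int]:
    """Hypothetical single-state version: remove item `action` from the set."""
    if state[action] != 1:
        raise ValueError(f"item {action} is not in the set")
    parent = list(state)  # copy so the caller's state is not mutated
    parent[action] = 0
    return parent


print(backward_step([1, 0, 1, 0], 2))  # [1, 0, 0, 0]
```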

get_states_indices(states)

Returns an integer index identifying each state.

Parameters:

states (gfn.env.DiscreteStates) – The states to get the indices of.

Returns:

The indices of the states.
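One natural way to index a binary state vector is to read it as a base-2 number. The sketch below assumes that scheme with the first item as the most significant bit; the library may use a different bit order or a different indexing scheme altogether, and state_index is a hypothetical name.

```python
from typing import List


def state_index(state: List[int]) -> int:
    """Hypothetical: interpret the binary state vector as a base-2 integer."""
    index = 0
    for bit in state:
        index = 2 * index + bit
    return index


print(state_index([1, 0, 1, 0]))  # 10
```

This scheme maps every subset to a distinct integer in [0, 2**n_items), but when max_items < n_items not every integer in that range corresponds to a reachable state.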

make_states_class()

Returns the DiscreteStates class for the SetAddition environment.

Return type:

type[gfn.env.DiscreteStates]

max_traj_len

reward(final_states)

Computes the reward for a batch of final states.

Parameters:

final_states (gfn.env.DiscreteStates) – The final states.

Returns:

The reward of the final states.

Return type:

torch.Tensor
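The exact signature expected of reward_fn is not stated here. As an illustration only, the sketch below builds a hypothetical reward function that maps a batch of binary states to rewards, using an exponentiated sum of per-item weights. The weights and the helper make_additive_reward are invented for the example, and plain lists stand in for tensors.

```python
import math
from typing import Callable, List, Sequence


def make_additive_reward(weights: Sequence[float]) -> Callable[[List[List[int]]], List[float]]:
    """Hypothetical: reward(s) = exp(sum of weights of the items present in s)."""

    def reward_fn(final_states: List[List[int]]) -> List[float]:
        return [math.exp(sum(w * x for w, x in zip(weights, s))) for s in final_states]

    return reward_fn


reward_fn = make_additive_reward([0.5, -1.0, 2.0, 0.0])
print(reward_fn([[1, 0, 1, 0], [0, 0, 0, 0]]))
```

Exponentiating keeps every reward strictly positive, which GFlowNet training objectives generally require.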

step(states, actions)

Performs a forward step in the environment by adding the item selected by each action.

Parameters:
  • states (gfn.env.DiscreteStates) – The current states.

  • actions (gfn.env.Actions) – The actions to take.

Returns:

The next states.

Return type:

gfn.env.DiscreteStates
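A single-state sketch of the forward transition, enforcing the rules stated in the class description. It assumes the exit action is handled separately (so only add-actions are accepted), and forward_step is a hypothetical name rather than the library's batched method.

```python
from typing import List


def forward_step(state: List[int], action: int, max_items: int) -> List[int]:
    """Hypothetical single-state version: add item `action` to the set."""
    n_items = len(state)
    if not 0 <= action < n_items:
        raise ValueError("only add-actions 0..n_items-1 are handled here")
    if state[action] == 1:
        raise ValueError(f"item {action} is already in the set")
    if sum(state) >= max_items:
        raise ValueError("set is full; only the exit action is allowed")
    child = list(state)  # copy so the caller's state is not mutated
    child[action] = 1
    return child


print(forward_step([1, 0, 0, 0], 2, max_items=2))  # [1, 0, 1, 0]
```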

property terminating_states: gfn.env.DiscreteStates

Returns the terminating states of the environment.

Return type:

gfn.env.DiscreteStates
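The sketch below shows one plausible reading of the terminating states, under two assumptions not confirmed by this page: with fixed_length=True only full sets (exactly max_items items) can terminate, and otherwise every non-empty set of at most max_items items can. The function terminating_states here is a hypothetical plain-Python stand-in.

```python
from itertools import combinations
from typing import List


def terminating_states(n_items: int, max_items: int, fixed_length: bool) -> List[List[int]]:
    """Hypothetical: states that may end a trajectory, under the assumptions above."""
    sizes = [max_items] if fixed_length else range(1, max_items + 1)
    states = []
    for k in sizes:
        for items in combinations(range(n_items), k):
            states.append([1 if i in items else 0 for i in range(n_items)])
    return states


print(len(terminating_states(4, 2, fixed_length=True)))   # 6
print(len(terminating_states(4, 2, fixed_length=False)))  # 10
```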