gfn.gym¶
This module contains all the environments implemented as Gym environments.
Submodules¶
Attributes¶
Classes¶
- BitSequence: Append-only BitSequence environment.
- BitSequencePlus: Prepend-Append version of BitSequence env.
- BoxCartesian: Box environment with Cartesian per-dimension action validation.
- BoxPolar: Box environment with polar (norm-based) action validation.
- ChipDesign: GFlowNet environment for chip placement.
- ConditionalHyperGrid: HyperGrid environment with condition-aware rewards.
- DiscreteEBM: Environment for discrete energy-based models.
- GraphBuilding: Environment for incrementally building graphs.
- GraphBuildingOnEdges: Environment for building graphs edge by edge with discrete action space.
- HyperGrid: HyperGrid environment from the GFlowNets paper.
- Mixture of Gaussians Line environment.
- Non-autoregressive BitSequence environment.
- Perfect Tree Environment.
- Append-only MDP, similar to the one described in Remark 8 of Shen et al. (2023).
Package Contents¶
- class gfn.gym.BitSequence(word_size=4, seq_size=120, n_modes=60, temperature=1.0, H=None, device_str='cpu', seed=0, debug=False)¶
Bases:
gfn.env.DiscreteEnv
Append-only BitSequence environment.
This environment represents a sequence of binary words and provides methods to manipulate and evaluate these sequences. Each action appends one whole binary word (word_size bits) at once. Each binary word is represented by its decimal value in both states and actions.
- Parameters:
word_size (int)
seq_size (int)
n_modes (int)
temperature (float)
H (Optional[torch.Tensor])
device_str (str)
seed (int)
debug (bool)
- word_size¶
The size of each binary word in the sequence.
- seq_size¶
The total number of bits in the sequence.
- n_modes¶
The number of unique modes in the sequence.
- temperature¶
The temperature parameter for reward calculation.
- H¶
A tensor used to create the modes.
- device_str¶
The device to run the computations on (“cpu” or “cuda”).
- words_per_seq¶
The number of words per sequence.
- modes¶
The set of modes written as binary.
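The defaults documented below (word_size = 4, seq_size = 120, words_per_seq = 30, n_actions = 17) fit the relationships in this minimal sketch. The helper name derived_sizes is hypothetical. The assumption that n_actions counts one action per possible word plus one exit action is inferred from those defaults, not taken from the library source.

```python
def derived_sizes(word_size: int, seq_size: int) -> dict:
    """Hypothetical helper: sizes implied by word_size and seq_size."""
    # This sketch insists that the sequence splits evenly into whole words.
    if seq_size % word_size != 0:
        raise ValueError("seq_size must be a multiple of word_size")
    words_per_seq = seq_size // word_size
    n_words = 2 ** word_size  # distinct binary words of word_size bits
    # Assumption: one action per possible word, plus a single exit action.
    n_actions = n_words + 1
    return {"words_per_seq": words_per_seq, "n_words": n_words, "n_actions": n_actions}


print(derived_sizes(word_size=4, seq_size=120))
# {'words_per_seq': 30, 'n_words': 16, 'n_actions': 17}
```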
- H = None¶
- States: type[BitSequenceStates]¶
- _backward_step(states, actions)¶
Perform a backward step in the environment by undoing the given actions to the current states.
- Parameters:
states (BitSequenceStates) – The current states of the environment.
actions (gfn.actions.Actions) – The actions to be applied to the current states.
- Returns:
The new states after performing the backward step.
- _step(states, actions)¶
Perform a step in the environment by applying the given actions to the current states.
- Parameters:
states (BitSequenceStates) – The current states of the environment.
actions (gfn.actions.Actions) – The actions to be applied to the current states.
- Returns:
The new states of the environment after applying the actions.
- backward_step(states, actions)¶
Performs a backward step in the environment.
- Parameters:
states (BitSequenceStates) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The previous states.
- Return type:
- static binary_to_integers(binary_tensor, k)¶
Convert a binary tensor to a tensor of integers.
- Parameters:
binary_tensor (torch.Tensor) – A tensor containing binary values. The tensor must be of type int64.
k (int) – The number of bits in each integer.
- Returns:
A tensor of integers obtained from the binary tensor.
- Return type:
torch.Tensor
- create_test_set(k, seed=0)¶
Create a test set by generating k variants of each mode, each obtained by flipping a random number of its bits.
The test set therefore contains n_modes * k sequences.
- Parameters:
k (int) – Number of variations per mode.
seed (int) – Seed for reproducibility. If None, randomness is not fixed.
- Returns:
The generated test set in the decimal representation.
- Return type:
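Here is a minimal plain-Python sketch of the perturbation idea. It assumes modes are given as lists of 0/1 bits. It also assumes the number of flipped bits is drawn uniformly from 1 to a hypothetical max_flips parameter, since the docstring only says "a random number of bits". Unlike the method above, it returns binary lists rather than the decimal word representation.

```python
import random


def make_test_set(modes, k, max_flips, seed=0):
    """Hypothetical sketch: k randomly perturbed copies of every mode."""
    rng = random.Random(seed)  # seed=None gives non-reproducible randomness
    test_set = []
    for mode in modes:
        for _ in range(k):
            # Assumed: flip between 1 and max_flips distinct positions.
            n_flips = rng.randint(1, min(max_flips, len(mode)))
            variant = list(mode)
            for pos in rng.sample(range(len(mode)), n_flips):
                variant[pos] ^= 1  # flip one bit
            test_set.append(variant)
    return test_set  # n_modes * k entries, grouped by mode


modes = [[0] * 8, [1] * 8]
print(len(make_test_set(modes, k=3, max_flips=2)))  # 6
```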
- static hamming_distance(candidates, reference)¶
Compute the smallest Hamming distance (number of differing positions) from each candidate row to any reference row.
- Parameters:
candidates (torch.Tensor) – Tensor of shape (*batch_shape, length).
reference (torch.Tensor) – Tensor of shape (n_ref, length).
- Returns:
Tensor of shape (*batch_shape) containing the smallest Hamming distance for each candidate row.
- Return type:
torch.Tensor
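To make the semantics concrete, here is a plain-Python sketch for a flat batch of rows. A tensor version could broadcast candidates of shape (*batch_shape, 1, length) against reference of shape (n_ref, length), count mismatches along the last dimension, and take the minimum over the reference axis. That vectorization is only a suggestion, not a description of the library code.

```python
def min_hamming_distance(candidates, reference):
    """Smallest Hamming distance from each candidate row to any reference row."""
    distances = []
    for cand in candidates:
        # Hamming distance = number of positions where the two rows differ.
        best = min(sum(a != b for a, b in zip(cand, ref)) for ref in reference)
        distances.append(best)
    return distances


reference = [[1, 1, 1, 1], [0, 0, 0, 0]]
candidates = [[0, 0, 0, 0], [0, 1, 1, 0], [1, 1, 1, 0]]
print(min_hamming_distance(candidates, reference))  # [0, 2, 1]
```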
- static integers_to_binary(tensor, k)¶
Convert a tensor of integers to their binary representation using k bits.
- Parameters:
tensor (torch.Tensor) – A tensor containing integers.
k (int) – The number of bits to use for the binary representation of each integer.
- Returns:
A tensor containing the binary representation of the input integers.
- Return type:
torch.Tensor
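The two conversions are inverses of each other. This sketch works on flat Python lists rather than tensors. It assumes the most significant bit comes first, because the bit order the library uses is not stated here.

```python
def integers_to_binary(values, k):
    """Each integer -> k bits, most significant bit first (assumed order)."""
    bits = []
    for v in values:
        if not 0 <= v < 2 ** k:
            raise ValueError(f"{v} does not fit in {k} bits")
        bits.extend((v >> (k - 1 - i)) & 1 for i in range(k))
    return bits


def binary_to_integers(bits, k):
    """Inverse of integers_to_binary: read k bits at a time."""
    if len(bits) % k != 0:
        raise ValueError("length must be a multiple of k")
    return [
        sum(bit << (k - 1 - i) for i, bit in enumerate(bits[j:j + k]))
        for j in range(0, len(bits), k)
    ]


words = [5, 0, 15]
bits = integers_to_binary(words, k=4)  # [0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1]
assert binary_to_integers(bits, k=4) == words
```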
- log_reward(final_states)¶
Calculates the log-reward for the given final states.
- Parameters:
final_states (BitSequenceStates) – The final states for which to calculate the log-reward.
- Returns:
The calculated log-reward.
- Return type:
torch.Tensor
- make_modes_set(seed)¶
Generates a set of unique mode sequences based on the predefined tensor H.
- Parameters:
seed – The seed for random number generation.
- Returns:
A tensor containing the unique mode sequences.
- Raises:
ValueError – If the number of requested modes exceeds the number of possible unique sequences.
- Return type:
torch.Tensor
- make_states_class()¶
Creates a BitSequenceStates class implementation.
- Returns:
A BitSequenceStates class implementation.
- Return type:
type[BitSequenceStates]
- modes¶
- n_actions = 17¶
- n_modes: int = 60¶
- property n_states: int¶
Returns the total number of states in the environment.
- Return type:
int
- property n_terminating_states: int¶
Returns the number of terminating states.
- Return type:
int
- reset(batch_shape=None, sink=False)¶
Generates initial or sink states from batch_shape.
- Parameters:
batch_shape (int | Tuple[int] | None) – The shape of the batch. If None, defaults to (1,). If an integer is provided, it is converted to a tuple.
sink (bool) – If True, sink state is created. Defaults to False.
- Returns:
The initial states of the environment after reset.
- Return type:
- reward(final_states)¶
Calculate the reward for the given final states.
The reward is computed based on the Hamming distance between the binary representation of the final states and the predefined modes. The reward is then scaled using an exponential function with a temperature parameter.
- Parameters:
final_states (BitSequenceStates) – The final states for which the reward is to be calculated.
- Returns:
The calculated reward for the given final states.
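The docstring does not give the exact formula. One natural reading, used in this sketch as an assumption, is reward = exp(-d / temperature), where d is the smallest Hamming distance between the final sequence and any mode. Under that reading the log-reward is -d / temperature.

```python
import math


def bitseq_reward(bits, modes, temperature=1.0):
    """Assumed form: exp(-d / temperature), d = distance to the nearest mode."""
    d = min(sum(a != b for a, b in zip(bits, mode)) for mode in modes)
    return math.exp(-d / temperature)


modes = [[0, 0, 0, 0], [1, 1, 1, 1]]
print(bitseq_reward([0, 0, 0, 0], modes))  # 1.0: exactly on a mode
print(bitseq_reward([1, 1, 0, 0], modes))  # exp(-2): two bits from either mode
```

A lower temperature makes the reward fall off faster as sequences move away from the modes.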
- seq_size: int = 120¶
- states_from_tensor(tensor, length=None)¶
Wraps the supplied Tensor in a States instance.
- Parameters:
tensor (torch.Tensor) – The tensor of shape state_shape representing the states.
length (Optional[torch.Tensor]) – The length of each state in the tensor.
- Returns:
An instance of States.
- Return type:
- step(states, actions)¶
Performs a step in the environment.
- Parameters:
states (BitSequenceStates) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The next states.
- Return type:
- temperature = 1.0¶
- property terminating_states: BitSequenceStates¶
Returns all terminating states of the environment.
- Return type:
- trajectory_from_terminating_states(terminating_states_tensor)¶
Generate trajectories from terminating states.
This works because the DAG is a tree in the append-only version of BitSequence.
- Parameters:
terminating_states_tensor (torch.Tensor) – A tensor containing the terminating states from which to generate the trajectories. The shape of the tensor should be (batch_size, words_per_seq).
- Returns:
An object containing the generated trajectories.
- Return type:
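In the append-only MDP every non-initial state has exactly one parent, so the trajectory to a terminating state is simply the sequence of its prefixes. This sketch represents a state as a list of words. It uses a hypothetical placeholder PAD = -1 for word slots that are not yet filled, and the library's actual padding convention may differ.

```python
PAD = -1  # hypothetical marker for word slots that are still empty


def trajectory_states(terminal_words):
    """States on the unique path s0 -> ... -> terminal state (append-only MDP)."""
    n = len(terminal_words)
    # After step t, the first t words have been appended; the rest are empty.
    return [list(terminal_words[:t]) + [PAD] * (n - t) for t in range(n + 1)]


for state in trajectory_states([3, 7, 1]):
    print(state)
# [-1, -1, -1]
# [3, -1, -1]
# [3, 7, -1]
# [3, 7, 1]
```

The action taken at step t is just terminal_words[t] (append that word), followed by the exit action.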
- true_dist(condition=None)¶
Returns the true probability mass function of the reward distribution.
- Return type:
torch.Tensor
- word_size: int = 4¶
- words_per_seq: int = 30¶
- class gfn.gym.BitSequencePlus(word_size=4, seq_size=120, n_modes=60, temperature=1.0, H=None, device_str='cpu', seed=0)¶
Bases:
BitSequence
Prepend-Append version of BitSequence env.
This environment is similar to BitSequence, but words can be both prepended and appended to the sequence.
- Parameters:
word_size (int)
seq_size (int)
n_modes (int)
temperature (float)
H (Optional[torch.Tensor])
device_str (str)
seed (int)
- H = None¶
- backward_step(states, actions)¶
Performs a backward step in the environment.
- Parameters:
states (BitSequenceStates) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The previous states.
- Return type:
- make_states_class()¶
Creates a BitSequenceStates class implementation for BitSequencePlus.
- Returns:
A BitSequenceStates class implementation with prepend-append mask logic.
- Return type:
type[BitSequenceStates]
- modes¶
- n_modes: int = 60¶
- seq_size: int = 120¶
- step(states, actions)¶
Performs a step in the environment.
- Parameters:
states (BitSequenceStates) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The next states.
- Return type:
- temperature = 1.0¶
- abstract trajectory_from_terminating_states(terminating_states_tensor)¶
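Generates trajectories from terminating states. Not implemented for this environment: because words can be prepended or appended, a terminating state can be reached by several trajectories, so the DAG is not a tree as it is in the append-only BitSequence.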
- Parameters:
terminating_states_tensor (torch.Tensor) – A tensor of terminating states.
- Raises:
NotImplementedError – This method is not implemented for this environment.
- Return type:
- word_size: int = 4¶
- words_per_seq: int = 30¶
- gfn.gym.Box¶
- class gfn.gym.BoxCartesian¶
Bases:
gfn.gym.box.BoxPolar
Box environment with Cartesian per-dimension action validation.
Inherits all behavior from BoxPolar (__init__, step, backward_step, reward, log_partition, norm, make_random_states). Overrides only is_action_valid() to use per-dimension bounds instead of polar norm constraints. Use with the Cartesian estimators/distributions in box_cartesian_utils.py.
See also BoxPolar for the original polar norm-based variant.
- is_action_valid(states, actions, backward=False)¶
Checks if the actions are valid (Cartesian per-dimension semantics).
For Cartesian actions:
  - Forward from s0: action[i] >= 0 and action[i] <= 1.
  - Forward from non-s0: action[i] >= delta and state[i] + action[i] <= 1.
  - Backward: state[i] - action[i] >= 0.
  - Backward to s0: if all resulting dims are < delta, the action must equal the state.
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions to be taken.
backward (bool) – Whether the actions are backward actions.
- Returns:
True if the actions are valid, False otherwise.
- Return type:
bool
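The rules listed above can be sketched for a single state and action. This sketch makes three simplifying assumptions: states and actions are plain lists of floats, is_s0 is passed in explicitly, and the final comparison is exact. A batched implementation would work on tensors and may use a tolerance.

```python
def cartesian_action_valid(state, action, delta, is_s0, backward=False):
    """Per-dimension validity rules for one state/action pair (Cartesian Box)."""
    if not backward:
        if is_s0:
            # From s0: every component must lie in [0, 1].
            return all(0.0 <= a <= 1.0 for a in action)
        # Elsewhere: move at least delta per dimension and stay inside the box.
        return all(a >= delta and s + a <= 1.0 for s, a in zip(state, action))
    result = [s - a for s, a in zip(state, action)]
    if any(r < 0.0 for r in result):
        return False  # a backward step may not leave the box
    if all(r < delta for r in result):
        # Every resulting coordinate is below delta: the step must return to s0.
        return list(action) == list(state)
    return True


print(cartesian_action_valid([0.0, 0.0], [0.3, 0.9], delta=0.1, is_s0=True))    # True
print(cartesian_action_valid([0.5, 0.5], [0.05, 0.2], delta=0.1, is_s0=False))  # False
```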
- class gfn.gym.BoxPolar(delta=0.1, R0=0.1, R1=0.5, R2=2.0, epsilon=0.0001, device='cpu', debug=False)¶
Bases:
gfn.env.Env
Box environment with polar (norm-based) action validation.
Corresponds to the environment in Section 4.1 of https://arxiv.org/abs/2301.12594
Actions are 2D vectors whose L2 norm must equal delta (for non-s0 forward steps) or be at most delta (for the initial s0 step). Use with the polar estimators/distributions in box_polar_utils.py.
See also BoxCartesian for a simpler per-dimension Cartesian variant.
- Parameters:
delta (float)
R0 (float)
R1 (float)
R2 (float)
epsilon (float)
device (Literal['cpu', 'cuda'] | torch.device)
debug (bool)
- delta¶
The step size.
- R0¶
The base reward.
- R1¶
The reward for being outside the first box.
- R2¶
The reward for being inside the second box.
- epsilon¶
A small value to avoid numerical issues.
- device¶
The device to use.
- Type:
Literal[“cpu”, “cuda”] | torch.device
- Return type:
torch.device
- R0 = 0.1¶
- R1 = 0.5¶
- R2 = 2.0¶
- backward_step(states, actions)¶
Backward step function for the Box environment.
- Parameters:
states (gfn.states.States) – States object representing the current states.
actions (gfn.actions.Actions) – Actions object representing the actions to be taken.
- Returns:
The previous states as a States object.
- Return type:
- delta = 0.1¶
- epsilon = 0.0001¶
- is_action_valid(states, actions, backward=False)¶
Checks if the actions are valid (polar norm-based semantics).
For polar actions:
  - Forward from s0: norm(action) <= delta.
  - Forward from non-s0: norm(action) == delta (within tolerance).
  - Backward: state - action >= 0 component-wise.
  - Backward to s0: if norm(state) < delta, the action must equal the state.
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions to be taken.
backward (bool) – Whether the actions are backward actions.
- Returns:
True if the actions are valid, False otherwise.
- Return type:
bool
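The polar rules above can be sketched in the same single-pair style. The tolerance tol is a hypothetical stand-in for whatever tolerance the environment applies (possibly its epsilon attribute; this is an assumption). The final equality check is exact.

```python
import math


def polar_action_valid(state, action, delta, is_s0, backward=False, tol=1e-6):
    """Norm-based validity rules for one 2D state/action pair (polar Box)."""
    norm = math.hypot(*action)  # L2 norm of the action vector
    if not backward:
        if is_s0:
            return norm <= delta + tol      # first step: norm at most delta
        return abs(norm - delta) <= tol     # later steps: norm equal to delta
    if any(s - a < 0.0 for s, a in zip(state, action)):
        return False                        # component-wise: state - action >= 0
    if math.hypot(*state) < delta:
        return list(action) == list(state)  # near the origin: jump back to s0
    return True


print(polar_action_valid([0.5, 0.5], [0.06, 0.08], delta=0.1, is_s0=False))  # True
print(polar_action_valid([0.5, 0.5], [0.05, 0.05], delta=0.1, is_s0=False))  # False
```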
- log_partition(condition=None)¶
Returns the log partition of the reward function.
- Return type:
float
- make_random_states(batch_shape, conditions=None, device=None, debug=False)¶
Generates random states tensor of shape (*batch_shape, 2).
- Parameters:
batch_shape (Tuple[int, Ellipsis]) – The shape of the batch.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A States object with random states.
- Return type:
- static norm(x)¶
Computes the L2 norm of the input tensor along the last dimension.
- Parameters:
x (torch.Tensor) – Input tensor of shape (*batch_shape, 2).
- Returns:
Tensor of shape batch_shape containing the L2 norms.
- Return type:
torch.Tensor
- reward(final_states)¶
Computes the reward of the final states from their position in the box (see R0, R1 and R2).
- Parameters:
final_states (gfn.states.States) – States object representing the final states.
- Returns:
The reward tensor of shape batch_shape.
- Return type:
torch.Tensor
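The defaults R0 = 0.1, R1 = 0.5 and R2 = 2.0 match the Box reward of Section 4.1 of the paper cited above. As we read that paper, the reward is R(x) = R0 + R1 · ∏_d 1[0.25 < |x_d − 0.5|] + R2 · ∏_d 1[0.3 < |x_d − 0.5| < 0.4]. The sketch below implements that formula for a single point; treat it as an assumption to check against the library source.

```python
def box_reward(x, R0=0.1, R1=0.5, R2=2.0):
    """Assumed Section 4.1 reward for a single point x in [0, 1]^2."""
    ax = [abs(xi - 0.5) for xi in x]        # distance of each coordinate from 0.5
    outer = all(0.25 < a for a in ax)       # every coordinate near a corner
    band = all(0.3 < a < 0.4 for a in ax)   # every coordinate in the thin band
    return R0 + R1 * outer + R2 * band


for point in ([0.5, 0.5], [0.0, 0.0], [0.15, 0.85]):
    print(point, box_reward(point))  # rewards of about 0.1, 0.6 and 2.6
```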
- step(states, actions)¶
Step function for the Box environment.
- Parameters:
states (gfn.states.States) – States object representing the current states.
actions (gfn.actions.Actions) – Actions object representing the actions to be taken.
- Returns:
The next states as a States object.
- Return type:
- class gfn.gym.ChipDesign(netlist_file=SAMPLE_NETLIST_FILE, init_placement=SAMPLE_INIT_PLACEMENT, std_cell_placer_mode='fd', wirelength_weight=1.0, density_weight=1.0, congestion_weight=0.5, device=None, debug=False, singularity_image=None, cd_finetune=True, reward_norm=None, reward_temper=1.0, reward_norm_samples=100, cost_stats=None)¶
Bases:
gfn.env.DiscreteEnv
GFlowNet environment for chip placement.
The state is a vector of length n_macros, where state[i] is the grid cell location of the i-th macro to be placed. Unplaced macros have a location of -1.
Actions are integers from 0 to n_grid_cells - 1, representing the grid cell to place the current macro on. Action n_grid_cells is the exit action.
- Parameters:
netlist_file (str)
init_placement (str)
std_cell_placer_mode (str)
wirelength_weight (float)
density_weight (float)
congestion_weight (float)
device (str | None)
debug (bool)
singularity_image (Optional[str])
cd_finetune (bool)
reward_norm (Optional[str])
reward_temper (float)
reward_norm_samples (int)
cost_stats (Optional[CostStats])
- States: type[ChipDesignStates]¶
- __del__()¶
- Return type:
None
- _apply_state_to_plc(state_tensor)¶
Applies a single state tensor to the plc object.
- Parameters:
state_tensor (torch.Tensor)
- _backward_step(states, actions)¶
Wraps parent _backward_step and updates masks.
- Parameters:
states (gfn.states.DiscreteStates)
actions (gfn.actions.Actions)
- Return type:
- _estimate_cost_stats(n_samples)¶
Estimates cost distribution from random placements.
- Parameters:
n_samples (int)
- Return type:
None
- _hard_macro_indices¶
- _normalize_cost(cost)¶
Applies reward normalization to a raw cost value.
- Parameters:
cost (float)
- Return type:
float
- _sorted_node_indices¶
- _step(states, actions)¶
Wraps parent _step and updates masks.
- Parameters:
states (gfn.states.DiscreteStates)
actions (gfn.actions.Actions)
- Return type:
- analytical_placer()¶
Places standard cells using an analytical placer.
- backward_step(states, actions)¶
Performs a backward step in the environment.
- Parameters:
states (ChipDesignStates)
actions (gfn.actions.Actions)
- Return type:
- cd_finetune = True¶
- close()¶
Closes the PlacementCost subprocess to free resources.
- Return type:
None
- congestion_weight = 0.5¶
- density_weight = 1.0¶
- log_reward(final_states)¶
Computes the log reward of the final states.
- Parameters:
final_states (ChipDesignStates)
- Return type:
torch.Tensor
- make_states_class()¶
Creates the ChipDesignStates class.
- Return type:
type[ChipDesignStates]
- n_grid_cells¶
- n_macros¶
- plc¶
- reset(batch_shape, random=False, sink=False, seed=None, conditions=None)¶
Resets the environment and computes initial masks.
- Parameters:
batch_shape (int | Tuple[int, Ellipsis])
random (bool)
sink (bool)
seed (Optional[int])
conditions (Optional[torch.Tensor])
- Return type:
- reward_norm = None¶
- reward_temper = 1.0¶
- std_cell_placer_mode = 'fd'¶
- step(states, actions)¶
Performs a forward step in the environment.
- Parameters:
states (ChipDesignStates)
actions (gfn.actions.Actions)
- Return type:
- update_masks(states)¶
Updates the forward and backward masks of the states.
- Parameters:
states (ChipDesignStates)
- Return type:
None
- wirelength_weight = 1.0¶
- class gfn.gym.ConditionalHyperGrid(*args, **kwargs)¶
Bases:
HyperGrid
HyperGrid environment with condition-aware rewards.
Let the condition c be a real value in [0, 1]. The reward is a linear interpolation between the uniform reward and the original reward. Special cases are:
  - c = 0: uniform reward (all terminal states get reward = R0 + R1 + R2).
  - c = 1: original HyperGrid reward (the original multi-modal reward landscape).
- _log_partition_cache: dict[torch.Tensor, float]¶
- _max_reward: float¶
- _original_reward_fn¶
- _true_dist_cache: dict[torch.Tensor, torch.Tensor]¶
- condition_dim: int = 1¶
- is_conditional: bool = True¶
- log_partition(condition)¶
Compute the log partition for the given condition.
- Parameters:
condition (torch.Tensor) – The condition to compute the log partition for. condition.shape should be (1,)
- Returns:
The log partition function, as a float.
- Return type:
float
- reward(states)¶
Compute rewards for the conditional environment.
The condition is a continuous value from 0 to 1:
  - 0: fully uniform reward (all states get R0 + R1 + R2).
  - 1: fully original HyperGrid reward.
  - In between: linear interpolation between uniform and original.
- Parameters:
states (gfn.states.DiscreteStates) – The states to compute rewards for. states.tensor.shape should be (*batch_shape, *state_shape)
- Returns:
A tensor of shape (*batch_shape,) containing the rewards.
- Return type:
torch.Tensor
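The interpolation described above amounts to r_c = (1 − c) · (R0 + R1 + R2) + c · r_original. Here is a minimal per-state sketch. The original reward is passed in as a number, and R0, R1 and R2 stand for whatever values the underlying HyperGrid uses.

```python
def conditional_reward(original_reward, c, R0, R1, R2):
    """Linear interpolation between the uniform and the original reward."""
    if not 0.0 <= c <= 1.0:
        raise ValueError("the condition c must lie in [0, 1]")
    uniform_reward = R0 + R1 + R2  # value given to every terminal state when c = 0
    return (1.0 - c) * uniform_reward + c * original_reward


print(conditional_reward(0.6, 0.0, R0=0.1, R1=0.5, R2=2.0))  # uniform value, about 2.6
print(conditional_reward(0.6, 1.0, R0=0.1, R1=0.5, R2=2.0))  # original value, 0.6
```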
- sample_conditions(batch_shape)¶
Sample conditions for the environment.
- Parameters:
batch_shape (int | tuple[int, Ellipsis])
- Return type:
torch.Tensor
- true_dist(condition)¶
Compute the true distribution for the given condition.
- Parameters:
condition (torch.Tensor) – The condition to compute the true distribution for. condition.shape should be (1,).
- Returns:
The true distribution for the given condition as a 1-dimensional tensor.
- Return type:
torch.Tensor
- class gfn.gym.DiscreteEBM(ndim, energy=None, alpha=1.0, device='cpu', debug=False)¶
Bases:
gfn.env.DiscreteEnv
Environment for discrete energy-based models.
This environment is based on the paper https://arxiv.org/pdf/2202.01361.pdf.
The states are represented as 1d tensors of length ndim with values in {-1, 0, 1}. s0 is empty (represented as -1), so s0=[-1, -1, …, -1]. An action corresponds to replacing a -1 with a 0 or a 1. Action i in [0, ndim - 1] corresponds to replacing s[i] with 0. Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1. The last action is the exit action that is only available for complete states (those with no -1).
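The action encoding above can be sketched for a single state stored as a Python list. On exit, this sketch simply returns the complete state, whereas a real environment would move it to its sink state. That difference is deliberate, to keep the example small.

```python
def ebm_step(state, action):
    """Apply one forward action to a state with entries in {-1, 0, 1}."""
    ndim = len(state)
    if not 0 <= action <= 2 * ndim:
        raise ValueError(f"action must be in [0, {2 * ndim}]")
    if action == 2 * ndim:  # exit action
        if -1 in state:
            raise ValueError("exit is only allowed from complete states")
        return list(state)  # a real environment would go to the sink state
    # Actions [0, ndim) write a 0; actions [ndim, 2 * ndim) write a 1.
    index, value = (action, 0) if action < ndim else (action - ndim, 1)
    if state[index] != -1:
        raise ValueError(f"position {index} is already set")
    new_state = list(state)
    new_state[index] = value
    return new_state


s = [-1, -1, -1]
s = ebm_step(s, 1)  # set s[1] = 0 -> [-1, 0, -1]
s = ebm_step(s, 5)  # set s[2] = 1 -> [-1, 0, 1]
s = ebm_step(s, 3)  # set s[0] = 1 -> [1, 0, 1]
print(s)  # [1, 0, 1]
```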
- Parameters:
ndim (int)
energy (EnergyFunction | None)
alpha (float)
device (Literal['cpu', 'cuda'] | torch.device)
debug (bool)
- ndim¶
Dimension D of the sampling space {0, 1}^D.
- Type:
int
- energy¶
Energy function of the EBM.
- Type:
- alpha¶
Interaction strength of the EBM.
- Type:
float
- States: type[gfn.states.DiscreteStates]¶
- property all_states: gfn.states.DiscreteStates¶
Returns all possible states of the environment.
- Return type:
- alpha = 1.0¶
- backward_step(states, actions)¶
Performs a backward step.
In this environment, states are ndim-dimensional vectors. s0 is empty (represented as -1), so s0 = [-1, -1, …, -1], and each forward action replaces a -1 with either a 0 or a 1. Action i in [0, ndim - 1] replaces s[i] with 0, whereas action i in [ndim, 2 * ndim - 1] replaces s[i - ndim] with 1. A backward action only has to identify which index to reset to -1, so the action index is reduced modulo ndim (hence the fmod), which maps both i and i + ndim to index i.
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions to be undone.
- Returns:
The previous states.
- Return type:
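The modulo trick can be sketched on a single list-based state. For the non-negative integer actions used here, math.fmod(action, ndim) gives the same index as action % ndim. The function name is illustrative.

```python
import math


def ebm_backward_step(state, action):
    """Undo a forward action: reset the targeted coordinate to -1."""
    ndim = len(state)
    # Actions i and i + ndim touch the same coordinate, so reduce modulo ndim.
    index = int(math.fmod(action, ndim))
    if state[index] == -1:
        raise ValueError(f"position {index} is already empty")
    new_state = list(state)
    new_state[index] = -1
    return new_state


print(ebm_backward_step([1, 0, 1], 5))  # [1, 0, -1]: action 5 maps to index 2
print(ebm_backward_step([1, 0, 1], 2))  # [1, 0, -1]: action 2 also maps to index 2
```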
- energy: EnergyFunction = None¶
- get_states_indices(states)¶
Given that each state is of length ndim with values in {-1, 0, 1}, there are 3**ndim states, which we can label from 0 to 3**ndim - 1.
The easiest way to map each state to a unique integer is to consider the state as a number in base 3, where each digit can be in {0, 1, 2}. We thus need to shift this number by 1 so that {-1, 0, 1} -> {0, 1, 2}.
- Parameters:
states (gfn.states.DiscreteStates) – DiscreteStates object representing the states.
- Returns:
The states indices as tensor of shape (*batch_shape).
- Return type:
torch.Tensor
- get_terminating_states_indices(states)¶
Given that each terminating state is of length ndim with values in {0, 1}, there are 2**ndim terminating states, which we can label from 0 to 2**ndim - 1.
The easiest way to map each state to a unique integer is to consider the state as a number in base 2.
- Parameters:
states (gfn.states.DiscreteStates) – DiscreteStates object representing the states.
- Returns:
The indices of the terminating states as tensor of shape (*batch_shape).
- Return type:
torch.Tensor
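Both indexings can be sketched as positional number systems. This sketch assumes the first coordinate is the most significant digit, since the digit order is not stated above.

```python
def state_index(state):
    """Base-3 index of a state with entries in {-1, 0, 1} (shifted to {0, 1, 2})."""
    index = 0
    for value in state:
        index = 3 * index + (value + 1)  # first coordinate = most significant digit
    return index


def terminating_state_index(state):
    """Base-2 index of a complete state with entries in {0, 1}."""
    index = 0
    for value in state:
        index = 2 * index + value
    return index


print(state_index([-1, -1, -1]))           # 0: the initial state s0
print(state_index([1, 1, 1]))              # 26 = 3**3 - 1
print(terminating_state_index([1, 0, 1]))  # 5
```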
- is_exit_actions(actions)¶
Determines if the actions are exit actions.
- Parameters:
actions (torch.Tensor) – tensor of actions of shape (*batch_shape, *action_shape).
- Returns:
Tensor of booleans of shape (*batch_shape).
- Return type:
torch.Tensor
- log_partition(condition=None)¶
Returns the log partition of the reward function.
- Return type:
float
- log_reward(final_states)¶
The energy weighted by alpha is our log reward.
- Parameters:
final_states (gfn.states.DiscreteStates) – DiscreteStates object representing the final states.
- Returns:
The log reward as tensor of shape (*batch_shape).
- Return type:
torch.Tensor
- make_random_states(batch_shape, conditions=None, device=None, debug=False)¶
Generates random states tensor of shape (*batch_shape, ndim).
- Parameters:
batch_shape (Tuple) – The shape of the batch.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A DiscreteStates object with random states.
- Return type:
- make_states_class()¶
Returns the DiscreteStates class for the DiscreteEBM environment.
- Return type:
- property n_states: int¶
Returns the number of states in the environment.
- Return type:
int
- property n_terminating_states: int¶
Returns the number of terminating states in the environment.
- Return type:
int
- ndim¶
- reward(final_states)¶
Computes the reward for a batch of final states.
- Parameters:
final_states (gfn.states.DiscreteStates) – A batch of final states.
- Returns:
A tensor of rewards.
- Return type:
torch.Tensor
- step(states, actions)¶
Performs a step.
- Parameters:
states (gfn.states.States) – States object representing the current states.
actions (gfn.actions.Actions) – Actions object representing the actions to be taken.
- Returns:
The next states as a States object.
- Return type:
- property terminating_states: gfn.states.DiscreteStates¶
Returns all terminating states of the environment.
- Return type:
- true_dist(condition=None)¶
Returns the true probability mass function of the reward distribution.
- Return type:
torch.Tensor
- class gfn.gym.GraphBuilding(num_node_classes, num_edge_classes, state_evaluator, is_directed=True, max_nodes=None, device='cpu', s0=None, sf=None, debug=False)¶
Bases:
gfn.env.GraphEnv
Environment for incrementally building graphs.
This environment allows constructing graphs by:
  - Adding nodes of a given class
  - Adding edges of a given class between existing nodes
  - Terminating construction (EXIT)
- Parameters:
num_node_classes (int)
num_edge_classes (int)
state_evaluator (Callable[[gfn.states.GraphStates], torch.Tensor])
is_directed (bool)
max_nodes (int | None)
device (Literal['cpu', 'cuda'] | torch.device)
s0 (torch_geometric.data.Data | None)
sf (torch_geometric.data.Data | None)
debug (bool)
- num_node_classes¶
The number of node classes.
- num_edge_classes¶
The number of edge classes.
- state_evaluator¶
A callable that computes rewards for final states.
- is_directed¶
Whether the graph is directed.
- backward_step(states, actions)¶
Performs a backward step in the environment.
- Parameters:
states (gfn.states.GraphStates) – The current states.
actions (gfn.actions.GraphActions) – The actions to undo.
- Returns:
The previous states.
- Return type:
- is_action_valid(states, actions, backward=False)¶
Check if actions are valid for the given states.
- Parameters:
states (gfn.states.GraphStates) – Current graph states.
actions (gfn.actions.GraphActions) – Actions to validate.
backward (bool) – Whether this is a backward step.
- Returns:
True if all actions are valid, False otherwise.
- Return type:
bool
- make_actions_class()¶
Returns the GraphActions class for this environment.
- Returns:
A type of a subclass of GraphActions with environment-specific functionalities.
- Return type:
type[gfn.actions.GraphActions]
- make_random_states(batch_shape, conditions=None, device=None, debug=False)¶
Generates random states.
- Parameters:
batch_shape (Tuple) – The shape of the batch.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A GraphStates object with random states.
- Return type:
- make_states_class()¶
Creates a GraphStates class for this environment.
- Return type:
type[gfn.states.GraphStates]
- max_nodes = None¶
- reward(final_states)¶
Computes the environment’s reward for a batch of final states.
- Parameters:
final_states (gfn.states.GraphStates) – A batch of final states.
- Returns:
A tensor of shape (batch_size,) containing the rewards.
- Return type:
torch.Tensor
- state_evaluator¶
- step(states, actions)¶
Performs a step in the environment.
- Parameters:
states (gfn.states.GraphStates) – The current states.
actions (gfn.actions.GraphActions) – The actions to take.
- Returns:
The next states.
- Return type:
- class gfn.gym.GraphBuildingOnEdges(n_nodes, state_evaluator, directed, device, debug=False)¶
Bases:
GraphBuilding
Environment for building graphs edge by edge with discrete action space.
The environment supports both directed and undirected graphs.
In each state, the policy can:
  1. Add an edge between existing nodes.
  2. Use the exit action to terminate graph building.
The action space is discrete, with size:
  - For directed graphs: n_nodes^2 - n_nodes + 1 (all possible directed edges + exit).
  - For undirected graphs: (n_nodes^2 - n_nodes)/2 + 1 (upper triangle + exit).
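The two action-space sizes follow from counting ordered pairs versus unordered pairs of distinct nodes. In this sketch, the helper edge_actions is hypothetical and the lexicographic ordering of edges is an assumption; the library may order its edge actions differently.

```python
from itertools import combinations, permutations


def edge_actions(n_nodes, directed):
    """Hypothetical action table: index -> (src, dst); index len(table) is exit."""
    if directed:
        # All ordered pairs (i, j) with i != j: n_nodes**2 - n_nodes edges.
        return list(permutations(range(n_nodes), 2))
    # Unordered pairs with i < j (upper triangle): (n_nodes**2 - n_nodes) // 2 edges.
    return list(combinations(range(n_nodes), 2))


n = 4
print(len(edge_actions(n, directed=True)) + 1)   # 13 = 4**2 - 4 + 1
print(len(edge_actions(n, directed=False)) + 1)  # 7 = (4**2 - 4) / 2 + 1
```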
- Parameters:
n_nodes (int)
state_evaluator (callable)
directed (bool)
device (Literal['cpu', 'cuda'] | torch.device)
debug (bool)
- n_nodes¶
The number of nodes in the graph.
- Type:
int
- n_possible_edges¶
The number of possible edges.
- Type:
int
- is_action_valid(states, actions, backward=False)¶
Checks if the actions are valid.
- Parameters:
states (gfn.states.GraphStates) – The current states.
actions (gfn.actions.GraphActions) – The actions to validate.
backward (bool) – Whether the actions are backward actions.
- Returns:
True if the actions are valid, False otherwise.
- Return type:
bool
- make_random_states(batch_shape, conditions=None, device=None, debug=False)¶
Makes a batch of random graph states with a fixed number of nodes.
- Parameters:
batch_shape (Tuple) – Shape of the batch dimensions.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A GraphStates object containing random graph states.
- Return type:
- make_states_class()¶
Creates a GraphStates class for this environment.
- Return type:
type[gfn.states.GraphStates]
- n_nodes¶
- class gfn.gym.HyperGrid(ndim=2, height=8, reward_fn_str='original', reward_fn_kwargs=None, device='cpu', calculate_partition=False, store_all_states=False, debug=False, validate_modes=True, mode_stats='none', mode_stats_samples=20000)¶
Bases:
gfn.env.DiscreteEnv
HyperGrid environment from the GFlowNets paper.
The states are represented as 1-d tensors of length ndim with values in {0, 1, …, height - 1}.
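The docstring does not spell out the default reward_fn_str='original' reward. As an assumption to verify against the library, the sketch below uses the reward from the GFlowNets paper, with each coordinate first normalized to [0, 1]: R(x) = R0 + R1 · ∏_d 1[0.25 < |x_d/(H − 1) − 0.5|] + R2 · ∏_d 1[0.3 < |x_d/(H − 1) − 0.5| < 0.4]. R0 = 0.1, R1 = 0.5 and R2 = 2.0 are used as illustrative defaults.

```python
def hypergrid_original_reward(x, height, R0=0.1, R1=0.5, R2=2.0):
    """Assumed 'original' HyperGrid reward for one grid point x (list of ints)."""
    # Distance of each normalized coordinate from the centre of the grid.
    ax = [abs(xi / (height - 1) - 0.5) for xi in x]
    outer = all(a > 0.25 for a in ax)      # every coordinate far from the centre
    band = all(0.3 < a < 0.4 for a in ax)  # every coordinate in the thin band
    return R0 + R1 * outer + R2 * band


for point in ([3, 3], [0, 0], [1, 1]):
    print(point, hypergrid_original_reward(point, height=8))  # about 0.1, 0.6, 2.6
```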
- Parameters:
ndim (int)
height (int)
reward_fn_str (str)
reward_fn_kwargs (dict | None)
device (Literal['cpu', 'cuda'] | torch.device)
calculate_partition (bool)
store_all_states (bool)
debug (bool)
validate_modes (bool)
mode_stats (Literal['none', 'approx', 'exact'])
mode_stats_samples (int)
- ndim¶
The dimension of the grid.
- height¶
The height of the grid.
- reward_fn¶
The reward function.
- calculate_partition¶
Whether to calculate the log partition function.
- store_all_states¶
Whether to store all states.
- validate_modes¶
Whether to check that at least one state reaches the mode threshold at init; raises if not.
- mode_stats¶
One of {“none”, “approx”, “exact”}. If not “none”, computes (exact or approximate) n_modes and n_mode_states. “exact” requires store_all_states=True and enumerates all states.
- mode_stats_samples¶
Number of random samples used when mode_stats=“approx”.
- States: type[gfn.states.DiscreteStates]¶
- _all_states_tensor = None¶
- _enumerate_all_states_tensor(batch_size=20000)¶
Enumerate all grid states, optionally storing them and computing log Z.
Iterates over the full Cartesian product {0, ..., H-1}^D in batches (via multiprocessing) to avoid materializing all H^D states at once.
- Parameters:
batch_size (int) – Number of states per batch.
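Computing log Z without holding every state in memory only needs a running log-sum-exp over batches of log-rewards. This is a single-process sketch, whereas the method above uses multiprocessing. It assumes every log-reward is finite.

```python
import math
from itertools import islice, product


def streaming_logsumexp(batches):
    """log(sum(exp(v))) over every value in an iterable of batches."""
    log_z = -math.inf
    for batch in batches:
        if not batch:
            continue
        m = max(max(batch), log_z)  # shift by the running maximum for stability
        log_z = m + math.log(math.exp(log_z - m) + sum(math.exp(v - m) for v in batch))
    return log_z


def log_reward_batches(ndim, height, log_reward, batch_size):
    """Yield lists of log-rewards over {0, ..., height-1}^ndim, batch by batch."""
    states = product(range(height), repeat=ndim)
    while batch := list(islice(states, batch_size)):
        yield [log_reward(s) for s in batch]


log_z = streaming_logsumexp(log_reward_batches(2, 3, lambda s: 0.0, batch_size=4))
print(math.isclose(log_z, math.log(9)))  # True: 9 states, each with reward 1
```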
- _exists_bitwise_xor(thr)¶
Feasibility and constructive check for BitwiseXORReward.
Steps:
  - For each tier, verify the GF(2) parity system has at least one solution using Gaussian elimination modulo 2. If any tier is infeasible, no mode exists.
  - The all-zero configuration satisfies even-parity constraints, so if tiers are feasible we evaluate that point against the threshold with tolerance.
- Parameters:
thr (float)
- Return type:
bool
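The feasibility test in the first step is ordinary Gaussian elimination over GF(2). The sketch below checks whether A x = b has a solution, packing each equation into an integer bitmask. How the reward encodes its tiers as such systems is not shown here, so the input format (lists of 0/1 coefficients plus a 0/1 right-hand side) is an assumption.

```python
def gf2_solvable(rows, rhs):
    """Return True iff the linear system rows @ x = rhs has a solution mod 2."""
    n_vars = len(rows[0]) if rows else 0
    # Pack each equation into one int: bits 0..n_vars-1 are coefficients,
    # bit n_vars holds the right-hand side.
    eqs = []
    for coeffs, b in zip(rows, rhs):
        mask = sum((c & 1) << j for j, c in enumerate(coeffs))
        eqs.append(mask | ((b & 1) << n_vars))
    pivot = 0
    for col in range(n_vars):
        # Find a remaining equation with a 1 in this column.
        sel = next((r for r in range(pivot, len(eqs)) if (eqs[r] >> col) & 1), None)
        if sel is None:
            continue
        eqs[pivot], eqs[sel] = eqs[sel], eqs[pivot]
        # Eliminate this column from every other equation (XOR is addition mod 2).
        for r in range(len(eqs)):
            if r != pivot and (eqs[r] >> col) & 1:
                eqs[r] ^= eqs[pivot]
        pivot += 1
    coeff_mask = (1 << n_vars) - 1
    # The system is inconsistent iff elimination produced the equation 0 = 1.
    return not any((e & coeff_mask) == 0 and (e >> n_vars) & 1 for e in eqs)


# x0 + x2 = 1, x1 + x2 = 1, x0 + x1 = 1: the left sides sum to 0 but the
# right sides sum to 1, so there is no solution.
print(gf2_solvable([[1, 0, 1], [0, 1, 1], [1, 1, 0]], [1, 1, 1]))  # False
print(gf2_solvable([[1, 0, 1], [0, 1, 1], [1, 1, 0]], [1, 1, 0]))  # True
```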
- _exists_cosine(thr)¶
Analytic upper-bound check for CosineReward.
Idea:
  - The per-dimension factor is (cos(50·ax) + 1) · N(0,1)(5·ax) with ax in [0, 0.5]. We estimate its maximum over the discrete grid by evaluating all candidate ax and taking the maximum value m.
  - The full reward upper bound is R0 + m^D * R1. If this is at least the mode target and the given threshold, a mode-level state must exist.
  - We also compute a theoretical per-dimension peak (at ax ≈ 0) to form a slightly conservative target scaled by mode_gamma.
- Parameters:
thr (float)
- Return type:
bool
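Here is a sketch of the bound computation, writing N(0,1) for the standard normal density. It assumes ax = |x/(H − 1) − 0.5| for grid index x, since the docstring only says ax lies in [0, 0.5]. If the reward is R0 plus R1 times the product of the per-dimension factors (as the bound suggests), picking the maximizing index in every dimension attains R0 + m^D · R1.

```python
import math


def cosine_reward_upper_bound(height, ndim, R0, R1):
    """R0 + m**ndim * R1, where m is the largest per-dimension factor on the grid."""

    def factor(ax):
        normal_pdf = math.exp(-0.5 * (5 * ax) ** 2) / math.sqrt(2 * math.pi)
        return (math.cos(50 * ax) + 1) * normal_pdf

    # Assumed normalization of grid index x to a distance from the centre.
    m = max(factor(abs(x / (height - 1) - 0.5)) for x in range(height))
    return R0 + m**ndim * R1


# With an odd height the centre index gives ax = 0, where the factor peaks at
# 2 / sqrt(2 * pi), about 0.798.
print(cosine_reward_upper_bound(height=9, ndim=2, R0=0.1, R1=0.5))
```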
- _exists_fallback_random(thr)¶
Random sampling fallback.
Draw a modest batch of random states on CPU and accept if any exceed the threshold with a small tolerance. This is a last resort to avoid expensive enumeration for large grids.
- Parameters:
thr (float)
- Return type:
bool
- _exists_multiplicative_coprime(thr)¶
Number-theoretic constructive check for MultiplicativeCoprimeReward.
Constructs a candidate state by factoring the target LCM (if any) over the allowed primes, assigning each prime power to a separate active dimension, and verifying coprimality and grid-bound constraints.
- Parameters:
thr (float)
- Return type:
bool
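Here is a minimal sketch of that construction. Its inputs (the target LCM, the list of allowed primes and the grid size) are hypothetical. Each prime power of the target goes to its own dimension. Distinct prime powers are pairwise coprime, so their LCM is their product, which equals the target. Filling the unused dimensions with 1 is an assumption made here for illustration.

```python
import math


def coprime_candidate(target_lcm, primes, ndim, height):
    """Hypothetical sketch: one prime power per active dimension, or None."""
    if target_lcm < 1:
        raise ValueError("target_lcm must be a positive integer")
    powers, rest = [], target_lcm
    for p in primes:
        pk = 1
        while rest % p == 0:  # extract the full power of p dividing the target
            rest //= p
            pk *= p
        if pk > 1:
            powers.append(pk)
    if rest != 1:
        return None  # the target has a prime factor outside the allowed set
    if len(powers) > ndim or any(pk > height - 1 for pk in powers):
        return None  # too few dimensions, or a coordinate would leave the grid
    # Sanity check: the active coordinates are pairwise coprime.
    assert all(math.gcd(a, b) == 1 for i, a in enumerate(powers) for b in powers[i + 1:])
    return powers + [1] * (ndim - len(powers))  # assumed filler for unused dimensions


print(coprime_candidate(12, [2, 3, 5], ndim=3, height=10))  # [4, 3, 1]
print(coprime_candidate(14, [2, 3, 5], ndim=3, height=10))  # None: 7 is not allowed
```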
- _exists_original_or_deceptive(thr)¶
Constructive check for OriginalReward and DeceptiveReward.
Intuition:
- When each coordinate is normalized to [0, 1], these rewards form rings/bands around the center. The mode lies on a thin band at specific normalized distances from the center.
- The fractional band boundaries are translated into integer indices via small inside/outside nudges (using EPS_INDEX_CMP), and one candidate index is tested from any non-empty feasible interval.
- If the reward at that candidate exceeds the threshold (with EPS_REWARD_CMP tolerance), the check returns True.
- Parameters:
thr (float)
- Return type:
bool
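The index translation can be sketched as follows. This is a hypothetical helper, not the library's code: it assumes the normalized coordinate is s/(H-1), that a band is given as distances (lo, hi) from the center 0.5, and it uses a made-up EPS_INDEX_CMP value. Only the upper half of the axis is searched; the lower half is symmetric.

```python
import math
from typing import Optional

EPS_INDEX_CMP = 1e-9  # hypothetical nudge; the library's constant may differ


def band_candidate(height: int, lo: float, hi: float,
                   hi_inclusive: bool = True) -> Optional[int]:
    """Return one index s with lo < |s/(H-1) - 0.5| <= hi (or < hi), or None.

    Only the upper half of the axis, where s/(H-1) - 0.5 >= 0, is searched.
    """
    scale = height - 1
    # Smallest integer strictly above (0.5 + lo) * scale. The +EPS nudge keeps a
    # boundary that is an integer up to rounding error outside the band.
    start = math.floor((0.5 + lo) * scale + EPS_INDEX_CMP) + 1
    if hi_inclusive:
        # Largest integer at or below (0.5 + hi) * scale, tolerant to rounding.
        stop = math.floor((0.5 + hi) * scale + EPS_INDEX_CMP)
    else:
        # Largest integer strictly below (0.5 + hi) * scale.
        stop = math.ceil((0.5 + hi) * scale - EPS_INDEX_CMP) - 1
    stop = min(stop, scale)  # stay on the grid
    return start if start <= stop else None
```

A candidate returned by such a helper would then be placed in every dimension and its reward compared against the threshold.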
- _exists_sparse(thr)¶
Constructive check for SparseReward.
This reward assigns positive mass only to a finite set of target configurations. When H >= 2 and D >= 1, a known target is the zero vector except for certain coordinates fixed at 1 or H-2. The check probes a canonical target and confirms that the threshold does not exceed its reward.
- Parameters:
thr (float)
- Return type:
bool
- _generate_combinations_in_batches(ndim, max_val, batch_size)¶
Yield batches of the Cartesian product {0, …, max_val}^ndim.
Generates the product in batches (using multiprocessing) so that the full product, of size (max_val+1)^ndim, is never materialized in memory.
- Parameters:
ndim (int) – Number of dimensions (tuple length).
max_val (int) – Maximum coordinate value (inclusive).
batch_size (int) – Number of tuples per batch.
- Yields:
An iterator of tuples for each batch.
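A single-process sketch of the batching idea, using only the standard library. Unlike the method above, it does not use multiprocessing, and the function name is hypothetical. Because itertools.product is lazy, islice pulls one batch at a time without building the whole (max_val+1)^ndim product.

```python
from itertools import islice, product
from typing import Iterator, List, Tuple


def combinations_in_batches(ndim: int, max_val: int,
                            batch_size: int) -> Iterator[List[Tuple[int, ...]]]:
    """Yield lists of at most batch_size tuples from {0, ..., max_val}^ndim."""
    # product() is a lazy iterator, so memory use is bounded by one batch.
    it = product(range(max_val + 1), repeat=ndim)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch
```

For ndim=2, max_val=2, and batch_size=4, this yields batches of sizes 4, 4, and 1 covering all 9 tuples.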
- _log_partition = None¶
- _mode_reward_threshold()¶
Returns the reward threshold used to define a mode.
By default, a state is considered in a mode if its reward is at least the schema-defined threshold derived from the configured reward.
- Return type:
float
- _mode_stats_kind: str = 'none'¶
- _modes_exist_quick_check()¶
Lightweight check that a mode-level state exists.
In simple terms, this answers: “Is there at least one state whose reward reaches the mode threshold?” without enumerating all states. It proceeds in three stages:
1) If the grid is small (or pre-enumerated), it computes rewards exactly and checks them against the threshold.
2) Otherwise, it dispatches to reward-specific constructive tests that are sufficient to guarantee that at least one state reaches the threshold.
3) As a last resort, it samples a small batch of random states.
- Return type:
bool
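The three-stage structure can be illustrated with a generic, hypothetical helper. The parameter names (constructive, max_exact, n_samples, tol) and their defaults are assumptions for illustration. The real method dispatches on the configured reward class instead of taking a callback. In this sketch, a negative constructive answer falls through to sampling, because such tests are sufficient but not necessary.

```python
import itertools
import random
from typing import Callable, Optional, Sequence


def modes_exist_quick_check(
    reward: Callable[[Sequence[int]], float],
    height: int,
    ndim: int,
    thr: float,
    constructive: Optional[Callable[[float], bool]] = None,
    max_exact: int = 100_000,
    n_samples: int = 1024,
    tol: float = 1e-12,
    seed: int = 0,
) -> bool:
    # Stage 1: the grid is small, so compute every reward exactly.
    if height ** ndim <= max_exact:
        return any(reward(s) >= thr - tol
                   for s in itertools.product(range(height), repeat=ndim))
    # Stage 2: a reward-specific constructive test, if one is available.
    if constructive is not None and constructive(thr):
        return True
    # Stage 3: random sampling as a last resort (it can miss rare modes).
    rng = random.Random(seed)
    for _ in range(n_samples):
        s = [rng.randrange(height) for _ in range(ndim)]
        if reward(s) >= thr - tol:
            return True
    return False
```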
- _modes_exist_quick_check_info()¶
Same as _modes_exist_quick_check but returns (ok, message).
- Return type:
tuple[bool, str]
- _n_mode_states_estimate: float | None = None¶
- _n_mode_states_exact: int | None = None¶
- static _solve_gf2_has_solution(A, c)¶
Return True if A x = c over GF(2) has at least one solution.
Performs Gaussian elimination modulo 2 (XOR arithmetic) without constructing a specific solution. A row that reduces to all-zero coefficients with a non-zero right-hand side (0 = 1) indicates inconsistency.
- Parameters:
A (torch.Tensor)
c (torch.Tensor)
- Return type:
bool
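A pure-Python sketch of the same idea, with the system given as lists of 0/1 integers instead of tensors. The function name is hypothetical. Each pivot column is cleared from the other rows with XOR, and the system is consistent unless some row ends up encoding 0 = 1.

```python
from typing import List


def gf2_has_solution(A: List[List[int]], c: List[int]) -> bool:
    """Return True if A x = c has at least one solution over GF(2)."""
    # Augmented matrix [A | c], with entries reduced mod 2.
    rows = [[a & 1 for a in row] + [b & 1] for row, b in zip(A, c)]
    n_cols = len(A[0]) if A else 0
    pivot_row = 0
    for col in range(n_cols):
        # Find a row at or below pivot_row with a 1 in this column.
        pivot = next((r for r in range(pivot_row, len(rows)) if rows[r][col]), None)
        if pivot is None:
            continue
        rows[pivot_row], rows[pivot] = rows[pivot], rows[pivot_row]
        # Eliminate this column from every other row (XOR is addition mod 2).
        for r in range(len(rows)):
            if r != pivot_row and rows[r][col]:
                rows[r] = [x ^ y for x, y in zip(rows[r], rows[pivot_row])]
        pivot_row += 1
    # A row with all-zero coefficients but right-hand side 1 encodes 0 = 1.
    return not any(not any(row[:-1]) and row[-1] for row in rows)
```

For example, the three equations x0 + x1 = 1, x1 + x2 = 0, x0 + x2 = 0 sum to 0 = 1 over GF(2), so the system has no solution.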
- _true_dist = None¶
- _worker(task)¶
Return a slice of the Cartesian product for one batch.
- Parameters:
task (tuple) – (values, ndim, start_idx, end_idx) where values is the list of coordinate values, ndim is the number of dimensions, and [start_idx, end_idx) is the range within the full product.
- Return type:
itertools.islice
- all_indices()¶
Generate all possible indices for the grid.
- Returns:
A list of all possible indices for the grid.
- Return type:
List[Tuple[int, Ellipsis]]
- property all_states: gfn.states.DiscreteStates | None¶
Returns all hypergrid states as a DiscreteStates instance, or None if they are not available.
- Return type:
gfn.states.DiscreteStates | None
- backward_step(states, actions)¶
Performs a backward step in the environment.
- Parameters:
states (gfn.states.DiscreteStates) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The previous states.
- Return type:
- calculate_partition = False¶
- get_states_indices(states)¶
Get the indices of the states in the canonical ordering.
- Parameters:
states (gfn.states.DiscreteStates | torch.Tensor) – The states to get the indices of.
- Returns:
The indices of the states in the canonical ordering.
- Return type:
torch.Tensor
- get_terminating_states_indices(states)¶
Get the indices of the terminating states in the canonical ordering.
- Parameters:
states (gfn.states.DiscreteStates) – The states to get the indices of.
- Returns:
The indices of the terminating states in the canonical ordering.
- Return type:
torch.Tensor
- height = 8¶
- log_partition(condition=None)¶
Returns the log partition of the reward function.
- Return type:
float | None
- make_random_states(batch_shape, conditions=None, device=None, debug=False)¶
Creates a batch of random states.
- Parameters:
batch_shape (Tuple[int, Ellipsis]) – The shape of the batch.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A DiscreteStates object with random states.
- Return type:
- make_states_class()¶
Returns the DiscreteStates class for the HyperGrid environment.
- Return type:
- mode_mask(states)¶
Boolean mask indicating which states are in a mode.
A state is flagged as a mode state if its reward is greater than or equal to the threshold derived from reward_fn_kwargs (R0 + R1 + R2 by default).
- Parameters:
states (gfn.states.DiscreteStates)
- Return type:
torch.Tensor
- modes_found(states)¶
Returns the set of canonical state indices for mode states in the batch.
Each mode state is identified by its unique canonical index (from get_states_indices), not by a quadrant-based grouping. This allows correct mode-state tracking for all reward functions.
- Parameters:
states (gfn.states.DiscreteStates)
- Return type:
set[int]
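A plain-Python sketch of mode tracking by canonical index. It assumes a row-major, base-H encoding, index = Σ_d s^d · H^(D-1-d), for the canonical ordering, and it takes the reward function and threshold as explicit arguments. Both choices are illustrative assumptions rather than the library's exact internals.

```python
from typing import Callable, Iterable, Sequence, Set


def canonical_index(state: Sequence[int], height: int) -> int:
    """Row-major, base-H index of a grid state (assumed canonical ordering)."""
    index = 0
    for coord in state:
        index = index * height + coord  # Horner form of sum_d s^d * H^(D-1-d)
    return index


def modes_found(
    states: Iterable[Sequence[int]],
    reward: Callable[[Sequence[int]], float],
    threshold: float,
    height: int,
) -> Set[int]:
    """Canonical indices of the states whose reward reaches the mode threshold."""
    return {canonical_index(s, height) for s in states if reward(s) >= threshold}
```

Because a set is returned, repeated visits to the same mode state count once, and the union of these sets across batches gives the distinct mode states discovered so far.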
- property n_mode_states: int | float | None¶
Number of states inside a mode (exact, approx, or None).
If mode_stats="exact", returns an exact integer count.
If mode_stats="approx", returns a floating-point estimate.
If store_all_states is True (but mode_stats was "none"), computes the count on demand from all_states.
Otherwise, returns None.
- Return type:
int | float | None
- property n_modes: int | float | None¶
Returns the total number of mode states for this environment.
Equivalent to n_mode_states. Each individual grid cell whose reward meets the mode threshold counts as one mode.
- Return type:
int | float | None
- property n_states: int¶
Returns the number of states in the environment.
- Return type:
int
- property n_terminating_states: int¶
Returns the number of terminating states in the environment.
- Return type:
int
- ndim = 2¶
- reward(states)¶
Computes the reward for a batch of final states.
In the normal setting, the reward is:
$$R(s) = R_0 + 0.5 \prod_{d=1}^{D} \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5] \right) + 2 \prod_{d=1}^{D} \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)$$
where s^d is the d-th coordinate of state s, H is the grid height, and D is the number of dimensions.
- Parameters:
states (gfn.states.DiscreteStates) – The final states.
- Returns:
The reward of the final states.
- Return type:
torch.Tensor
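A plain-Python transcription of the reward formula for a single state. R0 is left as a parameter, and its 0.1 default here is an assumption. The function illustrates the formula and is not the library's vectorized implementation.

```python
from typing import Sequence


def hypergrid_reward(state: Sequence[int], height: int, r0: float = 0.1) -> float:
    """R(s) = R0 + 0.5 * prod_d 1(|s^d/(H-1) - 0.5| in (0.25, 0.5])
               + 2   * prod_d 1(|s^d/(H-1) - 0.5| in (0.3, 0.4))."""
    ax = [abs(s / (height - 1) - 0.5) for s in state]
    outer = all(0.25 < a <= 0.5 for a in ax)  # every coordinate in the wide band
    inner = all(0.3 < a < 0.4 for a in ax)    # every coordinate in the narrow band
    return r0 + 0.5 * float(outer) + 2.0 * float(inner)
```

With H = 8, for instance, index 1 gives |1/7 - 0.5| ≈ 0.357, which lies in both bands, so the state (1, 1) receives R0 + 0.5 + 2.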
- reward_fn¶
- reward_fn_kwargs = None¶
- step(states, actions)¶
Performs a step in the environment.
- Parameters:
states (gfn.states.DiscreteStates) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The next states.
- Return type:
- store_all_states = False¶
- property terminating_states: gfn.states.DiscreteStates | None¶
Returns all terminating states of the environment.
- Return type:
gfn.states.DiscreteStates | None
- true_dist(condition=None)¶
Returns the pmf over all states in the hypergrid.
- Return type:
torch.Tensor | None
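When every state can be enumerated, the log partition function and the true distribution follow directly from the rewards: log Z = log Σ_s R(s) and p(s) = R(s) / Z. The following minimal sketch works over a list of log-rewards in canonical order, using plain Python and a numerically stable log-sum-exp. The function names are hypothetical.

```python
import math
from typing import List


def log_partition(log_rewards: List[float]) -> float:
    """log Z = log sum_s exp(log R(s)), computed stably."""
    m = max(log_rewards)
    return m + math.log(sum(math.exp(lr - m) for lr in log_rewards))


def true_dist(log_rewards: List[float]) -> List[float]:
    """pmf p(s) = R(s) / Z over the enumerated states."""
    log_z = log_partition(log_rewards)
    return [math.exp(lr - log_z) for lr in log_rewards]
```

Subtracting the maximum before exponentiating avoids overflow when some rewards are very large.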
- class gfn.gym.Line(mus, sigmas, init_value, n_sd=4.5, n_steps_per_trajectory=5, device='cpu', debug=False)¶
Bases: gfn.env.Env
Mixture of Gaussians Line environment.
- Parameters:
mus (list)
sigmas (list)
init_value (float)
n_sd (float)
n_steps_per_trajectory (int)
device (Literal['cpu', 'cuda'] | torch.device)
debug (bool)
- mus¶
The means of the Gaussians.
- sigmas¶
The standard deviations of the Gaussians.
- n_sd¶
The number of standard deviations to consider for the bounds.
- n_steps_per_trajectory¶
The number of steps per trajectory.
- mixture¶
The mixture of Gaussians.
- init_value¶
The initial value of the state.
- backward_step(states, actions)¶
Performs a backward step in the environment.
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The previous states.
- Return type:
- init_value¶
- is_action_valid(states, actions, backward=False)¶
Checks if the actions are valid.
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions to check.
backward (bool) – Whether to check for backward actions.
- Returns:
True if the actions are valid, False otherwise.
- Return type:
bool
- log_partition(condition=None)¶
Returns the log partition of the reward function.
- Return type:
torch.Tensor
- log_reward(final_states)¶
Computes the log reward of the environment.
- Parameters:
final_states (gfn.states.States) – The final states of the environment.
- Returns:
The log reward.
- Return type:
torch.Tensor
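A scalar sketch of a mixture-of-Gaussians log reward. It assumes that the reward is the unweighted sum of the component densities, log R(x) = log Σ_i N(x; μ_i, σ_i), computed with log-sum-exp. This page does not state whether the library weights or normalizes the components, so treat that choice as an assumption. The function names are hypothetical.

```python
import math
from typing import Sequence


def normal_log_prob(x: float, mu: float, sigma: float) -> float:
    """log N(x; mu, sigma^2)."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def line_log_reward(x: float, mus: Sequence[float], sigmas: Sequence[float]) -> float:
    """log sum_i N(x; mu_i, sigma_i), via a numerically stable log-sum-exp."""
    log_probs = [normal_log_prob(x, m, s) for m, s in zip(mus, sigmas)]
    top = max(log_probs)
    return top + math.log(sum(math.exp(lp - top) for lp in log_probs))
```

Under this unweighted assumption, R integrates to len(mus) over the whole real line.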
- mixture¶
- mus¶
- n_sd = 4.5¶
- n_steps_per_trajectory = 5¶
- sigmas¶
- step(states, actions)¶
Performs a step in the environment.
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The next states.
- Return type:
- class gfn.gym.NonAutoregressiveBitSequence(word_size=1, seq_size=4, n_modes=2, reward_exponent=2.0, H=None, device_str='cpu', seed=0, debug=False)¶
Bases: gfn.env.DiscreteEnv
Non-autoregressive BitSequence environment.
In this environment, the agent constructs a binary sequence by placing words at arbitrary positions. Each action specifies both which position to fill and which word value to place there. The episode ends when all positions are filled.
The reward is based on the minimum Hamming distance (computed at the bit level) between the completed sequence and a set of target “mode” sequences.
- Parameters:
word_size (int) – Number of bits per word (e.g., 1 for single-bit actions).
seq_size (int) – Total number of bits in the sequence. Must be divisible by word_size.
n_modes (int) – Number of target mode sequences.
reward_exponent (float) – Controls reward sharpness. Higher values make the reward more peaked around the modes.
H (Optional[torch.Tensor]) – Optional tensor of shape (n_modes, seq_size) specifying the target modes in binary. If None, modes are generated randomly using block patterns.
device_str (str) – Device to use ("cpu" or "cuda").
seed (int) – Random seed for mode generation.
debug (bool) – If True, enable runtime guards (not compile-friendly).
- word_size¶
Number of bits per word.
- seq_size¶
Total number of bits.
- words_per_seq¶
Number of word positions (seq_size // word_size).
- n_words¶
Number of possible word values (2 ** word_size).
- n_modes¶
Number of target modes.
- reward_exponent¶
Reward sharpness parameter.
- modes¶
Target mode sequences as a binary tensor of shape (n_modes, seq_size).
Example
>>> env = NonAutoregressiveBitSequence(word_size=1, seq_size=4, n_modes=2)
>>> # Action space: 4 positions * 2 word values + 1 exit = 9 actions
>>> env.n_actions
9
>>> # State shape: 4 word positions
>>> env.s0
tensor([-1, -1, -1, -1])
- H = None¶
- States: type[NonAutoregressiveBitSequenceStates]¶
- _decode_action(action)¶
Decode a flat action index into (position, word) pair.
- Parameters:
action (torch.Tensor) – Action tensor of shape (*batch_shape, 1).
- Returns:
Tuple of (position, word) tensors, each of shape (*batch_shape, 1).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
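With the flat encoding action = position * n_words + word (the same encoding that step and backward_step use), decoding is integer division and remainder. The following plain-Python sketch decodes actions and applies forward and backward transitions to a single list-based state, with -1 marking an empty position as in s0 from the class example. The helper names are hypothetical, and the real methods work on batched tensors. Invalid moves are checked here with assertions. The exit index words_per_seq * n_words is an assumption inferred from the class example, where 4 · 2 = 8 placement actions plus one exit action give n_actions = 9.

```python
from typing import List, Tuple


def decode_action(action: int, n_words: int) -> Tuple[int, int]:
    """Split a flat action index into (position, word)."""
    return divmod(action, n_words)


def exit_action(words_per_seq: int, n_words: int) -> int:
    """Assumed exit index: the first index after all placement actions."""
    return words_per_seq * n_words


def place_word(state: List[int], action: int, n_words: int) -> List[int]:
    """Forward step: write the decoded word at the decoded position."""
    position, word = decode_action(action, n_words)
    assert state[position] == -1, "position already filled"
    nxt = list(state)
    nxt[position] = word
    return nxt


def clear_word(state: List[int], action: int, n_words: int) -> List[int]:
    """Backward step: in this sketch only the position component matters."""
    position, _ = decode_action(action, n_words)
    assert state[position] != -1, "position is empty"
    prev = list(state)
    prev[position] = -1
    return prev
```

Placing word 1 at position 2 with n_words = 2, for example, uses action 2 · 2 + 1 = 5.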
- static _integers_to_binary(tensor, k)¶
Convert a tensor of word integers to their binary representation.
- Parameters:
tensor (torch.Tensor) – Integer tensor of shape (*batch_shape, words_per_seq) with values in {0, ..., 2^k - 1}.
k (int) – Number of bits per word.
- Returns:
Binary tensor of shape (*batch_shape, words_per_seq * k) with values in {0, 1}.
- Return type:
torch.Tensor
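A plain-Python sketch of the conversion for one sequence. The function name is hypothetical, and the bit order is an assumption: each word is written most-significant bit first, and the per-word blocks are concatenated in position order.

```python
from typing import List


def integers_to_binary(words: List[int], k: int) -> List[int]:
    """Expand each k-bit word into k bits (MSB first) and concatenate."""
    bits: List[int] = []
    for w in words:
        bits.extend((w >> (k - 1 - i)) & 1 for i in range(k))
    return bits
```

With k = 2, the words [2, 3] become the bits [1, 0, 1, 1].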
- _make_modes(seed, device)¶
Generate target mode sequences in binary representation.
If H is provided, it is used directly as the modes. Otherwise, modes are constructed by randomly combining 8-bit block patterns, following the procedure from the Trajectory Balance paper.
- Parameters:
seed (int) – Random seed.
device (torch.device) – Device to place the modes tensor on.
- Returns:
Binary tensor of shape (n_modes, seq_size) with values in {0, 1}.
- Return type:
torch.Tensor
- static _min_hamming_distance(candidates, references)¶
Compute minimum Hamming distance from each candidate to any reference.
- Parameters:
candidates (torch.Tensor) – Binary tensor of shape (*batch_shape, seq_size).
references (torch.Tensor) – Binary tensor of shape (n_refs, seq_size).
- Returns:
Tensor of shape (*batch_shape,) with the minimum distance.
- Return type:
torch.Tensor
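For intuition, the following plain-Python helper performs the same computation on lists of bits. It is a hypothetical helper, not the tensor implementation. For each candidate, it counts the mismatching positions against every reference and keeps the smallest count.

```python
from typing import List, Sequence


def min_hamming_distance(candidates: Sequence[Sequence[int]],
                         references: Sequence[Sequence[int]]) -> List[int]:
    """Minimum Hamming distance from each candidate to any reference."""
    return [
        min(sum(a != b for a, b in zip(cand, ref)) for ref in references)
        for cand in candidates
    ]
```

Against the references 0000 and 1111, the candidate 1101 is at distance 1 (from 1111), and 1010 is at distance 2 from both.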
- backward_step(states, actions)¶
Undo a word placement by clearing the position back to -1.
The backward action uses the same encoding as the forward action: action = position * n_words + word. The position component identifies which position to clear.
- Parameters:
states (NonAutoregressiveBitSequenceStates) – Current states.
actions (gfn.actions.Actions) – Backward actions to undo.
- Returns:
Previous states with the specified positions cleared.
- Return type:
- log_reward(final_states)¶
Compute log-reward based on Hamming distance to nearest mode.
The log-reward is:
log R(x) = -reward_exponent * min_d(x, modes) / seq_size
where min_d is the minimum bit-level Hamming distance between the completed sequence and any target mode.
- Parameters:
final_states (NonAutoregressiveBitSequenceStates) – Terminal states with all positions filled.
- Returns:
Log-reward tensor of shape (*batch_shape,).
- Return type:
torch.Tensor
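The following worked sketch of the formula takes the already-computed minimum distance as input. It uses plain Python, and the function names are hypothetical. A sequence that matches a mode exactly gets log R = 0, that is, R = 1, and the reward decays exponentially with the distance.

```python
import math


def nar_log_reward(min_dist: int, seq_size: int, reward_exponent: float = 2.0) -> float:
    """log R(x) = -reward_exponent * min_d(x, modes) / seq_size."""
    return -reward_exponent * min_dist / seq_size


def nar_reward(min_dist: int, seq_size: int, reward_exponent: float = 2.0) -> float:
    """R(x) = exp(log R(x))."""
    return math.exp(nar_log_reward(min_dist, seq_size, reward_exponent))
```

Since the distance lies between 0 and seq_size, the default reward_exponent=2.0 keeps the reward in [exp(-2), 1], roughly [0.135, 1].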
- make_random_states(batch_shape, conditions=None, device=None, debug=False)¶
Generate random partially-filled states.
Each position is independently either unfilled (-1) or filled with a random word value.
- Parameters:
batch_shape (Tuple) – Shape of the batch.
conditions (Optional[torch.Tensor]) – Optional conditions tensor.
device (Optional[torch.device]) – Device to use.
debug (bool) – If True, enable debug mode.
- Returns:
Random states.
- Return type:
- make_states_class()¶
Create the States class with environment-specific constants.
- Return type:
- modes¶
- n_modes_count = 2¶
- property n_terminating_states: int¶
Total number of possible terminal states.
- Return type:
int
- n_words = 2¶
- reward(final_states)¶
Compute reward as exp(log_reward).
- Parameters:
final_states (NonAutoregressiveBitSequenceStates) – Terminal states.
- Returns:
Reward tensor of shape (*batch_shape,).
- Return type:
torch.Tensor
- reward_exponent = 2.0¶
- seq_size = 4¶
- step(states, actions)¶
Place a word at the specified position.
The action encodes (position, word) as a flat index: action = position * n_words + word.
- Parameters:
states (NonAutoregressiveBitSequenceStates) – Current states.
actions (gfn.actions.Actions) – Actions encoding (position, word) pairs.
- Returns:
Next states with the specified positions filled.
- Return type:
- property terminating_states: NonAutoregressiveBitSequenceStates¶
Enumerate all terminal states (only feasible for small environments).
- Return type:
- true_dist(condition=None)¶
Compute the true reward distribution over all terminal states.
- Return type:
torch.Tensor
- word_size = 1¶
- words_per_seq = 4¶
- class gfn.gym.PerfectBinaryTree(reward_fn, depth=4, device=None, debug=False)¶
Bases: gfn.env.DiscreteEnv
Perfect Tree Environment.
This environment is a perfect binary tree, where there is a bijection between trajectories and terminating states. Nodes are represented by integers, starting from 0 for the root. States are represented by a single integer tensor corresponding to the node index. Actions are integers: 0 (left child), 1 (right child), 2 (exit).
For example:
                  0 (root)
              /       \
          1               2
        /   \           /   \
      3       4       5       6
     / \     / \     / \     / \
    7   8   9  10  11  12  13  14   (terminating states if depth=3)
Recommended preprocessor: OneHotPreprocessor.
- Parameters:
reward_fn (Callable)
depth (int)
device (Literal['cpu', 'cuda'] | torch.device | None)
debug (bool)
- reward_fn¶
A function that computes the reward for a given state.
- Type:
Callable
- depth¶
The depth of the tree.
- Type:
int
- branching_factor¶
The branching factor of the tree.
- Type:
int
- n_actions¶
The number of actions.
- Type:
int
- n_nodes¶
The number of nodes in the tree.
- Type:
int
- transition_table¶
A dictionary that maps (state, action) to the next state.
- Type:
dict
- inverse_transition_table¶
A dictionary that maps (state, action) to the previous state.
- Type:
dict
- term_states¶
The terminating states.
- Type:
- States: type[gfn.env.DiscreteStates]¶
- _build_tree()¶
Builds the tree and the transition tables.
- Returns:
A tuple containing the transition table, the inverse transition table, and the terminating states.
- Return type:
tuple[dict, dict, gfn.env.DiscreteStates]
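The heap-style numbering in the tree figure, where the children of node i are 2i + 1 and 2i + 2, makes the tables easy to build. The following plain-Python sketch uses hypothetical names. It assumes that the backward action mirrors the forward one, so that 0 returns from a left child and 1 from a right child. With depth = 4 it gives 31 nodes, matching n_nodes.

```python
from typing import Dict, List, Tuple


def build_tree(depth: int) -> Tuple[Dict[Tuple[int, int], int],
                                    Dict[Tuple[int, int], int],
                                    List[int]]:
    """Return (transition_table, inverse_transition_table, terminating_nodes)."""
    n_nodes = 2 ** (depth + 1) - 1
    first_leaf = 2 ** depth - 1
    transition: Dict[Tuple[int, int], int] = {}
    inverse: Dict[Tuple[int, int], int] = {}
    for node in range(first_leaf):  # internal nodes only
        for action in (0, 1):  # 0 = left child, 1 = right child
            child = 2 * node + 1 + action
            transition[(node, action)] = child
            inverse[(child, action)] = node  # assumed backward encoding
    terminating = list(range(first_leaf, n_nodes))
    return transition, inverse, terminating
```

For depth = 3, the terminating nodes are 7 through 14, as in the figure.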
- property all_states: gfn.env.DiscreteStates¶
Returns all the states of the environment.
- Return type:
gfn.env.DiscreteStates
- backward_step(states, actions)¶
Performs a backward step in the environment.
- Parameters:
states (gfn.env.DiscreteStates) – The current states.
actions (gfn.env.Actions) – The actions to take.
- Returns:
The previous states.
- Return type:
gfn.env.DiscreteStates
- branching_factor = 2¶
- depth = 4¶
- get_states_indices(states)¶
Returns the indices of the states.
- Parameters:
states (gfn.states.States) – The states to get the indices of.
- Returns:
The indices of the states.
- make_states_class()¶
Returns the DiscreteStates class for the PerfectBinaryTree environment.
- Return type:
type[gfn.env.DiscreteStates]
- n_actions = 3¶
- n_nodes = 31¶
- reward(final_states)¶
Computes the reward for a batch of final states.
- Parameters:
final_states – The final states.
- Returns:
The reward of the final states.
- reward_fn¶
- s0¶
- sf¶
- step(states, actions)¶
Performs a step in the environment.
- Parameters:
states (gfn.env.DiscreteStates) – The current states.
actions (gfn.env.Actions) – The actions to take.
- Returns:
The next states.
- Return type:
gfn.env.DiscreteStates
- property terminating_states: gfn.env.DiscreteStates¶
Returns the terminating states of the environment.
- Return type:
gfn.env.DiscreteStates
- class gfn.gym.SetAddition(n_items, max_items, reward_fn, fixed_length=False, device=None, debug=False)¶
Bases: gfn.env.DiscreteEnv
Append-only MDP, similar to the one described in Remark 8 of Shen et al. 2023 [Towards Understanding and Improving GFlowNet Training](https://proceedings.mlr.press/v202/shen23a.html)
The state is a binary vector of length n_items, where 1 indicates the presence of an item. Actions are integers from 0 to n_items - 1 to add the corresponding item, or n_items to exit. Adding an existing item is invalid. The trajectory must end when max_items are present.
Recommended preprocessor: IdentityPreprocessor.
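The action rules described above can be sketched as a forward mask over the n_items + 1 actions. This is a plain-Python illustration with a hypothetical function name. The exit rule is an assumption based on the parameter name: exit is allowed at any step when fixed_length is False, and only at max_items when it is True.

```python
from typing import List


def forward_mask(state: List[int], max_items: int, fixed_length: bool) -> List[bool]:
    """Valid actions: add item i (index i) or exit (index n_items)."""
    n_present = sum(state)
    full = n_present >= max_items
    # Adding is valid only for absent items, and only while the set is not full.
    add_ok = [(not full) and bit == 0 for bit in state]
    # Exit is forced at max_items; otherwise it depends on fixed_length (assumed).
    exit_ok = full or not fixed_length
    return add_ok + [exit_ok]
```

With three items and max_items=2, the state [1, 1, 0] allows only the exit action.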
- Parameters:
n_items (int)
max_items (int)
reward_fn (Callable)
fixed_length (bool)
device (Literal['cpu', 'cuda'] | torch.device | None)
debug (bool)
- n_items¶
The number of items in the set.
- Type:
int
- max_items¶
The maximum number of items that can be added to the set.
- Type:
int
- reward_fn¶
The reward function.
- Type:
Callable
- fixed_length¶
Whether the trajectories have a fixed length.
- Type:
bool
- States: type[gfn.env.DiscreteStates]¶
- property all_states: gfn.env.DiscreteStates¶
Returns all the states of the environment.
- Return type:
gfn.env.DiscreteStates
- backward_step(states, actions)¶
Performs a backward step in the environment.
- Parameters:
states (gfn.env.DiscreteStates) – The current states.
actions (gfn.env.Actions) – The actions to take.
- Returns:
The previous states.
- Return type:
gfn.env.DiscreteStates
- fixed_length = False¶
- get_states_indices(states)¶
Returns the indices of the states.
- Parameters:
states (gfn.env.DiscreteStates) – The states to get the indices of.
- Returns:
The indices of the states.
- make_states_class()¶
Returns the DiscreteStates class for the SetAddition environment.
- Return type:
type[gfn.env.DiscreteStates]
- max_traj_len¶
- n_items¶
- reward(final_states)¶
Computes the reward for a batch of final states.
- Parameters:
final_states (gfn.env.DiscreteStates) – The final states.
- Returns:
The reward of the final states.
- Return type:
torch.Tensor
- reward_fn¶
- step(states, actions)¶
Performs a step in the environment.
- Parameters:
states (gfn.env.DiscreteStates) – The current states.
actions (gfn.env.Actions) – The actions to take.
- Returns:
The next states.
- Return type:
gfn.env.DiscreteStates
- property terminating_states: gfn.env.DiscreteStates¶
Returns the terminating states of the environment.
- Return type:
gfn.env.DiscreteStates