gfn.gym

This module contains all the environments implemented as Gym environments.

Submodules

Attributes

Box

Classes

BitSequence

Append-only BitSequence environment.

BitSequencePlus

Prepend-Append version of BitSequence env.

BoxCartesian

Box environment with Cartesian per-dimension action validation.

BoxPolar

Box environment with polar (norm-based) action validation.

ChipDesign

GFlowNet environment for chip placement.

ConditionalHyperGrid

HyperGrid environment with condition-aware rewards.

DiscreteEBM

Environment for discrete energy-based models.

GraphBuilding

Environment for incrementally building graphs.

GraphBuildingOnEdges

Environment for building graphs edge by edge with discrete action space.

HyperGrid

HyperGrid environment from the GFlowNets paper.

Line

Mixture of Gaussians Line environment.

NonAutoregressiveBitSequence

Non-autoregressive BitSequence environment.

PerfectBinaryTree

Perfect Tree Environment.

SetAddition

Append-only MDP, similar to the one described in Remark 8 of Shen et al. 2023.

Package Contents

class gfn.gym.BitSequence(word_size=4, seq_size=120, n_modes=60, temperature=1.0, H=None, device_str='cpu', seed=0, debug=False)

Bases: gfn.env.DiscreteEnv

Append-only BitSequence environment.

This environment represents a sequence of binary words and provides methods to manipulate and evaluate these sequences. Each action appends an entire binary word at once. Each binary word is represented by its decimal value in both states and actions.

Parameters:
  • word_size (int)

  • seq_size (int)

  • n_modes (int)

  • temperature (float)

  • H (Optional[torch.Tensor])

  • device_str (str)

  • seed (int)

  • debug (bool)

word_size

The size of each binary word in the sequence.

seq_size

The total number of bits in the sequence.

n_modes

The number of unique modes.

temperature

The temperature parameter for reward calculation.

H

A tensor used to create the modes.

device_str

The device to run the computations on (“cpu” or “cuda”).

words_per_seq

The number of words per sequence.

modes

The set of modes written as binary.

H = None
States: type[BitSequenceStates]
_backward_step(states, actions)

Perform a backward step in the environment by undoing the given actions on the current states.

Parameters:
  • states

  • actions

Returns:

The new states after performing the backward step.

_step(states, actions)

Perform a step in the environment by applying the given actions to the current states.

Parameters:
  • states

  • actions

Returns:

The new states of the environment after applying the actions.

backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
  • states

  • actions

Returns:

The previous states.

Return type:

BitSequenceStates

static binary_to_integers(binary_tensor, k)

Convert a binary tensor to a tensor of integers.

Parameters:
  • binary_tensor (torch.Tensor) – A tensor containing binary values. The tensor must be of type int64.

  • k (int) – The number of bits in each integer.

Returns:

A tensor of integers obtained from the binary tensor.

Return type:

torch.Tensor

create_test_set(k, seed=0)

Create a test set by perturbing each mode k times, each time flipping a random number of its bits.

The resulting test set has size n_modes * k.

Parameters:
  • k (int) – Number of variations per mode.

  • seed (int) – Seed for reproducibility. If None, randomness is not fixed.

Returns:

The generated test set in the decimal representation.

Return type:

BitSequenceStates
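The perturbation scheme can be sketched in plain Python on modes given as bit lists. The helper `make_test_set` and its `max_flips` bound are hypothetical, and drawing the number of flipped bits uniformly between 1 and `max_flips` is an assumption; the library method works on tensors and returns the test set in decimal form.

```python
import random
from typing import List, Optional


def make_test_set(
    modes: List[List[int]], k: int, max_flips: int, seed: Optional[int] = 0
) -> List[List[int]]:
    """Return k perturbed copies of every mode (hypothetical helper).

    Each copy flips between 1 and max_flips distinct bit positions, so the
    result has len(modes) * k entries.
    """
    rng = random.Random(seed)  # seed=None: randomness is not fixed
    test_set = []
    for mode in modes:
        for _ in range(k):
            n_flips = rng.randint(1, min(max_flips, len(mode)))
            positions = rng.sample(range(len(mode)), n_flips)
            variant = list(mode)
            for p in positions:
                variant[p] ^= 1  # flip one bit
            test_set.append(variant)
    return test_set
```

Because the flipped positions are distinct, each variant differs from its mode in exactly `n_flips` bits.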

static hamming_distance(candidates, reference)

Compute the smallest Hamming distance (number of differing positions) from each candidate row to any reference row.

Parameters:
  • candidates (torch.Tensor) – Tensor of shape (*batch_shape, length).

  • reference (torch.Tensor) – Tensor of shape (n_ref, length).

Returns:

Tensor of shape (*batch_shape) containing the smallest Hamming distance for each candidate row.

Return type:

torch.Tensor

static integers_to_binary(tensor, k)

Convert a tensor of integers to their binary representation using k bits.

Parameters:
  • tensor (torch.Tensor) – A tensor containing integers.

  • k (int) – The number of bits to use for the binary representation of each integer.

Returns:

A tensor containing the binary representation of the input integers.

Return type:

torch.Tensor
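A minimal sketch of the two conversions on plain Python integers, assuming most-significant-bit-first ordering within each word (the bit order used by the library is an assumption here). Applied word by word, they move between the decimal word representation used in states and actions and the bit representation used when computing rewards.

```python
from typing import List


def integers_to_binary(words: List[int], k: int) -> List[int]:
    """Expand each k-bit integer into k bits, most significant bit first."""
    bits = []
    for w in words:
        if not 0 <= w < 2 ** k:
            raise ValueError(f"{w} does not fit in {k} bits")
        bits.extend((w >> (k - 1 - i)) & 1 for i in range(k))
    return bits


def binary_to_integers(bits: List[int], k: int) -> List[int]:
    """Group bits into chunks of k and read each chunk as an integer."""
    if len(bits) % k != 0:
        raise ValueError("number of bits must be a multiple of k")
    words = []
    for start in range(0, len(bits), k):
        value = 0
        for b in bits[start:start + k]:
            value = (value << 1) | b
        words.append(value)
    return words
```

For example, with k = 4 the words [5, 12] expand to 0101 followed by 1100.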

log_reward(final_states)

Calculates the log-reward for the given final states.

Parameters:

final_states (BitSequenceStates) – The final states for which to calculate the log-reward.

Returns:

The calculated log-reward.

Return type:

torch.Tensor

make_modes_set(seed)

Generates a set of unique mode sequences based on the predefined tensor H.

Parameters:

seed – The seed for random number generation.

Returns:

A tensor containing the unique mode sequences.

Raises:

ValueError – If the number of requested modes exceeds the number of possible unique sequences.

Return type:

torch.Tensor

make_states_class()

Creates a BitSequenceStates class implementation.

Returns:

A BitSequenceStates class implementation.

Return type:

type[BitSequenceStates]

modes
n_actions = 17
n_modes: int = 60
property n_states: int

Returns the total number of states in the environment.

Return type:

int

property n_terminating_states: int

Returns the number of terminating states.

Return type:

int

reset(batch_shape=None, sink=False)

Generates initial or sink states from batch_shape.

Parameters:
  • batch_shape (int | Tuple[int] | None) – The shape of the batch. If None, defaults to (1,). If an integer is provided, it is converted to a tuple.

  • sink (bool) – If True, sink states are created. Defaults to False.

Returns:

The initial states of the environment after reset.

Return type:

BitSequenceStates

reward(final_states)

Calculate the reward for the given final states.

The reward is computed based on the Hamming distance between the binary representation of the final states and the predefined modes. The reward is then scaled using an exponential function with a temperature parameter.

Parameters:

final_states (BitSequenceStates) – The final states for which the reward is to be calculated.

Returns:

The calculated reward for the given final states.
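The reward computation described above can be sketched in plain Python on bit lists. It assumes the exponential scaling takes the form exp(-d / temperature), where d is the smallest Hamming distance between the state's bits and any mode; that exact form is an assumption, and `min_hamming_distance` and `bitsequence_reward` are hypothetical names.

```python
import math
from typing import List


def min_hamming_distance(candidate: List[int], modes: List[List[int]]) -> int:
    """Smallest number of differing positions between candidate and any mode."""
    return min(sum(a != b for a, b in zip(candidate, mode)) for mode in modes)


def bitsequence_reward(
    candidate: List[int], modes: List[List[int]], temperature: float = 1.0
) -> float:
    """Assumed form exp(-d / temperature); the log-reward is then -d / temperature."""
    d = min_hamming_distance(candidate, modes)
    return math.exp(-d / temperature)
```

A state that matches a mode exactly gets reward 1, and lowering the temperature makes the reward decay faster with distance.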

seq_size: int = 120
states_from_tensor(tensor, length=None)

Wraps the supplied Tensor in a States instance.

Parameters:
  • tensor (torch.Tensor) – The tensor of shape state_shape representing the states.

  • length (Optional[torch.Tensor]) – The length of each state in the tensor.

Returns:

An instance of States.

Return type:

BitSequenceStates

step(states, actions)

Performs a step in the environment.

Parameters:
  • states

  • actions

Returns:

The next states.

Return type:

BitSequenceStates

temperature = 1.0
property terminating_states: BitSequenceStates

Returns all terminating states of the environment.

Return type:

BitSequenceStates

trajectory_from_terminating_states(terminating_states_tensor)

Generate trajectories from terminating states.

This works because, in the append-only version of BitSequence, the DAG is a tree: every non-initial state has exactly one parent.

Parameters:

terminating_states_tensor (torch.Tensor) – A tensor containing the terminating states from which to generate the trajectories. The shape of the tensor should be (batch_size, words_per_seq).

Returns:

An object containing the generated trajectories.

Return type:

gfn.containers.Trajectories
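Since each append-only state has a single parent (itself minus its last word), the trajectory leading to a terminating state is simply the chain of its prefixes. A sketch on plain lists of words, assuming unfilled positions are padded with -1 (the padding value is an assumption); `prefix_trajectory` is a hypothetical name.

```python
from typing import List


def prefix_trajectory(words: List[int], pad: int = -1) -> List[List[int]]:
    """Return the states s0, s1, ..., sT visited when appending `words` one by one."""
    n = len(words)
    states = []
    for t in range(n + 1):
        # State t holds the first t words; the remaining slots are padded.
        states.append(list(words[:t]) + [pad] * (n - t))
    return states
```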

true_dist(condition=None)

Returns the true probability mass function of the reward distribution.

Return type:

torch.Tensor

word_size: int = 4
words_per_seq: int = 30
class gfn.gym.BitSequencePlus(word_size=4, seq_size=120, n_modes=60, temperature=1.0, H=None, device_str='cpu', seed=0)

Bases: BitSequence

Prepend-Append version of BitSequence env.

This environment is similar to BitSequence, but allows words to be both prepended and appended to the sequence.

Parameters:
  • word_size (int)

  • seq_size (int)

  • n_modes (int)

  • temperature (float)

  • H (Optional[torch.Tensor])

  • device_str (str)

  • seed (int)

H = None
backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
  • states

  • actions

Returns:

The previous states.

Return type:

BitSequenceStates

make_states_class()

Creates a BitSequenceStates class implementation for BitSequencePlus.

Returns:

A BitSequenceStates class implementation with prepend-append mask logic.

Return type:

type[BitSequenceStates]

modes
n_modes: int = 60
seq_size: int = 120
step(states, actions)

Performs a step in the environment.

Parameters:
  • states

  • actions

Returns:

The next states.

Return type:

BitSequenceStates

temperature = 1.0
abstract trajectory_from_terminating_states(terminating_states_tensor)

Generates trajectories from terminating states. Not implemented for this environment.

Parameters:

terminating_states_tensor (torch.Tensor) – A tensor of terminating states.

Raises:

NotImplementedError – This method is not implemented for this environment.

Return type:

gfn.containers.Trajectories

word_size: int = 4
words_per_seq: int = 30
gfn.gym.Box
class gfn.gym.BoxCartesian

Bases: gfn.gym.box.BoxPolar

Box environment with Cartesian per-dimension action validation.

Inherits all behavior from BoxPolar (init, step, backward_step, reward, log_partition, norm, make_random_states). Overrides only is_action_valid() to use per-dimension bounds instead of polar norm constraints.

Use with the Cartesian estimators/distributions in box_cartesian_utils.py.

See also

BoxPolar for the original polar norm-based variant.

is_action_valid(states, actions, backward=False)

Checks if the actions are valid (Cartesian per-dimension semantics).

For Cartesian actions:
  • Forward from s0: action[i] >= 0 and action[i] <= 1.

  • Forward from non-s0: action[i] >= delta and state[i] + action[i] <= 1.

  • Backward: state[i] - action[i] >= 0.

  • Backward to s0: if all resulting dims are < delta, the action must equal the state.

Parameters:
  • states

  • actions

  • backward

Returns:

True if the actions are valid, False otherwise.

Return type:

bool
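The four Cartesian rules can be checked for a single state with the framework-free sketch below. It assumes s0 is the origin and that comparing the action to the state uses a small tolerance; `is_cartesian_action_valid` is a hypothetical name, and the library method operates on batched tensors instead.

```python
from typing import Sequence


def is_cartesian_action_valid(
    state: Sequence[float],
    action: Sequence[float],
    delta: float,
    backward: bool = False,
    tol: float = 1e-9,
) -> bool:
    at_s0 = all(x == 0.0 for x in state)  # assumption: s0 is the origin
    if not backward:
        if at_s0:
            # Forward from s0: each component lies in [0, 1].
            return all(0.0 <= a <= 1.0 for a in action)
        # Forward from non-s0: step at least delta and stay inside the box.
        return all(a >= delta and x + a <= 1.0 for x, a in zip(state, action))
    # Backward: no coordinate may go below zero.
    result = [x - a for x, a in zip(state, action)]
    if any(r < 0.0 for r in result):
        return False
    # If every resulting coordinate is below delta, only a jump back to s0 is valid.
    if all(r < delta for r in result):
        return all(abs(a - x) <= tol for x, a in zip(state, action))
    return True
```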

class gfn.gym.BoxPolar(delta=0.1, R0=0.1, R1=0.5, R2=2.0, epsilon=0.0001, device='cpu', debug=False)

Bases: gfn.env.Env

Box environment with polar (norm-based) action validation.

Corresponds to the environment in Section 4.1 of https://arxiv.org/abs/2301.12594

Actions are 2D vectors whose L2 norm must equal delta (for non-s0 forward steps) or be at most delta (for the initial s0 step). Use with the polar estimators/distributions in box_polar_utils.py.

See also

BoxCartesian for a simpler per-dimension Cartesian variant.

Parameters:
  • delta (float)

  • R0 (float)

  • R1 (float)

  • R2 (float)

  • epsilon (float)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • debug (bool)

delta

The step size.

R0

The base reward.

R1

The reward for being outside the first box.

R2

The reward for being inside the second box.

epsilon

A small value to avoid numerical issues.

device

The device to use.

Type:

Literal[“cpu”, “cuda”] | torch.device

Return type:

torch.device

R0 = 0.1
R1 = 0.5
R2 = 2.0
backward_step(states, actions)

Backward step function for the Box environment.

Parameters:
  • states

  • actions

Returns:

The previous states as a States object.

Return type:

gfn.states.States

delta = 0.1
epsilon = 0.0001
is_action_valid(states, actions, backward=False)

Checks if the actions are valid (polar norm-based semantics).

For polar actions:
  • Forward from s0: norm(action) <= delta.

  • Forward from non-s0: norm(action) == delta (within tolerance).

  • Backward: state - action >= 0 component-wise.

  • Backward to s0: if norm(state) < delta, the action must equal the state.

Parameters:
  • states

  • actions

  • backward

Returns:

True if the actions are valid, False otherwise.

Return type:

bool
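For comparison with the Cartesian variant, the polar rules can be sketched for a single state as follows. As before, s0 is assumed to be the origin and the tolerance value is an assumption; `l2_norm` and `is_polar_action_valid` are hypothetical names.

```python
import math
from typing import Sequence


def l2_norm(v: Sequence[float]) -> float:
    return math.sqrt(sum(x * x for x in v))


def is_polar_action_valid(
    state: Sequence[float],
    action: Sequence[float],
    delta: float,
    backward: bool = False,
    tol: float = 1e-6,
) -> bool:
    at_s0 = all(x == 0.0 for x in state)  # assumption: s0 is the origin
    if not backward:
        if at_s0:
            return l2_norm(action) <= delta + tol  # first step: norm at most delta
        return abs(l2_norm(action) - delta) <= tol  # later steps: norm exactly delta
    # Backward: the previous state must stay component-wise non-negative.
    if any(x - a < 0.0 for x, a in zip(state, action)):
        return False
    # Close to the origin, the only valid backward move is straight back to s0.
    if l2_norm(state) < delta:
        return all(abs(a - x) <= tol for x, a in zip(state, action))
    return True
```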

log_partition(condition=None)

Returns the log partition of the reward function.

Return type:

float

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Generates random states tensor of shape (*batch_shape, 2).

Parameters:
  • batch_shape (Tuple[int, Ellipsis]) – The shape of the batch.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A States object with random states.

Return type:

gfn.states.States

static norm(x)

Computes the L2 norm of the input tensor along the last dimension.

Parameters:

x (torch.Tensor) – Input tensor of shape (*batch_shape, 2).

Returns:

Tensor of L2 norms, of shape batch_shape.

Return type:

torch.Tensor

reward(final_states)

Reward is distance from the goal point.

Parameters:

final_states (gfn.states.States) – States object representing the final states.

Returns:

The reward tensor of shape batch_shape.

Return type:

torch.Tensor
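The description above does not give an explicit formula. One piecewise-constant form consistent with the R0 (base reward), R1 (outside the first box), and R2 (inside the second box) attributes is sketched below; the thresholds 0.25, 0.3, and 0.4 on the distance from the center are assumptions for illustration, and `box_reward` is a hypothetical name.

```python
from typing import Sequence


def box_reward(
    x: Sequence[float], R0: float = 0.1, R1: float = 0.5, R2: float = 2.0
) -> float:
    """Assumed piecewise-constant reward on [0, 1]^d (thresholds are illustrative)."""
    ax = [abs(xi - 0.5) for xi in x]  # distance of each coordinate from the center
    outside_first_box = all(a > 0.25 for a in ax)
    inside_second_box = all(0.3 < a < 0.4 for a in ax)
    return R0 + R1 * outside_first_box + R2 * inside_second_box
```

With the default attribute values, the center of the box gets 0.1, a corner gets 0.6, and a point such as (0.85, 0.15) gets 2.6.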

step(states, actions)

Step function for the Box environment.

Parameters:
  • states

  • actions

Returns:

The next states as a States object.

Return type:

gfn.states.States

class gfn.gym.ChipDesign(netlist_file=SAMPLE_NETLIST_FILE, init_placement=SAMPLE_INIT_PLACEMENT, std_cell_placer_mode='fd', wirelength_weight=1.0, density_weight=1.0, congestion_weight=0.5, device=None, debug=False, singularity_image=None, cd_finetune=True, reward_norm=None, reward_temper=1.0, reward_norm_samples=100, cost_stats=None)

Bases: gfn.env.DiscreteEnv

GFlowNet environment for chip placement.

The state is a vector of length n_macros, where state[i] is the grid cell location of the i-th macro to be placed. Unplaced macros have a location of -1.

Actions are integers from 0 to n_grid_cells - 1, representing the grid cell to place the current macro on. Action n_grid_cells is the exit action.
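The state and action encoding can be illustrated with the sketch below. It assumes the "current macro" is the first entry still equal to -1; the real environment orders macros through internal index lists, so this ordering is an assumption, and `place_next_macro` is a hypothetical name.

```python
from typing import List

UNPLACED = -1  # location of a macro that has not been placed yet


def place_next_macro(state: List[int], action: int, n_grid_cells: int) -> List[int]:
    """Apply one placement action to a ChipDesign-style state (illustrative only)."""
    if not 0 <= action <= n_grid_cells:
        raise ValueError("action must be in [0, n_grid_cells]")
    if action == n_grid_cells:
        # Exit action: a GFlowNet environment would move to the sink state here;
        # this sketch simply returns the state unchanged.
        return list(state)
    if UNPLACED not in state:
        raise ValueError("all macros are already placed")
    new_state = list(state)
    new_state[state.index(UNPLACED)] = action  # grid cell of the current macro
    return new_state
```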

Parameters:
  • netlist_file (str)

  • init_placement (str)

  • std_cell_placer_mode (str)

  • wirelength_weight (float)

  • density_weight (float)

  • congestion_weight (float)

  • device (str | None)

  • debug (bool)

  • singularity_image (Optional[str])

  • cd_finetune (bool)

  • reward_norm (Optional[str])

  • reward_temper (float)

  • reward_norm_samples (int)

  • cost_stats (Optional[CostStats])

States: type[ChipDesignStates]
__del__()
Return type:

None

_apply_state_to_plc(state_tensor)

Applies a single state tensor to the plc object.

Parameters:

state_tensor (torch.Tensor)

_backward_step(states, actions)

Wraps parent _backward_step and updates masks.

Parameters:
Return type:

ChipDesignStates

_estimate_cost_stats(n_samples)

Estimates cost distribution from random placements.

Parameters:

n_samples (int)

Return type:

None

_hard_macro_indices
_normalize_cost(cost)

Applies reward normalization to a raw cost value.

Parameters:

cost (float)

Return type:

float

_sorted_node_indices
_step(states, actions)

Wraps parent _step and updates masks.

Parameters:
Return type:

ChipDesignStates

analytical_placer()

Places standard cells using an analytical placer.

backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
  • states

  • actions

Return type:

ChipDesignStates

cd_finetune = True
close()

Closes the PlacementCost subprocess to free resources.

Return type:

None

congestion_weight = 0.5
density_weight = 1.0
log_reward(final_states)

Computes the log reward of the final states.

Parameters:

final_states (ChipDesignStates)

Return type:

torch.Tensor

make_states_class()

Creates the ChipDesignStates class.

Return type:

type[ChipDesignStates]

n_grid_cells
n_macros
plc
reset(batch_shape, random=False, sink=False, seed=None, conditions=None)

Resets the environment and computes initial masks.

Parameters:
  • batch_shape (int | Tuple[int, Ellipsis])

  • random (bool)

  • sink (bool)

  • seed (Optional[int])

  • conditions (Optional[torch.Tensor])

Return type:

ChipDesignStates

reward_norm = None
reward_temper = 1.0
std_cell_placer_mode = 'fd'
step(states, actions)

Performs a forward step in the environment.

Parameters:
  • states

  • actions

Return type:

ChipDesignStates

update_masks(states)

Updates the forward and backward masks of the states.

Parameters:

states (ChipDesignStates)

Return type:

None

wirelength_weight = 1.0
class gfn.gym.ConditionalHyperGrid(*args, **kwargs)

Bases: HyperGrid

HyperGrid environment with condition-aware rewards.

Let the condition c be a real value in [0, 1]. The condition defines the reward as a linear interpolation between the uniform reward and the original reward. Special cases are:
  • c = 0: uniform reward (all terminal states get reward = R0 + R1 + R2).

  • c = 1: original HyperGrid reward (original multi-modal reward landscape).

_log_partition_cache: dict[torch.Tensor, float]
_max_reward: float
_original_reward_fn
_true_dist_cache: dict[torch.Tensor, torch.Tensor]
condition_dim: int = 1
is_conditional: bool = True
log_partition(condition)

Compute the log partition for the given condition.

Parameters:

condition (torch.Tensor) – The condition to compute the log partition for. condition.shape should be (1,)

Returns:

The log partition function, as a float.

Return type:

float

reward(states)

Compute rewards for the conditional environment.

A condition is continuous from 0 to 1:
  • 0: fully uniform reward (all states get R0 + R1 + R2).

  • 1: fully original HyperGrid reward.

  • In between: linear interpolation between uniform and original.

Parameters:

states (gfn.states.DiscreteStates) – The states to compute rewards for. states.tensor.shape should be (*batch_shape, *state_shape)

Returns:

A tensor of shape (*batch_shape,) containing the rewards.

Return type:

torch.Tensor
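Written out, the interpolation reads r_c(s) = (1 - c) · U + c · r(s), where r is the original HyperGrid reward and U = R0 + R1 + R2 is the uniform value. The sketch below applies this to precomputed original rewards; treating the interpolation as exactly this convex combination is an assumption, and `conditional_rewards` is a hypothetical name.

```python
from typing import List


def conditional_rewards(
    original: List[float], c: float, R0: float, R1: float, R2: float
) -> List[float]:
    """Blend uniform and original rewards with weight c in [0, 1]."""
    if not 0.0 <= c <= 1.0:
        raise ValueError("condition c must lie in [0, 1]")
    uniform = R0 + R1 + R2  # reward given to every terminal state when c = 0
    return [(1.0 - c) * uniform + c * r for r in original]
```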

sample_conditions(batch_shape)

Sample conditions for the environment.

Parameters:

batch_shape (int | tuple[int, Ellipsis])

Return type:

torch.Tensor

true_dist(condition)

Compute the true distribution for the given condition.

Parameters:
  • condition (torch.Tensor) – The condition to compute the true distribution for.

Returns:

The true distribution for the given condition as a 1-dimensional tensor.

Return type:

torch.Tensor

class gfn.gym.DiscreteEBM(ndim, energy=None, alpha=1.0, device='cpu', debug=False)

Bases: gfn.env.DiscreteEnv

Environment for discrete energy-based models.

This environment is based on the paper https://arxiv.org/pdf/2202.01361.pdf.

The states are represented as 1d tensors of length ndim with values in {-1, 0, 1}. s0 is empty (represented as -1), so s0=[-1, -1, …, -1]. An action corresponds to replacing a -1 with a 0 or a 1. Action i in [0, ndim - 1] corresponds to replacing s[i] with 0. Action i in [ndim, 2 * ndim - 1] corresponds to replacing s[i - ndim] with 1. The last action is the exit action that is only available for complete states (those with no -1).
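The action encoding above can be traced with a single-state sketch on plain lists. `ebm_step` is a hypothetical name; the library applies the same mapping to batched tensors.

```python
from typing import List


def ebm_step(state: List[int], action: int) -> List[int]:
    """Forward step on a DiscreteEBM-style state (illustrative sketch)."""
    ndim = len(state)
    exit_action = 2 * ndim
    if action == exit_action:
        if -1 in state:
            raise ValueError("exit is only allowed from complete states")
        return list(state)  # a GFlowNet environment would move to the sink state
    if not 0 <= action < exit_action:
        raise ValueError("action out of range")
    # Actions [0, ndim) write a 0; actions [ndim, 2 * ndim) write a 1.
    index, value = (action, 0) if action < ndim else (action - ndim, 1)
    if state[index] != -1:
        raise ValueError(f"position {index} is already set")
    new_state = list(state)
    new_state[index] = value
    return new_state
```

With ndim = 3, actions 1, 3, and 5 take s0 = [-1, -1, -1] to [-1, 0, -1], then [1, 0, -1], then the complete state [1, 0, 1], from which action 6 exits.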

Parameters:
  • ndim (int)

  • energy (EnergyFunction | None)

  • alpha (float)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • debug (bool)

ndim

Dimension D of the sampling space {0, 1}^D.

Type:

int

energy

Energy function of the EBM.

Type:

EnergyFunction

alpha

Interaction strength of the EBM.

Type:

float

States: type[gfn.states.DiscreteStates]
property all_states: gfn.states.DiscreteStates

Returns all possible states of the environment.

Return type:

gfn.states.DiscreteStates

alpha = 1.0
backward_step(states, actions)

Performs a backward step.

In this environment, states are ndim-dimensional vectors. s0 is empty (represented as -1), so s0 = [-1, -1, …, -1], and each forward action replaces a -1 with either a 0 or a 1: action i in [0, ndim-1] replaces s[i] with 0, whereas action i in [ndim, 2*ndim-1] replaces s[i - ndim] with 1. A backward action only asks “which index should be set back to -1?”, hence the fmod (modulo ndim), which maps both action i and action i + ndim to index i.

Parameters:
  • states

  • actions

Returns:

The previous states.

Return type:

gfn.states.States
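The index wrapping can be sketched for a single state as follows; Python's `%` agrees with `torch.fmod` for the non-negative action indices used here. `ebm_backward_step` is a hypothetical name.

```python
from typing import List


def ebm_backward_step(state: List[int], action: int) -> List[int]:
    """Backward step: reset the position targeted by `action` to -1."""
    ndim = len(state)
    index = action % ndim  # actions i and i + ndim both refer to position i
    if state[index] == -1:
        raise ValueError(f"position {index} is already empty")
    new_state = list(state)
    new_state[index] = -1
    return new_state
```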

energy: EnergyFunction = None
get_states_indices(states)

Given that each state is of length ndim with values in {-1, 0, 1}, there are 3**ndim states, which we can label from 0 to 3**ndim - 1.

The easiest way to map each state to a unique integer is to consider the state as a number in base 3, where each digit can be in {0, 1, 2}. We therefore shift each entry by 1 so that {-1, 0, 1} -> {0, 1, 2}.

Parameters:

states (gfn.states.DiscreteStates) – DiscreteStates object representing the states.

Returns:

The states indices as tensor of shape (*batch_shape).

Return type:

torch.Tensor
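The base-3 labeling amounts to a Horner-style accumulation over the shifted entries. The sketch assumes the first coordinate is the most significant digit (the library's digit order is an assumption); `state_index` is a hypothetical name.

```python
from typing import List


def state_index(state: List[int]) -> int:
    """Label a state with entries in {-1, 0, 1} by an integer in [0, 3**ndim)."""
    index = 0
    for v in state:
        index = 3 * index + (v + 1)  # shift {-1, 0, 1} -> {0, 1, 2}
    return index
```

s0 = [-1, …, -1] maps to 0 and [1, …, 1] maps to 3**ndim - 1.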

get_terminating_states_indices(states)

Given that each terminating state is of length ndim with values in {0, 1}, there are 2**ndim terminating states, which we can label from 0 to 2**ndim - 1.

The easiest way to map each state to a unique integer is to consider the state as a number in base 2.

Parameters:

states (gfn.states.DiscreteStates) – DiscreteStates object representing the states.

Returns:

The indices of the terminating states as tensor of shape (*batch_shape).

Return type:

torch.Tensor
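The same accumulation in base 2 labels the terminating states, again assuming the first coordinate is the most significant bit; `terminating_state_index` is a hypothetical name.

```python
from typing import List


def terminating_state_index(state: List[int]) -> int:
    """Label a complete state with entries in {0, 1} by an integer in [0, 2**ndim)."""
    index = 0
    for v in state:
        if v not in (0, 1):
            raise ValueError("terminating states contain only 0s and 1s")
        index = 2 * index + v
    return index
```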

is_exit_actions(actions)

Determines if the actions are exit actions.

Parameters:

actions (torch.Tensor) – tensor of actions of shape (*batch_shape, *action_shape).

Returns:

Tensor of booleans of shape (*batch_shape).

Return type:

torch.Tensor

log_partition(condition=None)

Returns the log partition of the reward function.

Return type:

float

log_reward(final_states)

The energy weighted by alpha is our log reward.

Parameters:

final_states (gfn.states.DiscreteStates) – DiscreteStates object representing the final states.

Returns:

The log reward as tensor of shape (*batch_shape).

Return type:

torch.Tensor

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Generates random states tensor of shape (*batch_shape, ndim).

Parameters:
  • batch_shape (Tuple) – The shape of the batch.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A DiscreteStates object with random states.

Return type:

gfn.states.DiscreteStates

make_states_class()

Returns the DiscreteStates class for the DiscreteEBM environment.

Return type:

type[gfn.states.DiscreteStates]

property n_states: int

Returns the number of states in the environment.

Return type:

int

property n_terminating_states: int

Returns the number of terminating states in the environment.

Return type:

int

ndim
reward(final_states)

Computes the reward for a batch of final states.

Parameters:

final_states (gfn.states.DiscreteStates) – A batch of final states.

Returns:

A tensor of rewards.

Return type:

torch.Tensor

step(states, actions)

Performs a step.

Parameters:
  • states

  • actions

Returns:

The next states as a States object.

Return type:

gfn.states.States

property terminating_states: gfn.states.DiscreteStates

Returns all terminating states of the environment.

Return type:

gfn.states.DiscreteStates

true_dist(condition=None)

Returns the true probability mass function of the reward distribution.

Return type:

torch.Tensor

class gfn.gym.GraphBuilding(num_node_classes, num_edge_classes, state_evaluator, is_directed=True, max_nodes=None, device='cpu', s0=None, sf=None, debug=False)

Bases: gfn.env.GraphEnv

Environment for incrementally building graphs.

This environment allows constructing graphs by:
  • Adding nodes of a given class.

  • Adding edges of a given class between existing nodes.

  • Terminating construction (EXIT).

Parameters:
  • num_node_classes (int)

  • num_edge_classes (int)

  • state_evaluator (Callable[[gfn.states.GraphStates], torch.Tensor])

  • is_directed (bool)

  • max_nodes (int | None)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • s0 (torch_geometric.data.Data | None)

  • sf (torch_geometric.data.Data | None)

  • debug (bool)

num_node_classes

The number of node classes.

num_edge_classes

The number of edge classes.

state_evaluator

A callable that computes rewards for final states.

is_directed

Whether the graph is directed.

backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
  • states

  • actions

Returns:

The previous states.

Return type:

gfn.states.GraphStates

is_action_valid(states, actions, backward=False)

Check if actions are valid for the given states.

Parameters:
  • states

  • actions

  • backward

Returns:

True if all actions are valid, False otherwise.

Return type:

bool

make_actions_class()

Returns the GraphActions class for this environment.

Returns:

A type of a subclass of GraphActions with environment-specific functionalities.

Return type:

type[gfn.actions.GraphActions]

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Generates random states.

Parameters:
  • batch_shape (Tuple) – The shape of the batch.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A GraphStates object with random states.

Return type:

gfn.states.GraphStates

make_states_class()

Creates a GraphStates class for this environment.

Return type:

type[gfn.states.GraphStates]

max_nodes = None
reward(final_states)

Computes the environment’s reward for a batch of final states.

Parameters:

final_states (gfn.states.GraphStates) – A batch of final states.

Returns:

A tensor of shape (batch_size,) containing the rewards.

Return type:

torch.Tensor

state_evaluator
step(states, actions)

Performs a step in the environment.

Parameters:
  • states

  • actions

Returns:

The next states.

Return type:

gfn.states.GraphStates

class gfn.gym.GraphBuildingOnEdges(n_nodes, state_evaluator, directed, device, debug=False)

Bases: GraphBuilding

Environment for building graphs edge by edge with discrete action space.

The environment supports both directed and undirected graphs.

In each state, the policy can:
  1. Add an edge between existing nodes.

  2. Use the exit action to terminate graph building.

The action space is discrete, with size:
  • For directed graphs: n_nodes^2 - n_nodes + 1 (all possible directed edges + exit).

  • For undirected graphs: (n_nodes^2 - n_nodes)/2 + 1 (upper triangle + exit).

Parameters:
  • n_nodes (int)

  • state_evaluator (callable)

  • directed (bool)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • debug (bool)

n_nodes

The number of nodes in the graph.

Type:

int

n_possible_edges

The number of possible edges.

Type:

int
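The action-space sizes given above follow from enumerating candidate edges. The sketch lists one (source, target) pair per non-exit action; the row-major ordering of edges is an assumption, and `edge_action_space` and `count_edge_actions` are hypothetical names.

```python
from typing import List, Tuple


def edge_action_space(n_nodes: int, directed: bool) -> List[Tuple[int, int]]:
    """Return the (source, target) pair for every non-exit action index."""
    if directed:
        # Every ordered pair of distinct nodes: n_nodes^2 - n_nodes edges.
        return [(i, j) for i in range(n_nodes) for j in range(n_nodes) if i != j]
    # Upper triangle only: (n_nodes^2 - n_nodes) / 2 edges.
    return [(i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)]


def count_edge_actions(n_nodes: int, directed: bool) -> int:
    """Edge actions plus one exit action."""
    return len(edge_action_space(n_nodes, directed)) + 1
```

For 4 nodes this gives 13 actions when directed and 7 when undirected.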

is_action_valid(states, actions, backward=False)

Checks if the actions are valid.

Parameters:
  • states

  • actions

  • backward

Returns:

True if the actions are valid, False otherwise.

Return type:

bool

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Makes a batch of random graph states with a fixed number of nodes.

Parameters:
  • batch_shape (Tuple) – Shape of the batch dimensions.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A GraphStates object containing random graph states.

Return type:

gfn.states.GraphStates

make_states_class()

Creates a GraphStates class for this environment.

Return type:

type[gfn.states.GraphStates]

n_nodes
class gfn.gym.HyperGrid(ndim=2, height=8, reward_fn_str='original', reward_fn_kwargs=None, device='cpu', calculate_partition=False, store_all_states=False, debug=False, validate_modes=True, mode_stats='none', mode_stats_samples=20000)

Bases: gfn.env.DiscreteEnv

HyperGrid environment from the GFlowNets paper.

The states are represented as 1-d tensors of length ndim with values in {0, 1, …, height - 1}.

Parameters:
  • ndim (int)

  • height (int)

  • reward_fn_str (str)

  • reward_fn_kwargs (dict | None)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • calculate_partition (bool)

  • store_all_states (bool)

  • debug (bool)

  • validate_modes (bool)

  • mode_stats (Literal['none', 'approx', 'exact'])

  • mode_stats_samples (int)

ndim

The dimension of the grid.

height

The height of the grid.

reward_fn

The reward function.

calculate_partition

Whether to calculate the log partition function.

store_all_states

Whether to store all states.

validate_modes

Whether to check that at least one state reaches the mode threshold at init; raises if not.

mode_stats

One of {“none”, “approx”, “exact”}. If not “none”, computes (exact or approximate) n_modes and n_mode_states. “exact” requires store_all_states=True and enumerates all states.

mode_stats_samples

Number of random samples when mode_stats=“approx”.

States: type[gfn.states.DiscreteStates]
_all_states_tensor = None
_enumerate_all_states_tensor(batch_size=20000)

Enumerate all grid states, optionally storing them and computing log Z.

Iterates over the full Cartesian product {0, ..., H-1}^D in batches (via multiprocessing) to avoid materializing all H^D states at once.

Parameters:

batch_size (int) – Number of states per batch.

_exists_bitwise_xor(thr)

Feasibility and constructive check for BitwiseXORReward.

Steps:
  • For each tier, verify that the GF(2) parity system has at least one solution, using Gaussian elimination modulo 2. If any tier is infeasible, no mode exists.

  • The all-zero configuration satisfies even-parity constraints, so if all tiers are feasible we evaluate that point against the threshold with tolerance.

Parameters:

thr (float)

Return type:

bool

_exists_cosine(thr)

Analytic upper-bound check for CosineReward.

Idea:
  • The per-dimension factor is (cos(50·ax) + 1) · N(0,1)(5·ax), i.e. the standard normal density evaluated at 5·ax, with ax in [0, 0.5]. We estimate its maximum over the discrete grid by evaluating all candidate ax values and taking the maximum value m.

  • The full reward upper bound is R0 + m^D * R1. If this is at least the mode target and the given threshold, a mode-level state must exist.

  • We also compute a theoretical per-dimension peak (at ax≈0) to form a slightly conservative target scaled by mode_gamma.

Parameters:

thr (float)

Return type:

bool
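The bound can be sketched in plain Python. It assumes grid index i maps to ax = |i/(H-1) - 0.5|, which keeps ax in [0, 0.5]; that mapping is an assumption, the mode_gamma target is omitted, and `cosine_factor_max` and `cosine_reward_upper_bound` are hypothetical names.

```python
import math


def standard_normal_pdf(z: float) -> float:
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def cosine_factor_max(height: int) -> float:
    """Max over the grid of (cos(50*ax) + 1) * pdf(5*ax), with ax = |i/(H-1) - 0.5|."""
    if height < 2:
        raise ValueError("height must be at least 2")
    best = 0.0
    for i in range(height):
        ax = abs(i / (height - 1) - 0.5)
        best = max(best, (math.cos(50.0 * ax) + 1.0) * standard_normal_pdf(5.0 * ax))
    return best


def cosine_reward_upper_bound(height: int, ndim: int, R0: float, R1: float) -> float:
    """R0 + m^D * R1, attained when every coordinate sits at the best ax."""
    return R0 + cosine_factor_max(height) ** ndim * R1
```

For an odd height the grid contains ax = 0, so m equals the theoretical peak 2 / sqrt(2π).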

_exists_fallback_random(thr)

Random sampling fallback.

Draw a modest batch of random states on CPU and accept if any exceed the threshold with a small tolerance. This is a last resort to avoid expensive enumeration for large grids.

Parameters:

thr (float)

Return type:

bool

_exists_multiplicative_coprime(thr)

Number-theoretic constructive check for MultiplicativeCoprimeReward.

Constructs a candidate state by factoring the target LCM (if any) over the allowed primes, assigning each prime power to a separate active dimension, and verifying coprimality and grid-bound constraints.

Parameters:

thr (float)

Return type:

bool

_exists_original_or_deceptive(thr)

Constructive check for OriginalReward and DeceptiveReward.

Intuition:
  • These rewards form rings/bands around the center when each coordinate is normalized to [0,1]. The mode lies on a thin band at specific normalized distances from the center.

  • We translate those fractional band boundaries into integer indices via small inside/outside nudges (using EPS_INDEX_CMP) and test one candidate index from any non-empty feasible interval.

  • If the reward at that candidate exceeds the threshold (with EPS_REWARD_CMP tolerance), we return True.

Parameters:

thr (float)

Return type:

bool

_exists_sparse(thr)

Constructive check for SparseReward.

This reward assigns positive mass only to a finite set of target configurations. When H>=2 and D>=1, a known target is the zero vector except for certain coordinates fixed at 1 or H-2. We probe a canonical target and confirm the threshold is not above its reward.

Parameters:

thr (float)

Return type:

bool

_generate_combinations_in_batches(ndim, max_val, batch_size)

Yield batches of the Cartesian product {0, …, max_val}^ndim.

Generates the product in batches (using multiprocessing), so the full product (size (max_val+1)^ndim) is never materialized in memory.

Parameters:
  • ndim (int) – Number of dimensions (tuple length).

  • max_val (int) – Maximum coordinate value (inclusive).

  • batch_size (int) – Number of tuples per batch.

Yields:

An iterator of tuples for each batch.
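The batching idea can be shown with a single-process sketch built on `itertools.product` and `itertools.islice`; the library version additionally spreads the work over worker processes. `combinations_in_batches` is a hypothetical name.

```python
import itertools
from typing import Iterator, List, Tuple


def combinations_in_batches(
    ndim: int, max_val: int, batch_size: int
) -> Iterator[List[Tuple[int, ...]]]:
    """Yield lists of at most batch_size tuples from {0, ..., max_val}^ndim."""
    product = itertools.product(range(max_val + 1), repeat=ndim)  # lazy iterator
    while True:
        batch = list(itertools.islice(product, batch_size))
        if not batch:
            return
        yield batch
```

Only one batch is held in memory at a time, so the peak footprint is governed by batch_size rather than by (max_val+1)^ndim.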

_log_partition = None
_mode_reward_threshold()

Returns the reward threshold used to define a mode.

By default, a state is considered in a mode if its reward is at least the schema-defined threshold derived from the configured reward.

Return type:

float

_mode_stats_kind: str = 'none'
_modes_exist_quick_check()

Lightweight check that a mode-level state exists.

In simple terms, this answers: “Is there at least one state whose reward reaches the mode threshold?” without enumerating all states. It proceeds in three stages:

  1. If the grid is small (or pre-enumerated), it computes rewards exactly and checks them against the threshold.

  2. Otherwise, it dispatches to reward-specific constructive tests that are sufficient to guarantee at least one state reaches the threshold.

  3. As a last resort, it samples a small batch of random states.

Return type:

bool

_modes_exist_quick_check_info()

Same as _modes_exist_quick_check but returns (ok, message).

Return type:

tuple[bool, str]

_n_mode_states_estimate: float | None = None
_n_mode_states_exact: int | None = None
static _solve_gf2_has_solution(A, c)

Return True if A x = c over GF(2) has at least one solution.

Performs Gaussian elimination modulo 2 (XOR arithmetic) without constructing a specific solution. A row that reduces to all-zero coefficients with a non-zero RHS (0 = 1) indicates inconsistency.

Parameters:
  • A (torch.Tensor)

  • c (torch.Tensor)

Return type:

bool
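The consistency test described above can be sketched with plain PyTorch tensor ops. This is a minimal illustration, not the library's `_solve_gf2_has_solution`; the name `gf2_has_solution` is hypothetical. It performs Gauss-Jordan elimination on the augmented matrix [A | c] with XOR as row addition, then looks for a row whose coefficients are all zero but whose right-hand side is 1.

```python
import torch


def gf2_has_solution(A: torch.Tensor, c: torch.Tensor) -> bool:
    """Hypothetical helper: True if A x = c has a solution over GF(2).

    A: (m, n) tensor of 0/1 entries; c: (m,) tensor of 0/1 entries.
    """
    # Augmented matrix [A | c] with integer 0/1 entries.
    M = torch.cat([A.to(torch.int64) % 2, (c.to(torch.int64) % 2).unsqueeze(1)], dim=1)
    m, n = A.shape
    row = 0
    for col in range(n):
        if row == m:
            break
        # Look for a pivot (a 1 in this column) at or below the current row.
        pivot_rows = torch.nonzero(M[row:, col], as_tuple=False)
        if pivot_rows.numel() == 0:
            continue
        p = row + int(pivot_rows[0])
        if p != row:
            M[[row, p]] = M[[p, row]]  # swap the pivot row into place
        # Clear this column in every other row: addition mod 2 is XOR.
        mask = M[:, col].bool()
        mask[row] = False
        M[mask] ^= M[row]
        row += 1
    # Inconsistent iff some row reads 0 = 1.
    zero_coeffs = (M[:, :n] == 0).all(dim=1)
    return not bool((zero_coeffs & (M[:, n] == 1)).any())
```

Only a yes/no answer is needed, so no back-substitution is performed: the system is solvable exactly when elimination produces no 0 = 1 row.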

_true_dist = None
_worker(task)

Return a slice of the Cartesian product for one batch.

Parameters:

task (tuple) – (values, ndim, start_idx, end_idx) where values is the list of coordinate values, ndim is the number of dimensions, and [start_idx, end_idx) is the range within the full product.

Return type:

itertools.islice
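The batching scheme described by `_generate_combinations_in_batches` and `_worker` can be sketched sequentially with the standard library. The functions `worker` and `combinations_in_batches` below are hypothetical stand-ins. The multiprocessing pool is left out, so this only illustrates how an `islice` over a lazy `itertools.product` returns the range [start_idx, end_idx) without building the whole product.

```python
from itertools import islice, product
from typing import Iterator, List, Tuple


def worker(task: Tuple[List[int], int, int, int]) -> Iterator[Tuple[int, ...]]:
    """Hypothetical stand-in: the slice [start_idx, end_idx) of values^ndim."""
    values, ndim, start_idx, end_idx = task
    # product() is lazy, so islice only generates tuples up to end_idx.
    return islice(product(values, repeat=ndim), start_idx, end_idx)


def combinations_in_batches(
    ndim: int, max_val: int, batch_size: int
) -> Iterator[List[Tuple[int, ...]]]:
    """Yield batches of {0, ..., max_val}^ndim, one worker task per batch."""
    values = list(range(max_val + 1))
    total = (max_val + 1) ** ndim
    for start in range(0, total, batch_size):
        end = min(start + batch_size, total)
        yield list(worker((values, ndim, start, end)))
```

Note that `islice` still has to step past the first `start_idx` tuples, so later batches cost more to produce than earlier ones.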

all_indices()

Generate all possible indices for the grid.

Returns:

A list of all possible indices for the grid.

Return type:

List[Tuple[int, Ellipsis]]

property all_states: gfn.states.DiscreteStates | None

Returns all hypergrid states as a DiscreteStates instance.

Return type:

gfn.states.DiscreteStates | None

backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
  • states – The current states.

  • actions – The backward actions to take.

Returns:

The previous states.

Return type:

gfn.states.DiscreteStates

calculate_partition = False
get_states_indices(states)

Get the indices of the states in the canonical ordering.

Parameters:

states (gfn.states.DiscreteStates | torch.Tensor) – The states to get the indices of.

Returns:

The indices of the states in the canonical ordering.

Return type:

torch.Tensor
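One way to realize a canonical ordering is to read each state as a base-H number with the first coordinate as the most significant digit. This is a sketch under that assumption; `states_to_indices` is a hypothetical name, not the library method. With this choice, the index of a state equals its position in the `itertools.product(range(H), repeat=D)` enumeration.

```python
import itertools

import torch


def states_to_indices(states: torch.Tensor, height: int) -> torch.Tensor:
    """Hypothetical helper: map (..., ndim) coordinates to base-`height` indices."""
    ndim = states.shape[-1]
    # Place values height^(ndim-1), ..., height^1, height^0 (row-major order).
    base = height ** torch.arange(ndim - 1, -1, -1, device=states.device)
    return (states * base).sum(dim=-1)


# Usage: the indices of an enumerated grid are 0, 1, ..., height**ndim - 1.
grid = torch.tensor(list(itertools.product(range(3), repeat=2)))
indices = states_to_indices(grid, height=3)
```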

get_terminating_states_indices(states)

Get the indices of the terminating states in the canonical ordering.

Parameters:

states (gfn.states.DiscreteStates) – The states to get the indices of.

Returns:

The indices of the terminating states in the canonical ordering.

Return type:

torch.Tensor

height = 8
log_partition(condition=None)

Returns the log partition of the reward function.

Return type:

float | None

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Creates a batch of random states.

Parameters:
  • batch_shape (Tuple[int, Ellipsis]) – The shape of the batch.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A DiscreteStates object with random states.

Return type:

gfn.states.DiscreteStates

make_states_class()

Returns the DiscreteStates class for the HyperGrid environment.

Return type:

type[gfn.states.DiscreteStates]

mode_mask(states)

Boolean mask indicating which states are in a mode.

A state is flagged as a mode state if its reward is greater than or equal to the threshold derived from reward_fn_kwargs (R0 + R1 + R2 by default).

Parameters:

states (gfn.states.DiscreteStates)

Return type:

torch.Tensor

modes_found(states)

Returns the set of canonical state indices for mode states in the batch.

Each mode state is identified by its unique canonical index (from get_states_indices), not by a quadrant-based grouping. This allows correct mode-state tracking for all reward functions.

Parameters:

states (gfn.states.DiscreteStates)

Return type:

set[int]

property n_mode_states: int | float | None

Number of states inside a mode (exact, approx, or None).

  • If mode_stats="exact", returns an exact integer count.

  • If mode_stats="approx", returns a floating-point estimate.

  • If store_all_states is True (but mode_stats was "none"), computes the count on demand from all_states.

  • Otherwise, returns None.

Return type:

int | float | None

property n_modes: int | float | None

Returns the total number of mode states for this environment.

Equivalent to n_mode_states. Each individual grid cell whose reward meets the mode threshold counts as one mode.

Return type:

int | float | None

property n_states: int

Returns the number of states in the environment.

Return type:

int

property n_terminating_states: int

Returns the number of terminating states in the environment.

Return type:

int

ndim = 2
reward(states)

Computes the reward for a batch of final states.

In the normal setting, the reward is:

$$R(s) = R_0 + 0.5 \prod_{d=1}^{D} \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5] \right) + 2 \prod_{d=1}^{D} \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)$$

where s^d is the d-th coordinate of s, H is the grid height, and D is the number of dimensions.

Parameters:

states – The final states.

Returns:

The reward of the final states.

Return type:

torch.Tensor
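The formula above can be evaluated directly with tensor comparisons. This is a minimal sketch, not the library's `reward`; `hypergrid_reward` is a hypothetical name. R0 is left as an explicit argument because the formula does not fix its value, and R1 = 0.5 and R2 = 2 are taken from the coefficients shown above.

```python
import torch


def hypergrid_reward(
    states: torch.Tensor, height: int, R0: float, R1: float = 0.5, R2: float = 2.0
) -> torch.Tensor:
    """Hypothetical helper for the HyperGrid reward formula.

    states: (..., D) integer coordinates in {0, ..., height - 1}.
    """
    # Normalize each coordinate to [0, 1] and take the distance to the center.
    ax = (states / (height - 1) - 0.5).abs()
    # Each product of indicators is an "all dimensions satisfy the band" test.
    band1 = ((ax > 0.25) & (ax <= 0.5)).all(dim=-1)
    band2 = ((ax > 0.3) & (ax < 0.4)).all(dim=-1)
    return R0 + R1 * band1.float() + R2 * band2.float()


# On an 8 x 8 grid, (1, 6) lies in both bands, (0, 0) only in the outer band.
rewards = hypergrid_reward(torch.tensor([[1, 6], [0, 0], [3, 4]]), height=8, R0=0.1)
```

Because the interval (0.3, 0.4) lies inside (0.25, 0.5], a state reaches R0 + R1 + R2 exactly when it falls in the inner band, which is why that sum serves as the default mode threshold.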

reward_fn
reward_fn_kwargs = None
step(states, actions)

Performs a step in the environment.

Parameters:
  • states – The current states.

  • actions – The actions to take.

Returns:

The next states.

Return type:

gfn.states.DiscreteStates

store_all_states = False
property terminating_states: gfn.states.DiscreteStates | None

Returns all terminating states of the environment.

Return type:

gfn.states.DiscreteStates | None

true_dist(condition=None)

Returns the pmf over all states in the hypergrid.

Return type:

torch.Tensor | None

class gfn.gym.Line(mus, sigmas, init_value, n_sd=4.5, n_steps_per_trajectory=5, device='cpu', debug=False)

Bases: gfn.env.Env

Mixture of Gaussians Line environment.

Parameters:
  • mus (list)

  • sigmas (list)

  • init_value (float)

  • n_sd (float)

  • n_steps_per_trajectory (int)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • debug (bool)

mus

The means of the Gaussians.

sigmas

The standard deviations of the Gaussians.

n_sd

The number of standard deviations to consider for the bounds.

n_steps_per_trajectory

The number of steps per trajectory.

mixture

The mixture of Gaussians.

init_value

The initial value of the state.

backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
  • states – The current states.

  • actions – The backward actions to take.

Returns:

The previous states.

Return type:

gfn.states.States

init_value
is_action_valid(states, actions, backward=False)

Checks if the actions are valid.

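Parameters:
  • states – The current states.

  • actions – The actions to check.

  • backward – If True, validate the actions as backward actions.

Returns: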

True if the actions are valid, False otherwise.

Return type:

bool

log_partition(condition=None)

Returns the log partition of the reward function.

Return type:

torch.Tensor

log_reward(final_states)

Computes the log reward of the environment.

Parameters:

final_states (gfn.states.States) – The final states of the environment.

Returns:

The log reward.

Return type:

torch.Tensor
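A mixture-of-Gaussians log reward can be computed stably with `torch.logsumexp` over per-component log-densities. This sketch assumes the reward is the unweighted sum of the Gaussian densities N(x; mu_i, sigma_i); that assumption, and the helper names `line_log_reward` and `line_log_partition`, are illustrative rather than taken from the source.

```python
import math

import torch
from torch.distributions import Normal


def line_log_reward(x: torch.Tensor, mus: list, sigmas: list) -> torch.Tensor:
    """Hypothetical helper: log sum_i N(x; mu_i, sigma_i) via logsumexp."""
    comps = [Normal(float(m), float(s)) for m, s in zip(mus, sigmas)]
    # Shape (n_components, *x.shape): one log-density per component.
    log_probs = torch.stack([c.log_prob(x) for c in comps], dim=0)
    return torch.logsumexp(log_probs, dim=0)


def line_log_partition(mus: list) -> float:
    # Each density integrates to 1, so the summed reward integrates to len(mus).
    return math.log(len(mus))
```

Under the same assumption, the log partition is simply the log of the number of components, since each Gaussian contributes a total mass of 1.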

mixture
mus
n_sd = 4.5
n_steps_per_trajectory = 5
sigmas
step(states, actions)

Performs a step in the environment.

Parameters:
  • states – The current states.

  • actions – The actions to take.

Returns:

The next states.

Return type:

gfn.states.States

class gfn.gym.NonAutoregressiveBitSequence(word_size=1, seq_size=4, n_modes=2, reward_exponent=2.0, H=None, device_str='cpu', seed=0, debug=False)

Bases: gfn.env.DiscreteEnv

Non-autoregressive BitSequence environment.

In this environment, the agent constructs a binary sequence by placing words at arbitrary positions. Each action specifies both which position to fill and which word value to place there. The episode ends when all positions are filled.

The reward is based on the minimum Hamming distance (computed at the bit level) between the completed sequence and a set of target “mode” sequences.

Parameters:
  • word_size (int) – Number of bits per word (e.g., 1 for single-bit actions).

  • seq_size (int) – Total number of bits in the sequence. Must be divisible by word_size.

  • n_modes (int) – Number of target mode sequences.

  • reward_exponent (float) – Controls reward sharpness. Higher values make the reward more peaked around the modes.

  • H (Optional[torch.Tensor]) – Optional tensor of shape (n_modes, seq_size) specifying the target modes in binary. If None, modes are generated randomly using block patterns.

  • device_str (str) – Device to use ("cpu" or "cuda").

  • seed (int) – Random seed for mode generation.

  • debug (bool) – If True, enable runtime guards (not compile-friendly).

word_size

Number of bits per word.

seq_size

Total number of bits.

words_per_seq

Number of word positions (seq_size // word_size).

n_words

Number of possible word values (2 ** word_size).

n_modes

Number of target modes.

reward_exponent

Reward sharpness parameter.

modes

Target mode sequences as a binary tensor of shape (n_modes, seq_size).

Example

>>> env = NonAutoregressiveBitSequence(word_size=1, seq_size=4, n_modes=2)
>>> # Action space: 4 positions * 2 word values + 1 exit = 9 actions
>>> env.n_actions
9
>>> # State shape: 4 word positions
>>> env.s0
tensor([-1, -1, -1, -1])
H = None
States: type[NonAutoregressiveBitSequenceStates]
_decode_action(action)

Decode a flat action index into (position, word) pair.

Parameters:

action (torch.Tensor) – Action tensor of shape (*batch_shape, 1).

Returns:

Tuple of (position, word) tensors, each of shape (*batch_shape, 1).

Return type:

Tuple[torch.Tensor, torch.Tensor]
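The flat encoding action = position * n_words + word (given for `step` and `backward_step`) is inverted with integer division and remainder. The helpers below are hypothetical illustrations of that arithmetic, not the library's `_decode_action`. The exit action (index words_per_seq * n_words) has no (position, word) meaning and would have to be handled separately.

```python
from typing import Tuple

import torch


def encode_action(position: torch.Tensor, word: torch.Tensor, n_words: int) -> torch.Tensor:
    # Flat index: action = position * n_words + word.
    return position * n_words + word


def decode_action(action: torch.Tensor, n_words: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # Quotient recovers the position, remainder recovers the word; shapes are preserved.
    return action // n_words, action % n_words
```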

static _integers_to_binary(tensor, k)

Convert a tensor of word integers to their binary representation.

Parameters:
  • tensor (torch.Tensor) – Integer tensor of shape (*batch_shape, words_per_seq) with values in {0, ..., 2^k - 1}.

  • k (int) – Number of bits per word.

Returns:

Binary tensor of shape (*batch_shape, words_per_seq * k) with values in {0, 1}.

Return type:

torch.Tensor
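The conversion can be sketched with bit shifts: expand each word into k bits, then merge the word and bit dimensions. The most-significant-bit-first order used here is an assumption, and `integers_to_binary` is a hypothetical standalone version rather than the library's static method.

```python
import torch


def integers_to_binary(tensor: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: (..., W) words in [0, 2^k) -> (..., W * k) bits."""
    # Shift amounts k-1, ..., 0 extract bits from most to least significant.
    shifts = torch.arange(k - 1, -1, -1, device=tensor.device)
    bits = (tensor.unsqueeze(-1) >> shifts) & 1  # shape (..., W, k)
    # Merge the word and bit dimensions into a single sequence dimension.
    return bits.flatten(start_dim=-2)
```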

_make_modes(seed, device)

Generate target mode sequences in binary representation.

If H is provided, it is used directly as the modes. Otherwise, modes are constructed by randomly combining 8-bit block patterns, following the procedure from the Trajectory Balance paper.

Parameters:
  • seed (int) – Random seed.

  • device (torch.device) – Device to place the modes tensor on.

Returns:

Binary tensor of shape (n_modes, seq_size) with values in {0, 1}.

Return type:

torch.Tensor

static _min_hamming_distance(candidates, references)

Compute minimum Hamming distance from each candidate to any reference.

Parameters:
  • candidates (torch.Tensor) – Binary tensor of shape (*batch_shape, seq_size).

  • references (torch.Tensor) – Binary tensor of shape (n_refs, seq_size).

Returns:

Tensor of shape (*batch_shape,) with the minimum distance.

Return type:

torch.Tensor
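Broadcasting each candidate against every reference gives all pairwise Hamming distances at once, after which a `min` over the reference dimension keeps the nearest one. This is a sketch with the documented shapes; `min_hamming_distance` here is a hypothetical standalone function.

```python
import torch


def min_hamming_distance(candidates: torch.Tensor, references: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (*B, L) and (R, L) binary tensors -> (*B,) distances."""
    # (*B, 1, L) vs (R, L) broadcasts to (*B, R, L) mismatch flags.
    mismatches = candidates.unsqueeze(-2) != references
    dists = mismatches.sum(dim=-1)  # (*B, R) Hamming distances
    return dists.min(dim=-1).values  # closest reference per candidate
```

This materializes a (*B, R, L) boolean tensor, which is fine for modest numbers of modes but grows with batch size times n_modes times seq_size.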

backward_step(states, actions)

Undo a word placement by clearing the position back to -1.

The backward action uses the same encoding as the forward action: action = position * n_words + word. The position component (action // n_words) identifies which position to clear.

Parameters:
  • states – The current states.

  • actions – The backward actions to take.

Returns:

Previous states with the specified positions cleared.

Return type:

NonAutoregressiveBitSequenceStates

log_reward(final_states)

Compute log-reward based on Hamming distance to nearest mode.

The log-reward is:

log R(x) = -reward_exponent * min_d(x, modes) / seq_size

where min_d is the minimum bit-level Hamming distance between the completed sequence and any target mode.

Parameters:

final_states (NonAutoregressiveBitSequenceStates) – Terminal states with all positions filled.

Returns:

Log-reward tensor of shape (*batch_shape,).

Return type:

torch.Tensor
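Given the minimum distances, the log-reward formula above is a single scaled negation. The helper name `log_reward_from_distance` is hypothetical. With seq_size = 4 and reward_exponent = 2, distances 0, 1 and 4 map to log-rewards 0, -0.5 and -2.

```python
import torch


def log_reward_from_distance(
    min_dist: torch.Tensor, reward_exponent: float, seq_size: int
) -> torch.Tensor:
    # log R(x) = -reward_exponent * min_d(x, modes) / seq_size
    return -reward_exponent * min_dist.float() / seq_size


log_r = log_reward_from_distance(torch.tensor([0, 1, 4]), reward_exponent=2.0, seq_size=4)
reward = log_r.exp()  # R(x) = exp(log R(x))
```

The reward equals 1 exactly at a mode and decays to exp(-reward_exponent) at the largest possible distance, seq_size, so reward_exponent directly sets the contrast between modes and the farthest sequences.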

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Generate random partially-filled states.

Each position is independently either unfilled (-1) or filled with a random word value.

Parameters:
  • batch_shape (Tuple) – Shape of the batch.

  • conditions (Optional[torch.Tensor]) – Optional conditions tensor.

  • device (Optional[torch.device]) – Device to use.

  • debug (bool) – If True, enable debug mode.

Returns:

Random states.

Return type:

NonAutoregressiveBitSequenceStates

make_states_class()

Create the States class with environment-specific constants.

Return type:

type[NonAutoregressiveBitSequenceStates]

modes
n_modes_count = 2
property n_terminating_states: int

Total number of possible terminal states.

Return type:

int

n_words = 2
reward(final_states)

Compute reward as exp(log_reward).

Parameters:

final_states (NonAutoregressiveBitSequenceStates) – Terminal states.

Returns:

Reward tensor of shape (*batch_shape,).

Return type:

torch.Tensor

reward_exponent = 2.0
seq_size = 4
step(states, actions)

Place a word at the specified position.

The action encodes (position, word) as a flat index: action = position * n_words + word.

Parameters:
  • states – The current states.

  • actions – The actions to take.

Returns:

Next states with the specified positions filled.

Return type:

NonAutoregressiveBitSequenceStates
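Forward and backward steps can both be expressed as a `scatter` into the state tensor, using -1 for empty positions as described above. This sketch operates on raw (B, words_per_seq) tensors with (B, 1) non-exit actions; `place_words` and `clear_words` are hypothetical names, and it assumes actions have already been checked for validity.

```python
import torch


def place_words(states: torch.Tensor, actions: torch.Tensor, n_words: int) -> torch.Tensor:
    """Forward step: write `word` at `position` (action = position * n_words + word)."""
    position, word = actions // n_words, actions % n_words
    # scatter writes one value per row at the given column index.
    return states.scatter(-1, position, word)


def clear_words(states: torch.Tensor, actions: torch.Tensor, n_words: int) -> torch.Tensor:
    """Backward step: reset the decoded position to -1; the word part is not used here."""
    position = actions // n_words
    return states.scatter(-1, position, torch.full_like(position, -1))
```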

property terminating_states: NonAutoregressiveBitSequenceStates

Enumerate all terminal states (only feasible for small environments).

Return type:

NonAutoregressiveBitSequenceStates

true_dist(condition=None)

Compute the true reward distribution over all terminal states.

Return type:

torch.Tensor
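For small environments, the terminal states and the true distribution can be sketched as a full enumeration followed by normalization of the rewards. The helpers are hypothetical. There are n_words ** words_per_seq = 2 ** seq_size terminal states, which is why enumeration is only feasible for small sequences. Normalizing R(x) by its sum equals a softmax over log-rewards.

```python
import itertools

import torch


def all_terminal_states(words_per_seq: int, n_words: int) -> torch.Tensor:
    """Hypothetical helper: every fully filled sequence, one per row."""
    rows = list(itertools.product(range(n_words), repeat=words_per_seq))
    return torch.tensor(rows, dtype=torch.long)


def true_distribution(log_rewards: torch.Tensor) -> torch.Tensor:
    # R(x) / sum_x' R(x'), computed stably as a softmax over log-rewards.
    return torch.softmax(log_rewards, dim=0)
```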

word_size = 1
words_per_seq = 4
class gfn.gym.PerfectBinaryTree(reward_fn, depth=4, device=None, debug=False)

Bases: gfn.env.DiscreteEnv

Perfect Tree Environment.

This environment is a perfect binary tree, where there is a bijection between trajectories and terminating states. Nodes are represented by integers, starting from 0 for the root. States are represented by a single integer tensor corresponding to the node index. Actions are integers: 0 (left child), 1 (right child), 2 (exit).

e.g.:

                  0 (root)
              /       \
          1               2
        /   \           /   \
      3       4       5       6
     / \     / \     / \     / \
    7   8   9   10  11  12  13  14   (terminating states if depth=3)

Recommended preprocessor: OneHotPreprocessor.
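The numbering in the diagram follows heap order, so parents, children and leaves can be computed arithmetically. These helpers are hypothetical illustrations of the pictured layout. They reproduce the leaves 7 to 14 for depth=3 and n_nodes = 31 for the default depth=4.

```python
from typing import Tuple


def children(node: int) -> Tuple[int, int]:
    # Heap-style numbering: left child 2i + 1 (action 0), right child 2i + 2 (action 1).
    return 2 * node + 1, 2 * node + 2


def parent(node: int) -> int:
    return (node - 1) // 2


def n_nodes(depth: int) -> int:
    # A perfect binary tree of depth d has 2^(d + 1) - 1 nodes.
    return 2 ** (depth + 1) - 1


def leaves(depth: int) -> range:
    # The last level holds nodes 2^d - 1, ..., 2^(d + 1) - 2.
    return range(2 ** depth - 1, 2 ** (depth + 1) - 1)
```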

Parameters:
  • reward_fn (Callable)

  • depth (int)

  • device (Literal['cpu', 'cuda'] | torch.device | None)

  • debug (bool)

reward_fn

A function that computes the reward for a given state.

Type:

Callable

depth

The depth of the tree.

Type:

int

branching_factor

The branching factor of the tree.

Type:

int

n_actions

The number of actions.

Type:

int

n_nodes

The number of nodes in the tree.

Type:

int

transition_table

A dictionary that maps (state, action) to the next state.

Type:

dict

inverse_transition_table

A dictionary that maps (state, action) to the previous state.

Type:

dict

term_states

The terminating states.

Type:

DiscreteStates

States: type[gfn.env.DiscreteStates]
_build_tree()

Builds the tree and the transition tables.

Returns:

A tuple containing the transition table, the inverse transition table, and the terminating states.

Return type:

tuple[dict, dict, gfn.env.DiscreteStates]
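A plain-dictionary version of the tables returned by `_build_tree` might look like the sketch below. It assumes the inverse table is keyed by (child, action that produced the child). The exit action and the sink state are not modeled, and terminating states are returned as a list of node indices instead of DiscreteStates. `build_tables` is a hypothetical name.

```python
from typing import Dict, List, Tuple


def build_tables(depth: int) -> Tuple[Dict[Tuple[int, int], int], Dict[Tuple[int, int], int], List[int]]:
    """Hypothetical helper: (transition, inverse_transition, leaves) for a perfect binary tree."""
    transition: Dict[Tuple[int, int], int] = {}
    inverse: Dict[Tuple[int, int], int] = {}
    n_internal = 2 ** depth - 1  # nodes 0 .. n_internal - 1 have children
    for node in range(n_internal):
        for action in (0, 1):  # 0 = left child, 1 = right child
            child = 2 * node + 1 + action
            transition[(node, action)] = child
            # Assumed keying: undoing `action` from `child` returns to `node`.
            inverse[(child, action)] = node
    leaves = list(range(n_internal, 2 ** (depth + 1) - 1))
    return transition, inverse, leaves
```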

property all_states: gfn.env.DiscreteStates

Returns all the states of the environment.

Return type:

gfn.env.DiscreteStates

backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
  • states (gfn.env.DiscreteStates) – The current states.

  • actions (gfn.env.Actions) – The actions to take.

Returns:

The previous states.

Return type:

gfn.env.DiscreteStates

branching_factor = 2
depth = 4
get_states_indices(states)

Returns the indices of the states.

Parameters:

states (gfn.states.States) – The states to get the indices of.

Returns:

The indices of the states.

make_states_class()

Returns the DiscreteStates class for the PerfectBinaryTree environment.

Return type:

type[gfn.env.DiscreteStates]

n_actions = 3
n_nodes = 31
reward(final_states)

Computes the reward for a batch of final states.

Parameters:

final_states – The final states.

Returns:

The reward of the final states.

reward_fn
s0
sf
step(states, actions)

Performs a step in the environment.

Parameters:
  • states (gfn.env.DiscreteStates) – The current states.

  • actions (gfn.env.Actions) – The actions to take.

Returns:

The next states.

Return type:

gfn.env.DiscreteStates

property terminating_states: gfn.env.DiscreteStates

Returns the terminating states of the environment.

Return type:

gfn.env.DiscreteStates

class gfn.gym.SetAddition(n_items, max_items, reward_fn, fixed_length=False, device=None, debug=False)

Bases: gfn.env.DiscreteEnv

Append-only MDP, similar to the one described in Remark 8 of Shen et al. (2023), [Towards Understanding and Improving GFlowNet Training](https://proceedings.mlr.press/v202/shen23a.html).

The state is a binary vector of length n_items, where 1 indicates the presence of an item. Actions are integers from 0 to n_items - 1 to add the corresponding item, or n_items to exit. Adding an existing item is invalid. The trajectory must end when max_items are present.
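The action rules above translate into a forward mask over the n_items + 1 actions and a one-line forward step. This is a sketch on raw 0/1 tensors, and `forward_masks` and `add_items` are hypothetical names. The exit rule is an assumption: exit is allowed from any state when fixed_length is False, and only once max_items are present when it is True.

```python
import torch


def forward_masks(states: torch.Tensor, max_items: int, fixed_length: bool) -> torch.Tensor:
    """states: (B, n_items) 0/1 tensor -> (B, n_items + 1) bool mask; last column is exit."""
    n_present = states.sum(dim=-1, keepdim=True)
    full = n_present >= max_items
    # Adding item i is valid only if it is absent and the set is not yet full.
    add_ok = (states == 0) & ~full
    # Assumption: exit always allowed unless fixed_length, then only once full.
    exit_ok = full if fixed_length else torch.ones_like(full)
    return torch.cat([add_ok, exit_ok], dim=-1)


def add_items(states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Forward step for (B, 1) non-exit actions: set the chosen item's bit to 1."""
    return states.scatter(-1, actions, 1)
```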

Recommended preprocessor: IdentityPreprocessor.

Parameters:
  • n_items (int)

  • max_items (int)

  • reward_fn (Callable)

  • fixed_length (bool)

  • device (Literal['cpu', 'cuda'] | torch.device | None)

  • debug (bool)

n_items

The number of items in the set.

Type:

int

max_items

The maximum number of items that can be added to the set.

Type:

int

reward_fn

The reward function.

Type:

Callable

fixed_length

Whether the trajectories have a fixed length.

Type:

bool

States: type[gfn.env.DiscreteStates]
property all_states: gfn.env.DiscreteStates

Returns all the states of the environment.

Return type:

gfn.env.DiscreteStates

backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
  • states (gfn.env.DiscreteStates) – The current states.

  • actions (gfn.env.Actions) – The actions to take.

Returns:

The previous states.

Return type:

gfn.env.DiscreteStates

fixed_length = False
get_states_indices(states)

Returns the indices of the states.

Parameters:

states (gfn.env.DiscreteStates) – The states to get the indices of.

Returns:

The indices of the states.

make_states_class()

Returns the DiscreteStates class for the SetAddition environment.

Return type:

type[gfn.env.DiscreteStates]

max_traj_len
n_items
reward(final_states)

Computes the reward for a batch of final states.

Parameters:

final_states (gfn.env.DiscreteStates) – The final states.

Returns:

The reward of the final states.

Return type:

torch.Tensor

reward_fn
step(states, actions)

Performs a step in the environment.

Parameters:
  • states (gfn.env.DiscreteStates) – The current states.

  • actions (gfn.env.Actions) – The actions to take.

Returns:

The next states.

Return type:

gfn.env.DiscreteStates

property terminating_states: gfn.env.DiscreteStates

Returns the terminating states of the environment.

Return type:

gfn.env.DiscreteStates