gfn.env

Attributes

NonValidActionsError

Classes

DiscreteEnv

Base class for discrete environments, where states are defined in a discrete space.

Env

Base class for all environments.

GraphEnv

Base class for graph-based environments.

Module Contents

class gfn.env.DiscreteEnv(n_actions, s0, state_shape, dummy_action=None, exit_action=None, sf=None, debug=False)

Bases: Env, abc.ABC

Base class for discrete environments, where states are defined in a discrete space, and actions are represented by an integer in {0, …, n_actions - 1}, the last one being the exit action.

For a guide on creating your own environments, see the documentation at: guides/creating_environments.

For a complete example, see the HyperGrid environment in src/gfn/gym/hypergrid.py.
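The integer action convention can be illustrated with a minimal, framework-free sketch loosely modeled on a HyperGrid-style environment. With `NDIM` coordinates there are `NDIM + 1` actions: actions `0 … NDIM - 1` each increment one coordinate, and the last index is the exit action. The names `NDIM`, `N_ACTIONS`, `EXIT_ACTION`, `is_exit` and `grid_step` are hypothetical and are not part of the gfn API. States here are plain tuples, not DiscreteStates.

```python
# Hypothetical sketch: integer actions in {0, ..., N_ACTIONS - 1}, the last one is exit.
NDIM = 2
N_ACTIONS = NDIM + 1          # one "increment" action per coordinate, plus exit
EXIT_ACTION = N_ACTIONS - 1   # by convention the exit action is the last index


def is_exit(action: int) -> bool:
    """Return True if the integer action is the exit action."""
    return action == EXIT_ACTION


def grid_step(state: tuple, action: int) -> tuple:
    """Apply a non-exit action: increment the coordinate selected by `action`."""
    if not 0 <= action < N_ACTIONS:
        raise ValueError(f"action must be in [0, {N_ACTIONS - 1}], got {action}")
    if is_exit(action):
        raise ValueError("the exit action leads to the sink state, not a grid move")
    coords = list(state)
    coords[action] += 1
    return tuple(coords)


print(grid_step((0, 0), 1))  # (0, 1)
```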

Parameters:
  • n_actions (int)

  • s0 (torch.Tensor)

  • state_shape (Tuple | int)

  • dummy_action (Optional[torch.Tensor])

  • exit_action (Optional[torch.Tensor])

  • sf (Optional[torch.Tensor])

  • debug (bool)

s0

Tensor of shape (*state_shape) representing the initial state.

sf

Tensor of shape (*state_shape) representing the sink (final) state.

n_actions

The number of actions in the environment.

state_shape

Tuple representing the shape of the states.

dummy_action

Tensor of shape (1,) representing the dummy action.

exit_action

Tensor of shape (1,) representing the exit action.

States

The States class associated with this environment.

Actions

The Actions class associated with this environment.

is_discrete

Class variable, whether the environment is discrete.

_add_logz_diff(validation_info, gflownet, validate_condition)

Compute |learned_logZ - true_logZ| and add to validation_info.

Parameters:
  • validation_info (Dict[str, float])

  • gflownet (gfn.gflownet.GFlowNet)

  • validate_condition (torch.Tensor | None)

Return type:

None

_backward_step(states, actions)

Wrapper for the user-defined backward_step function.

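Parameters:
  • states (gfn.states.DiscreteStates) – The batch of current discrete states.

  • actions (gfn.actions.Actions) – The batch of backward actions.
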
Returns:

The batch of previous discrete states.

Return type:

gfn.states.DiscreteStates

static _jsd(p, q)

Jensen-Shannon divergence between two discrete distributions.

Uses the convention 0 * log(0 / x) = 0 via masking so that zero-probability bins contribute nothing (no clamping, no mass distortion).

Parameters:
  • p (torch.Tensor) – First distribution (1-D, sums to ~1).

  • q (torch.Tensor) – Second distribution (1-D, sums to ~1).

Returns:

JSD in nats (base-e). Bounded in [0, ln(2)].

Return type:

float
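For intuition, the masking convention can be reproduced with plain Python lists. This is a sketch of the standard formula JSD(p, q) = ½ KL(p || m) + ½ KL(q || m) with m = (p + q) / 2, not the library's tensor implementation; the function name `jsd` is hypothetical.

```python
import math
from typing import Sequence


def jsd(p: Sequence[float], q: Sequence[float]) -> float:
    """Jensen-Shannon divergence (in nats) between two discrete distributions."""
    if len(p) != len(q):
        raise ValueError("p and q must have the same length")
    m = [(pi + qi) / 2.0 for pi, qi in zip(p, q)]

    def kl(a: Sequence[float], b: Sequence[float]) -> float:
        # Convention 0 * log(0 / x) = 0: zero-probability bins are skipped
        # (masked) instead of clamped. If a_i > 0 then m_i >= a_i / 2 > 0,
        # so the log argument is always well defined.
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)
```

Disjoint supports give the upper bound ln(2), and identical distributions give 0.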

_step(states, actions)

Wrapper for the user-defined step function.

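Parameters:
  • states (gfn.states.DiscreteStates) – The batch of current discrete states.

  • actions (gfn.actions.Actions) – The batch of actions to apply.
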
Returns:

The batch of next discrete states.

Return type:

gfn.states.DiscreteStates

_warn_if_insufficient_samples(n_validation_samples)

Emit a warning if validation sample count is too low for the state space.

Parameters:

n_validation_samples (int)

Return type:

None

property all_states: gfn.states.DiscreteStates
Abstract property.

Optional method to return a batch of all discrete states in the environment.

Returns:

A batch of all discrete states (batch_shape = (n_states,)).

Return type:

gfn.states.DiscreteStates

Note

self.get_states_indices(self.all_states) and torch.arange(self.n_states) should be equivalent.
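The equivalence in the note can be checked on a toy example. The sketch below assumes a hypothetical grid of side `height` in `ndim` dimensions whose states are tuples of coordinates, enumerated in lexicographic order and indexed in mixed radix (row-major). `all_states` and `get_states_indices` here are plain functions, not the gfn methods.

```python
from itertools import product
from typing import List, Tuple


def all_states(height: int, ndim: int) -> List[Tuple[int, ...]]:
    """Enumerate every grid state in lexicographic order."""
    return list(product(range(height), repeat=ndim))


def get_states_indices(states: List[Tuple[int, ...]], height: int) -> List[int]:
    """Map each state to a unique integer in [0, height ** ndim) (row-major)."""
    indices = []
    for state in states:
        idx = 0
        for coord in state:
            idx = idx * height + coord
        indices.append(idx)
    return indices


height, ndim = 3, 2
n_states = height ** ndim
# Mirrors the note: the indices of all_states should equal arange(n_states).
assert get_states_indices(all_states(height, ndim), height) == list(range(n_states))
```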

abstract get_states_indices(states)

Optional method to return the indices of the states in the environment.

Parameters:

states (gfn.states.DiscreteStates) – The batch of states.

Returns:

Tensor of shape (*batch_shape) containing the indices of the states.

Return type:

torch.Tensor

get_terminating_state_dist(states)

Computes the empirical distribution over terminating states.

Uses vectorized scatter_add_ for efficient histogram computation.

Parameters:

states (gfn.states.DiscreteStates) – A batch of terminating DiscreteStates.

Returns:

A 1D CPU tensor of shape (n_terminating_states,) with empirical frequencies summing to 1.

Raises:
  • NotImplementedError – If the environment lacks get_terminating_states_indices or n_terminating_states.

  • ValueError – If states is empty.

Return type:

torch.Tensor
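A plain-Python sketch of the same histogram, assuming the terminating states have already been mapped to integer indices (for example by get_terminating_states_indices). In torch, the counting loop corresponds to `counts.scatter_add_(0, indices, torch.ones_like(indices, dtype=counts.dtype))` on a zero-initialized float `counts`. The function name `empirical_terminating_dist` is hypothetical.

```python
from typing import List, Sequence


def empirical_terminating_dist(indices: Sequence[int], n_terminating_states: int) -> List[float]:
    """Normalized histogram of terminating-state indices (frequencies sum to 1)."""
    if len(indices) == 0:
        raise ValueError("states is empty")
    counts = [0] * n_terminating_states
    for idx in indices:
        counts[idx] += 1  # same role as scatter_add_ with a tensor of ones
    total = len(indices)
    return [c / total for c in counts]


print(empirical_terminating_dist([0, 2, 2, 3], 4))  # [0.25, 0.0, 0.5, 0.25]
```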

abstract get_terminating_states_indices(states)

Optional method to return the indices of the terminating states in the environment.

Parameters:

states (gfn.states.DiscreteStates) – The batch of states.

Returns:

Tensor of shape (*batch_shape) containing the indices of the terminating states.

Return type:

torch.Tensor

is_action_valid(states, actions, backward=False)

Checks whether the actions are valid in the given discrete states.

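Parameters:
  • states (gfn.states.DiscreteStates) – The batch of states.

  • actions (gfn.actions.Actions) – The batch of actions to check.

  • backward (bool) – If True, check the actions as backward actions.
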
Returns:

True if all actions are valid in the given states, False otherwise. When debug is False, returns True without checking to keep hot paths compile-friendly.

Return type:

bool
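The debug gating described above can be sketched as follows, assuming validity is judged from a per-state boolean mask over the n_actions actions (discrete states typically expose such forward and backward masks). The function and argument names are hypothetical; the point is that the check only runs when debug is True.

```python
from typing import Sequence


def is_action_valid(
    actions: Sequence[int],
    masks: Sequence[Sequence[bool]],
    debug: bool = False,
) -> bool:
    """Return True if every action is allowed by its state's mask.

    When debug is False the check is skipped entirely, mirroring the
    compile-friendly fast path described above.
    """
    if not debug:
        return True
    return all(mask[action] for action, mask in zip(actions, masks))
```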

is_discrete: bool = True
make_actions_class()

Returns the Actions class for this environment.

Returns:

A type of a subclass of Actions with environment-specific functionalities.

Return type:

type[gfn.actions.Actions]

make_states_class()

Returns the DiscreteStates class for this environment.

Returns:

A type of a subclass of DiscreteStates with environment-specific functionalities.

Return type:

type[gfn.states.DiscreteStates]

n_actions
property n_states: int
Abstract property.

Optional method to return the number of states in the environment.

Returns:

The number of states.

Return type:

int

property n_terminating_states: int
Abstract property.

Optional method to return the number of terminating states in the environment.

Returns:

The number of terminating states.

Return type:

int

reset(batch_shape, random=False, sink=False, seed=None, conditions=None)

Instantiates a batch of random, initial, or sink states.

Parameters:
  • batch_shape (int | Tuple[int, Ellipsis]) – Shape of the batch (int or tuple).

  • random (bool) – If True, initialize states randomly.

  • sink (bool) – If True, initialize states as sink states (\(s_f\)).

  • seed (Optional[int]) – (Optional) Random seed for reproducibility.

  • conditions (Optional[torch.Tensor]) – (Optional) Tensor of shape (*batch_shape, condition_dim) containing the conditions.

Returns:

A batch of initial or sink states.

Return type:

gfn.states.DiscreteStates

s0: torch.Tensor
sf: torch.Tensor
states_from_batch_shape(batch_shape, random=False, sink=False, conditions=None)

Returns a batch of random, initial, or sink states with a given batch shape.

Parameters:
  • batch_shape (int | Tuple[int, Ellipsis]) – Tuple representing the shape of the batch of states.

  • random (bool) – If True, initialize states randomly.

  • sink (bool) – If True, initialize states as sink states (\(s_f\)).

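  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.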

Returns:

A batch of random, initial, or sink states.

Return type:

gfn.states.DiscreteStates

states_from_tensor(tensor, conditions=None)

Wraps the supplied tensor in a DiscreteStates instance.

Parameters:
  • tensor (torch.Tensor) – Tensor of shape (*batch_shape, *state_shape) representing the states.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

Returns:

An instance of DiscreteStates.

Return type:

gfn.states.DiscreteStates
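The shape convention (*batch_shape, *state_shape) means the batch shape is whatever precedes the trailing state dimensions. A small sketch with plain tuples; the helper name `infer_batch_shape` is hypothetical and not part of gfn.

```python
from typing import Tuple


def infer_batch_shape(tensor_shape: Tuple[int, ...], state_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Split a tensor shape of the form (*batch_shape, *state_shape)."""
    k = len(state_shape)
    n = len(tensor_shape)
    # The trailing k dimensions must match the environment's state_shape exactly.
    if k > n or tuple(tensor_shape[n - k:]) != tuple(state_shape):
        raise ValueError(f"shape {tensor_shape} does not end with state_shape {state_shape}")
    return tuple(tensor_shape[: n - k])
```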

property terminating_states: gfn.states.DiscreteStates
Abstract property.

Optional method to return a batch of all terminating states in the environment.

Returns:

A batch of all terminating states (batch_shape = (n_terminating_states,)).

Return type:

gfn.states.DiscreteStates

Note

self.get_terminating_states_indices(self.terminating_states) and torch.arange(self.n_terminating_states) should be equivalent.

validate(gflownet, n_validation_samples=1000, visited_terminating_states=None, validate_condition=None, sampling_chunk_size=5000, check_sample_sufficiency=True)

Evaluate a GFlowNet against this environment’s true distribution.

Always draws fresh samples from the current policy to produce an unbiased estimate. Computes the L1 distance and the Jensen-Shannon divergence between the empirical and true distributions over terminating states. If the GFlowNet has a learned logZ and the environment implements log_partition, also reports the absolute difference between the learned and true logZ.

Parameters:
  • gflownet (gfn.gflownet.GFlowNet) – The GFlowNet to evaluate.

  • n_validation_samples (int) – Number of fresh trajectories to sample.

  • visited_terminating_states (Optional[gfn.states.DiscreteStates]) – Deprecated. Ignored if passed; a DeprecationWarning is emitted.

  • validate_condition (torch.Tensor | None) – Optional condition tensor for conditional envs.

  • sampling_chunk_size (int) – Max trajectories to sample at once (avoids OOM).

  • check_sample_sufficiency (bool) – If True, emits a one-time warning when n_validation_samples is too small relative to the state space. Set False to suppress.

Returns:

(metrics_dict, sampled_terminating_states) where metrics_dict contains "l1_dist", "jsd", and optionally "logZ_diff".

Raises:

ValueError – If true_dist is unavailable, n_validation_samples is non-positive, or enumeration APIs are missing.

Return type:

Tuple[Dict[str, float], gfn.states.DiscreteStates | None]
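The sampling_chunk_size parameter bounds how many trajectories are sampled in a single call. A sketch of how a budget of n_validation_samples might be split into chunks; the helper `chunk_sizes` is hypothetical, and how validate actually loops over chunks is not specified here. Each chunk would be sampled, its terminating states collected, and the combined batch passed to get_terminating_state_dist.

```python
from typing import List


def chunk_sizes(n_validation_samples: int, sampling_chunk_size: int = 5000) -> List[int]:
    """Split a sample budget into chunks of at most sampling_chunk_size trajectories."""
    if n_validation_samples <= 0:
        raise ValueError("n_validation_samples must be positive")
    if sampling_chunk_size <= 0:
        raise ValueError("sampling_chunk_size must be positive")
    sizes = []
    remaining = n_validation_samples
    while remaining > 0:
        size = min(sampling_chunk_size, remaining)
        sizes.append(size)
        remaining -= size
    return sizes


print(chunk_sizes(12000))  # [5000, 5000, 2000]
```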

class gfn.env.Env(s0, state_shape, action_shape, dummy_action, exit_action, sf=None, debug=False)

Bases: abc.ABC

Base class for all environments.

Environments define the state and action spaces, as well as the forward and backward transition functions and the reward function.

Parameters:
  • s0 (torch.Tensor | torch_geometric.data.Data)

  • state_shape (Tuple)

  • action_shape (Tuple)

  • dummy_action (torch.Tensor)

  • exit_action (torch.Tensor)

  • sf (Optional[torch.Tensor | torch_geometric.data.Data])

  • debug (bool)

s0

The initial state (tensor or GeometricData).

sf

The sink (final) state (tensor or GeometricData).

state_shape

Tuple representing the shape of the states.

action_shape

Tuple representing the shape of the actions.

dummy_action

Tensor representing the dummy action for padding.

exit_action

Tensor representing the exit action.

States

The States class associated with this environment.

Actions

The Actions class associated with this environment.

is_discrete

Class variable, whether the environment is discrete.

is_conditional

Class variable, whether the environment is conditional.

Actions
States
_backward_step(states, actions)

Wrapper for the user-defined backward_step function.

This wrapper ensures that the backward_step is called only on valid states and actions, and sets the states to the initial state when the action is not valid. It also ensures that the new states are a distinct object from the old states.

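Parameters:
  • states (gfn.states.States) – The batch of current states.

  • actions (gfn.actions.Actions) – The batch of backward actions.
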
Returns:

A batch of previous states.

Return type:

gfn.states.States

_step(states, actions)

Wrapper for the user-defined step function.

This wrapper ensures that the step is called only on valid states and actions, and sets the states to the sink state when the action is exit. It also ensures that the new states are a distinct object from the old states.

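Parameters:
  • states (gfn.states.States) – The batch of current states.

  • actions (gfn.actions.Actions) – The batch of actions to apply.
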
Returns:

A batch of next states.

Return type:

gfn.states.States
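The forward wrapper behaviour can be illustrated per element with plain Python: exit actions send the state to the sink state, every other action is delegated to the user-defined transition, and a new list is returned so the input is left untouched. This is a hypothetical, unvectorized sketch (`wrapped_step`, `increment` and the string sink marker are illustrative); the library itself operates on batched States.

```python
from typing import Any, Callable, List, Sequence


def wrapped_step(
    states: Sequence[Any],
    actions: Sequence[int],
    user_step: Callable[[Any, int], Any],
    exit_action: int,
    sink_state: Any,
) -> List[Any]:
    """Apply user_step only to non-exit actions; exit actions lead to the sink state."""
    new_states = []  # always a fresh container, distinct from `states`
    for state, action in zip(states, actions):
        if action == exit_action:
            new_states.append(sink_state)
        else:
            new_states.append(user_step(state, action))
    return new_states


def increment(state, action):
    # Toy user-defined step: increment the coordinate indexed by the action.
    coords = list(state)
    coords[action] += 1
    return tuple(coords)


old = [(0, 0), (1, 0)]
new = wrapped_step(old, [0, 2], increment, exit_action=2, sink_state="SINK")
print(new)  # [(1, 0), 'SINK']
```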

action_shape
actions_from_batch_shape(batch_shape)

Returns a batch of dummy actions with the supplied batch shape.

Parameters:

batch_shape (Tuple) – Tuple representing the shape of the batch of actions.

Returns:

A batch of dummy actions.

Return type:

gfn.actions.Actions

actions_from_tensor(tensor)

Wraps the supplied tensor in an Actions instance.

Parameters:

tensor (torch.Tensor) – Tensor of shape (*action_shape) representing the actions.

Returns:

An Actions instance.

Return type:

gfn.actions.Actions

abstract backward_step(states, actions)

Backward transition function of the environment.

This method takes a batch of states and actions and returns a batch of previous states. It does not need to check whether the actions are valid or the states are sink states, because the _backward_step method wraps it and checks for validity.

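Parameters:
  • states (gfn.states.States) – The batch of current states.

  • actions (gfn.actions.Actions) – The batch of backward actions.
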
Returns:

A batch of previous states.

Return type:

gfn.states.States

debug = False
property device: torch.device

The device on which the environment’s elements are stored.

Returns:

The device of the initial state.

Return type:

torch.device

dummy_action
exit_action
abstract is_action_valid(states, actions, backward=False)

Checks whether the actions are valid in the given states.

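Parameters:
  • states (gfn.states.States) – The batch of states.

  • actions (gfn.actions.Actions) – The batch of actions to check.

  • backward (bool) – If True, check the actions as backward actions.
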
Returns:

True if all actions are valid in the given states, False otherwise.

Return type:

bool

is_conditional: bool = False
is_discrete: bool = False
abstract log_partition(condition=None)

Optional method to return the logarithm of the partition function.

Parameters:

condition (torch.Tensor | None) – Optional tensor of shape (condition_dim,) containing the condition.

Returns:

The log partition function.

Return type:

float
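Assuming the usual GFlowNet setting, where the target over terminating states is p(x) ∝ R(x), the log partition function is log Z = log Σ_x R(x). Computing it from log rewards with the log-sum-exp trick avoids overflow, and the absolute difference to a learned logZ is the quantity described for _add_logz_diff. Both function names below are hypothetical.

```python
import math
from typing import Sequence


def log_partition_from_log_rewards(log_rewards: Sequence[float]) -> float:
    """log Z = log(sum_x exp(log R(x))), computed stably via log-sum-exp."""
    if len(log_rewards) == 0:
        raise ValueError("need at least one terminating state")
    m = max(log_rewards)
    return m + math.log(sum(math.exp(lr - m) for lr in log_rewards))


def logz_diff(learned_logz: float, true_logz: float) -> float:
    """|learned_logZ - true_logZ|, as reported under the "logZ_diff" key."""
    return abs(learned_logz - true_logz)


# A naive math.exp(1000.0) would raise OverflowError; the shifted sum does not.
true_logz = log_partition_from_log_rewards([1000.0, 1000.0])
```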

log_reward(states)

Returns the environment’s log of rewards for a batch of states.

Either this method or reward must be implemented by the environment.

Parameters:

states (gfn.states.States) – A batch of states with a batch_shape.

Returns:

Tensor of shape (*batch_shape) containing the log rewards.

Return type:

torch.Tensor
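Since only one of reward and log_reward has to be provided, one natural pattern (assumed here, not quoted from the library) is to implement reward and derive the log reward from it. The class below is a hypothetical stand-in that works on lists of tuples rather than States.

```python
import math
from typing import List, Sequence, Tuple


class ToyRewardEnv:
    """Hypothetical environment that only defines reward()."""

    def reward(self, states: Sequence[Tuple[int, ...]]) -> List[float]:
        # Strictly positive reward so that the logarithm is always defined.
        return [1.0 + sum(state) for state in states]

    def log_reward(self, states: Sequence[Tuple[int, ...]]) -> List[float]:
        # Derived from reward(); environments with very small rewards may
        # prefer to implement log_reward directly for numerical stability.
        return [math.log(r) for r in self.reward(states)]
```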

make_actions_class()

Returns the Actions class for this environment.

Defines a custom Actions class that inherits from Actions and implements the methods the library expects. Override make_actions_class to add environment-specific Actions functionality.

Returns:

A type of a subclass of Actions with environment-specific functionalities.

Return type:

type[gfn.actions.Actions]

abstract make_random_states(batch_shape, conditions=None, device=None)

Optional method to return a batch of random states.

Parameters:
  • batch_shape (Tuple) – Tuple representing the shape of the batch of states.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to create the states on.

Returns:

A batch of random states.

Return type:

gfn.states.States

make_states_class()

Returns the States class for this environment.

Defines a custom States class that inherits from States and implements the methods the library expects. Override make_states_class to add environment-specific States functionality.

Returns:

A type of a subclass of States with environment-specific functionalities.

Return type:

type[gfn.states.States]

reset(batch_shape, random=False, sink=False, seed=None, conditions=None)

Instantiates a batch of random, initial, or sink states.

Parameters:
  • batch_shape (int | Tuple[int, Ellipsis]) – Shape of the batch (int, tuple, or list).

  • random (bool) – If True, initialize states randomly.

  • sink (bool) – If True, initialize states as sink states (\(s_f\)).

  • seed (Optional[int]) – (Optional) Random seed for reproducibility.

  • conditions (Optional[torch.Tensor]) – (Optional) Tensor of shape (*batch_shape, condition_dim) containing the conditions.

Returns:

A batch of initial or sink states.

Return type:

gfn.states.States
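reset accepts batch_shape as an int or a sequence. A minimal sketch of normalizing it to a tuple, and of the resulting conditions shape (*batch_shape, condition_dim); the helper names are hypothetical.

```python
from typing import Sequence, Tuple, Union


def normalize_batch_shape(batch_shape: Union[int, Sequence[int]]) -> Tuple[int, ...]:
    """Turn an int, tuple, or list batch shape into a tuple."""
    if isinstance(batch_shape, int):
        return (batch_shape,)
    return tuple(batch_shape)


def conditions_shape(batch_shape: Union[int, Sequence[int]], condition_dim: int) -> Tuple[int, ...]:
    """Expected shape of the conditions tensor: (*batch_shape, condition_dim)."""
    return (*normalize_batch_shape(batch_shape), condition_dim)
```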

abstract reward(states)

Returns the environment’s rewards for a batch of states.

Either this method or log_reward must be implemented by the environment.

Parameters:

states (gfn.states.States) – A batch of states with a batch_shape.

Returns:

Tensor of shape (*batch_shape) containing the rewards.

Return type:

torch.Tensor

s0
abstract sample_conditions(batch_shape)

Sample conditions for the environment. Required for conditional environments.

Parameters:

batch_shape (int | Tuple[int, Ellipsis]) – The shape of the batch of conditions to sample.

Returns:

A tensor of shape (*batch_shape, condition_dim) containing the conditions.

Return type:

torch.Tensor

sf
state_shape
states_from_batch_shape(batch_shape, random=False, sink=False, conditions=None)

Returns a batch of random, initial, or sink states with a given batch shape.

Parameters:
  • batch_shape (int | Tuple[int, Ellipsis]) – Tuple representing the shape of the batch of states.

  • random (bool) – If True, initialize states randomly (requires implementation).

  • sink (bool) – If True, initialize states as sink states (\(s_f\)).

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

Returns:

A batch of random, initial, or sink states.

Return type:

gfn.states.States

states_from_tensor(tensor, conditions=None)

Wraps the supplied tensor in a States instance.

Parameters:
  • tensor (torch.Tensor) – Tensor of shape (*batch_shape, *state_shape) representing the states.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

Returns:

A States instance.

Return type:

gfn.states.States

abstract step(states, actions)

Forward transition function of the environment.

This method takes a batch of states and actions and returns a batch of next states. It does not need to check whether the actions are valid or the states are sink states, because the _step method wraps it and checks for validity.

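Parameters:
  • states (gfn.states.States) – The batch of current states.

  • actions (gfn.actions.Actions) – The batch of actions to apply.
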
Returns:

A batch of next states.

Return type:

gfn.states.States

abstract true_dist(condition=None)

Optional method to return the true distribution.

Parameters:

condition (torch.Tensor | None) – Optional tensor of shape (condition_dim,) containing the condition.

Returns:

The true distribution as a 1-dimensional tensor.

Return type:

torch.Tensor
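Assuming the usual GFlowNet target, where the probability of reaching a terminating state x is proportional to its reward, the true distribution is R(x) / Z with Z = Σ_x R(x), indexed consistently with get_terminating_states_indices. A plain-Python sketch; the function name is hypothetical.

```python
from typing import List, Sequence


def true_dist_from_rewards(rewards: Sequence[float]) -> List[float]:
    """Normalize non-negative rewards of all terminating states into a distribution."""
    total = sum(rewards)
    if total <= 0:
        raise ValueError("rewards must have a positive sum")
    return [r / total for r in rewards]


print(true_dist_from_rewards([1.0, 3.0]))  # [0.25, 0.75]
```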

class gfn.env.GraphEnv(s0, sf, num_node_classes, num_edge_classes, is_directed, debug=False)

Bases: Env

Base class for graph-based environments.

Graph environments represent states as graphs (torch_geometric Data objects) and actions as graph modifications.

Parameters:
  • s0 (torch_geometric.data.Data)

  • sf (torch_geometric.data.Data)

  • num_node_classes (int)

  • num_edge_classes (int)

  • is_directed (bool)

  • debug (bool)

s0

GeometricData representing the initial graph state.

sf

GeometricData representing the sink (final) graph state.

num_node_classes

Number of node classes.

num_edge_classes

Number of edge classes.

is_directed

Whether the graph is directed.

States

The States class associated with this environment.

Actions

The Actions class associated with this environment.

Actions
States
abstract backward_step(states, actions)

Backward transition function of the graph environment.

This method takes a batch of graph states and actions and returns a batch of previous graph states.

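Parameters:
  • states (gfn.states.GraphStates) – The batch of current graph states.

  • actions (gfn.actions.GraphActions) – The batch of backward graph actions.
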
Returns:

A batch of previous graph states.

Return type:

gfn.states.GraphStates

debug = False
property device: torch.device

The device on which the graph states are stored.

Returns:

The device of the initial graph state’s node features.

Return type:

torch.device

dummy_action
exit_action
is_directed
make_actions_class()

Returns the GraphActions class for this environment.

Returns:

A type of a subclass of GraphActions with environment-specific functionalities.

Return type:

type[gfn.actions.GraphActions]

abstract make_random_states(batch_shape, conditions=None, device=None)

Optional method to return a batch of random graph states.

Parameters:
  • batch_shape (int | Tuple) – Shape of the batch (int or tuple).

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to create the graph states on.

Returns:

A batch of random graph states.

Return type:

gfn.states.GraphStates

make_states_class()

Returns the GraphStates class for this environment.

Returns:

A type of a subclass of GraphStates with environment-specific functionalities.

Return type:

type[gfn.states.GraphStates]

num_edge_classes
num_node_classes
reset(batch_shape, random=False, sink=False, seed=None, conditions=None)

Instantiates a batch of random, initial, or sink graph states.

Parameters:
  • batch_shape (int | Tuple[int, Ellipsis]) – Shape of the batch (int or tuple).

  • random (bool) – If True, initialize states randomly.

  • sink (bool) – If True, initialize states as sink states (\(s_f\)).

  • seed (Optional[int]) – (Optional) Random seed for reproducibility.

  • conditions (Optional[torch.Tensor]) – (Optional) Tensor of shape (*batch_shape, condition_dim) containing the conditions.

Returns:

A batch of random, initial, or sink graph states.

Return type:

gfn.states.GraphStates

s0: torch_geometric.data.Data
sf: torch_geometric.data.Data
abstract step(states, actions)

Forward transition function of the graph environment.

This method takes a batch of graph states and actions and returns a batch of next graph states.

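Parameters:
  • states (gfn.states.GraphStates) – The batch of current graph states.

  • actions (gfn.actions.GraphActions) – The batch of graph actions to apply.
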
Returns:

A batch of next graph states.

Return type:

gfn.states.GraphStates

gfn.env.NonValidActionsError