gfn.states

Attributes

logger

Classes

DiscreteStates

Base class for states of discrete environments.

GraphStates

Base class for graph-based state representations.

States

Base class for states, representing nodes in the DAG of a GFlowNet.

Functions

_assert_factory_accepts_debug(factory, factory_name)

Ensure the factory can accept a debug kwarg (explicit or via **kwargs).

Module Contents

class gfn.states.DiscreteStates(tensor, conditions=None, device=None, debug=False)

Bases: States, abc.ABC

Base class for states of discrete environments.

DiscreteStates provide forward_masks and backward_masks as cached properties that compute which actions are allowed at each state on demand. This approach (similar to GraphStates) makes slicing operations faster since masks don’t need to be sliced - they are recomputed only when accessed.

Subclasses must implement _compute_forward_masks and _compute_backward_masks to define the mask computation logic for their specific environment.
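The lazy-mask pattern described above can be illustrated with a small, library-independent sketch. `ToyLazyStates`, its `max_value` cap, and the `compute_calls` counter are hypothetical; plain Python lists stand in for tensors. Only the names `_compute_forward_masks`, `_forward_masks_cache`, and `_invalidate_masks_cache` mirror members documented on this page, and the real implementation may differ.

```python
class ToyLazyStates:
    """Toy stand-in for a discrete-states container with lazily computed masks."""

    n_actions = 3  # two "increment" actions plus one exit action (assumed layout)
    max_value = 2  # hypothetical per-coordinate cap

    def __init__(self, rows):
        self.rows = [list(r) for r in rows]
        self._forward_masks_cache = None  # filled on first access
        self.compute_calls = 0            # instrumentation for this demo only

    def _compute_forward_masks(self):
        # Environment-specific logic: a coordinate can grow until it hits the cap;
        # the last action (exit) is always allowed.
        self.compute_calls += 1
        return [[v < self.max_value for v in row] + [True] for row in self.rows]

    @property
    def forward_masks(self):
        if self._forward_masks_cache is None:
            self._forward_masks_cache = self._compute_forward_masks()
        return self._forward_masks_cache

    def _invalidate_masks_cache(self):
        self._forward_masks_cache = None

    def __getitem__(self, index):
        # Slicing copies only the state data. The new object starts with an
        # empty cache, so no mask is sliced or computed here.
        return ToyLazyStates(self.rows[index])


states = ToyLazyStates([[0, 2], [2, 2], [1, 0]])
subset = states[1:]                  # cheap: no mask work yet
masks = subset.forward_masks         # computed now, for two states only
masks_again = subset.forward_masks   # served from the cache
```

In this sketch the parent object never computes its masks at all, which is the saving the class description refers to when slicing happens at every sampling step.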

Parameters:
  • tensor (torch.Tensor)

  • conditions (Optional[torch.Tensor])

  • device (torch.device | None)

  • debug (bool)

n_actions

Number of possible actions.

device

The device on which the states are stored.

Return type:

torch.device

forward_masks

Property that returns boolean tensor of allowed forward actions.

Return type:

torch.Tensor

backward_masks

Property that returns boolean tensor of allowed backward actions.

Return type:

torch.Tensor

Compile-related expectations:
  • Inputs (state tensor) should already be on the target device with correct shapes; debug can be used to validate during development/tests.

  • Masks are computed on-demand and cached; cache is invalidated when needed.

__getitem__(index)

Returns a subset of the discrete states.

Masks are computed on demand for the new states rather than being sliced, which makes this operation faster.

Parameters:

index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select states.

Returns:

A new DiscreteStates object with the selected states and conditions.

Return type:

DiscreteStates

__repr__()

Returns a detailed string representation of the DiscreteStates object.

Returns:

A string summary of the DiscreteStates object.

Return type:

str

__setitem__(index, states)

Sets particular discrete states.

Parameters:
  • index (int | Sequence[int] | Sequence[bool]) – Indices to set.

  • states (DiscreteStates) – DiscreteStates object containing the new states.

Return type:

None

_backward_masks_cache: torch.Tensor | None = None
_compute_backward_masks()

Computes backward action masks for the current states.

By default, all backward actions are allowed. Typically, this method should be overridden by subclasses to define environment-specific mask logic.

Returns:

Boolean tensor of shape (*batch_shape, n_actions - 1).

Return type:

torch.Tensor

_compute_forward_masks()

Computes forward action masks for the current states.

By default, all forward actions are allowed. Typically, this method should be overridden by subclasses to define environment-specific mask logic.

Returns:

Boolean tensor of shape (*batch_shape, n_actions).

Return type:

torch.Tensor

_forward_masks_cache: torch.Tensor | None = None
_invalidate_masks_cache()

Invalidates the cached masks, forcing recomputation on next access.

Return type:

None

classmethod _make_view(tensor, conditions=None, debug=False)

Fast constructor for slicing: extends base _make_view with mask cache initialization (left empty for on-demand computation).

Parameters:
  • tensor (torch.Tensor)

  • conditions (torch.Tensor | None)

  • debug (bool)

Return type:

DiscreteStates

property backward_masks: torch.Tensor

Returns backward action masks, computing and caching if needed.

Returns:

Boolean tensor of shape (*batch_shape, n_actions - 1) indicating which backward actions are allowed at each state.

Return type:

torch.Tensor

clone()

Returns a clone of the current instance.

Returns:

A new DiscreteStates object with the same data and conditions. Masks are recomputed on demand for the cloned states.

Return type:

DiscreteStates

extend(other)

Concatenates another DiscreteStates object along the batch dimension.

Parameters:

other (DiscreteStates) – DiscreteStates object to concatenate with.

Return type:

None

flatten()

Flattens the batch dimension of the discrete states.

Masks are computed on demand for the flattened states.

Returns:

A new DiscreteStates object with the batch dimension flattened.

Return type:

DiscreteStates

property forward_masks: torch.Tensor

Returns forward action masks, computing and caching if needed.

Returns:

Boolean tensor of shape (*batch_shape, n_actions) indicating which forward actions are allowed at each state.

Return type:

torch.Tensor

init_forward_masks(set_ones=True)

Initializes forward masks.

A convenience function for common mask operations.

Parameters:

set_ones (bool) – if True, forward masks are initialized to all ones. Otherwise, they are initialized to all zeros.

Return type:

None

n_actions: ClassVar[int]
pad_dim0_with_sf(required_first_dim)

Extends states along the first batch dimension with sink states.

Given a batch of states (i.e. of batch_shape=(a, b)), extends a to a DiscreteStates object of batch_shape = (required_first_dim, b), by adding the required number of \(s_f\) tensors. This is useful to extend trajectories of different lengths.

Parameters:

required_first_dim (int) – The size of the first batch dimension post-expansion.

Return type:

None

set_exit_masks(batch_idx)

Sets forward masks such that the only allowable next action is to exit.

A convenience function for common mask operations.

Parameters:

batch_idx (torch.Tensor) – A boolean index along the batch dimension, along which to enforce exits.

Return type:

None

Notes

  • Works for 1D or 2D batch shapes; batch_idx must match batch_shape.

  • Clears all actions for the selected batch entries, then sets only the exit action True via masked_fill to stay torch.compile friendly.

  • Does not move devices; expects masks/tensors already on the target device.
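As a rough illustration of the effect described in the notes above, the following sketch applies the same rule to a 1D batch using plain Python lists. The helper `exit_only_masks` is hypothetical and assumes the exit action is the last column; unlike the documented method, which updates the masks in place and returns None, it returns a new list.

```python
def exit_only_masks(forward_masks, batch_idx):
    """Return a copy of forward_masks where selected rows allow only the exit action."""
    n_actions = len(forward_masks[0])
    exit_row = [False] * (n_actions - 1) + [True]  # exit assumed to be the last action
    return [
        list(exit_row) if selected else list(row)
        for row, selected in zip(forward_masks, batch_idx)
    ]


forward_masks = [[True, True, True], [True, False, True], [False, True, False]]
batch_idx = [False, True, True]  # boolean index along the batch dimension
updated = exit_only_masks(forward_masks, batch_idx)
```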

set_nonexit_action_masks(cond, allow_exit)

Masks denoting disallowed actions according to cond, appending the exit mask.

A convenience function for common mask operations.

Parameters:
  • cond (torch.Tensor) – a boolean tensor of shape (batch_shape,) + (n_actions - 1,), which denotes which actions are not allowed. For example, if a state element represents action count, and no action can be repeated more than 5 times, cond might be state.tensor > 5 (assuming count starts at 0).

  • allow_exit (bool) – sets whether exiting can happen at any point in the trajectory - if so, it should be set to True.

Return type:

None

Notes

  • Always resets forward_masks to all True before applying the new mask so updates do not leak across steps.

  • Works for 1D or 2D batch shapes; cond must match batch_shape.

  • Debug guards validate shape/dtype but should be off in compiled regions.
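The resulting mask layout can be sketched with plain Python lists for a 1D batch. The helper `nonexit_masks` is hypothetical and is not the library's implementation; it only shows the intended outcome, assuming the exit action occupies the last column: non-exit actions flagged by `cond` are disallowed, and the exit column follows `allow_exit`.

```python
def nonexit_masks(cond, allow_exit):
    """Build forward masks of shape (batch, n_actions) from a (batch, n_actions - 1) cond.

    cond[i][j] is True when non-exit action j is NOT allowed in state i.
    """
    # Equivalent to starting from "everything allowed" and then disallowing
    # whatever cond flags; the exit column depends only on allow_exit.
    return [[not blocked for blocked in row] + [allow_exit] for row in cond]


# Two states, two non-exit actions; each coordinate counts how often an action was used.
counts = [[0, 6], [3, 5]]
cond = [[c > 5 for c in row] for row in counts]  # mirrors the state.tensor > 5 example
masks = nonexit_masks(cond, allow_exit=True)
```

Because the masks are rebuilt from scratch on every call, a restriction applied at one step cannot carry over to the next, which is the behaviour the first note describes.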

classmethod stack(states)

Stacks a list of DiscreteStates objects along a new dimension (0).

Parameters:

states (Sequence[DiscreteStates]) – List of DiscreteStates objects to stack.

Returns:

A new DiscreteStates object with the stacked states and conditions. Masks are computed on demand for the stacked states.

Return type:

DiscreteStates

to(device)

Moves the tensor to the specified device in-place.

Masks will be recomputed on the new device when accessed.

Parameters:

device (torch.device) – The device to move to.

Returns:

The DiscreteStates object on the specified device.

Return type:

DiscreteStates

class gfn.states.GraphStates(data, categorical_node_features=False, categorical_edge_features=False, conditions=None, device=None, debug=False)

Bases: States

Base class for graph-based state representations.

A GraphStates object is a collection of multiple graph objects stored as a numpy array of GeometricData objects. This supports batched management of graphs.

Parameters:
  • data (numpy.ndarray)

  • categorical_node_features (bool)

  • categorical_edge_features (bool)

  • conditions (torch.Tensor | None)

  • device (torch.device | None)

  • debug (bool)

num_node_classes

Number of node classes.

num_edge_classes

Number of edge classes.

is_directed

Whether the graph is directed.

s0

Initial state (graph).

sf

Final state (graph).

data

A numpy array of GeometricData objects representing individual graphs.

_device

The device on which the graphs are stored.

__getitem__(index)

Returns a subset of the GraphStates.

Parameters:

index (Union[int, Sequence[int], slice, torch.Tensor, Literal[1], Tuple]) – Index or indices to select.

Returns:

A new GraphStates object containing the selected graphs.

Return type:

GraphStates

__len__()

Returns the total number of graphs.

Returns:

The number of graphs in the batch.

Return type:

int

__repr__()

Returns a detailed string representation of the GraphStates object.

Returns:

A string summary of the GraphStates object.

Return type:

str

__setitem__(index, graph)

Sets a subset of the GraphStates.

Parameters:
  • index (Union[int, Sequence[int], slice, torch.Tensor, Tuple]) – Index or indices to set.

  • graph (GraphStates) – GraphStates object containing the new graphs.

Return type:

None

_compare(other)

Compares the current batch of graphs with another graph.

Note that this does not check if the conditions are equal.

Parameters:

other (torch_geometric.data.Data) – A GeometricData object to compare with.

Returns:

A boolean tensor of shape (*batch_shape,) indicating which graphs in the batch are equal to other.

Return type:

torch.Tensor

_compare_reference(ref)

Compares batch against a reference graph (s0 or sf), handling device mismatch.

Parameters:

ref (torch_geometric.data.Data)

Return type:

torch.Tensor

_conditions: torch.Tensor | None = None
_device
_get_index_np(index)

Converts a tensor-based index to a numpy index.

Parameters:

index (Union[int, Sequence[int], slice, torch.Tensor, Tuple]) – The index to convert.

Returns:

The converted index.

Return type:

Union[int, Sequence[int], slice, numpy.ndarray, Tuple]

property backward_masks: tensordict.TensorDict

Computes masks for valid backward actions from the current state.

A backward action is valid if:
  1. The edge exists in the current graph (i.e., can be removed)

  2. The node exists in the current graph and no edges are connected to it

For directed graphs, all existing edges are considered for removal. For undirected graphs, only the upper triangular edges are considered.

The EXIT action is not included in backward masks.

Returns:

Boolean mask where True indicates valid actions.

Return type:

TensorDict

property batch_shape: tuple[int, Ellipsis]

The batch shape of the graphs.

Returns:

The batch shape as a tuple.

Return type:

tuple[int, Ellipsis]

categorical_edge_features = False
categorical_node_features = False
clone()

Returns a detached clone of the current instance.

Returns:

A new GraphStates object with the same data.

Return type:

GraphStates

data
debug = False
property device: torch.device

The device on which the states are stored.

Returns:

The device of the underlying array of GeometricData.

Return type:

torch.device

extend(other)

Concatenates another GraphStates object along the batch dimension.

Parameters:

other (GraphStates) – GraphStates object to concatenate with.

property forward_masks: tensordict.TensorDict

Computes masks for valid forward actions from the current state.

A forward action is valid if:
  1. The edge doesn’t already exist in the graph

  2. The edge connects two distinct nodes

For directed graphs, all possible src->dst edges are considered. For undirected graphs, only the upper triangular portion of the adjacency matrix is used.

Returns:

Boolean mask where True indicates valid actions.

Return type:

TensorDict
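The two edge rules above can be made concrete with a small sketch that enumerates candidate edges for a single graph. The function `valid_new_edges` is hypothetical and works on plain tuples rather than GeometricData objects; it illustrates only the edge rule, not the full TensorDict returned by the property.

```python
from itertools import combinations, permutations


def valid_new_edges(n_nodes, existing_edges, is_directed):
    """List candidate (src, dst) pairs that could be added as a new edge."""
    if is_directed:
        # Every ordered pair of distinct nodes is a candidate.
        candidates = permutations(range(n_nodes), 2)
        present = set(existing_edges)
    else:
        # Only pairs with src < dst, i.e. the strict upper triangle.
        candidates = combinations(range(n_nodes), 2)
        present = {tuple(sorted(e)) for e in existing_edges}
    return [edge for edge in candidates if edge not in present]


directed = valid_new_edges(3, [(0, 1)], is_directed=True)
undirected = valid_new_edges(3, [(1, 0)], is_directed=False)
```

Self-loops never appear because both iterators only yield pairs of distinct nodes, and in the undirected case the existing edge (1, 0) is normalised to (0, 1) before the lookup.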

is_directed: ClassVar[bool]
property is_initial_state: torch.Tensor

Returns a boolean tensor indicating which graphs are initial states (\(s_0\)).

Returns:

A boolean tensor of shape (*batch_shape,) that is True for initial states.

Return type:

torch.Tensor

property is_sink_state: torch.Tensor

Returns a boolean tensor indicating which graphs are sink states (\(s_f\)).

Returns:

A boolean tensor of shape (*batch_shape,) that is True for sink states.

Return type:

torch.Tensor

classmethod make_initial_states(batch_shape, conditions=None, device=None, debug=False)

Creates a numpy array of graphs consisting of initial states (\(s_0\)).

Parameters:
  • batch_shape (int | Tuple) – Shape of the batch dimensions.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – Device to create the graphs on.

  • debug (bool) – If True, keeps compile graph-breaking checks in the logic for safety.

Returns:

A GraphStates object containing copies of the initial state.

Return type:

GraphStates

classmethod make_sink_states(batch_shape, conditions=None, device=None, debug=False)

Creates a numpy array of graphs consisting of sink states (\(s_f\)).

Parameters:
  • batch_shape (int | Tuple) – Shape of the batch dimensions.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – Device to create the graphs on.

  • debug (bool) – If True, keeps compile graph-breaking checks in the logic for safety.

Returns:

A GraphStates object containing copies of the sink state.

Return type:

GraphStates

max_nodes: ClassVar[int | None]
num_edge_classes: ClassVar[int]
num_node_classes: ClassVar[int]
pad_dim0_with_sf(required_first_dim)

Extends a 2-dimensional batch of graph states along the first batch dimension.

Given a batch of states (i.e. of batch_shape=(a, b)), extends a to a GraphStates object of batch_shape = (required_first_dim, b), by adding the required number of \(s_f\) graphs. This is useful to extend trajectories of different lengths.

Parameters:

required_first_dim (int) – The size of the first batch dimension post-expansion.

Return type:

None

s0: ClassVar[torch_geometric.data.Data]
sf: ClassVar[torch_geometric.data.Data]
classmethod stack(states)

Stacks a list of GraphStates objects along a new dimension (0).

Parameters:

states (List[GraphStates]) – List of GraphStates objects to stack.

Returns:

A new GraphStates object with the stacked graphs and conditions.

Return type:

GraphStates

property tensor: gfn.utils.graphs.GeometricBatch

Returns the batch representation of the data as a GeometricBatch.

Returns:

A GeometricBatch object representing the batch of graphs.

Return type:

gfn.utils.graphs.GeometricBatch

to(device)

Moves the GraphStates to the specified device.

Parameters:

device (torch.device) – The device to move to.

Returns:

The GraphStates object on the specified device.

Return type:

GraphStates

class gfn.states.States(tensor, conditions=None, device=None, debug=False)

Bases: abc.ABC

Base class for states, representing nodes in the DAG of a GFlowNet.

Each environment needs to define a subclass of States. A States object is a collection of multiple states (nodes of the DAG) that supports batching. Generally, if a state is represented with a tensor of shape (*state_shape), a batch of states is represented with a States object whose tensor attribute has shape (*batch_shape, *state_shape). Other representations are possible (e.g., state as a string, numpy array, graph, etc.), but these may need additional logic to support batching (see GraphStates for an example).

Two useful subclasses of States are provided:
  • DiscreteStates for discrete environments, which represents discrete states with a tensor of shape (*batch_shape, *state_shape).

  • GraphStates for graph-based environments, which represents graphs as a numpy object array of shape (*batch_shape,) containing GeometricData objects.

A batch_shape property keeps track of the batch dimension. A trajectory can be represented by a States object with batch_shape = (n_states,). Multiple trajectories can be represented by a States object with batch_shape = (n_states, batch_size).
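The relationship between the tensor shape, state_shape, and batch_shape can be sketched as follows. `split_shape` is a hypothetical helper that works on plain tuples; it assumes, as described above, that the trailing dimensions of the tensor are the state dimensions.

```python
def split_shape(tensor_shape, state_shape):
    """Split a full tensor shape into (batch_shape, state_shape)."""
    n_state_dims = len(state_shape)
    if n_state_dims and tuple(tensor_shape[-n_state_dims:]) != tuple(state_shape):
        raise ValueError("trailing dimensions do not match state_shape")
    batch_shape = tuple(tensor_shape[: len(tensor_shape) - n_state_dims])
    return batch_shape, tuple(state_shape)


# One trajectory of 10 states, each state a vector of length 2.
single_trajectory, _ = split_shape((10, 2), state_shape=(2,))
# 16 trajectories of 10 steps each: batch_shape = (n_states, batch_size).
many_trajectories, _ = split_shape((10, 16, 2), state_shape=(2,))
```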

Because multiple trajectories can have different lengths, batching requires appending a dummy state (\(s_f\)) to trajectories that are shorter than the longest trajectory. This dummy state should never be processed, and is used to pad the batch of states only.

Compile-related expectations:
  • Hot paths should be called with tensors already on the target device and with correct shapes; debug guards can be enabled during development/tests to validate.

  • Set debug=False inside torch.compile regions to avoid Python-side graph breaks; enable debug=True only when running eager checks.

Parameters:
  • tensor (torch.Tensor)

  • conditions (torch.Tensor | None)

  • device (torch.device | None)

  • debug (bool)

tensor

Tensor of shape (*batch_shape, *state_shape) representing a batch of states.

state_shape

Class variable, a tuple defining the shape of a single state.

s0

Class variable, a tensor of shape (*state_shape,) representing the initial state.

sf

Class variable, a tensor of shape (*state_shape,) representing the sink state.

make_random_states

Class variable, a callable that returns a random state. This is used to initialize random states.

__getitem__(index)

Returns a subset of the states along the batch dimension.

Parameters:

index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select states.

Returns:

A new States object with the selected states and conditions.

Return type:

States

__len__()

Returns the number of states in the batch.

Returns:

The number of states.

Return type:

int

__repr__()

Returns a string representation of the States object.

Returns:

A string summary of the States object.

Return type:

str

__setitem__(index, states)

Sets particular states of the batch to a new States object.

Parameters:
  • index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to set.

  • states (States) – States object containing the new states.

Return type:

None

_compare(other)

Computes elementwise equality between state tensor and an external tensor.

Note that this does not check if the conditions are equal.

Parameters:

other (torch.Tensor) – Tensor with shape (*batch_shape, *state_shape) representing states to compare to.

Returns:

A boolean tensor of shape (*batch_shape,) indicating whether the states are equal to other.

Return type:

torch.Tensor
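A minimal sketch of this comparison, for a 1D batch of flat states and with plain Python lists standing in for tensors: elementwise equality is reduced over the state dimension so that one boolean remains per state. The function `compare` is hypothetical. Comparing against copies of s0, as done here, is one plausible way such a helper could back is_initial_state, but that link is an assumption.

```python
def compare(states, other):
    """Return one boolean per state: True where every element matches."""
    return [all(a == b for a, b in zip(s, o)) for s, o in zip(states, other)]


s0 = [0, 0]  # hypothetical initial state
states = [[0, 0], [1, 0], [0, 0]]
is_initial = compare(states, [s0] * len(states))
```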

_conditions: torch.Tensor | None = None
_is_initial_cache: torch.Tensor | None = None
_is_sink_cache: torch.Tensor | None = None
classmethod _make_view(tensor, conditions=None, debug=False)

Fast constructor for internal slicing operations.

Bypasses __init__’s device resolution, .to() dispatch, and conditions shape/dtype validation — all redundant when the source States object was already validated. Used by __getitem__ to avoid per-step overhead in the sampling loop.

Parameters:
  • tensor (torch.Tensor)

  • conditions (torch.Tensor | None)

  • debug (bool)

Return type:

States
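One common way to build such a fast constructor in Python is to allocate the instance with `__new__` and set its fields directly, skipping `__init__`. The sketch below shows that idea on a hypothetical `ToyStates` class; whether the library uses this exact mechanism is an assumption, and the `init_calls` counter exists only for the demonstration.

```python
class ToyStates:
    """Toy container whose __init__ performs (pretend) expensive validation."""

    init_calls = 0  # counts how often the full constructor ran

    def __init__(self, rows, debug=False):
        type(self).init_calls += 1
        if not all(isinstance(r, list) for r in rows):  # stand-in for validation
            raise TypeError("rows must be a list of lists")
        self.rows = rows
        self.debug = debug

    @classmethod
    def _make_view(cls, rows, debug=False):
        # Allocate the instance without running __init__, then set fields directly.
        view = cls.__new__(cls)
        view.rows = rows
        view.debug = debug
        return view

    def __getitem__(self, index):
        # The source object was validated already, so the slice can skip __init__.
        return type(self)._make_view(self.rows[index], debug=self.debug)


states = ToyStates([[0, 1], [2, 3], [4, 5]])
view = states[1:]
```

Using `cls` rather than a hard-coded class name means subclasses get views of their own type, which is what lets a subclass extend the constructor with extra fields such as mask caches.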

_merge_conditions(other, self_was_empty)

Merges conditions after extending tensors.

When self was empty its conditions are None by construction, so we adopt the other’s conditions rather than warning about inconsistency. Symmetrically, extending with an empty other is a no-op for conditions.

Parameters:
  • other (States)

  • self_was_empty (bool)

Return type:

None

property batch_shape: tuple[int, Ellipsis]

The batch shape of the states.

Returns:

The batch shape as a tuple.

Return type:

tuple[int, Ellipsis]

clone()

Returns a clone of the current instance.

Returns:

A new States object with the same data and conditions.

Return type:

States

property conditions: torch.Tensor | None

The conditions attached to these states for conditional GFlowNets.

Returns:

Tensor of shape (*batch_shape, condition_dim) or None if no conditions.

Return type:

torch.Tensor | None

debug = False
property device: torch.device

The device on which the states are stored.

Returns:

The device of the underlying tensor.

Return type:

torch.device

extend(other)

Concatenates another States object along the final batch dimension.

Both States objects must have the same number of batch dimensions, which should be 1 or 2.

Parameters:

other (States) – States object to be concatenated to the current States object.

Return type:

None

flatten()

Flattens the batch dimension of the states.

Useful for example when extracting individual states from trajectories.

Returns:

A new States object with the batch dimension flattened.

Return type:

States

classmethod from_batch_shape(batch_shape, random=False, sink=False, conditions=None, device=None, debug=False)

Creates a States object with the given batch shape.

By default, all states are initialized to \(s_0\), the initial state. Setting random=True initializes random states instead, which requires that the environment implements make_random_states. Setting sink=True initializes all states to \(s_f\), the sink state. random and sink cannot both be True.

Parameters:
  • batch_shape (int | tuple[int, Ellipsis]) – Shape of the batch dimensions.

  • random (bool) – If True, initialize states randomly.

  • sink (bool) – If True, initialize states as sink states (\(s_f\)).

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to create the states on.

  • debug (bool) – If True, keeps compile graph-breaking checks in the logic for safety.

Returns:

A States object with the specified batch shape and initialization.

Return type:

States

property has_conditions: bool

Whether conditions are attached to these states.

Return type:

bool

property is_initial_state: torch.Tensor

Returns a boolean tensor indicating which states are initial (\(s_0\)).

Returns:

A boolean tensor of shape (*batch_shape,) that is True for initial states.

Return type:

torch.Tensor

property is_sink_state: torch.Tensor

Returns a boolean tensor indicating which states are sink (\(s_f\)).

Returns:

A boolean tensor of shape (*batch_shape,) that is True for sink states.

Return type:

torch.Tensor

classmethod make_initial_states(batch_shape, conditions=None, device=None, debug=False)

Creates a States object with all states set to \(s_0\).

Parameters:
  • batch_shape (tuple[int, Ellipsis]) – Shape of the batch dimensions.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to create the states on.

  • debug (bool) – If True, keeps compile graph-breaking checks in the logic for safety.

Returns:

A States object with all states set to \(s_0\).

Return type:

States

make_random_states: Callable
classmethod make_sink_states(batch_shape, conditions=None, device=None, debug=False)

Creates a States object with all states set to \(s_f\).

Parameters:
  • batch_shape (tuple[int, Ellipsis]) – Shape of the batch dimensions.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to create the states on.

  • debug (bool) – If True, keeps compile graph-breaking checks in the logic for safety.

Returns:

A States object with all states set to \(s_f\).

Return type:

States

pad_dim0_with_sf(required_first_dim)

Extends a 2-dimensional batch of states along the first batch dimension.

Given a batch of states (i.e. of batch_shape=(a, b)), extends a to a States object of batch_shape = (required_first_dim, b), by adding the required number of \(s_f\) tensors. This is useful to extend trajectories of different lengths.

Parameters:

required_first_dim (int) – The size of the first batch dimension post-expansion.

Return type:

None
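The padding described above can be sketched with nested Python lists in place of a tensor. The function below is a hypothetical stand-alone version: it returns a new list, whereas the documented method modifies the States object in place and returns None. The sink value `[-1, -1]` is an arbitrary choice for the example.

```python
def pad_dim0_with_sf(states, required_first_dim, sf):
    """Pad an (a, b, *state_shape) nested list along dim 0 with copies of sf."""
    a = len(states)
    b = len(states[0])
    if required_first_dim < a:
        raise ValueError("required_first_dim must be >= the current first dimension")
    # One extra "time step" per missing row, each holding b copies of sf.
    padding = [[list(sf) for _ in range(b)] for _ in range(required_first_dim - a)]
    return [[list(s) for s in step] for step in states] + padding


sf = [-1, -1]  # hypothetical sink state for a 2-element state
# batch_shape = (2, 3): two time steps, three trajectories.
states = [
    [[0, 0], [0, 0], [0, 0]],
    [[1, 0], [0, 1], [1, 0]],
]
padded = pad_dim0_with_sf(states, required_first_dim=4, sf=sf)
```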

s0: ClassVar[torch.Tensor | torch_geometric.data.Data]
sample(n_samples)

Randomly samples a subset of states from the batch.

Parameters:

n_samples (int) – The number of states to sample.

Returns:

A new States object with the sampled states.

Return type:

States

sf: ClassVar[torch.Tensor | torch_geometric.data.Data]
classmethod stack(states)

Stacks a list of States objects along a new dimension (0).

Parameters:

states (Sequence[States]) – List of States objects to stack.

Returns:

A new States object with the stacked states and conditions.

Return type:

States

state_shape: ClassVar[tuple[int, Ellipsis]]
tensor
to(device)

Moves the States tensor to the specified device in-place.

Parameters:

device (torch.device) – The device to move to.

Returns:

The States object on the specified device.

Return type:

States

gfn.states._assert_factory_accepts_debug(factory, factory_name)

Ensure the factory can accept a debug kwarg (explicit or via **kwargs).

Parameters:
  • factory (Callable)

  • factory_name (str)

Return type:

None
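A check of this kind can be written with the standard library's `inspect.signature`. The sketch below is an assumption about how such a guard might look, not the module's actual code; the helper names, the example factories, and the choice of TypeError are all hypothetical.

```python
import inspect


def accepts_debug(factory):
    """Return True if factory can be called with a debug keyword argument."""
    params = inspect.signature(factory).parameters.values()
    return any(
        # An explicit `debug` parameter that can be passed by keyword ...
        (p.name == "debug" and p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY))
        # ... or a catch-all **kwargs.
        or p.kind == p.VAR_KEYWORD
        for p in params
    )


def assert_accepts_debug(factory, factory_name):
    if not accepts_debug(factory):
        raise TypeError(f"{factory_name} must accept a `debug` keyword argument")


def explicit(batch_shape, debug=False): ...
def catch_all(batch_shape, **kwargs): ...
def neither(batch_shape): ...
```

Running the check once, when a factory is registered, surfaces the problem early instead of as a confusing TypeError at the first call that passes debug.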

gfn.states.logger