gfn.states¶
Attributes¶
logger |
Classes¶
DiscreteStates | Base class for states of discrete environments. |
GraphStates | Base class for graph-based state representations. |
States | Base class for states, representing nodes in the DAG of a GFlowNet. |
Functions¶
_assert_factory_accepts_debug(factory, factory_name) | Ensure the factory can accept a debug kwarg (explicit or via **kwargs). |
Module Contents¶
- class gfn.states.DiscreteStates(tensor, conditions=None, device=None, debug=False)¶
Bases: States, abc.ABC
Base class for states of discrete environments.
DiscreteStates exposes forward_masks and backward_masks as cached properties that compute, on demand, which actions are allowed at each state. This approach (similar to GraphStates) makes slicing faster because masks are not sliced along with the states; they are recomputed only when accessed.
Subclasses must implement _compute_forward_masks and _compute_backward_masks to define the mask computation logic for their specific environment.
- Parameters:
tensor (torch.Tensor)
conditions (Optional[torch.Tensor])
device (torch.device | None)
debug (bool)
- n_actions¶
Number of possible actions.
- device¶
The device on which the states are stored.
- Return type:
torch.device
- forward_masks¶
Property that returns boolean tensor of allowed forward actions.
- Return type:
torch.Tensor
- backward_masks¶
Property that returns boolean tensor of allowed backward actions.
- Return type:
torch.Tensor
Compile-related expectations:
- Inputs (state tensor) should already be on the target device with correct shapes; debug can be used to validate during development/tests.
- Masks are computed on demand and cached; the cache is invalidated when needed.
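The on-demand mask caching described above can be illustrated with a minimal, torch-free sketch. It is not the gfn implementation: LazyMaskStates, its nested-list storage, the masking rule, and the compute_calls counter are all hypothetical names introduced only for this example. Only the attribute and method names _forward_masks_cache, _compute_forward_masks, and forward_masks mirror the entries documented on this page.

```python
from typing import List, Optional


class LazyMaskStates:
    """Toy stand-in for a batch of discrete states (not the gfn class)."""

    n_actions = 3  # assumed layout: two increment actions, then an exit action

    def __init__(self, rows: List[List[int]]) -> None:
        self.rows = rows
        self._forward_masks_cache: Optional[List[List[bool]]] = None
        self.compute_calls = 0  # instrumentation for this example only

    def _compute_forward_masks(self) -> List[List[bool]]:
        # Hypothetical rule: a coordinate may be incremented while it is
        # below 2; exiting (last column) is always allowed.
        self.compute_calls += 1
        return [[value < 2 for value in row] + [True] for row in self.rows]

    @property
    def forward_masks(self) -> List[List[bool]]:
        # Compute on first access, then serve the cached result.
        if self._forward_masks_cache is None:
            self._forward_masks_cache = self._compute_forward_masks()
        return self._forward_masks_cache

    def __getitem__(self, index: slice) -> "LazyMaskStates":
        # Slice only the state data; the new view starts with an empty cache.
        return LazyMaskStates(self.rows[index])


states = LazyMaskStates([[0, 2], [1, 1], [2, 0]])
view = states[1:]            # no mask computation happens here
first = view.forward_masks   # masks are computed now
second = view.forward_masks  # masks are served from the cache
```

In this sketch, slicing never touches the masks, so its cost does not depend on n_actions; the price is one recomputation the first time a view's masks are read.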
- __getitem__(index)¶
Returns a subset of the discrete states.
Masks are computed on demand for the new states rather than being sliced, which makes this operation faster.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select states.
- Returns:
A new DiscreteStates object with the selected states and conditions.
- Return type:
DiscreteStates
- __repr__()¶
Returns a detailed string representation of the DiscreteStates object.
- Returns:
A string summary of the DiscreteStates object.
- Return type:
str
- __setitem__(index, states)¶
Sets particular discrete states.
- Parameters:
index (int | Sequence[int] | Sequence[bool]) – Indices to set.
states (DiscreteStates) – DiscreteStates object containing the new states.
- Return type:
None
- _backward_masks_cache: torch.Tensor | None = None¶
- _compute_backward_masks()¶
Computes backward action masks for the current states.
By default, all backward actions are allowed. Typically, this method should be overridden by subclasses to define environment-specific mask logic.
- Returns:
Boolean tensor of shape (*batch_shape, n_actions - 1).
- Return type:
torch.Tensor
- _compute_forward_masks()¶
Computes forward action masks for the current states.
By default, all forward actions are allowed. Typically, this method should be overridden by subclasses to define environment-specific mask logic.
- Returns:
Boolean tensor of shape (*batch_shape, n_actions).
- Return type:
torch.Tensor
- _forward_masks_cache: torch.Tensor | None = None¶
- _invalidate_masks_cache()¶
Invalidates the cached masks, forcing recomputation on next access.
- Return type:
None
- classmethod _make_view(tensor, conditions=None, debug=False)¶
Fast constructor for slicing: extends base _make_view with mask cache initialization (left empty for on-demand computation).
- Parameters:
tensor (torch.Tensor)
conditions (torch.Tensor | None)
debug (bool)
- Return type:
DiscreteStates
- property backward_masks: torch.Tensor¶
Returns backward action masks, computing and caching if needed.
- Returns:
Boolean tensor of shape (*batch_shape, n_actions - 1) indicating which backward actions are allowed at each state.
- Return type:
torch.Tensor
- clone()¶
Returns a clone of the current instance.
- Returns:
A new DiscreteStates object with the same data and conditions. Masks are recomputed on demand for the cloned states.
- Return type:
DiscreteStates
- extend(other)¶
Concatenates another DiscreteStates object along the batch dimension.
- Parameters:
other (DiscreteStates) – DiscreteStates object to concatenate with.
- Return type:
None
- flatten()¶
Flattens the batch dimension of the discrete states.
Masks are computed on demand for the flattened states.
- Returns:
A new DiscreteStates object with the batch dimension flattened.
- Return type:
DiscreteStates
- property forward_masks: torch.Tensor¶
Returns forward action masks, computing and caching if needed.
- Returns:
Boolean tensor of shape (*batch_shape, n_actions) indicating which forward actions are allowed at each state.
- Return type:
torch.Tensor
- init_forward_masks(set_ones=True)¶
Initializes forward masks.
A convenience function for common mask operations.
- Parameters:
set_ones (bool) – if True, forward masks are initialized to all ones. Otherwise, they are initialized to all zeros.
- Return type:
None
- n_actions: ClassVar[int]¶
- pad_dim0_with_sf(required_first_dim)¶
Extends states along the first batch dimension with sink states.
Given a batch of states with batch_shape = (a, b), extends the first dimension so that the DiscreteStates object has batch_shape = (required_first_dim, b), by adding the required number of \(s_f\) tensors. This is useful for padding trajectories of different lengths to a common length.
- Parameters:
required_first_dim (int) – The size of the first batch dimension post-expansion.
- Return type:
None
- set_exit_masks(batch_idx)¶
Sets forward masks such that the only allowable next action is to exit.
A convenience function for common mask operations.
- Parameters:
batch_idx (torch.Tensor) – A boolean index along the batch dimension, along which to enforce exits.
- Return type:
None
Notes
Works for 1D or 2D batch shapes; batch_idx must match batch_shape.
Clears all actions for the selected batch entries, then sets only the exit action True via masked_fill to stay torch.compile friendly.
Does not move devices; expects masks/tensors already on the target device.
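The clear-then-set behavior in the notes above can be sketched on plain nested lists. This is an illustration under assumptions, not the library code: it handles a 1D batch only, it assumes the exit action is the last column of the forward masks, and it uses ordinary Python loops rather than the masked_fill call mentioned in the notes, so it says nothing about torch.compile behavior.

```python
from typing import List


def set_exit_masks(masks: List[List[bool]], batch_idx: List[bool]) -> None:
    """Restrict the selected rows to the exit action only (toy, list-based)."""
    if len(masks) != len(batch_idx):
        raise ValueError("batch_idx must have one entry per state")
    for row, selected in zip(masks, batch_idx):
        if selected:
            # Clear every action for this state...
            for action in range(len(row)):
                row[action] = False
            # ...then re-enable only the exit action (assumed to be last).
            row[-1] = True


masks = [[True, True, True], [True, False, True], [False, True, False]]
set_exit_masks(masks, [True, False, True])
```

Rows that are not selected by batch_idx are left untouched, which is why the second row keeps its original mask.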
- set_nonexit_action_masks(cond, allow_exit)¶
Sets the forward masks by disallowing the non-exit actions flagged in cond, then appending the exit mask.
A convenience function for common mask operations.
- Parameters:
cond (torch.Tensor) – A boolean tensor of shape (*batch_shape, n_actions - 1) marking which non-exit actions are not allowed. For example, if a state element represents an action count, and no action can be repeated more than 5 times, cond might be state.tensor > 5 (assuming the count starts at 0).
allow_exit (bool) – sets whether exiting can happen at any point in the trajectory - if so, it should be set to True.
- Return type:
None
Notes
Always resets forward_masks to all True before applying the new mask so updates do not leak across steps.
Works for 1D or 2D batch shapes; cond must match batch_shape.
Debug guards validate shape/dtype but should be off in compiled regions.
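As a rough illustration of how cond and allow_exit combine, the sketch below rebuilds the masks from scratch on nested lists, which plays the role of the reset to all True described in the notes. The helper name nonexit_action_masks, the list representation, and the placement of the exit action in the last column are assumptions made for this example; the cond expression reuses the count > 5 example from the parameter description.

```python
from typing import List


def nonexit_action_masks(cond: List[List[bool]], allow_exit: bool) -> List[List[bool]]:
    """Return fresh forward masks: allowed = not cond, plus one exit column."""
    # Building new rows each time means no mask state leaks from earlier steps.
    return [[not blocked for blocked in row] + [allow_exit] for row in cond]


# Two states, each tracking how often two actions were taken so far.
counts = [[0, 6], [5, 2]]
cond = [[count > 5 for count in row] for row in counts]
masks = nonexit_action_masks(cond, allow_exit=True)
```

Here only the second action of the first state is blocked, and the appended last column reflects allow_exit for every state.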
- classmethod stack(states)¶
Stacks a list of DiscreteStates objects along a new dimension (0).
- Parameters:
states (Sequence[DiscreteStates]) – List of DiscreteStates objects to stack.
- Returns:
A new DiscreteStates object with the stacked states and conditions. Masks are computed on demand for the stacked states.
- Return type:
DiscreteStates
- to(device)¶
Moves the tensor to the specified device in-place.
Masks will be recomputed on the new device when accessed.
- Parameters:
device (torch.device) – The device to move to.
- Returns:
The DiscreteStates object on the specified device.
- Return type:
DiscreteStates
- class gfn.states.GraphStates(data, categorical_node_features=False, categorical_edge_features=False, conditions=None, device=None, debug=False)¶
Bases: States
Base class for graph-based state representations.
A GraphStates object is a collection of multiple graph objects stored as a numpy array of GeometricData objects. This supports batched management of graphs.
- Parameters:
data (numpy.ndarray)
categorical_node_features (bool)
categorical_edge_features (bool)
conditions (torch.Tensor | None)
device (torch.device | None)
debug (bool)
- num_node_classes¶
Number of node classes.
- num_edge_classes¶
Number of edge classes.
- is_directed¶
Whether the graph is directed.
- s0¶
Initial state (graph).
- sf¶
Final state (graph).
- data¶
A numpy array of GeometricData objects representing individual graphs.
- _device¶
The device on which the graphs are stored.
- __getitem__(index)¶
Returns a subset of the GraphStates.
- Parameters:
index (Union[int, Sequence[int], slice, torch.Tensor, Literal[1], Tuple]) – Index or indices to select.
- Returns:
A new GraphStates object containing the selected graphs.
- Return type:
GraphStates
- __len__()¶
Returns the total number of graphs.
- Returns:
The number of graphs in the batch.
- Return type:
int
- __repr__()¶
Returns a detailed string representation of the GraphStates object.
- Returns:
A string summary of the GraphStates object.
- Return type:
str
- __setitem__(index, graph)¶
Sets a subset of the GraphStates.
- Parameters:
index (Union[int, Sequence[int], slice, torch.Tensor, Tuple]) – Index or indices to set.
graph (GraphStates) – GraphStates object containing the new graphs.
- Return type:
None
- _compare(other)¶
Compares the current batch of graphs with another graph.
Note that this does not check if the conditions are equal.
- Parameters:
other (torch_geometric.data.Data) – A GeometricData object to compare with.
- Returns:
A boolean tensor of shape (*batch_shape,) indicating which graphs in the batch are equal to other.
- Return type:
torch.Tensor
- _compare_reference(ref)¶
Compares batch against a reference graph (s0 or sf), handling device mismatch.
- Parameters:
ref (torch_geometric.data.Data)
- Return type:
torch.Tensor
- _conditions: torch.Tensor | None = None¶
- _device¶
- _get_index_np(index)¶
Converts a tensor-based index to a numpy index.
- Parameters:
index (Union[int, Sequence[int], slice, torch.Tensor, Tuple]) – The index to convert.
- Returns:
The converted index.
- Return type:
Union[int, Sequence[int], slice, numpy.ndarray, Tuple]
- property backward_masks: tensordict.TensorDict¶
Computes masks for valid backward actions from the current state.
A backward action is valid if:
- The edge exists in the current graph (i.e., it can be removed).
- The node exists in the current graph and no edges are connected to it.
For directed graphs, all existing edges are considered for removal. For undirected graphs, only the upper triangular edges are considered.
The EXIT action is not included in backward masks.
- Returns:
Boolean mask where True indicates valid actions.
- Return type:
TensorDict
- property batch_shape: tuple[int, Ellipsis]¶
The batch shape of the graphs.
- Returns:
The batch shape as a tuple.
- Return type:
tuple[int, Ellipsis]
- categorical_edge_features = False¶
- categorical_node_features = False¶
- clone()¶
Returns a detached clone of the current instance.
- Returns:
A new GraphStates object with the same data.
- Return type:
GraphStates
- data¶
- debug = False¶
- property device: torch.device¶
The device on which the states are stored.
- Returns:
The device of the underlying array of GeometricData.
- Return type:
torch.device
- extend(other)¶
Concatenates another GraphStates object along the batch dimension.
- Parameters:
other (GraphStates) – GraphStates object to concatenate with.
- property forward_masks: tensordict.TensorDict¶
Computes masks for valid forward actions from the current state.
A forward action is valid if:
- The edge doesn’t already exist in the graph.
- The edge connects two distinct nodes.
For directed graphs, all possible src->dst edges are considered. For undirected graphs, only the upper triangular portion of the adjacency matrix is used.
- Returns:
Boolean mask where True indicates valid actions.
- Return type:
TensorDict
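The edge-validity rules for forward actions can be sketched for a single graph as a dense boolean matrix. This is only an illustration: the function forward_edge_mask, the set-of-pairs edge representation, and the dense n_nodes x n_nodes layout are assumptions made for this example, and the actual keys and shapes of the returned TensorDict are not shown in this section.

```python
from typing import List, Set, Tuple


def forward_edge_mask(
    n_nodes: int, edges: Set[Tuple[int, int]], directed: bool
) -> List[List[bool]]:
    """mask[src][dst] is True when adding the edge src->dst is allowed."""
    mask = [[False] * n_nodes for _ in range(n_nodes)]
    for src in range(n_nodes):
        for dst in range(n_nodes):
            if src == dst:
                continue  # an edge must connect two distinct nodes
            if not directed and src > dst:
                continue  # undirected: only the upper triangle is used
            exists = (src, dst) in edges
            if not directed:
                exists = exists or (dst, src) in edges
            mask[src][dst] = not exists  # existing edges cannot be added again
    return mask


directed = forward_edge_mask(3, {(0, 1)}, directed=True)
undirected = forward_edge_mask(3, {(0, 1)}, directed=False)
```

With three nodes and one existing edge, the directed mask leaves five of the six off-diagonal entries available, while the undirected mask leaves two of the three upper-triangular entries available.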
- is_directed: ClassVar[bool]¶
- property is_initial_state: torch.Tensor¶
Returns a boolean tensor indicating which graphs are initial states (\(s_0\)).
- Returns:
A boolean tensor of shape (*batch_shape,) that is True for initial states.
- Return type:
torch.Tensor
- property is_sink_state: torch.Tensor¶
Returns a boolean tensor indicating which graphs are sink states (\(s_f\)).
- Returns:
A boolean tensor of shape (*batch_shape,) that is True for sink states.
- Return type:
torch.Tensor
- classmethod make_initial_states(batch_shape, conditions=None, device=None, debug=False)¶
Creates a numpy array of graphs consisting of initial states (\(s_0\)).
- Parameters:
batch_shape (int | Tuple) – Shape of the batch dimensions.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – Device to create the graphs on.
debug (bool) – If True, keeps compile graph-breaking checks in the logic for safety.
- Returns:
A GraphStates object containing copies of the initial state.
- Return type:
GraphStates
- classmethod make_sink_states(batch_shape, conditions=None, device=None, debug=False)¶
Creates a numpy array of graphs consisting of sink states (\(s_f\)).
- Parameters:
batch_shape (int | Tuple) – Shape of the batch dimensions.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – Device to create the graphs on.
debug (bool) – If True, keeps compile graph-breaking checks in the logic for safety.
- Returns:
A GraphStates object containing copies of the sink state.
- Return type:
GraphStates
- max_nodes: ClassVar[int | None]¶
- num_edge_classes: ClassVar[int]¶
- num_node_classes: ClassVar[int]¶
- pad_dim0_with_sf(required_first_dim)¶
Extends a 2-dimensional batch of graph states along the first batch dimension.
Given a batch of states with batch_shape = (a, b), extends the first dimension so that the GraphStates object has batch_shape = (required_first_dim, b), by adding the required number of \(s_f\) graphs. This is useful for padding trajectories of different lengths to a common length.
- Parameters:
required_first_dim (int) – The size of the first batch dimension post-expansion.
- Return type:
None
- s0: ClassVar[torch_geometric.data.Data]¶
- sf: ClassVar[torch_geometric.data.Data]¶
- classmethod stack(states)¶
Stacks a list of GraphStates objects along a new dimension (0).
- Parameters:
states (List[GraphStates]) – List of GraphStates objects to stack.
- Returns:
A new GraphStates object with the stacked graphs and conditions.
- Return type:
GraphStates
- property tensor: gfn.utils.graphs.GeometricBatch¶
Returns the batch representation of the data as a GeometricBatch.
- Returns:
A GeometricBatch object representing the batch of graphs.
- Return type:
gfn.utils.graphs.GeometricBatch
- to(device)¶
Moves the GraphStates to the specified device.
- Parameters:
device (torch.device) – The device to move to.
- Returns:
The GraphStates object on the specified device.
- Return type:
GraphStates
- class gfn.states.States(tensor, conditions=None, device=None, debug=False)¶
Bases: abc.ABC
Base class for states, representing nodes in the DAG of a GFlowNet.
Each environment needs to define a subclass of States. A States object is a collection of multiple states (nodes of the DAG) that supports batching. Generally, if a state is represented with a tensor of shape (*state_shape), a batch of states is represented with a States object, with the attribute tensor of shape (*batch_shape, *state_shape). Other representations are possible (e.g., state as a string, numpy array, graph, etc.), but these may need additional logic to support batching (see GraphStates for an example).
Two useful subclasses of States are provided:
- DiscreteStates for discrete environments, which represents discrete states
- GraphStates for graph-based environments, which represents graphs as a numpy object array of shape (*batch_shape,) containing GeometricData objects.
A batch_shape property keeps track of the batch dimension. A trajectory can be represented by a States object with batch_shape = (n_states,). Multiple trajectories can be represented by a States object with batch_shape = (n_states, batch_size).
Because multiple trajectories can have different lengths, batching requires appending a dummy state (\(s_f\)) to trajectories that are shorter than the longest trajectory. This dummy state is used only to pad the batch of states and should never be processed.
Compile-related expectations:
- Hot paths should be called with tensors already on the target device and with correct shapes; debug guards can be enabled during development/tests to validate.
- Set debug=False inside torch.compile regions to avoid Python-side graph breaks; enable debug=True only when running eager checks.
- Parameters:
tensor (torch.Tensor)
conditions (torch.Tensor | None)
device (torch.device | None)
debug (bool)
- state_shape¶
Class variable, a tuple defining the shape of a single state.
- make_random_states¶
Class variable, a callable that returns a random state. This is used to initialize random states.
- __getitem__(index)¶
Returns a subset of the states along the batch dimension.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select states.
- Returns:
A new States object with the selected states and conditions.
- Return type:
States
- __len__()¶
Returns the number of states in the batch.
- Returns:
The number of states.
- Return type:
int
- __repr__()¶
Returns a string representation of the States object.
- Returns:
A string summary of the States object.
- Return type:
str
- __setitem__(index, states)¶
Sets particular states of the batch to a new States object.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to set.
states (States) – States object containing the new states.
- Return type:
None
- _compare(other)¶
Computes elementwise equality between state tensor and an external tensor.
Note that this does not check if the conditions are equal.
- _conditions: torch.Tensor | None = None¶
- _is_initial_cache: torch.Tensor | None = None¶
- _is_sink_cache: torch.Tensor | None = None¶
- classmethod _make_view(tensor, conditions=None, debug=False)¶
Fast constructor for internal slicing operations.
Bypasses __init__’s device resolution, .to() dispatch, and conditions shape/dtype validation — all redundant when the source States object was already validated. Used by __getitem__ to avoid per-step overhead in the sampling loop.
- Parameters:
tensor (torch.Tensor)
conditions (torch.Tensor | None)
debug (bool)
- Return type:
States
- _merge_conditions(other, self_was_empty)¶
Merges conditions after extending tensors.
When self was empty its conditions are None by construction, so we adopt the other’s conditions rather than warning about inconsistency. Symmetrically, extending with an empty other is a no-op for conditions.
- Parameters:
other (States)
self_was_empty (bool)
- Return type:
None
- property batch_shape: tuple[int, Ellipsis]¶
The batch shape of the states.
- Returns:
The batch shape as a tuple.
- Return type:
tuple[int, Ellipsis]
- clone()¶
Returns a clone of the current instance.
- Returns:
A new States object with the same data and conditions.
- Return type:
States
- property conditions: torch.Tensor | None¶
The conditions attached to these states for conditional GFlowNets.
- Returns:
Tensor of shape (*batch_shape, condition_dim) or None if no conditions.
- Return type:
torch.Tensor | None
- debug = False¶
- property device: torch.device¶
The device on which the states are stored.
- Returns:
The device of the underlying tensor.
- Return type:
torch.device
- extend(other)¶
Concatenates another States object along the final batch dimension.
Both States objects must have the same number of batch dimensions, which should be 1 or 2.
- Parameters:
other (States) – States object to be concatenated to the current States object.
- Return type:
None
- flatten()¶
Flattens the batch dimension of the states.
Useful for example when extracting individual states from trajectories.
- Returns:
A new States object with the batch dimension flattened.
- Return type:
States
- classmethod from_batch_shape(batch_shape, random=False, sink=False, conditions=None, device=None, debug=False)¶
Creates a States object with the given batch shape.
By default, all states are initialized to \(s_0\), the initial state. Optionally, one can initialize random states, which requires that the environment implements the make_random_states class method. Setting sink initializes the states at \(s_f\), the sink state. random and sink cannot both be True.
- Parameters:
batch_shape (int | tuple[int, Ellipsis]) – Shape of the batch dimensions.
random (bool) – If True, initialize states randomly.
sink (bool) – If True, initialize states as sink states (\(s_f\)).
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to create the states on.
debug (bool) – If True, keeps compile graph-breaking checks in the logic for safety.
- Returns:
A States object with the specified batch shape and initialization.
- Return type:
States
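The flag handling described for from_batch_shape can be sketched with a small list-based helper. Everything here is hypothetical and chosen only for illustration: the helper from_batch_size, the S0 and SF values, the make_random_state argument, the 1D batch restriction, and the exception types. The library may validate its arguments differently.

```python
from typing import Callable, List, Optional

S0 = [0, 0]    # hypothetical initial state s_0
SF = [-1, -1]  # hypothetical sink state s_f


def from_batch_size(
    n: int,
    random: bool = False,
    sink: bool = False,
    make_random_state: Optional[Callable[[], List[int]]] = None,
) -> List[List[int]]:
    """Build a 1D batch of n states (toy version of the dispatch above)."""
    if random and sink:
        # The two initialization modes are mutually exclusive.
        raise ValueError("Only one of `random` and `sink` can be True.")
    if random:
        if make_random_state is None:
            raise NotImplementedError("random=True needs a random-state factory.")
        return [make_random_state() for _ in range(n)]
    template = SF if sink else S0
    # Copy the template so that the states in the batch do not share storage.
    return [list(template) for _ in range(n)]


initial = from_batch_size(3)
sinks = from_batch_size(2, sink=True)
```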
- property has_conditions: bool¶
Whether conditions are attached to these states.
- Return type:
bool
- property is_initial_state: torch.Tensor¶
Returns a boolean tensor indicating which states are initial (\(s_0\)).
- Returns:
A boolean tensor of shape (*batch_shape,) that is True for initial states.
- Return type:
torch.Tensor
- property is_sink_state: torch.Tensor¶
Returns a boolean tensor indicating which states are sink (\(s_f\)).
- Returns:
A boolean tensor of shape (*batch_shape,) that is True for sink states.
- Return type:
torch.Tensor
- classmethod make_initial_states(batch_shape, conditions=None, device=None, debug=False)¶
Creates a States object with all states set to \(s_0\).
- Parameters:
batch_shape (tuple[int, Ellipsis]) – Shape of the batch dimensions.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to create the states on.
debug (bool) – If True, keeps compile graph-breaking checks in the logic for safety.
- Returns:
A States object with all states set to \(s_0\).
- Return type:
States
- make_random_states: Callable¶
- classmethod make_sink_states(batch_shape, conditions=None, device=None, debug=False)¶
Creates a States object with all states set to \(s_f\).
- Parameters:
batch_shape (tuple[int, Ellipsis]) – Shape of the batch dimensions.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to create the states on.
debug (bool) – If True, keeps compile graph-breaking checks in the logic for safety.
- Returns:
A States object with all states set to \(s_f\).
- Return type:
States
- pad_dim0_with_sf(required_first_dim)¶
Extends a 2-dimensional batch of states along the first batch dimension.
Given a batch of states with batch_shape = (a, b), extends the first dimension so that the States object has batch_shape = (required_first_dim, b), by adding the required number of \(s_f\) tensors. This is useful for padding trajectories of different lengths to a common length.
- Parameters:
required_first_dim (int) – The size of the first batch dimension post-expansion.
- Return type:
None
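The padding step can be pictured with nested lists standing in for a tensor of shape (a, b, *state_shape). The sketch below is a toy model, not the library code: the list representation, the sf value, and the error handling are assumptions made for this example. Like the documented method, it modifies the batch in place and returns None.

```python
from typing import List, Sequence

State = List[int]


def pad_dim0_with_sf(
    states: List[List[State]], required_first_dim: int, sf: Sequence[int]
) -> None:
    """Grow a (a, b) batch to (required_first_dim, b) by appending rows of s_f."""
    if not states:
        raise ValueError("expected a non-empty batch of shape (a, b)")
    a, b = len(states), len(states[0])
    if required_first_dim < a:
        raise ValueError("required_first_dim must be at least the current length")
    for _ in range(required_first_dim - a):
        # One new time step: b independent copies of the sink state.
        states.append([list(sf) for _ in range(b)])


# Two time steps (a = 2) for two trajectories (b = 2), state_shape = (2,).
batch = [[[0, 0], [1, 0]], [[0, 1], [1, 1]]]
pad_dim0_with_sf(batch, required_first_dim=4, sf=[-1, -1])
```

After the call, the first two time steps are unchanged and the last two contain only sink states, so this batch can be combined with another batch of trajectories of length 4.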
- s0: ClassVar[torch.Tensor | torch_geometric.data.Data]¶
- sample(n_samples)¶
Randomly samples a subset of states from the batch.
- Parameters:
n_samples (int) – The number of states to sample.
- Returns:
A new States object with the sampled states.
- Return type:
States
- sf: ClassVar[torch.Tensor | torch_geometric.data.Data]¶
- classmethod stack(states)¶
Stacks a list of States objects along a new dimension (0).
- state_shape: ClassVar[tuple[int, Ellipsis]]¶
- tensor¶
- gfn.states._assert_factory_accepts_debug(factory, factory_name)¶
Ensure the factory can accept a debug kwarg (explicit or via **kwargs).
- Parameters:
factory (Callable)
factory_name (str)
- Return type:
None
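One plausible way to perform such a check is with inspect.signature from the standard library. The sketch below is an assumption about the approach, not a copy of the library function: the helper accepts_debug, the public name assert_factory_accepts_debug, the TypeError, and its message are all hypothetical.

```python
import inspect
from typing import Callable


def accepts_debug(factory: Callable) -> bool:
    """True if factory takes `debug` explicitly or through **kwargs."""
    for param in inspect.signature(factory).parameters.values():
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True  # **kwargs swallows any keyword, including debug
        if param.name == "debug" and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True  # debug can be passed by keyword
    return False


def assert_factory_accepts_debug(factory: Callable, factory_name: str) -> None:
    # The exception type and message are assumptions made for this sketch.
    if not accepts_debug(factory):
        raise TypeError(f"{factory_name} must accept a `debug` keyword argument")


def explicit_factory(batch_shape, debug=False):
    return batch_shape


def kwargs_factory(batch_shape, **kwargs):
    return batch_shape


def plain_factory(batch_shape):
    return batch_shape


assert_factory_accepts_debug(explicit_factory, "explicit_factory")  # passes
assert_factory_accepts_debug(kwargs_factory, "kwargs_factory")      # passes
```

A positional-only parameter named debug is deliberately rejected in this sketch, since it could not be supplied as a keyword argument.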
- gfn.states.logger¶