States, Actions, and Containers
States
States are the primitive building blocks for GFlowNet objects such as transitions and trajectories, on which losses operate.
An abstract States class is provided, but each environment needs its own States subclass. A States object is a collection of multiple states (nodes of the DAG). A tensor representation of the states is required for batching: if a single state is represented by a tensor of shape (*state_shape), a batch of states is represented by a States object whose tensor attribute has shape (*batch_shape, *state_shape). Other representations are possible (e.g. a string, a numpy array, or a graph), but they cannot be batched unless the user specifies a function that transforms these raw states into tensors.
The batch_shape attribute is required to keep track of the batch dimension. A trajectory can be represented by a States object with batch_shape = (n_states,). Multiple trajectories can be represented by a States object with batch_shape = (n_states, n_trajectories).
Because trajectories can have different lengths, batching requires padding the trajectories that are shorter than the longest one with a dummy state. The dummy state is the \(s_f\) attribute of the environment (e.g. [-1, ..., -1] or [-inf, ..., -inf]). It is never processed and serves only to pad the batch of states.
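The padding scheme can be illustrated with a minimal sketch. It uses plain Python lists in place of tensors so that it runs without dependencies, and the names S_F and pad_trajectories are hypothetical helpers for illustration, not part of the library.

```python
# Hypothetical sink state s_f for a 2-dimensional state space (an assumption).
S_F = [-1, -1]


def pad_trajectories(trajectories, sink=S_F):
    """Pad variable-length trajectories into a (max_len, n_trajectories) grid.

    Each trajectory is a list of states and each state is a list of numbers.
    The result is indexed as padded[step][trajectory], mirroring
    batch_shape = (n_states, n_trajectories).
    """
    max_len = max(len(traj) for traj in trajectories)
    return [
        [list(traj[step]) if step < len(traj) else list(sink) for traj in trajectories]
        for step in range(max_len)
    ]


trajectories = [
    [[0, 0], [1, 0], [1, 1]],  # three states
    [[0, 0], [0, 1]],          # two states, needs one padding state
]
padded = pad_trajectories(trajectories)

batch_shape = (len(padded), len(padded[0]))  # (n_states, n_trajectories)
state_shape = (len(padded[0][0]),)
```

Here the second trajectory receives one copy of the sink state at its last step, so every column of the grid has the same length.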
For discrete environments, the action set is \(\{0, \dots, n_{actions} - 1\}\), where the last action, \(n_{actions} - 1\), always corresponds to the exit (or terminate) action, i.e. the one that results in a transition of the type \(s \rightarrow s_f\). Not all actions are possible at all states, so each States object in a discrete environment carries two extra attributes, forward_masks and backward_masks, which indicate which actions are allowed at each state and which actions could have led to each state, respectively. Such states are instances of the DiscreteStates abstract subclass of States. The forward_masks tensor has shape (*batch_shape, n_actions), and backward_masks has shape (*batch_shape, n_actions - 1). Each subclass of DiscreteStates needs to implement the update_masks function, which uses the environment’s logic to define the two tensors.
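As a rough illustration of the mask logic, consider a toy 2D grid in which action 0 increments x, action 1 increments y, and action 2 exits. The sketch below is a standalone function on plain Python data under those assumptions; HEIGHT, N_ACTIONS, and the function signature are hypothetical, whereas a real DiscreteStates subclass would implement update_masks as a method operating on tensors.

```python
HEIGHT = 3     # hypothetical grid side length
N_ACTIONS = 3  # 0: increment x, 1: increment y, 2: exit


def update_masks(states):
    """Return (forward_masks, backward_masks) for a batch of (x, y) states.

    forward_masks has one row of N_ACTIONS booleans per state;
    backward_masks has one row of N_ACTIONS - 1 booleans per state,
    because the exit action can never be undone.
    """
    forward_masks = []
    backward_masks = []
    for x, y in states:
        # A coordinate can grow only while it is below the grid boundary;
        # exiting is always allowed in this toy environment.
        forward_masks.append([x < HEIGHT - 1, y < HEIGHT - 1, True])
        # A coordinate could have been incremented only if it is positive.
        backward_masks.append([x > 0, y > 0])
    return forward_masks, backward_masks


states = [(0, 0), (2, 1), (2, 2)]
forward_masks, backward_masks = update_masks(states)
```

At the corner state (2, 2) only the exit action remains allowed, while at the initial state (0, 0) no backward action is possible.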
Debug guards and factory signatures
To keep compiled hot paths fast, States/DiscreteStates/GraphStates expect a debug flag passed at construction time. When debug=False (default) no Python-side checks run in hot paths; when debug=True, shape/device/type guards run to catch silent bugs. Environments carry an env-level debug and pass it when they instantiate States.
When defining your own States subclass or environment factories, make sure all state factories accept debug:
Constructors: __init__(..., debug: bool = False, ...) should store self.debug and pass it along when cloning or slicing.
Factory classmethods: make_random_states, make_initial_states, make_sink_states (and any overrides) must accept debug (or **kwargs) and forward it to States(...). The base class enforces this and will raise a clear TypeError otherwise.
Env helpers: if you override states_from_tensor or reset in an environment, thread self.debug into state construction so all emitted states share the env-level setting.
This pattern avoids graph breaks in torch.compile by letting you keep debug=False in compiled runs while still enabling strong checks in development and tests.
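The constructor and factory conventions above can be sketched with a toy class. ToyStates is a hypothetical stand-in built on plain Python lists, not the library's States class; it only illustrates how the debug flag is stored, gates the checks, and is forwarded when slicing, cloning, or calling a factory.

```python
class ToyStates:
    """Hypothetical stand-in for a States subclass (not the library class)."""

    state_shape = (2,)

    def __init__(self, tensor, debug=False):
        self.tensor = tensor
        self.debug = debug
        if debug:
            # Guards run only in debug mode, so the default path stays cheap.
            for state in tensor:
                if len(state) != self.state_shape[0]:
                    raise ValueError(f"bad state shape for {state!r}")

    def __getitem__(self, index):
        # Only slices are supported in this sketch; the flag is forwarded.
        return type(self)(self.tensor[index], debug=self.debug)

    def clone(self):
        # Cloning also forwards the flag.
        return type(self)([list(s) for s in self.tensor], debug=self.debug)

    @classmethod
    def make_initial_states(cls, batch_size, debug=False):
        # Factories accept debug and pass it to the constructor.
        return cls([[0, 0] for _ in range(batch_size)], debug=debug)


fast = ToyStates.make_initial_states(4)                 # no checks run
checked = ToyStates.make_initial_states(4, debug=True)  # checks run
sliced = checked[1:3]                                   # inherits debug=True

# A malformed state passes silently without debug, but is caught with it.
unchecked = ToyStates([[0, 0, 0]])
try:
    ToyStates([[0, 0, 0]], debug=True)
    error_message = None
except ValueError as err:
    error_message = str(err)
```

The design choice is that validation cost is paid only when requested, while every derived object keeps the setting of the object it came from.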
Actions
Actions should be thought of as internal actions of an agent building a compositional object. They correspond to transitions \(s \rightarrow s'\). An abstract Actions class is provided. It is automatically subclassed for discrete environments, but needs to be manually subclassed otherwise.
Similar to States objects, a batch of actions is represented by a tensor of shape (*batch_shape, *action_shape). For discrete environments, for instance, action_shape = (1,), representing an integer between \(0\) and \(n_{actions} - 1\).
Additionally, each subclass needs to define two more class variable tensors:
dummy_action: A tensor used to pad the action sequences of the shorter trajectories in a batch of trajectories. It is [-1] for discrete environments.
exit_action: A tensor that corresponds to the termination action. It is [n_actions - 1] for discrete environments.
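A small sketch may help show how the two class variables are used when batching discrete action sequences. It relies on plain Python lists instead of tensors, and N_ACTIONS and pad_actions are hypothetical names chosen for illustration.

```python
N_ACTIONS = 3                  # hypothetical number of actions
DUMMY_ACTION = [-1]            # padding action for discrete environments
EXIT_ACTION = [N_ACTIONS - 1]  # termination action


def pad_actions(action_sequences):
    """Pad integer action sequences into a (max_len, n_trajectories, 1) grid."""
    max_len = max(len(seq) for seq in action_sequences)
    return [
        [[seq[step]] if step < len(seq) else list(DUMMY_ACTION) for seq in action_sequences]
        for step in range(max_len)
    ]


# Two trajectories; each ends with the exit action (index 2).
sequences = [[0, 1, 1, 2], [1, 2]]
padded = pad_actions(sequences)

# Boolean grids marking where each trajectory exits and where padding sits.
is_exit = [[action == EXIT_ACTION for action in row] for row in padded]
is_dummy = [[action == DUMMY_ACTION for action in row] for row in padded]
```

Each entry of padded has action_shape = (1,), and the shorter sequence is filled with the dummy action after its exit action.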
Debug guards and factory signatures
Actions mirrors the States pattern: constructors and factories accept debug: bool = False. Keep debug=False in compiled/hot paths to avoid Python-side asserts; flip it on in development/tests to run shape/device validations. When defining custom subclasses, ensure:
__init__(..., debug: bool = False, ...) stores self.debug and only runs validations when debug is True.
Factory classmethods (make_dummy_actions, make_exit_actions, helpers like from_tensor_dict) accept debug (or **kwargs) and forward it to the constructor.
Environment helpers (actions_from_tensor, actions_from_batch_shape) should thread the env-level debug so all emitted actions share the setting.
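How an environment might thread its own debug setting into the actions it emits can be sketched as follows. ToyActions and ToyEnv are hypothetical classes on plain Python lists; the method names match the helpers listed above, but their signatures here are assumptions rather than the library's.

```python
class ToyActions:
    """Hypothetical stand-in for an Actions subclass (not the library class)."""

    dummy_action = [-1]

    def __init__(self, tensor, debug=False):
        self.tensor = tensor
        self.debug = debug
        if debug:
            # Validation runs only when debug is True.
            for action in tensor:
                if len(action) != 1:
                    raise ValueError(f"expected action_shape (1,), got {action!r}")

    @classmethod
    def make_dummy_actions(cls, batch_size, debug=False):
        # The factory forwards debug to the constructor.
        return cls([list(cls.dummy_action) for _ in range(batch_size)], debug=debug)


class ToyEnv:
    """Hypothetical environment holding the env-level debug flag."""

    def __init__(self, debug=False):
        self.debug = debug

    def actions_from_tensor(self, tensor):
        return ToyActions(tensor, debug=self.debug)

    def actions_from_batch_shape(self, batch_size):
        return ToyActions.make_dummy_actions(batch_size, debug=self.debug)


dev_env = ToyEnv(debug=True)  # development or tests: checks enabled
prod_env = ToyEnv()           # compiled or hot path: checks disabled

dev_actions = dev_env.actions_from_tensor([[0], [2]])
prod_actions = prod_env.actions_from_batch_shape(3)
```

With this arrangement a single flag on the environment decides whether any action object it creates runs validations.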
Containers
Containers are collections of States, along with other information, such as reward values, or densities \(p(s' \mid s)\). Three containers are available:
Transitions, representing a batch of transitions \(s \rightarrow s'\).
Trajectories, representing a batch of complete trajectories \(\tau = s_0 \rightarrow s_1 \rightarrow \dots \rightarrow s_n \rightarrow s_f\).
StatesContainer, representing a batch of states, particularly useful for flow-matching GFlowNet.
These containers can either be instantiated using a States object, or can be initialized as empty containers that can be populated on the fly, allowing the usage of the ReplayBuffer class.
They inherit from the base Container class, which provides some helpful methods.
In most cases, one needs to sample complete trajectories. From a batch of trajectories, various training samples can be generated:
Use Trajectories.to_transitions() and Trajectories.to_states() for edge-decomposable or state-decomposable losses.
Use Trajectories.to_state_pairs() for flow matching losses.
Use GFlowNet.loss_from_trajectories() as a convenience method that handles the conversion internally.
These methods exclude meaningless transitions and dummy states that were added to the batch of trajectories to allow for efficient batching.
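The idea of dropping padding when decomposing trajectories can be sketched on plain Python data. The function below is a hypothetical illustration of the concept and not the implementation of Trajectories.to_transitions(); it assumes states are tuples, that S_F is the sink state, and that padded_states is indexed as padded_states[step][trajectory].

```python
S_F = (-1, -1)  # hypothetical sink state used for padding


def to_transitions(padded_states):
    """Extract (s, s') pairs from padded trajectories, skipping padding.

    The last real state of each trajectory is paired with S_F, which
    represents the terminating transition s -> s_f.
    """
    n_steps = len(padded_states)
    n_trajectories = len(padded_states[0])
    transitions = []
    for j in range(n_trajectories):
        for t in range(n_steps):
            state = padded_states[t][j]
            if state == S_F:
                break  # everything from here on is padding
            next_state = padded_states[t + 1][j] if t + 1 < n_steps else S_F
            transitions.append((state, next_state))
    return transitions


padded_states = [
    [(0, 0), (0, 0)],
    [(1, 0), (0, 1)],
    [(1, 1), S_F],  # the second trajectory is one state shorter
]
transitions = to_transitions(padded_states)
```

The padded grid holds six entries, but only five transitions are produced, because no pair starts from a padding state.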