gfn.containers


Classes

Container

Base class for state containers (states, transitions, or trajectories).

NormBasedDiversePrioritizedReplayBuffer

A replay buffer with diversity-based prioritization.

ReplayBuffer

A replay buffer for storing training containers.

StatesContainer

Container for a batch of states (mainly used for FMGFlowNet).

TerminatingStateBuffer

A replay buffer for storing terminating states.

Trajectories

Container for complete trajectories (starting in $s_0$ and ending in $s_f$).

Transitions

Container for a batch of transitions.

Package Contents

class gfn.containers.Container

Bases: abc.ABC

Base class for state containers (states, transitions, or trajectories).

abstract __getitem__(index)

Returns a subset of the container based on the provided index.

Parameters:

index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – An integer, slice, tuple, sequence of indices or booleans, or a torch.Tensor specifying which elements to select.

Returns:

A new container containing the selected elements and associated data.

Return type:

Container

abstract __len__()

Returns the number of elements in the container.

Returns:

The number of elements in the container.

Return type:

int

abstract property device: torch.device

The device on which the container is stored.

Returns:

The device on which the container is stored.

Return type:

torch.device

abstract extend(other)

Extends the current container with elements from another container object.

Parameters:

other (Container) – The other container whose elements will be added.

Return type:

None

abstract classmethod from_tensordict(env, td)

Reconstruct a container from a TensorDict.

Parameters:
  • env (gfn.env.Env) – The environment needed to reconstruct States/Actions.

  • td (tensordict.base.TensorDictBase) – The TensorDict produced by to_tensordict().

Returns:

A new container instance.

Return type:

Container

property has_log_probs: bool

Whether the container has log probabilities.

Returns:

True if log probabilities are present and non-empty, False otherwise.

Return type:

bool

classmethod load(env, path)

Loads a container from a .pt file saved by save().

Parameters:
  • env (gfn.env.Env) – The environment needed to reconstruct States/Actions.

  • path (str) – File path to the saved container.

Returns:

A new container instance.

Return type:

Container

abstract property log_rewards: torch.Tensor

The log rewards associated with the container.

Returns:

The log rewards tensor.

Return type:

torch.Tensor

sample(n_samples)

Randomly samples a subset of elements from the container.

Parameters:

n_samples (int) – The number of elements to sample.

Returns:

A new container with the sampled elements.

Return type:

Container
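The abstract __len__() and __getitem__() are enough to express sample() generically. The sketch below is a toy, list-backed illustration of that idea; ToyContainerBase and ToyContainer are hypothetical names, not part of gfn, and sampling without replacement via random.sample is an assumption of this sketch rather than a description of how gfn implements sample().

```python
import random
from abc import ABC, abstractmethod


class ToyContainerBase(ABC):
    """Hypothetical stand-in for Container, reduced to what sample() needs."""

    @abstractmethod
    def __len__(self) -> int: ...

    @abstractmethod
    def __getitem__(self, index): ...

    def sample(self, n_samples: int):
        # Pick distinct positions (without replacement, an assumption) and
        # delegate the actual selection to __getitem__.
        indices = random.sample(range(len(self)), n_samples)
        return self[indices]


class ToyContainer(ToyContainerBase):
    """List-backed container supporting int, slice, and index-list selection."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, index):
        if isinstance(index, int):
            return ToyContainer([self.items[index]])
        if isinstance(index, slice):
            return ToyContainer(self.items[index])
        return ToyContainer([self.items[i] for i in index])


c = ToyContainer(["a", "b", "c", "d"])
s = c.sample(2)
print(len(s))  # 2
```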

save(path)

Saves the container to a single .pt file.

Parameters:

path (str) – File path (e.g. "trajectories.pt").

Return type:

None

abstract property terminating_states: gfn.states.States

The last (terminating) states of the container.

Returns:

The terminating states.

Return type:

gfn.states.States

abstract to_tensordict()

Serialize the container’s data into a TensorDict.

Returns:

A TensorDict containing all tensor data and scalar metadata. The env reference is not included; it must be supplied when reconstructing via from_tensordict().

Return type:

tensordict.base.TensorDictBase

class gfn.containers.NormBasedDiversePrioritizedReplayBuffer(env, capacity=1000, cutoff_distance=0.0, p_norm_distance=1.0, remote_manager_rank=None, remote_buffer_freq=1, communication_backend='mpi', timing=False, async_score=False, async_comm=False, lazy_sort=False, baseline_filtering=False, scoring_only=False, baseline_refresh_after=10)

Bases: ReplayBuffer

A replay buffer with diversity-based prioritization.

Parameters:
  • env (gfn.env.Env)

  • capacity (int)

  • cutoff_distance (float)

  • p_norm_distance (float)

  • remote_manager_rank (int | None)

  • remote_buffer_freq (int)

  • communication_backend (str)

  • timing (bool)

  • async_score (bool)

  • async_comm (bool)

  • lazy_sort (bool)

  • baseline_filtering (bool)

  • scoring_only (bool)

  • baseline_refresh_after (int)

env

The environment associated with the containers.

capacity

The maximum number of items the buffer can hold.

training_container

The buffer contents (Trajectories, Transitions, or StatesContainer). This is dynamically set based on the type of the first added object.

prioritized_capacity

Whether to use prioritized capacity (keep highest-reward items). This is set to True by default.

prioritized_sampling

Whether to sample items with probability proportional to their reward.

cutoff_distance

Threshold used to determine whether a new terminating state is different enough from those already in the buffer.

p_norm_distance

p-norm value for distance calculation (used in torch.cdist).
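A minimal plain-Python sketch of the acceptance rule that cutoff_distance and p_norm_distance suggest. It assumes a candidate is kept only if its p-norm distance to every item already kept (including earlier candidates from the same batch) is strictly greater than cutoff_distance. The helper names are hypothetical, the reward-based prioritization is omitted, and the real buffer computes distances on tensors with torch.cdist.

```python
def p_norm_distance(a, b, p=1.0):
    """p-norm distance between two equal-length vectors (torch.cdist computes this pairwise)."""
    return sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p)


def diverse_add(buffer, candidates, cutoff_distance=0.0, p=1.0):
    """Append each candidate that is farther than cutoff_distance from everything kept so far."""
    kept = list(buffer)
    for cand in candidates:
        if all(p_norm_distance(cand, other, p) > cutoff_distance for other in kept):
            kept.append(cand)
    return kept


buffer = [[0.0, 0.0], [5.0, 5.0]]
new = [[0.5, 0.0], [2.0, 3.0], [2.1, 3.0]]
print(diverse_add(buffer, new, cutoff_distance=1.0, p=1.0))
# [[0.0, 0.0], [5.0, 5.0], [2.0, 3.0]]
```

Under this assumed rule, the default cutoff_distance=0.0 rejects only exact duplicates.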

static _diversity_repr(container)

Returns the tensor used for pairwise distance in diversity filtering.

For conditional GFNs, concatenates conditions with the state tensor so that identical states under different conditions are treated as distinct.
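A toy illustration of why the conditions are concatenated: the same state paired with two different conditions yields two representations at non-zero distance, so neither is filtered out as a duplicate of the other. diversity_repr here is a hypothetical list-based helper, not the tensor implementation.

```python
def diversity_repr(state, condition=None):
    """Hypothetical list-based helper: append the condition to the state vector."""
    if condition is None:
        return list(state)
    return list(state) + list(condition)


s = [1.0, 2.0]
a = diversity_repr(s, condition=[0.0])
b = diversity_repr(s, condition=[1.0])
print(a, b)  # [1.0, 2.0, 0.0] [1.0, 2.0, 1.0]
print(sum(abs(x - y) for x, y in zip(a, b)))  # 1.0
```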

Parameters:

container (ContainerUnion)

Return type:

torch.Tensor

_local_add(training_container)

Adds items to the local buffer using diversity-based prioritization.

Overrides the base class hook so that add() (which handles remote communication) delegates local insertion here.

Parameters:

training_container (ContainerUnion)

cutoff_distance = 0.0
p_norm_distance = 1.0
class gfn.containers.ReplayBuffer(env, capacity=1000, prioritized_capacity=False, prioritized_sampling=False, remote_manager_rank=None, remote_buffer_freq=1, communication_backend='mpi', timing=False, async_score=False, async_comm=False, lazy_sort=False, baseline_filtering=False, scoring_only=False, baseline_refresh_after=10)

A replay buffer for storing training containers.

Supports local-only operation and distributed remote buffer communication.

Features:
  • Local buffering: Stores Trajectories, Transitions, or StatesContainers up to a fixed capacity.

  • Prioritized capacity: Optionally keeps only the highest-reward items when the buffer is full.

  • Prioritized sampling: Optionally samples with probability proportional to reward (softmax over log-rewards).

  • Remote buffer communication: When remote_manager_rank is set, periodically sends batched containers to a remote ReplayBufferManager and receives score dictionaries back.

  • Communication backends: The communication_backend parameter selects between "torch" (PyTorch distributed with the Gloo backend) and "mpi" (mpi4py; roughly 8-12 GB/s versus roughly 100 MB/s with Gloo).

  • Async scoring: When async_score is enabled, trajectory sends are fire-and-forget and scores are collected lazily on the next add() call (so they are one iteration stale), decoupling training throughput from buffer scoring latency.

  • Timing instrumentation: When timing is enabled, serialization, send, and receive durations are recorded for profiling.
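The prioritized-capacity and prioritized-sampling features can be sketched on a list of (log_reward, payload) pairs. This is an illustration under stated assumptions, not the ReplayBuffer code: the function names are hypothetical, the non-prioritized branch keeps the most recent items (an assumption), and sampling is done with replacement via random.choices (also an assumption). A softmax over log-rewards makes the sampling probability proportional to the reward itself, since softmax(log r)_i = r_i / sum_j r_j.

```python
import math
import random


def sort_and_truncate(items, capacity, prioritized_capacity=True):
    """Keep at most `capacity` (log_reward, payload) pairs."""
    if prioritized_capacity:
        # Highest log-reward first, then drop the tail.
        return sorted(items, key=lambda item: item[0], reverse=True)[:capacity]
    # Assumed non-prioritized behaviour: keep the most recent items.
    return items[max(len(items) - capacity, 0):]


def softmax(log_rewards):
    m = max(log_rewards)  # subtract the max for numerical stability
    exps = [math.exp(lr - m) for lr in log_rewards]
    total = sum(exps)
    return [e / total for e in exps]


def prioritized_sample(items, n_samples, rng=random):
    """Draw n_samples items with probability softmax(log_reward), with replacement."""
    probs = softmax([lr for lr, _ in items])
    return rng.choices(items, weights=probs, k=n_samples)


buffer = [(math.log(r), name) for r, name in [(1.0, "a"), (4.0, "b"), (2.0, "c"), (3.0, "d")]]
buffer = sort_and_truncate(buffer, capacity=3)
print([name for _, name in buffer])  # ['b', 'd', 'c']
print([round(p, 3) for p in softmax([lr for lr, _ in buffer])])  # [0.444, 0.333, 0.222]
print(prioritized_sample(buffer, 2))
```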

Parameters:
  • env (gfn.env.Env)

  • capacity (int)

  • prioritized_capacity (bool)

  • prioritized_sampling (bool)

  • remote_manager_rank (int | None)

  • remote_buffer_freq (int)

  • communication_backend (str)

  • timing (bool)

  • async_score (bool)

  • async_comm (bool)

  • lazy_sort (bool)

  • baseline_filtering (bool)

  • scoring_only (bool)

  • baseline_refresh_after (int)

env

The environment associated with the containers.

capacity

The maximum number of items the buffer can hold.

training_container

The buffer contents (Trajectories, Transitions, or StatesContainer). Dynamically set based on the type of the first added object.

__len__()

Returns the number of items in the buffer.

Returns:

The number of items in the buffer (including pending batches).

Return type:

int

__repr__()

Returns a string representation of the ReplayBuffer.

Returns:

A string summary of the buffer.

Return type:

str

_add_counter = 0
_baseline_kept: int = 0
_baseline_log_reward: float
_baseline_skipped_sends: int = 0
_baseline_total: int = 0
_collect_pending_score()

Collect a pending score response from a previous async send.

Returns None if no score is pending (e.g., first iteration).

Return type:

dict[str, float] | None

_consecutive_filtered_empty: int = 0
_filter_and_send(container, send_fn)

Filter by baseline, prepare for remote, and send.

Returns whatever send_fn returns (a score dict for sync sends, None for async sends), or None if baseline filtering drops everything.

_filter_by_baseline(container)

Filter a container to keep only items with log_reward >= baseline.

Returns the (possibly subsetted) container, or None if every item is below the baseline. After baseline_refresh_after consecutive fully filtered batches, the next batch bypasses the filter so that the worker can receive a fresh baseline. Transitions containers are not supported: their log_rewards are per-transition, with -inf for non-terminating rows, so per-row filtering would break the DB/SubTB losses.
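The refresh rule can be made concrete with a small state machine. BaselineFilter is a hypothetical class operating on a list of log-rewards; it mirrors the described behaviour (keep items with log_reward >= baseline, return None when nothing survives, bypass the filter after baseline_refresh_after consecutive empty batches) but is not the library's implementation. Starting the baseline at -inf before any score arrives is an assumption.

```python
class BaselineFilter:
    """Hypothetical sketch of baseline filtering with a periodic bypass."""

    def __init__(self, baseline_refresh_after=10):
        self.baseline_log_reward = float("-inf")  # assumed: no baseline received yet
        self.baseline_refresh_after = baseline_refresh_after
        self.consecutive_filtered_empty = 0

    def filter(self, log_rewards):
        """Return the indices to send, or None if the whole batch is dropped."""
        if self.consecutive_filtered_empty >= self.baseline_refresh_after:
            # Let one batch through unfiltered so a fresh baseline comes back.
            self.consecutive_filtered_empty = 0
            return list(range(len(log_rewards)))
        keep = [i for i, lr in enumerate(log_rewards) if lr >= self.baseline_log_reward]
        if not keep:
            self.consecutive_filtered_empty += 1
            return None
        self.consecutive_filtered_empty = 0
        return keep


f = BaselineFilter(baseline_refresh_after=2)
f.baseline_log_reward = 0.0
print(f.filter([-1.0, 0.5, 0.0]))  # [1, 2]
print(f.filter([-1.0]))            # None
print(f.filter([-2.0]))            # None
print(f.filter([-3.0]))            # [0]
```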

Parameters:

container (ContainerUnion)

Return type:

ContainerUnion | None

_flush_pending()

Concatenate all pending batches into training_container.

Called lazily when the accumulated size reaches 2 * capacity, or eagerly by callers that need a consistent view.

Merges all pending batches into a single combined batch first, then extends training_container once to avoid extra copy cost.
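A sketch of the lazy pending-batch scheme, using Python lists in place of containers. PendingBuffer is hypothetical; it assumes the 2 * capacity threshold is measured on stored plus pending items, and it truncates by keeping the newest items as a stand-in for _sort_and_truncate (which sorts by log-reward when prioritized).

```python
class PendingBuffer:
    """Hypothetical list-based model of lazy flushing of pending batches."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.stored = []           # plays the role of training_container
        self.pending_batches = []  # plays the role of _pending_batches
        self.pending_len = 0       # plays the role of _pending_len

    def add(self, batch):
        self.pending_batches.append(batch)
        self.pending_len += len(batch)
        # Assumed trigger: stored + pending items reach 2 * capacity.
        if len(self.stored) + self.pending_len >= 2 * self.capacity:
            self.flush()

    def flush(self):
        if not self.pending_batches:
            return
        # Merge all pending batches once, then extend the store a single time.
        merged = [x for batch in self.pending_batches for x in batch]
        self.stored.extend(merged)
        self.pending_batches.clear()
        self.pending_len = 0
        # Stand-in for _sort_and_truncate: keep the newest `capacity` items.
        self.stored = self.stored[max(len(self.stored) - self.capacity, 0):]

    def __len__(self):
        # Like ReplayBuffer.__len__, count pending items too.
        return len(self.stored) + self.pending_len


buf = PendingBuffer(capacity=3)
buf.add([1, 2])
buf.add([3, 4])
print(buf.stored, len(buf))  # [] 4
buf.add([5, 6])
print(buf.stored, len(buf))  # [4, 5, 6] 3
```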

Return type:

None

_is_full = False
_isend_and_defer_score(training_container)

Performs a non-blocking send (isend) and defers the score: the data is sent fire-and-forget, and the score is collected on the next add().

The send handle is kept alive in _send_handle until the next call to _wait_previous_send.

Parameters:

training_container (ContainerUnion)

Return type:

None

_local_add(training_container)

Adds a training object to the local buffer, handling capacity.

Subclasses override this to customize local insertion logic (e.g., diversity filtering). The base class add() calls this method, then handles remote buffer communication separately.

Parameters:

training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.

_pending_batches: list[ContainerUnion] = []
_pending_len: int = 0
_pending_score: bool = False
_prepare_for_remote(container)

Convert a container to a lightweight form for remote scoring.

When scoring_only is True, extracts the terminating states and log-rewards into a StatesContainer. Transitions containers are rejected because their log_rewards shape does not match terminating_states (log-rewards are per-transition, not per-trajectory). When scoring_only is False, the container is returned unchanged.

Parameters:

container (ContainerUnion)

Return type:

ContainerUnion

_recv_score()

Receive a score dictionary from the remote manager.

Return type:

dict[str, float]

_send_data(training_container)

Send a training container to the remote manager.

Parameters:

training_container (ContainerUnion)

Return type:

None

_send_handle: gfn.utils.distributed.AsyncSendHandle | None = None
_send_objs(training_container)

Sends a training container to the remote manager (synchronous).

Parameters:

training_container (ContainerUnion)

Return type:

dict[str, float]

_send_objs_async(training_container)

Sends a training container without waiting for the score response.

The score will be collected on the next call to _collect_pending_score.

Parameters:

training_container (ContainerUnion)

Return type:

None

_sort_and_truncate(training_container)

Sort by log-reward (if prioritized) and truncate to capacity.

Parameters:

training_container (ContainerUnion)

Return type:

None

_update_baseline(score_dict)

Extract and store the baseline log-reward from a score response.

Called after receiving a score dict from the buffer manager. Only updates if baseline_filtering is enabled and the score dict contains a baseline_log_reward key.

Parameters:

score_dict (dict[str, float] | None)

Return type:

None

_wait_previous_send()

Block until the previous non-blocking send has completed.

This is typically near-instantaneous because MPI buffers the data internally, but the call guarantees that the send buffer can be safely reused.

Return type:

None

add(training_container)

Adds a training container to the buffer.

The type of the training container is dynamically set based on the type of the first added container.

When async_score is enabled, scores are collected lazily: the first call returns None (no pending score yet), and subsequent calls return the score from the previous submission. This decouples training throughput from buffer scoring latency.

When baseline_filtering is enabled, only trajectories with a log-reward at or above the remote buffer’s baseline are sent. If every trajectory in the pending batch is below the baseline, the send is skipped entirely.
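The one-iteration-stale return value of add() under async_score can be modelled with a single pending slot. The sketch below only captures which score each call returns; the real buffer performs a non-blocking send and receives the reply later, whereas FakeManager and AsyncScorer (both hypothetical) score synchronously.

```python
class FakeManager:
    """Hypothetical remote manager: 'scores' a batch by reporting its size."""

    def score(self, batch):
        return {"n_received": float(len(batch))}


class AsyncScorer:
    """Models only the return-value timing of add() under async_score."""

    def __init__(self, manager):
        self.manager = manager
        self.pending = None  # reply to the previous submission, not yet returned

    def add(self, batch):
        previous = self.pending                   # collect last iteration's reply (None at first)
        self.pending = self.manager.score(batch)  # submit the current batch
        return previous


scorer = AsyncScorer(FakeManager())
print(scorer.add([1, 2]))     # None
print(scorer.add([1, 2, 3]))  # {'n_received': 2.0}
print(scorer.add([1]))        # {'n_received': 3.0}
```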

Parameters:

training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.

Return type:

dict[str, float] | None

async_comm = False
async_score = False
baseline_filtering = False
baseline_refresh_after = 10
capacity = 1000
communication_backend = 'mpi'
property device: torch.device

The device on which the buffer’s data is stored.

Returns:

The device object of the buffer’s contents.

Return type:

torch.device

drain_pending_score(timeout_sec=30.0)

Drain any outstanding async score before shutdown.

Should be called before sending the EXIT signal when async_score or async_comm is enabled, to avoid leaving the buffer manager with an undelivered response.

For async_comm mode this also waits for the outstanding non-blocking send to complete.

Uses a timeout to avoid hanging indefinitely if the buffer manager has crashed. Returns None on timeout (score is lost).
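The timeout behaviour can be sketched with concurrent.futures standing in for the MPI or torch.distributed receive that the buffer actually uses. The helper below is hypothetical: it returns None when nothing is pending or when the reply does not arrive within timeout_sec, matching the documented "score is lost" outcome.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeoutError


def drain_pending_score(future, timeout_sec=30.0):
    """Wait up to timeout_sec for an outstanding score; None if absent or late."""
    if future is None:
        return None  # nothing was pending
    try:
        return future.result(timeout=timeout_sec)
    except FutureTimeoutError:
        return None  # the reply is abandoned, mirroring "score is lost"


def slow_reply():
    time.sleep(0.5)  # simulates a manager that is slow to answer
    return {"score": 2.0}


with ThreadPoolExecutor(max_workers=2) as pool:
    fast = pool.submit(lambda: {"score": 1.0})
    slow = pool.submit(slow_reply)
    print(drain_pending_score(fast, timeout_sec=1.0))   # {'score': 1.0}
    print(drain_pending_score(slow, timeout_sec=0.05))  # None
```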

Parameters:

timeout_sec (float)

Return type:

dict[str, float] | None

env
initialize(training_container)

Initializes the buffer with the type of the first added object.

Parameters:

training_container (ContainerUnion) – The initial Trajectories, Transitions, or StatesContainer object to set the buffer type.

Return type:

None

lazy_sort = False
load(path)

Loads buffer contents from a .pt file saved by save().

Parameters:

path (str) – File path to the saved buffer.

pending_container: ContainerUnion | None = None
prioritized_capacity = False
prioritized_sampling = False
remote_buffer_freq = 1
remote_manager_rank = None
sample(n_samples)

Samples training objects from the buffer.

Parameters:

n_samples (int) – The number of items to sample.

Returns:

A sampled Trajectories, Transitions, or StatesContainer.

Return type:

ContainerUnion

save(path)

Saves the buffer to a single .pt file.

Parameters:

path (str) – File path (e.g. "replay_buffer.pt").

scoring_only = False
timing = False
timing_data: dict[str, list[float]]
timing_log()

Returns a formatted string of the timing information for the replay buffer.

Return type:

str

training_container: ContainerUnion | None = None
class gfn.containers.StatesContainer(env, states=None, is_terminating=None, log_rewards=None)

Bases: gfn.containers.base.Container, Generic[StateType]

Container for a batch of states (mainly used for FMGFlowNet).

This class manages a collection of states and their corresponding properties. It is mainly used for Flow Matching GFlowNet algorithms.

Parameters:
  • env (gfn.env.Env)

  • states (StateType | None)

  • is_terminating (torch.Tensor | None)

  • log_rewards (torch.Tensor | None)

env

The environment where the states are defined.

states

States with batch_shape (n_states,).

is_terminating

Boolean tensor of shape (n_states,) indicating which states are terminating.

_log_rewards

(Optional) Tensor of shape (n_states,) containing the log rewards for terminating states.

__getitem__(index)

Returns a subset of the states along the batch dimension.

Parameters:

index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select states.

Returns:

A new StatesContainer with the selected states and associated data.

Return type:

StatesContainer[StateType]

__len__()

Returns the number of states in the container.

Returns:

The number of states.

Return type:

int

__repr__()

Returns a string representation of the StatesContainer.

Returns:

A string summary of the container.

Return type:

str

_log_rewards = None
property device: torch.device

The device on which the states are stored.

Returns:

The device object of the self.states.

Return type:

torch.device

env
extend(other)

Extends this container with another StatesContainer object.

Parameters:

other (StatesContainer) – Another StatesContainer to append.

Return type:

None

classmethod from_tensordict(env, td)

Reconstruct a StatesContainer from a TensorDict.

Parameters:
  • env (gfn.env.Env)

  • td (gfn.containers.base.TensorDictBase)

Return type:

StatesContainer[StateType]

property intermediary_states: StateType

The intermediary states (not initial states) of the StatesContainer.

Returns:

The intermediary states.

Return type:

StateType

is_terminating
property log_rewards: torch.Tensor

The log rewards for all states.

Returns:

Log rewards tensor of shape (len(self.states),). Intermediate states have value -inf.

Return type:

torch.Tensor

Note

If not provided at initialization, log rewards are computed on demand for terminating states.

states
property terminating_log_rewards: torch.Tensor

The log rewards for terminating states only.

Returns:

The log rewards for terminating states.

Return type:

torch.Tensor

property terminating_states: StateType

The last (terminating) states of the StatesContainer.

Returns:

The terminating states.

Return type:

StateType

to_tensordict()

Serialize the states container into a TensorDict.

Return type:

gfn.containers.base.TensorDictBase

class gfn.containers.TerminatingStateBuffer(env, capacity=1000, communication_backend='mpi', timing=False, **kwargs)

Bases: ReplayBuffer

A replay buffer for storing terminating states.

Parameters:
  • env (gfn.env.Env)

  • capacity (int)

  • communication_backend (str)

  • timing (bool)

env

The environment associated with the containers.

capacity

The maximum number of items the buffer can hold.

training_container

The buffer contents (StatesContainer).

_local_add(training_container)

Extracts terminating states and adds them to the local buffer.

Overrides the base class hook so that add() (which handles remote communication) delegates local insertion here.

Parameters:

training_container (ContainerUnion)

training_container
class gfn.containers.Trajectories(env, states=None, actions=None, terminating_idx=None, is_backward=False, log_rewards=None, log_probs=None, estimator_outputs=None)

Bases: gfn.containers.base.Container

Container for complete trajectories (starting in \(s_0\) and ending in \(s_f\)).

Trajectories are represented as a States object with a two-dimensional batch shape, and actions as an Actions object with a two-dimensional batch shape. The first dimension is the time step and the second is the trajectory index. Because trajectories can have different lengths, shorter trajectories are padded with the tensor representation of the terminal state (\(s_f\) or \(s_0\), depending on the direction of the trajectory), and their actions are padded with dummy actions. The terminating_idx tensor gives the time step at which each trajectory ends.
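A plain-Python sketch of this padding layout for forward trajectories. Strings stand in for state tensors and actions, and pad_trajectories, SF and DUMMY are hypothetical names. It assumes terminating_idx[b] equals the number of actions of trajectory b, so that states[terminating_idx[b]][b] is \(s_f\) and states[terminating_idx[b] - 1][b] is the terminating state; check this convention against the library before relying on it.

```python
SF = "sf"        # hypothetical stand-in for the sink-state tensor s_f
DUMMY = "dummy"  # hypothetical stand-in for a padding action


def pad_trajectories(trajs):
    """Stack forward trajectories into time-major grids.

    Each trajectory is (states, actions) with len(states) == len(actions) + 1
    and states[-1] == SF. Returns states[t][b], actions[t][b], terminating_idx[b].
    """
    batch_size = len(trajs)
    max_length = max(len(a) for _, a in trajs)
    states = [[SF] * batch_size for _ in range(max_length + 1)]  # pad with s_f
    actions = [[DUMMY] * batch_size for _ in range(max_length)]   # pad with dummies
    terminating_idx = []
    for b, (traj_states, traj_actions) in enumerate(trajs):
        for t, state in enumerate(traj_states):
            states[t][b] = state
        for t, action in enumerate(traj_actions):
            actions[t][b] = action
        terminating_idx.append(len(traj_actions))  # assumed convention: index of s_f
    return states, actions, terminating_idx


trajs = [
    (["s0", "x1", SF], ["a1", "exit"]),
    (["s0", "y1", "y2", SF], ["b1", "b2", "exit"]),
]
states, actions, terminating_idx = pad_trajectories(trajs)
print(terminating_idx)  # [2, 3]
print(actions[2])       # ['dummy', 'exit']
print([states[i - 1][b] for b, i in enumerate(terminating_idx)])  # ['x1', 'y2']
```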

Parameters:
  • env (gfn.env.Env)

  • states (gfn.states.States | None)

  • actions (gfn.actions.Actions | None)

  • terminating_idx (torch.Tensor | None)

  • is_backward (bool)

  • log_rewards (torch.Tensor | None)

  • log_probs (torch.Tensor | None)

  • estimator_outputs (torch.Tensor | None)

env

The environment where the states and actions are defined.

states

States with batch_shape (max_length+1, batch_size).

actions

Actions with batch_shape (max_length, batch_size).

terminating_idx

Tensor of shape (batch_size,) indicating the time step at which each trajectory ends.

is_backward

Whether the trajectories are backward or forward. When is_backward is False, the states are ordered from initial to terminal states; when it is True, they are ordered from terminal to initial states.

_log_rewards

(Optional) Tensor of shape (batch_size,) containing the log rewards of the trajectories.

log_probs

(Optional) Tensor of shape (max_length, batch_size) indicating the log probabilities of the trajectories’ actions.

estimator_outputs

(Optional) Tensor of shape (max_length, batch_size, …) containing outputs of a function approximator for each step.

__getitem__(index)

Returns a subset of the trajectories along the batch dimension.

Parameters:

index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select trajectories.

Returns:

A new Trajectories object with the selected trajectories and associated data.

Return type:

Trajectories

__len__()

Returns the number of trajectories in the container.

Returns:

The number of trajectories.

Return type:

int

__repr__()

Returns a string representation of the Trajectories container.

Returns:

A string summary of the trajectories.

Return type:

str

_log_rewards = None
actions
property batch_size: int

The number of trajectories in the container.

Returns:

The number of trajectories.

Return type:

int

property device: torch.device

The device on which the trajectories are stored.

Returns:

The device object of the self.states.

Return type:

torch.device

env
estimator_outputs = None
extend(other)

Extends this Trajectories object with another Trajectories object.

Extends along all attributes in turn (actions, states, terminating_idx, log_probs, log_rewards).

Parameters:

other (Trajectories) – Another Trajectories to append.

Return type:

None

classmethod from_tensordict(env, td)

Reconstruct Trajectories from a TensorDict.

Parameters:
  • env (gfn.env.Env)

  • td (gfn.containers.base.TensorDictBase)

Return type:

Trajectories

is_backward = False
log_probs = None
property log_rewards: torch.Tensor | None

The log rewards for the trajectories.

Returns:

Log rewards tensor of shape (batch_size,).

Return type:

torch.Tensor | None

Note

If not provided at initialization, log rewards are computed on demand for terminating states.

property max_length: int

The maximum length of the trajectories in the container.

Returns:

The maximum trajectory length.

Return type:

int

property n_trajectories: int

Deprecated alias for batch_size.

Return type:

int

reverse_backward_trajectories()

Returns a reversed version of the backward trajectories.

Return type:

Trajectories

states
terminating_idx
property terminating_states: gfn.states.States

The terminating states of the trajectories.

Returns:

The terminating states.

Return type:

gfn.states.States

to_states_container()

Returns a StatesContainer object from the current Trajectories.

Returns:

A StatesContainer object built from the states and log_rewards of the current Trajectories.

Return type:

gfn.containers.states_container.StatesContainer

to_tensordict()

Serialize trajectories into a TensorDict.

Return type:

gfn.containers.base.TensorDictBase

to_transitions()

Returns a Transitions object from the current Trajectories.

Returns:

A Transitions object with the same states, actions, and log_rewards as the current Trajectories.
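The conversion from padded trajectories to transitions can be sketched as follows, again with strings standing in for tensors. to_transitions here is a hypothetical function; it shows the per-transition log-reward layout documented for Transitions.log_rewards (-inf for non-terminating transitions, the trajectory's log-reward on the exit transition), which is also why per-row baseline filtering is not meaningful for Transitions. The trajectory-major ordering of the output is an assumption.

```python
def to_transitions(states, actions, terminating_idx, log_rewards):
    """Flatten padded time-major trajectories into transition tuples.

    Each tuple is (state, action, next_state, is_terminating, log_reward).
    Padded steps (t >= terminating_idx[b]) are skipped.
    """
    neg_inf = float("-inf")
    transitions = []
    for b, t_end in enumerate(terminating_idx):
        for t in range(t_end):
            is_terminating = t == t_end - 1  # the last real action is the exit action
            log_reward = log_rewards[b] if is_terminating else neg_inf
            transitions.append((states[t][b], actions[t][b], states[t + 1][b], is_terminating, log_reward))
    return transitions


states = [["s0", "s0"], ["x1", "y1"], ["sf", "y2"], ["sf", "sf"]]
actions = [["a1", "b1"], ["exit", "b2"], ["dummy", "exit"]]
trs = to_transitions(states, actions, terminating_idx=[2, 3], log_rewards=[0.0, -1.0])
print(len(trs))             # 5
print(trs[1])               # ('x1', 'exit', 'sf', True, 0.0)
print([t[4] for t in trs])  # [-inf, 0.0, -inf, -inf, -1.0]
```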

Return type:

gfn.containers.transitions.Transitions

class gfn.containers.Transitions(env, states=None, actions=None, is_terminating=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)

Bases: gfn.containers.base.Container

Container for a batch of transitions.

This class manages a collection of transitions (triplet of states, actions, and next states) and their corresponding properties.

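Parameters:
  • env (gfn.env.Env)

  • states (gfn.states.States | None)

  • actions (gfn.actions.Actions | None)

  • is_terminating (torch.Tensor | None)

  • next_states (gfn.states.States | None)

  • is_backward (bool)

  • log_rewards (torch.Tensor | None)

  • log_probs (torch.Tensor | None)
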
env

The environment where the states and actions are defined.

states

States with batch_shape (n_transitions,).

actions

Actions with batch_shape (n_transitions,). The actions make the transitions from the states to the next_states.

is_terminating

Boolean tensor of shape (n_transitions,) indicating whether the action is the exit action.

next_states

States with batch_shape (n_transitions,).

is_backward

Whether the transitions are backward transitions. When is_backward is False, the states are the parents and the next_states are the children; when it is True, the states are the children and the next_states are the parents.

_log_rewards

(Optional) Tensor of shape (n_transitions,) containing the log rewards of the transitions.

log_probs

(Optional) Tensor of shape (n_transitions,) containing the log probabilities of the actions.

__getitem__(index)

Returns a subset of the transitions along the batch dimension.

Parameters:

index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select transitions.

Returns:

A new Transitions object with the selected transitions and associated data.

Return type:

Transitions

__len__()

Returns the number of transitions in the container.

Returns:

The number of transitions.

Return type:

int

__repr__()

Returns a string representation of the Transitions container.

Returns:

A string summary of the transitions.

Return type:

str

_log_rewards = None
actions
property all_log_rewards: torch.Tensor

Helper property that computes the log rewards for all transitions.

It applies to environments in which every state is terminating: the rewards are evaluated for all transitions that do not end in the sink state. This is useful for the Modified Detailed Balance loss.

Returns:

Log rewards tensor of shape (n_transitions, 2) for the transitions.

Return type:

torch.Tensor

property device: torch.device

The device on which the transitions are stored.

Returns:

The device object of the self.states.

Return type:

torch.device

env
extend(other)

Extends this Transitions object with another Transitions object.

Parameters:

other (Transitions) – Another Transitions object to append.

Return type:

None

classmethod from_tensordict(env, td)

Reconstruct Transitions from a TensorDict.

Parameters:
  • env (gfn.env.Env)

  • td (gfn.containers.base.TensorDictBase)

Return type:

Transitions

is_backward = False
is_terminating
log_probs = None
property log_rewards: torch.Tensor | None

The log rewards for the transitions.

Returns:

Log rewards tensor of shape (n_transitions,). Non-terminating transitions have value -inf.

Return type:

torch.Tensor | None

Note

If not provided at initialization, log rewards are computed on demand for terminating transitions.

property n_transitions: int

The number of transitions in the container.

Returns:

The number of transitions.

Return type:

int

next_states
states
property terminating_states: gfn.states.States

The terminating states of the transitions.

Returns:

The terminating states.

Return type:

gfn.states.States

to_tensordict()

Serialize transitions into a TensorDict.

Return type:

gfn.containers.base.TensorDictBase