gfn.containers.replay_buffer

Attributes

ContainerUnion

Classes

Container

Base class for protocol classes.

NormBasedDiversePrioritizedReplayBuffer

A replay buffer with diversity-based prioritization.

ReplayBuffer

A replay buffer for storing training containers.

TerminatingStateBuffer

A replay buffer for storing terminating states.

Module Contents

class gfn.containers.replay_buffer.Container

Bases: Protocol

Base class for protocol classes.

Protocol classes are defined as:

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).

For example:

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as:

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...

__getitem__(idx)
__len__()
Return type:

int

extend(other)
property log_rewards: torch.Tensor | None
Return type:

torch.Tensor | None

property terminating_states
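
The members listed above describe a structural interface: any object that provides them can be treated as a container by a static type checker. The following is a minimal sketch of that idea, not the library's code. ContainerLike and ListContainer are hypothetical names, and plain Python lists stand in for tensors.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class ContainerLike(Protocol):
    """Hypothetical stand-in that mirrors the members listed above."""

    def __getitem__(self, idx): ...
    def __len__(self) -> int: ...
    def extend(self, other) -> None: ...
    @property
    def log_rewards(self): ...
    @property
    def terminating_states(self): ...


class ListContainer:
    """Toy container backed by lists; it does not inherit from the protocol."""

    def __init__(self, states, log_rewards=None):
        self._states = list(states)
        self._log_rewards = None if log_rewards is None else list(log_rewards)

    def __getitem__(self, idx):
        return self._states[idx]

    def __len__(self) -> int:
        return len(self._states)

    def extend(self, other) -> None:
        # Append the other container's states (and rewards, when both have them).
        self._states.extend(other.terminating_states)
        if self._log_rewards is not None and other.log_rewards is not None:
            self._log_rewards.extend(other.log_rewards)

    @property
    def log_rewards(self):
        return self._log_rewards

    @property
    def terminating_states(self):
        return self._states


toy = ListContainer([(0, 1), (1, 1)], log_rewards=[-1.0, -0.5])
toy.extend(ListContainer([(2, 0)], log_rewards=[-2.0]))
```

The runtime check only confirms that the attributes exist, as noted above; signatures are not compared.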
gfn.containers.replay_buffer.ContainerUnion
class gfn.containers.replay_buffer.NormBasedDiversePrioritizedReplayBuffer(env, capacity=1000, cutoff_distance=0.0, p_norm_distance=1.0, remote_manager_rank=None, remote_buffer_freq=1, communication_backend='mpi', timing=False, async_score=False, async_comm=False, lazy_sort=False, baseline_filtering=False, scoring_only=False, baseline_refresh_after=10)

Bases: ReplayBuffer

A replay buffer with diversity-based prioritization.

Parameters:
  • env (gfn.env.Env)

  • capacity (int)

  • cutoff_distance (float)

  • p_norm_distance (float)

  • remote_manager_rank (int | None)

  • remote_buffer_freq (int)

  • communication_backend (str)

  • timing (bool)

  • async_score (bool)

  • async_comm (bool)

  • lazy_sort (bool)

  • baseline_filtering (bool)

  • scoring_only (bool)

  • baseline_refresh_after (int)

env

The environment associated with the containers.

capacity

The maximum number of items the buffer can hold.

training_container

The buffer contents (Trajectories, Transitions, or StatesContainer). This is dynamically set based on the type of the first added object.

prioritized_capacity

Whether to use prioritized capacity (keep highest-reward items). This class sets it to True by default, whereas the ReplayBuffer base class defaults to False.

prioritized_sampling

Whether to sample items with probability proportional to their reward.

cutoff_distance

Threshold used to determine whether a new terminating state is different enough from those already in the buffer.

p_norm_distance

p-norm value for distance calculation (used in torch.cdist).
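
How cutoff_distance and p_norm_distance might interact can be sketched in plain Python. This is an illustration only: the library computes distances with torch.cdist, and the function names below, the use of the nearest neighbour, and the strict comparison are all assumptions.

```python
def p_norm_distance(a, b, p=1.0):
    """p-norm distance between two equal-length vectors."""
    return sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p)


def is_diverse(candidate, buffer_states, cutoff_distance=0.0, p=1.0):
    """Accept a candidate only if its nearest neighbour is farther than the cutoff.

    Assumption: the comparison is strict, so an exact duplicate (distance 0)
    is rejected when cutoff_distance is 0.0.
    """
    if not buffer_states:
        return True
    nearest = min(p_norm_distance(candidate, s, p) for s in buffer_states)
    return nearest > cutoff_distance


stored = [(0.0, 0.0), (3.0, 4.0)]
duplicate_ok = is_diverse((0.0, 0.0), stored)                    # distance 0
l1_ok = is_diverse((1.0, 1.0), stored, cutoff_distance=1.5)      # L1 distance 2
l2_ok = is_diverse((1.0, 1.0), stored, cutoff_distance=1.5, p=2.0)  # L2 about 1.41
```

The same candidate can pass under one norm and fail under another, which is why the two parameters are best tuned together.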

static _diversity_repr(container)

Returns the tensor used for pairwise distance in diversity filtering.

For conditional GFNs, concatenates conditions with the state tensor so that identical states under different conditions are treated as distinct.

Parameters:

container (ContainerUnion)

Return type:

torch.Tensor
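
A small sketch of the concatenation described above, using lists in place of tensors. The function name and the ordering (conditions first, then state) are assumptions made for illustration.

```python
def diversity_repr(states, conditions=None):
    """Build one row per item for pairwise-distance computation.

    Without conditions the state is used as-is. With conditions, each row is
    the condition followed by the state (ordering is an assumption).
    """
    if conditions is None:
        return [list(s) for s in states]
    return [list(c) + list(s) for s, c in zip(states, conditions)]


same_states = [(1, 2), (1, 2)]
unconditional = diversity_repr(same_states)
conditional = diversity_repr(same_states, conditions=[(0,), (1,)])
```

With conditions attached, the two identical states produce different rows, so a distance-based filter no longer treats them as duplicates.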

_local_add(training_container)

Adds with diversity-based prioritization to the local buffer.

Overrides the base class hook so that add() (which handles remote communication) delegates local insertion here.

Parameters:

training_container (ContainerUnion)

cutoff_distance = 0.0
p_norm_distance = 1.0
class gfn.containers.replay_buffer.ReplayBuffer(env, capacity=1000, prioritized_capacity=False, prioritized_sampling=False, remote_manager_rank=None, remote_buffer_freq=1, communication_backend='mpi', timing=False, async_score=False, async_comm=False, lazy_sort=False, baseline_filtering=False, scoring_only=False, baseline_refresh_after=10)

A replay buffer for storing training containers.

Supports local-only operation and distributed remote buffer communication.

Features:
  • Local buffering: Stores Trajectories, Transitions, or StatesContainers up to a fixed capacity.

  • Prioritized capacity: Optionally keeps only the highest-reward items when the buffer is full.

  • Prioritized sampling: Optionally samples with probability proportional to reward (softmax over log-rewards).

  • Remote buffer communication: When remote_manager_rank is set, periodically sends batched containers to a remote ReplayBufferManager and receives score dictionaries back.

  • Communication backends: The communication_backend parameter selects between "torch" (PyTorch distributed with Gloo) and "mpi" (mpi4py, roughly 8-12 GB/s compared with roughly 100 MB/s for Gloo).

  • Async scoring: When async_score is enabled, trajectory sends are fire-and-forget; scores are collected lazily on the next add() call (1-iteration stale), decoupling training throughput from buffer scoring latency.

  • Timing instrumentation: When timing is enabled, serialization, send, and receive durations are recorded for profiling.
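
The prioritized sampling feature relies on the fact that a softmax over log-rewards yields probabilities proportional to the rewards themselves. The sketch below shows this with the standard library only; the function names and sampling with replacement are assumptions, not the library's implementation.

```python
import math
import random


def softmax(log_rewards):
    """Numerically stable softmax; the result is proportional to exp(log_reward)."""
    m = max(log_rewards)
    exps = [math.exp(lr - m) for lr in log_rewards]
    total = sum(exps)
    return [e / total for e in exps]


def sample_indices(log_rewards, n_samples, rng=None):
    """Draw indices with probability proportional to reward (with replacement)."""
    rng = rng or random.Random()
    probs = softmax(log_rewards)
    return rng.choices(range(len(log_rewards)), weights=probs, k=n_samples)


# Rewards 1 and 3 give probabilities 1/4 and 3/4.
probs = softmax([math.log(1.0), math.log(3.0)])
picked = sample_indices([math.log(1.0), math.log(3.0)], 5, rng=random.Random(0))
```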

Parameters:
  • env (gfn.env.Env)

  • capacity (int)

  • prioritized_capacity (bool)

  • prioritized_sampling (bool)

  • remote_manager_rank (int | None)

  • remote_buffer_freq (int)

  • communication_backend (str)

  • timing (bool)

  • async_score (bool)

  • async_comm (bool)

  • lazy_sort (bool)

  • baseline_filtering (bool)

  • scoring_only (bool)

  • baseline_refresh_after (int)

env

The environment associated with the containers.

capacity

The maximum number of items the buffer can hold.

training_container

The buffer contents (Trajectories, Transitions, or StatesContainer). Dynamically set based on the type of the first added object.

__len__()

Returns the number of items in the buffer.

Returns:

The number of items in the buffer (including pending batches).

Return type:

int

__repr__()

Returns a string representation of the ReplayBuffer.

Returns:

A string summary of the buffer.

Return type:

str

_add_counter = 0
_baseline_kept: int = 0
_baseline_log_reward: float
_baseline_skipped_sends: int = 0
_baseline_total: int = 0
_collect_pending_score()

Collect a pending score response from a previous async send.

Returns None if no score is pending (e.g., first iteration).

Return type:

dict[str, float] | None

_consecutive_filtered_empty: int = 0
_filter_and_send(container, send_fn)

Filter by baseline, prepare for remote, and send.

Returns whatever send_fn returns (a score dict for sync sends, None for async sends), or None if baseline filtering drops everything.

_filter_by_baseline(container)

Filter a container to keep only items with log_reward >= baseline.

Returns the (possibly subset) container, or None if every item is below the baseline. After baseline_refresh_after consecutive fully-filtered batches, the next batch bypasses the filter so the worker can receive a fresh baseline. Transitions is not supported: its log_rewards is per-transition, with -inf for non-terminating rows, so per-row filtering would break the detailed balance (DB) and sub-trajectory balance (SubTB) losses.

Parameters:

container (ContainerUnion)

Return type:

ContainerUnion | None
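
The filtering and refresh behaviour described above can be sketched as a small state machine. BaselineFilter is a hypothetical class that works on a list of log-rewards and returns the kept indices; it is not the library's code.

```python
class BaselineFilter:
    """Keep items with log_reward >= baseline, with a periodic bypass."""

    def __init__(self, baseline_log_reward=float("-inf"), baseline_refresh_after=10):
        self.baseline_log_reward = baseline_log_reward
        self.baseline_refresh_after = baseline_refresh_after
        self.consecutive_filtered_empty = 0

    def filter(self, log_rewards):
        """Return kept indices, or None if every item is below the baseline."""
        # After enough fully-filtered batches, let one batch through unfiltered
        # so that a fresh baseline can come back with its score.
        if self.consecutive_filtered_empty >= self.baseline_refresh_after:
            self.consecutive_filtered_empty = 0
            return list(range(len(log_rewards)))
        kept = [i for i, lr in enumerate(log_rewards)
                if lr >= self.baseline_log_reward]
        if not kept:
            self.consecutive_filtered_empty += 1
            return None
        self.consecutive_filtered_empty = 0
        return kept


f = BaselineFilter(baseline_log_reward=0.0, baseline_refresh_after=2)
results = [
    f.filter([-1.0, 0.5]),  # one item passes
    f.filter([-1.0]),       # fully filtered (1st)
    f.filter([-2.0]),       # fully filtered (2nd)
    f.filter([-3.0]),       # bypass: sent despite being below baseline
    f.filter([-3.0]),       # filtering resumes
]
```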

_flush_pending()

Concatenate all pending batches into training_container.

Called lazily when the accumulated size reaches 2 * capacity, or eagerly by callers that need a consistent view.

Merges all pending batches into a single combined batch first, then extends training_container once to avoid extra copy cost.

Return type:

None
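
The lazy flush can be illustrated with a list-backed buffer. This is a sketch under stated assumptions: LazyBuffer is hypothetical, the "accumulated size" is taken to be stored plus pending items, and truncation keeps the most recent items.

```python
class LazyBuffer:
    """Accumulate batches cheaply and merge them only when needed."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.items = []
        self.pending_batches = []
        self.pending_len = 0

    def add(self, batch):
        # Appending a batch is O(1); no concatenation happens here.
        self.pending_batches.append(list(batch))
        self.pending_len += len(batch)
        if len(self.items) + self.pending_len >= 2 * self.capacity:
            self.flush_pending()

    def flush_pending(self):
        if not self.pending_batches:
            return
        # Merge all pending batches first, then extend the main storage once.
        merged = [x for batch in self.pending_batches for x in batch]
        self.items.extend(merged)
        self.pending_batches = []
        self.pending_len = 0
        # Assumption: keep only the most recent `capacity` items.
        self.items = self.items[max(len(self.items) - self.capacity, 0):]

    def __len__(self):
        # Pending items count toward the length, as documented for __len__().
        return len(self.items) + self.pending_len


buf = LazyBuffer(capacity=3)
buf.add([1, 2])
buf.add([3, 4])
len_before_flush = len(buf)
buf.add([5, 6])  # total reaches 2 * capacity, which triggers the flush
```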

_is_full = False
_isend_and_defer_score(training_container)

Sends the data with a non-blocking send (isend) and defers the score: the data is fire-and-forget, and the score is collected on the next add() call.

The send handle is kept alive in _send_handle until the next call to _wait_previous_send.

Parameters:

training_container (ContainerUnion)

Return type:

None

_local_add(training_container)

Adds a training object to the local buffer, handling capacity.

Subclasses override this to customize local insertion logic (e.g., diversity filtering). The base class add() calls this method, then handles remote buffer communication separately.

Parameters:

training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
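
The hook arrangement described above is a template-method pattern: add() fixes the overall flow, and subclasses replace only the local insertion step. The classes below are hypothetical stand-ins that show the shape of the pattern, with a list recording what would be sent remotely.

```python
class BaseBuffer:
    def __init__(self):
        self.items = []
        self.sent = []  # stands in for remote communication

    def _local_add(self, batch):
        """Default local insertion: append everything."""
        self.items.extend(batch)

    def _send_remote(self, batch):
        self.sent.append(list(batch))

    def add(self, batch):
        # Fixed flow: local insertion first, then remote handling.
        self._local_add(batch)
        self._send_remote(batch)


class DedupBuffer(BaseBuffer):
    def _local_add(self, batch):
        """Customized insertion: skip items already stored."""
        for x in batch:
            if x not in self.items:
                self.items.append(x)


b = DedupBuffer()
b.add([1, 1, 2])
b.add([2, 3])
```

Only the local storage differs between the two classes; the remote path is inherited unchanged.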

_pending_batches: list[ContainerUnion] = []
_pending_len: int = 0
_pending_score: bool = False
_prepare_for_remote(container)

Convert a container to a lightweight form for remote scoring.

When scoring_only is True, extracts terminating states and log-rewards into a StatesContainer. Transitions is rejected because its log_rewards shape does not match terminating_states (it is per-transition, not per-trajectory). When scoring_only is False, returns the container unchanged.

Parameters:

container (ContainerUnion)

Return type:

ContainerUnion

_recv_score()

Receive a score dictionary from the remote manager.

Return type:

dict[str, float]

_send_data(training_container)

Send a training container to the remote manager.

Parameters:

training_container (ContainerUnion)

Return type:

None

_send_handle: gfn.utils.distributed.AsyncSendHandle | None = None
_send_objs(training_container)

Sends a training container to the remote manager (synchronous).

Parameters:

training_container (ContainerUnion)

Return type:

dict[str, float]

_send_objs_async(training_container)

Sends a training container without waiting for the score response.

The score will be collected on the next call to _collect_pending_score.

Parameters:

training_container (ContainerUnion)

Return type:

None

_sort_and_truncate(training_container)

Sort by log-reward (if prioritized) and truncate to capacity.

Parameters:

training_container (ContainerUnion)

Return type:

None
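
A list-based sketch of sorting and truncating follows. Unlike the method documented above, which returns None, this hypothetical function returns new lists for ease of testing; keeping the most recent items in the non-prioritized case is an assumption.

```python
def sort_and_truncate(items, log_rewards, capacity, prioritized=True):
    """Return at most `capacity` items, preferring high log-reward if prioritized."""
    pairs = list(zip(items, log_rewards))
    if prioritized:
        # Ascending sort, so the highest-reward items end up at the tail.
        pairs.sort(key=lambda pair: pair[1])
    # Keep the tail: highest-reward items if sorted, most recent otherwise.
    pairs = pairs[max(len(pairs) - capacity, 0):]
    return [p[0] for p in pairs], [p[1] for p in pairs]


names = ["a", "b", "c", "d"]
rewards = [0.1, -2.0, 3.0, 1.0]
top_items, top_rewards = sort_and_truncate(names, rewards, capacity=2)
recent_items, recent_rewards = sort_and_truncate(
    names, rewards, capacity=2, prioritized=False
)
```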

_update_baseline(score_dict)

Extract and store the baseline log-reward from a score response.

Called after receiving a score dict from the buffer manager. Only updates if baseline_filtering is enabled and the score dict contains a baseline_log_reward key.

Parameters:

score_dict (dict[str, float] | None)

Return type:

None

_wait_previous_send()

Block until the previous non-blocking send has completed.

This is typically near-instantaneous because MPI internally buffers the data, but guarantees the send buffer can be safely reused.

Return type:

None

add(training_container)

Adds a training container to the buffer.

The type of the training container is dynamically set based on the type of the first added container.

When async_score is enabled, scores are collected lazily: the first call returns None (no pending score yet), and subsequent calls return the score from the previous submission. This decouples training throughput from buffer scoring latency.

When baseline_filtering is enabled, only trajectories with log-reward at or above the remote buffer’s baseline are sent. If all trajectories in the pending batch are below the baseline, the send is skipped entirely.

Parameters:

training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.

Return type:

dict[str, float] | None
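
The one-iteration-stale return value can be illustrated without any distributed setup. In the sketch below, a worker thread stands in for the remote manager; DeferredScorer and mean_score are hypothetical names, and the real buffer communicates over MPI or PyTorch distributed rather than threads.

```python
from concurrent.futures import ThreadPoolExecutor


def mean_score(batch):
    """Hypothetical scorer standing in for the remote buffer manager."""
    return {"score": sum(batch) / len(batch)}


class DeferredScorer:
    def __init__(self, score_fn):
        self._score_fn = score_fn
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._pending = None

    def add(self, batch):
        # Collect the score of the previous submission, if there is one.
        previous = self._pending.result() if self._pending is not None else None
        # Submit the new batch without waiting for its score.
        self._pending = self._executor.submit(self._score_fn, list(batch))
        return previous

    def close(self):
        self._executor.shutdown(wait=True)


scorer = DeferredScorer(mean_score)
first = scorer.add([1.0, 3.0])   # nothing pending yet
second = scorer.add([10.0])      # score of the first batch
scorer.close()
```

Each call returns the score for the batch submitted one call earlier, so callers should not associate the returned dictionary with the container they just passed in.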

async_comm = False
async_score = False
baseline_filtering = False
baseline_refresh_after = 10
capacity = 1000
communication_backend = 'mpi'
property device: torch.device

The device on which the buffer’s data is stored.

Returns:

The device object of the buffer’s contents.

Return type:

torch.device

drain_pending_score(timeout_sec=30.0)

Drain any outstanding async score before shutdown.

Should be called before sending the EXIT signal when async_score or async_comm is enabled, to avoid leaving the buffer manager with an undelivered response.

For async_comm mode this also waits for the outstanding non-blocking send to complete.

Uses a timeout to avoid hanging indefinitely if the buffer manager has crashed. Returns None on timeout (score is lost).

Parameters:

timeout_sec (float)

Return type:

dict[str, float] | None
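
The timeout behaviour can be sketched with a standard-library queue standing in for the channel to the buffer manager. The function below is a hypothetical illustration that takes the queue and a pending flag as arguments; it is not the method's actual implementation.

```python
import queue


def drain_pending_score(score_queue, has_pending, timeout_sec=30.0):
    """Wait for an outstanding score, giving up after timeout_sec seconds."""
    if not has_pending:
        return None
    try:
        return score_queue.get(timeout=timeout_sec)
    except queue.Empty:
        # The sender never replied (for example, it crashed); the score is lost.
        return None


channel = queue.Queue()
channel.put({"score": 1.0})
delivered = drain_pending_score(channel, has_pending=True, timeout_sec=0.1)
timed_out = drain_pending_score(channel, has_pending=True, timeout_sec=0.05)
nothing = drain_pending_score(channel, has_pending=False)
```

Because None is returned both when nothing was pending and on timeout, a caller that needs to distinguish the two cases has to track the pending state separately.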

env
initialize(training_container)

Initializes the buffer with the type of the first added object.

Parameters:

training_container (ContainerUnion) – The initial Trajectories, Transitions, or StatesContainer object to set the buffer type.

Return type:

None

lazy_sort = False
load(path)

Loads buffer contents from a .pt file saved by save().

Parameters:

path (str) – File path to the saved buffer.

pending_container: ContainerUnion | None = None
prioritized_capacity = False
prioritized_sampling = False
remote_buffer_freq = 1
remote_manager_rank = None
sample(n_samples)

Samples training objects from the buffer.

Parameters:

n_samples (int) – The number of items to sample.

Returns:

A sampled Trajectories, Transitions, or StatesContainer.

Return type:

ContainerUnion

save(path)

Saves the buffer to a single .pt file.

Parameters:

path (str) – File path (e.g. "replay_buffer.pt").

scoring_only = False
timing = False
timing_data: dict[str, list[float]]
timing_log()

Returns a formatted string of the timing information for the replay buffer.

Return type:

str
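
Timing data of the shape dict[str, list[float]] can be collected and summarized as sketched below. The context manager, the key name "send", and the output format are assumptions for illustration; the format of the real timing_log() string is not specified here.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

timing_data = defaultdict(list)  # maps a phase name to its recorded durations


@contextmanager
def timed(name):
    """Record the wall-clock duration of the enclosed block under `name`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        timing_data[name].append(time.perf_counter() - start)


def timing_log():
    """Format one summary line per phase (hypothetical layout)."""
    lines = []
    for name, durations in timing_data.items():
        total = sum(durations)
        mean = total / len(durations)
        lines.append(f"{name}: n={len(durations)} total={total:.6f}s mean={mean:.6f}s")
    return "\n".join(lines)


with timed("send"):
    sum(range(1000))  # placeholder for the work being measured
summary = timing_log()
```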

training_container: ContainerUnion | None = None
class gfn.containers.replay_buffer.TerminatingStateBuffer(env, capacity=1000, communication_backend='mpi', timing=False, **kwargs)

Bases: ReplayBuffer

A replay buffer for storing terminating states.

Parameters:
  • env (gfn.env.Env)

  • capacity (int)

  • communication_backend (str)

  • timing (bool)

env

The environment associated with the containers.

capacity

The maximum number of items the buffer can hold.

training_container

The buffer contents (StatesContainer).

_local_add(training_container)

Extracts terminating states and adds them to the local buffer.

Overrides the base class hook so that add() (which handles remote communication) delegates local insertion here.

Parameters:

training_container (ContainerUnion)

training_container