gfn.containers.replay_buffer¶
Attributes¶
Classes¶
Container | Base class for protocol classes.
NormBasedDiversePrioritizedReplayBuffer | A replay buffer with diversity-based prioritization.
ReplayBuffer | A replay buffer for storing training containers.
TerminatingStateBuffer | A replay buffer for storing terminating states.
Module Contents¶
- class gfn.containers.replay_buffer.Container¶
Bases:
Protocol
Base class for protocol classes.
Protocol classes are defined as:

    class Proto(Protocol):
        def meth(self) -> int:
            ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example:

    class C:
        def meth(self) -> int:
            return 0

    def func(x: Proto) -> int:
        return x.meth()

    func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures.
Protocol classes can be generic; they are defined as:

    class GenProto(Protocol[T]):
        def meth(self) -> T:
            ...
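The runtime behaviour described above can be illustrated with a small sketch. `SizedContainer` and the three classes below are hypothetical names made up for this example; they are not the gfn `Container` protocol. The point is only that `isinstance` against a `@runtime_checkable` protocol checks that the members exist, not that their signatures match.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class SizedContainer(Protocol):
    """Hypothetical protocol, for illustration only (not the gfn Container)."""

    def __len__(self) -> int: ...

    def extend(self, other: "SizedContainer") -> None: ...


class ListBacked:
    """Satisfies the protocol structurally, without inheriting from it."""

    def __init__(self) -> None:
        self.items = []

    def __len__(self) -> int:
        return len(self.items)

    def extend(self, other: "ListBacked") -> None:
        self.items.extend(other.items)


class MissingExtend:
    """Lacks `extend`, so the runtime check fails."""

    def __len__(self) -> int:
        return 0


class WrongSignature:
    """Has both member names, but `extend` takes no argument."""

    def __len__(self) -> int:
        return 0

    def extend(self) -> None:
        pass


matches = isinstance(ListBacked(), SizedContainer)        # True
missing = isinstance(MissingExtend(), SizedContainer)     # False
wrong_sig = isinstance(WrongSignature(), SizedContainer)  # True: signatures are ignored
```

A static type checker would flag `WrongSignature` where a `SizedContainer` is expected, which is why runtime checks are best treated as a coarse guard.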
- __getitem__(idx)¶
- __len__()¶
- Return type:
int
- extend(other)¶
- property log_rewards: torch.Tensor | None¶
- Return type:
torch.Tensor | None
- property terminating_states¶
- gfn.containers.replay_buffer.ContainerUnion¶
- class gfn.containers.replay_buffer.NormBasedDiversePrioritizedReplayBuffer(env, capacity=1000, cutoff_distance=0.0, p_norm_distance=1.0, remote_manager_rank=None, remote_buffer_freq=1, communication_backend='mpi', timing=False, async_score=False, async_comm=False, lazy_sort=False, baseline_filtering=False, scoring_only=False, baseline_refresh_after=10)¶
Bases:
ReplayBuffer
A replay buffer with diversity-based prioritization.
- Parameters:
env (gfn.env.Env)
capacity (int)
cutoff_distance (float)
p_norm_distance (float)
remote_manager_rank (int | None)
remote_buffer_freq (int)
communication_backend (str)
timing (bool)
async_score (bool)
async_comm (bool)
lazy_sort (bool)
baseline_filtering (bool)
scoring_only (bool)
baseline_refresh_after (int)
- env¶
The environment associated with the containers.
- capacity¶
The maximum number of items the buffer can hold.
- training_container¶
The buffer contents (Trajectories, Transitions, or StatesContainer). This is dynamically set based on the type of the first added object.
- prioritized_capacity¶
Whether to use prioritized capacity (keep highest-reward items). This is set to True by default.
- prioritized_sampling¶
Whether to sample items with probability proportional to their reward.
- cutoff_distance¶
Threshold used to determine whether a new terminating state is different enough from those already in the buffer.
- p_norm_distance¶
p-norm value for distance calculation (used in torch.cdist).
- static _diversity_repr(container)¶
Returns the tensor used for pairwise distance in diversity filtering.
For conditional GFNs, concatenates conditions with the state tensor so that identical states under different conditions are treated as distinct.
- Parameters:
container (ContainerUnion)
- Return type:
torch.Tensor
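To make the interplay of `cutoff_distance`, `p_norm_distance`, and the diversity representation concrete, here is a minimal pure-Python sketch. It is not the library's implementation: the documentation says distances are computed with `torch.cdist`, whereas this example uses plain lists. The helper names, the concatenation order (condition first), the strict `>` comparison, and comparing candidates against each other within a batch are all assumptions.

```python
def p_norm_distance(a, b, p=1.0):
    """p-norm distance between two equal-length vectors."""
    return sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p)


def diversity_repr(state, condition=None):
    """Vector used for distance checks.

    Assumption: the condition is prepended to the state, so identical states
    under different conditions get different representations.
    """
    return list(state) if condition is None else list(condition) + list(state)


def filter_diverse(candidates, existing, cutoff_distance=0.0, p=1.0):
    """Keep candidates farther than `cutoff_distance` from everything kept so far."""
    kept = []
    for cand in candidates:
        others = existing + kept
        if all(p_norm_distance(cand, o, p) > cutoff_distance for o in others):
            kept.append(cand)
    return kept


existing = [[0.0, 0.0]]
candidates = [[0.0, 0.5], [3.0, 0.0], [3.0, 0.2]]
kept = filter_diverse(candidates, existing, cutoff_distance=1.0, p=1.0)

# The same state under two different conditions is no longer at distance zero.
same_state_gap = p_norm_distance(
    diversity_repr([1.0, 2.0], condition=[0.0]),
    diversity_repr([1.0, 2.0], condition=[1.0]),
)
```

In this example only `[3.0, 0.0]` survives: the first candidate is too close to the existing state, and the third is too close to the second.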
- _local_add(training_container)¶
Adds with diversity-based prioritization to the local buffer.
Overrides the base class hook so that add() (which handles remote communication) delegates local insertion here.
- Parameters:
training_container (ContainerUnion)
- cutoff_distance = 0.0¶
- p_norm_distance = 1.0¶
- class gfn.containers.replay_buffer.ReplayBuffer(env, capacity=1000, prioritized_capacity=False, prioritized_sampling=False, remote_manager_rank=None, remote_buffer_freq=1, communication_backend='mpi', timing=False, async_score=False, async_comm=False, lazy_sort=False, baseline_filtering=False, scoring_only=False, baseline_refresh_after=10)¶
A replay buffer for storing training containers.
Supports local-only operation and distributed remote buffer communication.
- Features:
Local buffering: Stores Trajectories, Transitions, or StatesContainers up to a fixed capacity.
Prioritized capacity: Optionally keeps only the highest-reward items when the buffer is full.
Prioritized sampling: Optionally samples with probability proportional to reward (softmax over log-rewards).
Remote buffer communication: When remote_manager_rank is set, periodically sends batched containers to a remote ReplayBufferManager and receives score dictionaries back.
Communication backends: The communication_backend parameter selects between "torch" (PyTorch distributed / Gloo) and "mpi" (MPI4PY, ~8-12 GB/s vs ~100 MB/s with Gloo).
Async scoring: When async_score is enabled, trajectory sends are fire-and-forget; scores are collected lazily on the next add() call (1-iteration stale), decoupling training throughput from buffer scoring latency.
Timing instrumentation: When timing is enabled, serialization, send, and receive durations are recorded for profiling.
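The two prioritization features listed above can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the library's code: items are modelled as hypothetical `(name, log_reward)` tuples, and sampling with replacement is an assumption. A softmax over log-rewards yields probabilities proportional to the rewards themselves.

```python
import math
import random


def sort_and_truncate(items, capacity):
    """Prioritized capacity: keep the `capacity` items with the highest log-reward."""
    return sorted(items, key=lambda it: it[1], reverse=True)[:capacity]


def softmax(log_rewards):
    """Numerically stable softmax; equals reward / sum(rewards)."""
    m = max(log_rewards)
    exps = [math.exp(lr - m) for lr in log_rewards]
    total = sum(exps)
    return [e / total for e in exps]


def prioritized_sample(items, n_samples, rng):
    """Prioritized sampling (assumption: with replacement)."""
    weights = softmax([lr for _, lr in items])
    return rng.choices(items, weights=weights, k=n_samples)


items = [
    ("a", math.log(1.0)),
    ("b", math.log(3.0)),
    ("c", math.log(0.5)),
    ("d", math.log(6.0)),
]
kept = sort_and_truncate(items, capacity=3)   # drops "c", the lowest reward
probs = softmax([lr for _, lr in kept])       # rewards 6, 3, 1 over a total of 10
batch = prioritized_sample(kept, n_samples=5, rng=random.Random(0))
```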
- Parameters:
env (gfn.env.Env)
capacity (int)
prioritized_capacity (bool)
prioritized_sampling (bool)
remote_manager_rank (int | None)
remote_buffer_freq (int)
communication_backend (str)
timing (bool)
async_score (bool)
async_comm (bool)
lazy_sort (bool)
baseline_filtering (bool)
scoring_only (bool)
baseline_refresh_after (int)
- env¶
The environment associated with the containers.
- capacity¶
The maximum number of items the buffer can hold.
- training_container¶
The buffer contents (Trajectories, Transitions, or StatesContainer). Dynamically set based on the type of the first added object.
- __len__()¶
Returns the number of items in the buffer.
- Returns:
The number of items in the buffer (including pending batches).
- Return type:
int
- __repr__()¶
Returns a string representation of the ReplayBuffer.
- Returns:
A string summary of the buffer.
- Return type:
str
- _add_counter = 0¶
- _baseline_kept: int = 0¶
- _baseline_log_reward: float¶
- _baseline_skipped_sends: int = 0¶
- _baseline_total: int = 0¶
- _collect_pending_score()¶
Collect a pending score response from a previous async send.
Returns None if no score is pending (e.g., first iteration).
- Return type:
dict[str, float] | None
- _consecutive_filtered_empty: int = 0¶
- _filter_and_send(container, send_fn)¶
Filter by baseline, prepare for remote, and send.
Returns whatever send_fn returns (a score dict for sync sends, None for async sends), or None if baseline filtering drops everything.
- _filter_by_baseline(container)¶
Filter a container to keep only items with log_reward >= baseline.
Returns the (possibly subset) container, or None if every item is below the baseline. After baseline_refresh_after consecutive fully-filtered batches, the next batch bypasses the filter so the worker can receive a fresh baseline.
Transitions is not supported (its log_rewards is per-transition with -inf for non-terminating rows, so per-row filtering would break DB/SubTB).
- Parameters:
container (ContainerUnion)
- Return type:
ContainerUnion | None
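The filtering and refresh rule described for `_filter_by_baseline` can be sketched as a small state machine. `BaselineFilter` is a hypothetical stand-in that operates on a plain list of log-rewards rather than a container. The initial baseline of `-inf` and the reset of the counter after any non-empty batch are assumptions.

```python
class BaselineFilter:
    """Hypothetical sketch of baseline filtering with a periodic bypass."""

    def __init__(self, baseline_log_reward=float("-inf"), refresh_after=10):
        self.baseline_log_reward = baseline_log_reward
        self.refresh_after = refresh_after
        self.consecutive_empty = 0  # fully-filtered batches seen in a row

    def __call__(self, log_rewards):
        # After enough fully-filtered batches, let one batch through unfiltered
        # so that a fresh baseline can come back with its score.
        if self.consecutive_empty >= self.refresh_after:
            self.consecutive_empty = 0
            return list(log_rewards)

        kept = [lr for lr in log_rewards if lr >= self.baseline_log_reward]
        if not kept:
            self.consecutive_empty += 1
            return None
        self.consecutive_empty = 0
        return kept


flt = BaselineFilter(baseline_log_reward=0.0, refresh_after=2)
results = [
    flt([1.0, -1.0]),  # partial keep
    flt([-1.0]),       # fully filtered (1 in a row)
    flt([-2.0]),       # fully filtered (2 in a row)
    flt([-3.0]),       # bypass: passes despite being below the baseline
    flt([-4.0]),       # filtering resumes
]
```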
- _flush_pending()¶
Concatenate all pending batches into training_container.
Called lazily when the accumulated size reaches 2 * capacity, or eagerly by callers that need a consistent view.
Merges all pending batches into a single combined batch first, then extends training_container once to avoid extra copy cost.
- Return type:
None
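The lazy-flush idea can be sketched with plain lists. `LazyBuffer` is a hypothetical class, not the library's. Two assumptions are made: the `2 * capacity` threshold is applied to the pending size alone, and each flush is followed by a sort and truncation to `capacity`.

```python
class LazyBuffer:
    """Hypothetical sketch: accumulate batches cheaply, merge them rarely."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.items = []        # stand-in for training_container
        self.pending = []      # batches not yet merged
        self.pending_len = 0
        self.flush_count = 0

    def add(self, batch):
        self.pending.append(list(batch))
        self.pending_len += len(batch)
        if self.pending_len >= 2 * self.capacity:
            self.flush()

    def flush(self):
        if not self.pending:
            return
        # Merge the pending batches first, then extend the main storage once.
        merged = [x for batch in self.pending for x in batch]
        self.items.extend(merged)
        # Assumption: keep only the highest values after merging.
        self.items = sorted(self.items, reverse=True)[: self.capacity]
        self.pending, self.pending_len = [], 0
        self.flush_count += 1

    def __len__(self):
        # Pending batches count towards the length.
        return len(self.items) + self.pending_len


buf = LazyBuffer(capacity=3)
buf.add([1, 2])
buf.add([3])
len_before, flushes_before = len(buf), buf.flush_count
buf.add([4, 5, 6])  # pending size reaches 6 == 2 * capacity, so a flush runs
```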
- _is_full = False¶
- _isend_and_defer_score(training_container)¶
Non-blocking send (isend), deferred score: fire-and-forget data, collect score on next add().
The send handle is kept alive in _send_handle until the next call to _wait_previous_send.
- Parameters:
training_container (ContainerUnion)
- Return type:
None
- _local_add(training_container)¶
Adds a training object to the local buffer, handling capacity.
Subclasses override this to customize local insertion logic (e.g., diversity filtering). The base class add() calls this method, then handles remote buffer communication separately.
- Parameters:
training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
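The hook arrangement described here is a template-method pattern, sketched below with hypothetical classes (`BaseBuffer`, `UniqueBuffer`) that only mimic the shape of `add()` and `_local_add()`. Remote communication is replaced by appending to a list, purely for illustration.

```python
class BaseBuffer:
    """Hypothetical base class: add() is fixed, _local_add() is the hook."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.items = []
        self.sent = []  # stand-in for remote buffer communication

    def add(self, batch):
        self._local_add(batch)    # customizable local insertion
        self._send_remote(batch)  # handled separately from local insertion

    def _local_add(self, batch):
        self.items.extend(batch)
        self.items = self.items[-self.capacity:]

    def _send_remote(self, batch):
        self.sent.append(list(batch))


class UniqueBuffer(BaseBuffer):
    """Overrides only the hook: skip items that are already stored."""

    def _local_add(self, batch):
        for x in batch:
            if x not in self.items:
                super()._local_add([x])


buf = UniqueBuffer(capacity=10)
buf.add([1, 2, 2])
buf.add([2, 3])
```

Because only the hook is overridden, the subclass inherits whatever `add()` does around local insertion without duplicating it.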
- _pending_batches: list[ContainerUnion] = []¶
- _pending_len: int = 0¶
- _pending_score: bool = False¶
- _prepare_for_remote(container)¶
Convert a container to a lightweight form for remote scoring.
When scoring_only is True, extracts terminating states and log-rewards into a StatesContainer. Transitions is rejected because its log_rewards shape does not match terminating_states (it is per-transition, not per-trajectory).
When scoring_only is False, returns the container unchanged.
- Parameters:
container (ContainerUnion)
- Return type:
ContainerUnion
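As a rough sketch of the conversion described above, the example below uses toy dataclasses in place of the real containers. `ToyTrajectories`, `ToyTransitions`, and `ToyStatesContainer` are hypothetical, and raising `TypeError` for transitions is an assumption; the documentation only says that `Transitions` is rejected.

```python
from dataclasses import dataclass


@dataclass
class ToyTrajectories:
    states: list              # full per-step states (the heavy part)
    terminating_states: list  # one per trajectory
    log_rewards: list         # one per trajectory


@dataclass
class ToyTransitions:
    terminating_states: list
    log_rewards: list         # one per transition, not per trajectory


@dataclass
class ToyStatesContainer:
    terminating_states: list
    log_rewards: list


def prepare_for_remote(container, scoring_only):
    """Return a lightweight payload when only scoring is needed."""
    if not scoring_only:
        return container
    if isinstance(container, ToyTransitions):
        # Assumption: the exception type is chosen for this sketch only.
        raise TypeError("per-transition log_rewards cannot be paired with terminating states")
    return ToyStatesContainer(
        terminating_states=list(container.terminating_states),
        log_rewards=list(container.log_rewards),
    )


trajs = ToyTrajectories(
    states=[[0, 1, 2], [0, 3]],
    terminating_states=[2, 3],
    log_rewards=[-0.5, -1.2],
)
light = prepare_for_remote(trajs, scoring_only=True)
unchanged = prepare_for_remote(trajs, scoring_only=False)

try:
    prepare_for_remote(ToyTransitions([2], [-0.5]), scoring_only=True)
    rejected = False
except TypeError:
    rejected = True
```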
- _recv_score()¶
Receive a score dictionary from the remote manager.
- Return type:
dict[str, float]
- _send_data(training_container)¶
Send a training container to the remote manager.
- Parameters:
training_container (ContainerUnion)
- Return type:
None
- _send_handle: gfn.utils.distributed.AsyncSendHandle | None = None¶
- _send_objs(training_container)¶
Sends a training container to the remote manager (synchronous).
- Parameters:
training_container (ContainerUnion)
- Return type:
dict[str, float]
- _send_objs_async(training_container)¶
Sends a training container without waiting for the score response.
The score will be collected on the next call to _collect_pending_score.
- Parameters:
training_container (ContainerUnion)
- Return type:
None
- _sort_and_truncate(training_container)¶
Sort by log-reward (if prioritized) and truncate to capacity.
- Parameters:
training_container (ContainerUnion)
- Return type:
None
- _update_baseline(score_dict)¶
Extract and store the baseline log-reward from a score response.
Called after receiving a score dict from the buffer manager. Only updates if baseline_filtering is enabled and the score dict contains a baseline_log_reward key.
- Parameters:
score_dict (dict[str, float] | None)
- Return type:
None
- _wait_previous_send()¶
Block until the previous non-blocking send has completed.
This is typically near-instantaneous because MPI internally buffers the data, but guarantees the send buffer can be safely reused.
- Return type:
None
- add(training_container)¶
Adds a training container to the buffer.
The type of the training container is dynamically set based on the type of the first added container.
When async_score is enabled, scores are collected lazily: the first call returns None (no pending score yet), and subsequent calls return the score from the previous submission. This decouples training throughput from buffer scoring latency.
When baseline_filtering is enabled, only trajectories with log-reward at or above the remote buffer’s baseline are sent. If all trajectories in the pending batch are below the baseline, the send is skipped entirely.
- Parameters:
training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
- Return type:
dict[str, float] | None
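The one-iteration-stale return value of `add()` under `async_score` can be illustrated with a thread pool standing in for the remote manager. Everything here is hypothetical: `AsyncScoreClient`, the `score` function, and the `mean_log_reward` key are invented for the sketch and say nothing about the library's actual transport or score format.

```python
from concurrent.futures import ThreadPoolExecutor


def score(batch):
    """Stand-in for the scoring done by the remote manager."""
    return {"mean_log_reward": sum(batch) / len(batch)}


class AsyncScoreClient:
    """Hypothetical sketch of fire-and-forget sends with deferred scores."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending = None  # future for the previous submission, if any

    def add(self, batch):
        previous = self._collect_pending()  # score of the previous call
        self._pending = self._pool.submit(score, batch)
        return previous

    def _collect_pending(self):
        if self._pending is None:
            return None  # nothing submitted yet, e.g. on the first call
        result = self._pending.result()
        self._pending = None
        return result

    def close(self):
        last = self._collect_pending()  # do not leave a response undelivered
        self._pool.shutdown()
        return last


client = AsyncScoreClient()
returned = [
    client.add([1.0, 3.0]),
    client.add([2.0, 2.0, 5.0]),
    client.add([10.0]),
]
last = client.close()
```

Each call returns the score of the batch submitted one call earlier, so the caller never waits on the batch it has just sent.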
- async_comm = False¶
- async_score = False¶
- baseline_filtering = False¶
- baseline_refresh_after = 10¶
- capacity = 1000¶
- communication_backend = 'mpi'¶
- property device: torch.device¶
The device on which the buffer’s data is stored.
- Returns:
The device object of the buffer’s contents.
- Return type:
torch.device
- drain_pending_score(timeout_sec=30.0)¶
Drain any outstanding async score before shutdown.
Should be called before sending the EXIT signal when async_score or async_comm is enabled, to avoid leaving the buffer manager with an undelivered response.
For async_comm mode this also waits for the outstanding non-blocking send to complete.
Uses a timeout to avoid hanging indefinitely if the buffer manager has crashed. Returns None on timeout (score is lost).
- Parameters:
timeout_sec (float)
- Return type:
dict[str, float] | None
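The timeout behaviour can be sketched with a standard-library queue in place of the channel to the buffer manager. The function below is a hypothetical stand-in with the same name; its `responses` and `pending` arguments are assumptions made for the example.

```python
import queue


def drain_pending_score(responses, pending, timeout_sec=30.0):
    """Wait for an outstanding score, giving up after `timeout_sec` seconds.

    `responses` is a queue.Queue standing in for the channel to the manager;
    `pending` says whether a score is still outstanding.
    """
    if not pending:
        return None
    try:
        return responses.get(timeout=timeout_sec)
    except queue.Empty:
        return None  # the manager never answered, so the score is lost


channel = queue.Queue()
channel.put({"score": 1.0})

delivered = drain_pending_score(channel, pending=True, timeout_sec=0.05)
lost = drain_pending_score(channel, pending=True, timeout_sec=0.05)  # times out
idle = drain_pending_score(channel, pending=False)
```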
- env¶
- initialize(training_container)¶
Initializes the buffer with the type of the first added object.
- Parameters:
training_container (ContainerUnion) – The initial Trajectories, Transitions, or StatesContainer object to set the buffer type.
- Return type:
None
- lazy_sort = False¶
- load(path)¶
Loads buffer contents from a .pt file saved by save().
- Parameters:
path (str) – File path to the saved buffer.
- pending_container: ContainerUnion | None = None¶
- prioritized_capacity = False¶
- prioritized_sampling = False¶
- remote_buffer_freq = 1¶
- remote_manager_rank = None¶
- sample(n_samples)¶
Samples training objects from the buffer.
- Parameters:
n_samples (int) – The number of items to sample.
- Returns:
A sampled Trajectories, Transitions, or StatesContainer.
- Return type:
ContainerUnion
- save(path)¶
Saves the buffer to a single .pt file.
- Parameters:
path (str) – File path (e.g. "replay_buffer.pt").
- scoring_only = False¶
- timing = False¶
- timing_data: dict[str, list[float]]¶
- timing_log()¶
Returns a formatted string of the timing information for the replay buffer.
- Return type:
str
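A `dict[str, list[float]]` of durations plus a formatting helper is a common way to build this kind of instrumentation. The sketch below shows one possible shape using only the standard library. The `timed` context manager, the labels, and the output format are assumptions and are not taken from the library.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

timing_data = defaultdict(list)  # label -> list of durations in seconds


@contextmanager
def timed(label, enabled=True):
    """Record the duration of the enclosed block under `label`."""
    if not enabled:
        yield
        return
    start = time.perf_counter()
    try:
        yield
    finally:
        timing_data[label].append(time.perf_counter() - start)


def timing_log():
    """Format one summary line per label."""
    lines = []
    for label, durations in timing_data.items():
        total = sum(durations)
        mean = total / len(durations)
        lines.append(f"{label}: n={len(durations)} total={total:.6f}s mean={mean:.6f}s")
    return "\n".join(lines)


for _ in range(3):
    with timed("serialize"):
        sum(range(1000))

with timed("send", enabled=False):  # nothing is recorded when disabled
    pass

report = timing_log()
```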
- training_container: ContainerUnion | None = None¶
- class gfn.containers.replay_buffer.TerminatingStateBuffer(env, capacity=1000, communication_backend='mpi', timing=False, **kwargs)¶
Bases:
ReplayBuffer
A replay buffer for storing terminating states.
- Parameters:
env (gfn.env.Env)
capacity (int)
communication_backend (str)
timing (bool)
- env¶
The environment associated with the containers.
- capacity¶
The maximum number of items the buffer can hold.
- training_container¶
The buffer contents (StatesContainer).
- _local_add(training_container)¶
Extracts terminating states and adds them to the local buffer.
Overrides the base class hook so that add() (which handles remote communication) delegates local insertion here.
- Parameters:
training_container (ContainerUnion)
- training_container¶