gfn.containers.replay_buffer

Attributes

ContainerUnion

Classes

Container

Structural protocol for the containers stored in a replay buffer.

NormBasedDiversePrioritizedReplayBuffer

A replay buffer with diversity-based prioritization.

ReplayBuffer

A replay buffer for storing containers.

TerminatingStateBuffer

A replay buffer for storing terminating states.

Module Contents

class gfn.containers.replay_buffer.Container

Bases: Protocol

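Structural protocol (see PEP 544) describing the interface a replay buffer expects from the containers it stores: indexing, length, extend(), and the log_rewards and terminating_states properties.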

__getitem__(idx)
__len__()
Return type:

int

extend(other)
property log_rewards: torch.Tensor | None
Return type:

torch.Tensor | None

property terminating_states
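The members above are all a class needs to satisfy Container structurally, without inheriting from it. The following is a minimal sketch that assumes only what is listed here. ContainerLike and ToyContainer are hypothetical names, plain lists stand in for torch.Tensor, and the real protocol in gfn.containers.replay_buffer may declare more than this mirror does.

```python
from typing import Protocol, runtime_checkable


# Hypothetical mirror of the Container members documented above. The real
# protocol types log_rewards as torch.Tensor | None; a list stands in here.
@runtime_checkable
class ContainerLike(Protocol):
    def __getitem__(self, idx): ...

    def __len__(self) -> int: ...

    def extend(self, other) -> None: ...

    @property
    def log_rewards(self) -> "list[float] | None": ...

    @property
    def terminating_states(self): ...


class ToyContainer:
    """Hypothetical container: it never subclasses ContainerLike."""

    def __init__(self, states, log_rewards=None):
        self.states = list(states)
        self._log_rewards = None if log_rewards is None else list(log_rewards)

    def __getitem__(self, idx):
        # Accept one int or an iterable of ints; always return a ToyContainer.
        indices = [idx] if isinstance(idx, int) else list(idx)
        rewards = self._log_rewards
        return ToyContainer(
            [self.states[i] for i in indices],
            None if rewards is None else [rewards[i] for i in indices],
        )

    def __len__(self) -> int:
        return len(self.states)

    def extend(self, other) -> None:
        # Assumes both containers carry rewards, or neither does.
        self.states.extend(other.states)
        if self._log_rewards is not None:
            self._log_rewards.extend(other.log_rewards)

    @property
    def log_rewards(self):
        return self._log_rewards

    @property
    def terminating_states(self):
        # Toy simplification: every stored state is terminal.
        return self.states


c = ToyContainer(["s0", "s1"], [0.0, -1.0])
c.extend(ToyContainer(["s2"], [-2.0]))
print(isinstance(c, ContainerLike), len(c))  # prints: True 3
```

The isinstance check passes only because of @runtime_checkable, and it verifies that the members exist, not their signatures; static type checkers are what enforce the full protocol.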
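gfn.containers.replay_buffer.ContainerUnion

Union of the container types a replay buffer can hold: Trajectories, Transitions, or StatesContainer.
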
class gfn.containers.replay_buffer.NormBasedDiversePrioritizedReplayBuffer(env, capacity=1000, cutoff_distance=0.0, p_norm_distance=1.0, remote_manager_rank=None, remote_buffer_freq=1)

Bases: ReplayBuffer

A replay buffer with diversity-based prioritization.

Parameters:
  • env (gfn.env.Env)

  • capacity (int)

  • cutoff_distance (float)

  • p_norm_distance (float)

  • remote_manager_rank (int | None)

  • remote_buffer_freq (int)

env

The environment associated with the containers.

capacity

The maximum number of items the buffer can hold.

training_container

The buffer contents (Trajectories, Transitions, or StatesContainer). This is dynamically set based on the type of the first added object.

prioritized_capacity

Whether to use prioritized capacity (keep the highest-reward items). This is set to True by default; the constructor does not take a prioritized_capacity argument.

prioritized_sampling

Whether to sample items with probability proportional to their reward.

cutoff_distance

Threshold used to determine whether a new terminating state is different enough from those already in the buffer.

p_norm_distance

The p value of the p-norm used to compute pairwise distances (passed to torch.cdist).
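For reference, torch.cdist with p = p_norm_distance measures each pair of representation rows x and y of length D by the p-norm (Minkowski) distance below. The default p_norm_distance = 1.0 therefore gives the Manhattan distance, and 2.0 gives the Euclidean distance.

```latex
d_p(x, y) = \lVert x - y \rVert_p = \left( \sum_{i=1}^{D} \lvert x_i - y_i \rvert^{p} \right)^{1/p}
```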

static _diversity_repr(container)

Returns the tensor used to compute pairwise distances during diversity filtering.

For conditional GFNs, concatenates conditions with the state tensor so that identical states under different conditions are treated as distinct.

Parameters:

container (ContainerUnion)

Return type:

torch.Tensor
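The following minimal sketch shows why the conditions are concatenated. It assumes that states and conditions each flatten to one feature row per item. diversity_repr and pairwise_pnorm are hypothetical names. The library itself works on torch tensors and torch.cdist; this sketch replaces them with plain lists and a hand-rolled distance so it runs without PyTorch.

```python
def diversity_repr(states, conditions=None):
    """Hypothetical stand-in: one flat row per item, condition appended if any."""
    if conditions is None:
        return [list(s) for s in states]
    return [list(s) + list(c) for s, c in zip(states, conditions)]


def pairwise_pnorm(xs, ys, p=1.0):
    """Pairwise p-norm distances between rows, like torch.cdist(xs, ys, p=p)."""
    return [
        [sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1.0 / p) for y in ys]
        for x in xs
    ]


states = [[1.0, 2.0], [1.0, 2.0]]  # two identical states ...
conditions = [[0.0], [1.0]]        # ... under different conditions
plain = diversity_repr(states)
conditioned = diversity_repr(states, conditions)
print(pairwise_pnorm(plain, plain)[0][1])              # 0.0: indistinguishable
print(pairwise_pnorm(conditioned, conditioned)[0][1])  # 1.0: now distinct
```

Without the condition columns, the two rows are at distance 0.0 and would be treated as duplicates. With the condition columns, they are 1.0 apart under the default p = 1.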

add(training_container)

Adds a training object to the buffer with diversity-based prioritization.

Parameters:

training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
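The filtering step can be sketched as follows. The sketch rests on these assumptions:

- Items are (representation, log_reward) pairs.
- Candidates are visited highest reward first.
- A candidate survives only if it is strictly farther than cutoff_distance from every buffered item and from every candidate already accepted.
- The merged buffer then keeps the capacity highest-reward items (prioritized capacity).

pnorm and diverse_add are hypothetical names. The library's exact ordering and tie rules may differ.

```python
def pnorm(x, y, p=1.0):
    """p-norm distance between two equal-length rows (torch.cdist semantics)."""
    return sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1.0 / p)


def diverse_add(buffer, new_items, capacity, cutoff_distance=0.0, p=1.0):
    """Hypothetical diversity-filtered add.

    buffer and new_items are lists of (repr_row, log_reward) pairs.
    """
    accepted = []
    # Visit the highest-reward candidates first so near-duplicates lose to them.
    for row, log_r in sorted(new_items, key=lambda item: item[1], reverse=True):
        others = [r for r, _ in buffer] + [r for r, _ in accepted]
        # Keep the candidate only if it is strictly farther than the cutoff
        # from everything already stored or accepted.
        if all(pnorm(row, other, p) > cutoff_distance for other in others):
            accepted.append((row, log_r))
    # Prioritized capacity: keep the `capacity` highest-reward items.
    merged = sorted(buffer + accepted, key=lambda item: item[1], reverse=True)
    return merged[:capacity]


buffer = [([0.0, 0.0], -1.0)]
new = [([0.0, 0.1], -0.5), ([5.0, 5.0], -3.0), ([5.0, 5.05], -2.0)]
print(diverse_add(buffer, new, capacity=2, cutoff_distance=0.5))
```

Here, [0.0, 0.1] is rejected as too close to the stored [0.0, 0.0], and [5.0, 5.0] loses to its higher-reward neighbor [5.0, 5.05]. With the default cutoff_distance = 0.0, only exact duplicates are filtered out.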

cutoff_distance = 0.0
p_norm_distance = 1.0
class gfn.containers.replay_buffer.ReplayBuffer(env, capacity=1000, prioritized_capacity=False, prioritized_sampling=False, remote_manager_rank=None, remote_buffer_freq=1)

A replay buffer for storing containers.

Parameters:
  • env (gfn.env.Env)

  • capacity (int)

  • prioritized_capacity (bool)

  • prioritized_sampling (bool)

  • remote_manager_rank (int | None)

  • remote_buffer_freq (int)

env

The environment associated with the containers.

capacity

The maximum number of items the buffer can hold.

training_container

The buffer contents (Trajectories, Transitions, or StatesContainer). This is dynamically set based on the type of the first added object.

prioritized_capacity

Whether to use prioritized capacity (keep highest-reward items).

prioritized_sampling

Whether to sample items with probability proportional to their reward.

__len__()

Returns the number of items in the buffer.

Returns:

The number of items in the buffer.

Return type:

int

__repr__()

Returns a string representation of the ReplayBuffer.

Returns:

A string summary of the buffer.

Return type:

str

_add_counter = 0
_add_objs(training_container)

Adds a training object to the buffer, enforcing the buffer's capacity.

Parameters:

training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
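The following sketch shows one way to enforce the capacity on parallel Python lists. With prioritized_capacity, the highest-reward items survive; otherwise, the oldest items are dropped first. The FIFO fallback and keeping survivors in insertion order are assumptions. add_with_capacity is a hypothetical helper, not part of the library.

```python
def add_with_capacity(items, log_rewards, new_items, new_log_rewards,
                      capacity, prioritized_capacity=False):
    """Hypothetical capacity handling for a list-backed buffer.

    Returns the (items, log_rewards) that survive appending the new batch.
    """
    items = items + new_items
    log_rewards = log_rewards + new_log_rewards
    if len(items) <= capacity:
        return items, log_rewards
    if prioritized_capacity:
        # Keep the `capacity` highest-reward entries, in insertion order.
        best = sorted(range(len(items)), key=lambda i: log_rewards[i], reverse=True)
        keep = sorted(best[:capacity])
    else:
        # FIFO: drop the oldest entries first.
        keep = range(len(items) - capacity, len(items))
    return [items[i] for i in keep], [log_rewards[i] for i in keep]


old, old_r = ["a", "b", "c"], [0.0, -5.0, -1.0]
print(add_with_capacity(old, old_r, ["d"], [-2.0], capacity=3))
print(add_with_capacity(old, old_r, ["d"], [-2.0], capacity=3, prioritized_capacity=True))
```

In the FIFO case the oldest item "a" is evicted even though it has the best reward. With prioritized capacity, the worst item "b" goes instead.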

_is_full = False
_send_objs(training_container)

Sends a training container to the remote manager.

Parameters:

training_container (ContainerUnion)

Return type:

dict[str, float]

add(training_container)

Adds a training container to the buffer.

The buffer's container type is set dynamically from the type of the first container added.

Parameters:

training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.

Return type:

dict[str, float] | None

capacity = 1000
property device: torch.device

The device on which the buffer’s data is stored.

Returns:

The device object of the buffer’s contents.

Return type:

torch.device

env
initialize(training_container)

Initializes the buffer's container type from the first added object.

Parameters:

training_container (ContainerUnion) – The first Trajectories, Transitions, or StatesContainer object added; its type determines the buffer's container type.

Return type:

None
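The lazy typing described above can be sketched with hypothetical stand-ins. Trajectories and Transitions below are toy list subclasses, not the gfn.containers classes. Rejecting a mismatched type with TypeError is an assumption about how a mixed add should behave. The real initialize() may need more than a bare constructor call (for example, the environment) to build an empty container; this sketch simply calls the type with no arguments.

```python
class Trajectories(list):
    """Toy stand-in for the library's Trajectories container."""


class Transitions(list):
    """Toy stand-in for the library's Transitions container."""


class LazyTypedBuffer:
    """Hypothetical buffer whose container type is fixed by the first add()."""

    def __init__(self):
        self.training_container = None

    def initialize(self, training_container):
        # Start from an empty container of the same type as the first object.
        self.training_container = type(training_container)()

    def add(self, training_container):
        if self.training_container is None:
            self.initialize(training_container)
        elif type(training_container) is not type(self.training_container):
            # Assumption: mixing container types in one buffer is an error.
            raise TypeError(
                f"buffer holds {type(self.training_container).__name__}, "
                f"got {type(training_container).__name__}"
            )
        self.training_container.extend(training_container)


buf = LazyTypedBuffer()
buf.add(Trajectories([1, 2]))
buf.add(Trajectories([3]))
print(type(buf.training_container).__name__, list(buf.training_container))
```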

load(path)

Loads buffer contents from a .pt file saved by save().

Parameters:

path (str) – File path to the saved buffer.

pending_container: ContainerUnion | None = None
prioritized_capacity = False
prioritized_sampling = False
remote_buffer_freq = 1
remote_manager_rank = None
sample(n_samples)

Samples training objects from the buffer.

Parameters:

n_samples (int) – The number of items to sample.

Returns:

A sampled Trajectories, Transitions, or StatesContainer.

Return type:

ContainerUnion
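The following sketch shows the index selection behind sample(), using the standard random module as a stand-in for the library's torch sampling. With prioritized_sampling, the weights are exp(log_reward), so each item's selection probability is proportional to its reward. Shifting by the maximum log-reward avoids overflow without changing those proportions. sample_indices is hypothetical. Sampling without replacement in the uniform case, and with replacement in the prioritized case, are assumptions.

```python
import math
import random


def sample_indices(log_rewards, n_samples, prioritized_sampling=False, rng=random):
    """Hypothetical index selection for a buffer holding len(log_rewards) items."""
    n = len(log_rewards)
    if not prioritized_sampling:
        # Uniform, without replacement (requires n_samples <= n).
        return rng.sample(range(n), n_samples)
    # Reward-proportional, with replacement: weight_i = exp(log_r_i - max).
    m = max(log_rewards)
    weights = [math.exp(lr - m) for lr in log_rewards]
    return rng.choices(range(n), weights=weights, k=n_samples)


rng = random.Random(0)
print(sample_indices([0.0, -1.0, -2.0], 2, rng=rng))
print(sample_indices([0.0, -1.0, -2.0], 5, prioritized_sampling=True, rng=rng))
```

The returned indices would then select items from the stored container to build the sampled Trajectories, Transitions, or StatesContainer.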

save(path)

Saves the buffer to a single .pt file.

Parameters:

path (str) – File path (e.g. "replay_buffer.pt").

training_container: ContainerUnion | None = None
class gfn.containers.replay_buffer.TerminatingStateBuffer(env, capacity=1000, **kwargs)

Bases: ReplayBuffer

A replay buffer for storing terminating states.

Parameters:
  • env (gfn.env.Env)

  • capacity (int)

  • **kwargs

env

The environment associated with the containers.

capacity

The maximum number of items the buffer can hold.

training_container

The buffer contents (StatesContainer).

add(training_container)

Adds a training container to the buffer.

The buffer's container type is set dynamically from the type of the first container added.

Parameters:

training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.

training_container