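gfn.containers.replay_buffer
Attributes
ContainerUnion
Classes
Container | Protocol describing the container interface the replay buffers rely on.
NormBasedDiversePrioritizedReplayBuffer | A replay buffer with diversity-based prioritization.
ReplayBuffer | A replay buffer for storing containers.
TerminatingStateBuffer | A replay buffer for storing terminating states.
Module Contents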
- class gfn.containers.replay_buffer.Container
Bases: Protocol
(Docstring inherited from typing.Protocol.) Base class for protocol classes.
Protocol classes are defined as:
    class Proto(Protocol):
        def meth(self) -> int:
            ...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing).
For example:
    class C:
        def meth(self) -> int:
            return 0

    def func(x: Proto) -> int:
        return x.meth()

    func(C())  # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic; they are defined as:
    class GenProto(Protocol[T]):
        def meth(self) -> T:
            ...
- __getitem__(idx)
- __len__()
- Return type:
int
- extend(other)
- property log_rewards: torch.Tensor | None
- Return type:
torch.Tensor | None
- property terminating_states
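Because Container is a typing.Protocol, any class that provides these members conforms to it structurally, without inheriting from it. The sketch below illustrates this with two hypothetical classes. `ContainerLike` is a local protocol that mirrors the members listed above, so the example runs without gfn installed. `ToyContainer` is a list-backed container that uses Python lists where the real containers use torch.Tensor. As the inherited docstring notes, a runtime_checkable isinstance check only verifies that the attributes exist, not their signatures.

```python
from typing import Any, List, Optional, Protocol, runtime_checkable


@runtime_checkable
class ContainerLike(Protocol):
    """Hypothetical local mirror of the members documented for Container."""

    def __getitem__(self, idx: Any) -> Any: ...

    def __len__(self) -> int: ...

    def extend(self, other: Any) -> None: ...

    @property
    def log_rewards(self) -> Optional[List[float]]: ...

    @property
    def terminating_states(self) -> Any: ...


class ToyContainer:
    """List-backed toy container; note that it does not subclass ContainerLike."""

    def __init__(self, states: List[int], log_rewards: Optional[List[float]] = None) -> None:
        self._states = list(states)
        self._log_rewards = None if log_rewards is None else list(log_rewards)

    def __getitem__(self, idx: Any) -> "ToyContainer":
        # Accept an int or a slice and always return a new container.
        picked = range(len(self))[idx]
        indices = [picked] if isinstance(picked, int) else list(picked)
        rewards = None if self._log_rewards is None else [self._log_rewards[i] for i in indices]
        return ToyContainer([self._states[i] for i in indices], rewards)

    def __len__(self) -> int:
        return len(self._states)

    def extend(self, other: "ToyContainer") -> None:
        self._states.extend(other._states)
        if self._log_rewards is None or other._log_rewards is None:
            self._log_rewards = None  # rewards unknown for part of the data
        else:
            self._log_rewards.extend(other._log_rewards)

    @property
    def log_rewards(self) -> Optional[List[float]]:
        return self._log_rewards

    @property
    def terminating_states(self) -> List[int]:
        # In this toy, every stored item is already a terminating state.
        return self._states


c = ToyContainer([1, 2, 3], [0.0, -1.0, -2.0])
c.extend(ToyContainer([4], [-3.0]))
print(isinstance(c, ContainerLike), len(c), c[1:3].terminating_states)  # True 4 [2, 3]
```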
- gfn.containers.replay_buffer.ContainerUnion
- class gfn.containers.replay_buffer.NormBasedDiversePrioritizedReplayBuffer(env, capacity=1000, cutoff_distance=0.0, p_norm_distance=1.0, remote_manager_rank=None, remote_buffer_freq=1)
Bases: ReplayBuffer
A replay buffer with diversity-based prioritization.
- Parameters:
env (gfn.env.Env)
capacity (int)
cutoff_distance (float)
p_norm_distance (float)
remote_manager_rank (int | None)
remote_buffer_freq (int)
- env
The environment associated with the containers.
- capacity
The maximum number of items the buffer can hold.
- training_container
The buffer contents (Trajectories, Transitions, or StatesContainer). Its type is set dynamically from the type of the first added object.
- prioritized_capacity
Whether to use prioritized capacity (keep the highest-reward items). This is set to True by default.
- prioritized_sampling
Whether to sample items with probability proportional to their reward.
- cutoff_distance
Threshold used to decide whether a new terminating state differs enough from those already in the buffer to be added.
- p_norm_distance
The p of the p-norm used for pairwise distances (the p argument of torch.cdist).
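The cutoff test described by cutoff_distance and p_norm_distance can be sketched in pure Python: a candidate is kept only if its p-norm distance to every item already held exceeds the cutoff. This is an illustration, not the library's implementation. The function names are hypothetical. The strict `>` comparison and the choice to also compare candidates against earlier accepted candidates are assumptions. torch.cdist(x1, x2, p=...) computes the same pairwise distances in batched tensor form.

```python
from typing import List, Sequence


def p_norm_distance(a: Sequence[float], b: Sequence[float], p: float = 1.0) -> float:
    """Minkowski (p-norm) distance between two equal-length vectors."""
    return sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p)


def diverse_subset(
    buffer: List[Sequence[float]],
    candidates: List[Sequence[float]],
    cutoff_distance: float = 0.0,
    p: float = 1.0,
) -> List[Sequence[float]]:
    """Keep candidates strictly farther than cutoff_distance from every held item.

    Assumption: accepted candidates are also compared with each other, so two
    near-duplicates arriving in the same batch are not both admitted.
    """
    kept: List[Sequence[float]] = []
    for cand in candidates:
        reference = buffer + kept
        if all(p_norm_distance(cand, other, p) > cutoff_distance for other in reference):
            kept.append(cand)
    return kept


print(p_norm_distance([0, 0], [3, 4], p=2.0))  # 5.0
# (0, 1) is within 1.5 of (0, 0); (3, 0.5) is within 1.5 of the newly kept (3, 0).
print(diverse_subset([[0, 0], [5, 5]], [[0, 1], [3, 0], [3, 0.5]], cutoff_distance=1.5))  # [[3, 0]]
```

Under these assumptions, the default cutoff_distance of 0.0 rejects only exact duplicates, because a duplicate's distance of 0 is not strictly greater than 0.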
- static _diversity_repr(container)
Returns the tensor used for pairwise distance in diversity filtering.
For conditional GFNs, concatenates conditions with the state tensor so that identical states under different conditions are treated as distinct.
- Parameters:
container (ContainerUnion)
- Return type:
torch.Tensor
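The effect of concatenating conditions can be seen in a small sketch. It assumes states and conditions are flat float vectors joined along their last dimension; in tensor form, torch.cat(..., dim=-1) would play the role of the list concatenation used here. The helper `diversity_repr` is hypothetical and only stands in for the static method above.

```python
from typing import List, Optional, Sequence


def diversity_repr(
    states: List[List[float]], conditions: Optional[List[List[float]]] = None
) -> List[List[float]]:
    """One vector per item: the state, with its condition appended when given."""
    if conditions is None:
        return [list(s) for s in states]
    if len(conditions) != len(states):
        raise ValueError("need exactly one condition per state")
    return [list(s) + list(c) for s, c in zip(states, conditions)]


def l1_distance(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(abs(x - y) for x, y in zip(a, b))


states = [[1.0, 0.0], [1.0, 0.0]]  # the same state twice...
conditions = [[0.0], [1.0]]        # ...under two different conditions

plain = diversity_repr(states)
conditioned = diversity_repr(states, conditions)
print(l1_distance(*plain), l1_distance(*conditioned))  # 0.0 1.0
```

Without the conditions, the two items are at distance 0 and would be filtered as duplicates. With the conditions appended, they are distinct.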
- add(training_container)
Adds a training object to the buffer with diversity-based prioritization.
- Parameters:
training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
- cutoff_distance = 0.0
- p_norm_distance = 1.0
- class gfn.containers.replay_buffer.ReplayBuffer(env, capacity=1000, prioritized_capacity=False, prioritized_sampling=False, remote_manager_rank=None, remote_buffer_freq=1)
A replay buffer for storing containers.
- Parameters:
env (gfn.env.Env)
capacity (int)
prioritized_capacity (bool)
prioritized_sampling (bool)
remote_manager_rank (int | None)
remote_buffer_freq (int)
- env
The environment associated with the containers.
- capacity
The maximum number of items the buffer can hold.
- training_container
The buffer contents (Trajectories, Transitions, or StatesContainer). This is dynamically set based on the type of the first added object.
- prioritized_capacity
Whether to use prioritized capacity (keep highest-reward items).
- prioritized_sampling
Whether to sample items with probability proportional to their reward.
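The two prioritization flags can be illustrated with plain Python lists. This is a sketch under stated assumptions, not the buffer's code:

- Rewards are taken to be non-negative. The containers expose log_rewards, so probabilities proportional to reward would be proportional to exp(log_reward).
- "Keep highest-reward items" is modeled as retaining the top `capacity` items after an add.
- The non-prioritized case is modeled as dropping the oldest items.
- Sampling is with replacement.

The function names are hypothetical.

```python
import random
from typing import List, Tuple

Item = Tuple[str, float]  # (payload, reward): hypothetical stand-in for one stored item


def trim_to_capacity(items: List[Item], capacity: int, prioritized_capacity: bool) -> List[Item]:
    """Shrink items to at most `capacity` entries."""
    if len(items) <= capacity:
        return items
    if prioritized_capacity:
        # Keep the highest-reward items (assumed semantics of prioritized_capacity).
        return sorted(items, key=lambda it: it[1], reverse=True)[:capacity]
    # Otherwise keep the most recent items (assumed FIFO eviction).
    return items[-capacity:]


def sample(items: List[Item], n_samples: int, prioritized_sampling: bool, rng: random.Random) -> List[Item]:
    """Draw n_samples items with replacement (an assumption for this sketch)."""
    if prioritized_sampling:
        # Selection probability proportional to each item's reward.
        weights = [reward for _, reward in items]
        return rng.choices(items, weights=weights, k=n_samples)
    return rng.choices(items, k=n_samples)


buffer = [("a", 1.0), ("b", 5.0), ("c", 0.0), ("d", 3.0)]
print(trim_to_capacity(buffer, 2, prioritized_capacity=True))   # [('b', 5.0), ('d', 3.0)]
print(trim_to_capacity(buffer, 2, prioritized_capacity=False))  # [('c', 0.0), ('d', 3.0)]
draws = sample(buffer, 1000, prioritized_sampling=True, rng=random.Random(0))
print(any(name == "c" for name, _ in draws))  # False: zero reward means zero probability
```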
- __len__()
Returns the number of items in the buffer.
- Returns:
The number of items in the buffer.
- Return type:
int
- __repr__()
Returns a string representation of the ReplayBuffer.
- Returns:
A string summary of the buffer.
- Return type:
str
- _add_counter = 0
- _add_objs(training_container)
Adds a training object to the buffer and enforces the capacity limit.
- Parameters:
training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
- _is_full = False
- _send_objs(training_container)
Sends a training container to the remote manager.
- Parameters:
training_container (ContainerUnion)
- Return type:
dict[str, float]
- add(training_container)
Adds a training container to the buffer.
The type of the training container is dynamically set based on the type of the first added container.
- Parameters:
training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
- Return type:
dict[str, float] | None
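Several attributes suggest that adds can be batched and forwarded to a remote manager: remote_manager_rank, remote_buffer_freq, _add_counter, pending_container and _send_objs, together with the dict[str, float] | None return type of add. The toy class below is one plausible reading of how these pieces fit together, written as a hypothetical sketch. This page does not document the actual control flow. The `send` callback, the modulo rule, and the choice not to keep a local copy while a remote manager is set are all assumptions.

```python
from typing import Callable, Dict, List, Optional


class ToyRemoteForwardingBuffer:
    """Hypothetical illustration: batch adds and forward them every
    `remote_buffer_freq` calls when a remote manager rank is configured."""

    def __init__(
        self,
        send: Callable[[List[int]], Dict[str, float]],
        remote_manager_rank: Optional[int] = None,
        remote_buffer_freq: int = 1,
    ) -> None:
        self.send = send  # stands in for _send_objs (assumption)
        self.remote_manager_rank = remote_manager_rank
        self.remote_buffer_freq = remote_buffer_freq
        self.local: List[int] = []
        self.pending: List[int] = []  # plays the role of pending_container
        self._add_counter = 0

    def add(self, items: List[int]) -> Optional[Dict[str, float]]:
        if self.remote_manager_rank is None:
            self.local.extend(items)  # purely local buffer: nothing to report
            return None
        self.pending.extend(items)
        self._add_counter += 1
        if self._add_counter % self.remote_buffer_freq != 0:
            return None  # keep accumulating until the next send
        batch, self.pending = self.pending, []
        return self.send(batch)


sent: List[List[int]] = []


def fake_send(batch: List[int]) -> Dict[str, float]:
    sent.append(batch)
    return {"n_sent": float(len(batch))}


buf = ToyRemoteForwardingBuffer(fake_send, remote_manager_rank=0, remote_buffer_freq=3)
results = [buf.add([i]) for i in range(6)]
print(results)  # [None, None, {'n_sent': 3.0}, None, None, {'n_sent': 3.0}]
print(sent)     # [[0, 1, 2], [3, 4, 5]]
```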
- capacity = 1000
- property device: torch.device
The device on which the buffer’s data is stored.
- Returns:
The device object of the buffer’s contents.
- Return type:
torch.device
- initialize(training_container)
Initializes the buffer with the type of the first added object.
- Parameters:
training_container (ContainerUnion) – The initial Trajectories, Transitions, or StatesContainer object to set the buffer type.
- Return type:
None
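A small sketch shows a buffer whose type is set by the first added object: it starts empty and adopts the type of its first input. Rejecting a later object of a different type with TypeError is an assumption made for the illustration, since this page does not say how the real buffer reacts to mixed types. `LazyTypedBuffer`, `ToyTrajectories` and `ToyTransitions` are hypothetical names.

```python
from typing import Any, List, Optional


class LazyTypedBuffer:
    """Toy buffer whose element type is fixed by the first add() call."""

    def __init__(self) -> None:
        self.container_type: Optional[type] = None
        self.items: List[Any] = []

    def initialize(self, first: Any) -> None:
        # Record the concrete type of the first object ever added.
        self.container_type = type(first)

    def add(self, obj: Any) -> None:
        if self.container_type is None:
            self.initialize(obj)
        elif not isinstance(obj, self.container_type):
            # Assumed behavior for this sketch: refuse to mix container types.
            raise TypeError(
                f"buffer holds {self.container_type.__name__}, got {type(obj).__name__}"
            )
        self.items.append(obj)


class ToyTrajectories(list):
    """Hypothetical stand-in for a Trajectories container."""


class ToyTransitions(list):
    """Hypothetical stand-in for a Transitions container."""


buf = LazyTypedBuffer()
buf.add(ToyTrajectories([1, 2]))
print(buf.container_type.__name__)  # ToyTrajectories
try:
    buf.add(ToyTransitions([3]))
    rejected = False
except TypeError as err:
    rejected = True
    print(err)  # buffer holds ToyTrajectories, got ToyTransitions
```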
- load(path)
Loads buffer contents from a .pt file saved by save().
- Parameters:
path (str) – File path to the saved buffer.
- pending_container: ContainerUnion | None = None
- prioritized_capacity = False
- prioritized_sampling = False
- remote_buffer_freq = 1
- remote_manager_rank = None
- sample(n_samples)
Samples training objects from the buffer.
- Parameters:
n_samples (int) – The number of items to sample.
- Returns:
A sampled Trajectories, Transitions, or StatesContainer.
- Return type:
ContainerUnion
- save(path)
Saves the buffer to a single .pt file.
- Parameters:
path (str) – File path (e.g. "replay_buffer.pt").
- training_container: ContainerUnion | None = None
- class gfn.containers.replay_buffer.TerminatingStateBuffer(env, capacity=1000, **kwargs)
Bases: ReplayBuffer
A replay buffer for storing terminating states.
- Parameters:
env (gfn.env.Env)
capacity (int)
- env
The environment associated with the containers.
- capacity
The maximum number of items the buffer can hold.
- training_container
The buffer contents (StatesContainer).
- add(training_container)
Adds a training container to the buffer.
The type of the training container is dynamically set based on the type of the first added container.
- Parameters:
training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
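In a GFlowNet, a terminating state is the last state of a complete trajectory before the transition to the sink state, so a buffer of terminating states keeps one state per finished trajectory. The sketch below assumes trajectories are plain lists of states ending in a hypothetical sink marker `SINK`. It also assumes oldest-first eviction once capacity is reached. How TerminatingStateBuffer actually converts Trajectories or Transitions into a StatesContainer is not described on this page.

```python
from collections import deque
from typing import Deque, Hashable, List, Sequence

SINK = "sf"  # hypothetical marker for the sink state s_f


def terminating_states(trajectories: Sequence[Sequence[Hashable]]) -> List[Hashable]:
    """Return the state visited just before the sink in each trajectory."""
    result: List[Hashable] = []
    for traj in trajectories:
        if len(traj) < 2 or traj[-1] != SINK:
            raise ValueError("expected a complete trajectory ending in the sink state")
        result.append(traj[-2])
    return result


# Assumed eviction policy for the sketch: a capacity-3 deque drops the oldest state first.
buffer: Deque[Hashable] = deque(maxlen=3)
batch = [["s0", "a", "x", SINK], ["s0", "y", SINK], ["s0", "a", "b", "z", SINK]]
buffer.extend(terminating_states(batch))
buffer.extend(terminating_states([["s0", "w", SINK]]))
print(list(buffer))  # ['y', 'z', 'w']
```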