gfn.containers
Submodules
Classes
- Container – Base class for state containers (states, transitions, or trajectories).
- NormBasedDiversePrioritizedReplayBuffer – A replay buffer with diversity-based prioritization.
- ReplayBuffer – A replay buffer for storing containers.
- StatesContainer – Container for a batch of states (mainly used for FMGFlowNet).
- TerminatingStateBuffer – A replay buffer for storing terminating states.
- Trajectories – Container for complete trajectories (starting in $s_0$ and ending in $s_f$).
- Transitions – Container for a batch of transitions.
Package Contents
- class gfn.containers.Container
Bases: abc.ABC
Base class for state containers (states, transitions, or trajectories).
- abstract __getitem__(index)
Returns a subset of the container based on the provided index.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – An integer, slice, tuple, sequence of indices or booleans, or a torch.Tensor specifying which elements to select.
- Returns:
A new container containing the selected elements and associated data.
- Return type:
Container
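The index forms accepted by __getitem__ can be illustrated by how they map onto explicit positions along the batch dimension. The helper below, normalize_index, is hypothetical and not part of gfn; it is a minimal sketch that covers only int, slice, and sequences of ints or booleans (not tuple or torch.Tensor indices), with plain Python lists standing in for tensors.

```python
from typing import Sequence, Union

# Hypothetical helper for illustration only; not part of gfn.containers.
Index = Union[int, slice, Sequence[int], Sequence[bool]]


def normalize_index(index: Index, length: int) -> list[int]:
    """Convert an int, slice, or sequence of ints/bools into explicit positions."""
    if isinstance(index, bool):
        raise TypeError("a single bool is not a valid index")
    if isinstance(index, int):
        index = [index]
    if isinstance(index, slice):
        # Slices follow normal Python semantics (start, stop, step, negatives).
        return list(range(length))[index]
    items = list(index)
    if items and all(isinstance(i, bool) for i in items):
        # A boolean mask must cover every element along the batch dimension.
        if len(items) != length:
            raise IndexError("boolean mask length must match container length")
        return [pos for pos, keep in enumerate(items) if keep]
    positions = []
    for i in items:
        if not -length <= i < length:
            raise IndexError(f"index {i} out of range for length {length}")
        # Negative integers count from the end, as with Python lists.
        positions.append(i % length)
    return positions
```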
- abstract __len__()
Returns the number of elements in the container.
- Returns:
The number of elements in the container.
- Return type:
int
- abstract property device: torch.device
The device on which the container is stored.
- Returns:
The device on which the container is stored.
- Return type:
torch.device
- abstract extend(other)
Extends the current container with elements from another container object.
- Parameters:
other (Container) – The other container whose elements will be added.
- Return type:
None
- abstract classmethod from_tensordict(env, td)
Reconstruct a container from a TensorDict.
- Parameters:
env (gfn.env.Env) – The environment needed to reconstruct States/Actions.
td (tensordict.base.TensorDictBase) – The TensorDict produced by to_tensordict().
- Returns:
A new container instance.
- Return type:
Container
- property has_log_probs: bool
Whether the container has log probabilities.
- Returns:
True if log probabilities are present and non-empty, False otherwise.
- Return type:
bool
- classmethod load(env, path)
Loads a container from a .pt file saved by save().
- Parameters:
env (gfn.env.Env) – The environment needed to reconstruct States/Actions.
path (str) – File path to the saved container.
- Returns:
A new container instance.
- Return type:
Container
- abstract property log_rewards: torch.Tensor
The log rewards associated with the container.
- Returns:
The log rewards tensor.
- Return type:
torch.Tensor
- sample(n_samples)
Randomly samples a subset of elements from the container.
- Parameters:
n_samples (int) – The number of elements to sample.
- Returns:
A new container with the sampled elements.
- Return type:
Container
- save(path)
Saves the container to a single .pt file.
- Parameters:
path (str) – File path (e.g. "trajectories.pt").
- Return type:
None
- abstract property terminating_states: gfn.states.States
The last (terminating) states of the container.
- Returns:
The terminating states.
- Return type:
gfn.states.States
- abstract to_tensordict()
Serialize the container’s data into a TensorDict.
- Returns:
A TensorDict containing all tensor data and scalar metadata. The env reference is not included; it must be supplied when reconstructing via from_tensordict().
- Return type:
tensordict.base.TensorDictBase
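To make the abstract interface concrete, here is a minimal sketch of a toy subclass. ToyContainerBase and ToyStates are hypothetical names, a plain dict stands in for TensorDict, pickle stands in for the .pt serialization, and uniform sampling without replacement is assumed for sample(). It mirrors the pattern described above: to_tensordict() leaves the env out, so load() and from_tensordict() need it passed back in.

```python
import pickle
import random
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path


class ToyContainerBase(ABC):
    """Hypothetical, simplified stand-in for the Container interface."""

    @abstractmethod
    def __len__(self) -> int: ...

    @abstractmethod
    def __getitem__(self, index): ...

    @abstractmethod
    def extend(self, other) -> None: ...

    @abstractmethod
    def to_tensordict(self) -> dict: ...

    @classmethod
    @abstractmethod
    def from_tensordict(cls, env, td: dict): ...

    def sample(self, n_samples: int):
        # Assumption: uniform sampling without replacement.
        positions = random.sample(range(len(self)), n_samples)
        return self[positions]

    def save(self, path: str) -> None:
        # Only the serialized data is written; the env is deliberately left out.
        with open(path, "wb") as f:
            pickle.dump(self.to_tensordict(), f)

    @classmethod
    def load(cls, env, path: str):
        # The caller supplies the env again when reconstructing.
        with open(path, "rb") as f:
            return cls.from_tensordict(env, pickle.load(f))


class ToyStates(ToyContainerBase):
    def __init__(self, env, values, log_rewards):
        self.env = env
        self.values = list(values)
        self.log_rewards = list(log_rewards)

    def __len__(self):
        return len(self.values)

    def __getitem__(self, index):
        positions = index if isinstance(index, list) else [index]
        return ToyStates(
            self.env,
            [self.values[i] for i in positions],
            [self.log_rewards[i] for i in positions],
        )

    def extend(self, other):
        self.values.extend(other.values)
        self.log_rewards.extend(other.log_rewards)

    def to_tensordict(self):
        return {"values": self.values, "log_rewards": self.log_rewards}

    @classmethod
    def from_tensordict(cls, env, td):
        return cls(env, td["values"], td["log_rewards"])


env = "toy-env"  # placeholder for a gfn.env.Env instance
states = ToyStates(env, [0, 1, 2], [-1.0, -0.5, 0.0])
states.extend(ToyStates(env, [3], [0.5]))

with tempfile.TemporaryDirectory() as tmp:
    path = str(Path(tmp) / "states.pt")
    states.save(path)
    restored = ToyStates.load(env, path)
```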
- class gfn.containers.NormBasedDiversePrioritizedReplayBuffer(env, capacity=1000, cutoff_distance=0.0, p_norm_distance=1.0, remote_manager_rank=None, remote_buffer_freq=1)
Bases: ReplayBuffer
A replay buffer with diversity-based prioritization.
- Parameters:
env (gfn.env.Env)
capacity (int)
cutoff_distance (float)
p_norm_distance (float)
remote_manager_rank (int | None)
remote_buffer_freq (int)
- env
The environment associated with the containers.
- capacity
The maximum number of items the buffer can hold.
- training_container
The buffer contents (Trajectories, Transitions, or StatesContainer). It is set dynamically based on the type of the first added object.
- prioritized_capacity
Whether to use prioritized capacity (keep highest-reward items). This is set to True by default.
- prioritized_sampling
Whether to sample items with probability proportional to their reward.
- cutoff_distance
Threshold used to determine whether a new terminating state is different enough from those already in the buffer.
- p_norm_distance
The p of the p-norm used for pairwise distance calculation (used in torch.cdist).
- static _diversity_repr(container)
Returns the tensor used for pairwise distance in diversity filtering.
For conditional GFNs, concatenates conditions with the state tensor so that identical states under different conditions are treated as distinct.
- Parameters:
container (ContainerUnion)
- Return type:
torch.Tensor
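The effect of concatenating conditions can be seen with two identical states under different conditions. This is a minimal sketch using plain Python lists in place of tensors; diversity_repr and p_norm_distance are hypothetical helpers, where the library itself hands its representation to torch.cdist.

```python
def diversity_repr(states, conditions=None):
    """Sketch: append each row's condition vector to its state vector."""
    if conditions is None:
        return [list(s) for s in states]
    return [list(s) + list(c) for s, c in zip(states, conditions)]


def p_norm_distance(x, y, p=1.0):
    """p-norm distance between two equal-length vectors."""
    return sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1.0 / p)


states = [[1.0, 2.0], [1.0, 2.0]]  # identical states
conditions = [[0.0], [1.0]]        # but different conditions

plain = diversity_repr(states)
conditioned = diversity_repr(states, conditions)

d_plain = p_norm_distance(plain[0], plain[1])            # states alone look identical
d_cond = p_norm_distance(conditioned[0], conditioned[1])  # conditions separate them
```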
- add(training_container)
Adds a training object to the buffer with diversity-based prioritization.
- Parameters:
training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
- cutoff_distance = 0.0
- p_norm_distance = 1.0
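One way to read "diversity-based prioritization" is sketched below: a new item is kept only if its p-norm distance to every item already in the buffer exceeds cutoff_distance, and prioritized capacity then keeps the highest-reward items. This is a minimal sketch under assumptions, not the library's implementation: diverse_add is hypothetical, items are (state_vector, log_reward) pairs, items are also compared against those accepted earlier in the same batch, and the comparison is strict, so with the default cutoff_distance=0.0 only exact duplicates would be rejected.

```python
def p_norm_distance(x, y, p):
    """p-norm distance between two equal-length vectors."""
    return sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1.0 / p)


def diverse_add(buffer, new_items, capacity, cutoff_distance=0.0, p=1.0):
    """Sketch of diversity-filtered, reward-prioritized insertion.

    buffer and new_items are lists of (state_vector, log_reward) pairs.
    """
    accepted = []
    for state, log_r in new_items:
        # Assumption: compare against the buffer and against items accepted from this batch.
        pool = buffer + accepted
        if all(p_norm_distance(state, s, p) > cutoff_distance for s, _ in pool):
            accepted.append((state, log_r))
    merged = buffer + accepted
    if len(merged) > capacity:
        # Prioritized capacity: keep the highest-reward items.
        merged = sorted(merged, key=lambda item: item[1], reverse=True)[:capacity]
    return merged


buffer = [([0.0, 0.0], -1.0)]
new_items = [
    ([0.0, 0.0], 5.0),   # duplicate of a buffered state: rejected
    ([3.0, 0.0], 0.0),   # far enough away: accepted
    ([3.0, 0.0], 1.0),   # duplicate of the item just accepted: rejected
    ([0.5, 0.0], -2.0),  # within cutoff of a buffered state: rejected
]
result = diverse_add(buffer, new_items, capacity=2, cutoff_distance=1.0)
small = diverse_add(buffer, new_items, capacity=1, cutoff_distance=1.0)
```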
- class gfn.containers.ReplayBuffer(env, capacity=1000, prioritized_capacity=False, prioritized_sampling=False, remote_manager_rank=None, remote_buffer_freq=1)
A replay buffer for storing containers.
- Parameters:
env (gfn.env.Env)
capacity (int)
prioritized_capacity (bool)
prioritized_sampling (bool)
remote_manager_rank (int | None)
remote_buffer_freq (int)
- env
The environment associated with the containers.
- capacity
The maximum number of items the buffer can hold.
- training_container
The buffer contents (Trajectories, Transitions, or StatesContainer). This is dynamically set based on the type of the first added object.
- prioritized_capacity
Whether to use prioritized capacity (keep highest-reward items).
- prioritized_sampling
Whether to sample items with probability proportional to their reward.
- __len__()
Returns the number of items in the buffer.
- Returns:
The number of items in the buffer.
- Return type:
int
- __repr__()
Returns a string representation of the ReplayBuffer.
- Returns:
A string summary of the buffer.
- Return type:
str
- _add_counter = 0
- _add_objs(training_container)
Adds a training object to the buffer, handling the capacity.
- Parameters:
training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
- _is_full = False
- _send_objs(training_container)
Sends a training container to the remote manager.
- Parameters:
training_container (ContainerUnion)
- Return type:
dict[str, float]
- add(training_container)
Adds a training container to the buffer.
The type of the training container is set dynamically based on the type of the first added container.
- Parameters:
training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
- Return type:
dict[str, float] | None
- capacity = 1000
- property device: torch.device
The device on which the buffer’s data is stored.
- Returns:
The device of the buffer’s contents.
- Return type:
torch.device
- env
- initialize(training_container)
Initializes the buffer with the type of the first added object.
- Parameters:
training_container (ContainerUnion) – The initial Trajectories, Transitions, or StatesContainer object, which sets the buffer type.
- Return type:
None
- load(path)
Loads buffer contents from a .pt file saved by save().
- Parameters:
path (str) – File path to the saved buffer.
- pending_container: ContainerUnion | None = None
- prioritized_capacity = False
- prioritized_sampling = False
- remote_buffer_freq = 1
- remote_manager_rank = None
- sample(n_samples)
Samples training objects from the buffer.
- Parameters:
n_samples (int) – The number of items to sample.
- Returns:
A sampled Trajectories, Transitions, or StatesContainer.
- Return type:
ContainerUnion
- save(path)
Saves the buffer to a single .pt file.
- Parameters:
path (str) – File path (e.g. "replay_buffer.pt").
- training_container: ContainerUnion | None = None
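The interaction of capacity, prioritized_capacity, and prioritized_sampling can be sketched with a toy buffer over (payload, log_reward) pairs. ToyReplayBuffer is hypothetical and ignores the container types and remote-manager options. Several behaviors are assumptions, not taken from this page: non-prioritized eviction drops the oldest items first, non-prioritized sampling is uniform without replacement, and prioritized sampling draws with replacement.

```python
import math
import random


class ToyReplayBuffer:
    """Sketch of capacity handling and sampling over (payload, log_reward) pairs."""

    def __init__(self, capacity=1000, prioritized_capacity=False, prioritized_sampling=False):
        self.capacity = capacity
        self.prioritized_capacity = prioritized_capacity
        self.prioritized_sampling = prioritized_sampling
        self.items = []

    def __len__(self):
        return len(self.items)

    def add(self, new_items):
        self.items.extend(new_items)
        if len(self.items) <= self.capacity:
            return
        if self.prioritized_capacity:
            # Keep the highest-reward items (ties keep insertion order).
            self.items.sort(key=lambda item: item[1], reverse=True)
            self.items = self.items[: self.capacity]
        else:
            # Assumption: without prioritization, the oldest items are evicted first.
            self.items = self.items[-self.capacity :]

    def sample(self, n_samples, rng=random):
        if self.prioritized_sampling:
            # Weights proportional to exp(log_reward); subtracting the max avoids overflow.
            max_log_r = max(log_r for _, log_r in self.items)
            weights = [math.exp(log_r - max_log_r) for _, log_r in self.items]
            # Assumption: prioritized draws are made with replacement.
            return rng.choices(self.items, weights=weights, k=n_samples)
        # Assumption: uniform sampling without replacement otherwise.
        return rng.sample(self.items, n_samples)


history = [("a", 0.0), ("b", 2.0), ("c", 1.0), ("d", -1.0)]

fifo = ToyReplayBuffer(capacity=3)
fifo.add(history)  # keeps b, c, d (most recent)

best = ToyReplayBuffer(capacity=3, prioritized_capacity=True)
best.add(history)  # keeps b, c, a (highest log rewards)
```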
- class gfn.containers.StatesContainer(env, states=None, is_terminating=None, log_rewards=None)
Bases: gfn.containers.base.Container, Generic[StateType]
Container for a batch of states (mainly used for FMGFlowNet).
This class manages a collection of states and their corresponding properties. It is mainly used for Flow Matching GFlowNet algorithms.
- Parameters:
env (gfn.env.Env)
states (StateType | None)
is_terminating (torch.Tensor | None)
log_rewards (torch.Tensor | None)
- env
The environment where the states are defined.
- states
States with batch_shape (n_states,).
- is_terminating
Boolean tensor of shape (n_states,) indicating which states are terminating.
- _log_rewards
(Optional) Tensor of shape (n_states,) containing the log rewards for terminating states.
- __getitem__(index)
Returns a subset of the states along the batch dimension.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select states.
- Returns:
A new StatesContainer with the selected states and associated data.
- Return type:
StatesContainer[StateType]
- __len__()
Returns the number of states in the container.
- Returns:
The number of states.
- Return type:
int
- __repr__()
Returns a string representation of the StatesContainer.
- Returns:
A string summary of the container.
- Return type:
str
- _log_rewards = None
- property device: torch.device
The device on which the states are stored.
- Returns:
The device of self.states.
- Return type:
torch.device
- env
- extend(other)
Extends this container with another StatesContainer object.
- Parameters:
other (StatesContainer[StateType]) – Another StatesContainer to append.
- Return type:
None
- classmethod from_tensordict(env, td)
Reconstruct a StatesContainer from a TensorDict.
- Parameters:
env (gfn.env.Env)
td (gfn.containers.base.TensorDictBase)
- Return type:
StatesContainer[StateType]
- property intermediary_states: StateType
The intermediary states (not initial states) of the StatesContainer.
- Returns:
The intermediary states.
- Return type:
StateType
- is_terminating
- property log_rewards: torch.Tensor
The log rewards for all states.
- Returns:
Log rewards tensor of shape (len(self.states),). Intermediate (non-terminating) states have value -inf.
- Return type:
torch.Tensor
Note
If not provided at initialization, log rewards are computed on demand for terminating states.
- states
- property terminating_log_rewards: torch.Tensor
The log rewards for terminating states only.
- Returns:
The log rewards for terminating states.
- Return type:
torch.Tensor
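The relationship between log_rewards (one entry per state, -inf for intermediate states) and terminating_log_rewards (terminating states only) is a scatter and a mask over is_terminating. A minimal sketch with plain lists in place of tensors; full_log_rewards is a hypothetical helper, and how the library computes missing rewards on demand is not shown.

```python
import math


def full_log_rewards(is_terminating, terminating_log_rewards):
    """Scatter per-terminating-state log rewards into a full vector.

    Non-terminating (intermediate) positions receive -inf.
    """
    if sum(is_terminating) != len(terminating_log_rewards):
        raise ValueError("need exactly one log reward per terminating state")
    values = iter(terminating_log_rewards)
    return [next(values) if term else -math.inf for term in is_terminating]


is_terminating = [False, True, False, True]
terminating_log_rewards = [-0.7, 1.2]
log_rewards = full_log_rewards(is_terminating, terminating_log_rewards)

# Masking with is_terminating recovers the terminating-only view.
recovered = [r for r, term in zip(log_rewards, is_terminating) if term]
```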
- property terminating_states: StateType
The last (terminating) states of the StatesContainer.
- Returns:
The terminating states.
- Return type:
StateType
- to_tensordict()
Serialize the states container into a TensorDict.
- Return type:
gfn.containers.base.TensorDictBase
- class gfn.containers.TerminatingStateBuffer(env, capacity=1000, **kwargs)
Bases: ReplayBuffer
A replay buffer for storing terminating states.
- Parameters:
env (gfn.env.Env)
capacity (int)
- env
The environment associated with the containers.
- capacity
The maximum number of items the buffer can hold.
- training_container
The buffer contents (StatesContainer).
- add(training_container)
Adds a training container to the buffer.
The type of the training container is dynamically set based on the type of the first added container.
- Parameters:
training_container (ContainerUnion) – The Trajectories, Transitions, or StatesContainer object to add.
- training_container
- class gfn.containers.Trajectories(env, states=None, actions=None, terminating_idx=None, is_backward=False, log_rewards=None, log_probs=None, estimator_outputs=None)
Bases: gfn.containers.base.Container
Container for complete trajectories (starting in $s_0$ and ending in $s_f$).
Trajectories are represented as a States object with a two-dimensional batch shape, and their actions as an Actions object with a two-dimensional batch shape. The first dimension indexes the time step and the second indexes the trajectory. Because trajectories may have different lengths, shorter trajectories are padded with the tensor representation of the terminal state ($s_f$ or $s_0$, depending on the direction of the trajectory), and their actions are padded with dummy actions. The terminating_idx tensor gives the time step at which each trajectory ends.
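The padded (time, batch) layout described above can be sketched with nested lists. In this minimal sketch, SF and DUMMY are placeholder values for the sink state and the dummy action, pad_forward_trajectories is a hypothetical helper, and terminating_idx is assumed to be the number of actions taken including the exit action (equivalently, the row of the first $s_f$).

```python
SF = "sf"     # placeholder for the sink state s_f
DUMMY = None  # placeholder for a dummy (padding) action


def pad_forward_trajectories(trajs):
    """Stack variable-length forward trajectories into (time, batch) grids.

    Each trajectory is (states, actions), where states runs s_0 ... s_f and
    len(states) == len(actions) + 1 (the last action is the exit action).
    """
    # Assumption: terminating_idx counts the actions, including the exit action.
    terminating_idx = [len(actions) for _, actions in trajs]
    max_length = max(terminating_idx)
    states_grid = [[SF] * len(trajs) for _ in range(max_length + 1)]
    actions_grid = [[DUMMY] * len(trajs) for _ in range(max_length)]
    for b, (states, actions) in enumerate(trajs):
        for t, s in enumerate(states):
            states_grid[t][b] = s
        for t, a in enumerate(actions):
            actions_grid[t][b] = a
    return states_grid, actions_grid, terminating_idx


trajs = [
    (["s0", "x1", "x2", SF], ["a", "b", "exit"]),  # three actions
    (["s0", "y1", SF], ["c", "exit"]),             # two actions, so it gets padded
]
states_grid, actions_grid, terminating_idx = pad_forward_trajectories(trajs)
```

The states grid has max_length + 1 rows and the actions grid has max_length rows, matching the batch shapes listed for states and actions below.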
- Parameters:
env (gfn.env.Env)
states (gfn.states.States | None)
actions (gfn.actions.Actions | None)
terminating_idx (torch.Tensor | None)
is_backward (bool)
log_rewards (torch.Tensor | None)
log_probs (torch.Tensor | None)
estimator_outputs (torch.Tensor | None)
- env
The environment where the states and actions are defined.
- states
States with batch_shape (max_length+1, batch_size).
- actions
Actions with batch_shape (max_length, batch_size).
- terminating_idx
Tensor of shape (batch_size,) indicating the time step at which each trajectory ends.
- is_backward
Whether the trajectories are backward or forward. When is_backward is False, the states are ordered from initial to terminal states; when it is True, they are ordered from terminal to initial states.
- _log_rewards
(Optional) Tensor of shape (batch_size,) containing the log rewards of the trajectories.
- log_probs
(Optional) Tensor of shape (max_length, batch_size) containing the log probabilities of the trajectories’ actions.
- estimator_outputs
(Optional) Tensor of shape (max_length, batch_size, …) containing outputs of a function approximator for each step.
- __getitem__(index)
Returns a subset of the trajectories along the batch dimension.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select trajectories.
- Returns:
A new Trajectories object with the selected trajectories and associated data.
- Return type:
Trajectories
- __len__()
Returns the number of trajectories in the container.
- Returns:
The number of trajectories.
- Return type:
int
- __repr__()
Returns a string representation of the Trajectories container.
- Returns:
A string summary of the trajectories.
- Return type:
str
- _log_rewards = None
- actions
- property batch_size: int
The number of trajectories in the container.
- Returns:
The number of trajectories.
- Return type:
int
- property device: torch.device
The device on which the trajectories are stored.
- Returns:
The device of self.states.
- Return type:
torch.device
- env
- estimator_outputs = None
- extend(other)
Extends this Trajectories object with another Trajectories object.
Extends along all attributes in turn (actions, states, terminating_idx, log_probs, log_rewards).
- Parameters:
other (Trajectories) – Another Trajectories object to append.
- Return type:
None
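Extending one Trajectories object with another has to reconcile batches whose max_length differ. One plausible approach, assumed here rather than stated on this page, is to pad the shorter time dimension (sink state for states, dummy action for actions) and then concatenate along the batch dimension. A minimal sketch with nested lists; pad_rows and extend_grids are hypothetical helpers.

```python
SF = "sf"     # placeholder sink state used to pad states
DUMMY = None  # placeholder dummy action used to pad actions


def pad_rows(grid, n_rows, fill):
    """Append rows of `fill` so a (time, batch) grid has exactly n_rows rows."""
    width = len(grid[0]) if grid else 0
    return grid + [[fill] * width for _ in range(n_rows - len(grid))]


def extend_grids(grid_a, grid_b, fill):
    """Pad both grids to a common time length, then join them along the batch axis."""
    n_rows = max(len(grid_a), len(grid_b))
    padded_a = pad_rows(grid_a, n_rows, fill)
    padded_b = pad_rows(grid_b, n_rows, fill)
    return [row_a + row_b for row_a, row_b in zip(padded_a, padded_b)]


# Container A: one trajectory s0 -> sf (max_length 1).
states_a, actions_a, idx_a = [["s0"], [SF]], [["exit"]], [1]
# Container B: one trajectory s0 -> y1 -> sf (max_length 2).
states_b, actions_b, idx_b = [["s0"], ["y1"], [SF]], [["c"], ["exit"]], [2]

states = extend_grids(states_a, states_b, SF)
actions = extend_grids(actions_a, actions_b, DUMMY)
terminating_idx = idx_a + idx_b
```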
- classmethod from_tensordict(env, td)
Reconstruct Trajectories from a TensorDict.
- Parameters:
env (gfn.env.Env)
td (gfn.containers.base.TensorDictBase)
- Return type:
Trajectories
- is_backward = False
- log_probs = None
- property log_rewards: torch.Tensor | None
The log rewards for the trajectories.
- Returns:
Log rewards tensor of shape (batch_size,).
- Return type:
torch.Tensor | None
Note
If not provided at initialization, log rewards are computed on demand for terminating states.
- property max_length: int
The maximum length of the trajectories in the container.
- Returns:
The maximum trajectory length.
- Return type:
int
- property n_trajectories: int
Deprecated alias for batch_size.
- Return type:
int
- reverse_backward_trajectories()
Returns a reversed version of the backward trajectories.
- Return type:
Trajectories
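For a single backward trajectory, reversal roughly means taking the valid prefix (terminal state back to $s_0$), flipping it, and closing it with $s_f$ and an exit action. This is a minimal sketch under stated assumptions: reverse_backward is a hypothetical helper, backward columns are padded with $s_0$, terminating_idx counts the backward actions, and the action order is simply reversed; mapping each backward action to its forward counterpart is environment-specific and is not shown.

```python
S0, SF = "s0", "sf"          # placeholders for the initial and sink states
EXIT, DUMMY = "exit", None   # placeholders for the exit and padding actions


def reverse_backward(states_col, actions_col, terminating_idx):
    """Reverse one backward trajectory stored as a single padded column.

    states_col runs x ... s_0 followed by s_0 padding; terminating_idx is
    assumed to be the number of backward actions taken.
    """
    valid_states = states_col[: terminating_idx + 1]       # x ... s_0
    valid_actions = actions_col[:terminating_idx]
    fwd_states = list(reversed(valid_states)) + [SF]       # s_0 ... x, s_f
    fwd_actions = list(reversed(valid_actions)) + [EXIT]   # ..., exit
    # The forward trajectory has one extra step: the exit action.
    return fwd_states, fwd_actions, terminating_idx + 1


states_col = ["x", "y1", S0, S0]   # last entry is s_0 padding
actions_col = ["b2", "b1", DUMMY]  # last entry is a dummy action
fwd_states, fwd_actions, fwd_idx = reverse_backward(states_col, actions_col, 2)
```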
- states
- terminating_idx
- property terminating_states: gfn.states.States
The terminating states of the trajectories.
- Returns:
The terminating states.
- Return type:
gfn.states.States
- to_states_container()
Returns a StatesContainer object built from the current Trajectories.
- Returns:
A StatesContainer object with the same states and log_rewards as the current Trajectories.
- Return type:
StatesContainer
- to_tensordict()
Serialize trajectories into a TensorDict.
- Return type:
gfn.containers.base.TensorDictBase
- to_transitions()
Returns a Transitions object built from the current Trajectories.
- Returns:
A Transitions object with the same states, actions, and log_rewards as the current Trajectories.
- Return type:
Transitions
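Converting trajectories to transitions amounts to reading off (state, action, next state) triplets from the padded grids and skipping padding steps. A minimal sketch reusing the grid layout of a two-trajectory batch; trajectories_to_transitions is hypothetical, and both the time-major ordering and the convention that terminating_idx counts the actions (so step end - 1 is the exit action) are assumptions.

```python
def trajectories_to_transitions(states_grid, actions_grid, terminating_idx):
    """Flatten padded (time, batch) grids into a list of transitions.

    Returns (state, action, next_state, is_terminating) tuples, skipping
    padding steps at or beyond each trajectory's terminating_idx.
    """
    transitions = []
    for t in range(len(actions_grid)):
        for b, end in enumerate(terminating_idx):
            if t < end:
                # is_terminating marks the exit action, the last real step.
                transitions.append(
                    (states_grid[t][b], actions_grid[t][b], states_grid[t + 1][b], t == end - 1)
                )
    return transitions


states_grid = [["s0", "s0"], ["x1", "y1"], ["x2", "sf"], ["sf", "sf"]]
actions_grid = [["a", "c"], ["b", "exit"], ["exit", None]]
terminating_idx = [3, 2]
transitions = trajectories_to_transitions(states_grid, actions_grid, terminating_idx)
```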
- class gfn.containers.Transitions(env, states=None, actions=None, is_terminating=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)
Bases: gfn.containers.base.Container
Container for a batch of transitions.
This class manages a collection of transitions (triplets of states, actions, and next states) and their corresponding properties.
- Parameters:
env (gfn.env.Env)
states (gfn.states.States | None)
actions (gfn.actions.Actions | None)
is_terminating (torch.Tensor | None)
next_states (gfn.states.States | None)
is_backward (bool)
log_rewards (torch.Tensor | None)
log_probs (torch.Tensor | None)
- env
The environment where the states and actions are defined.
- states
States with batch_shape (n_transitions,).
- actions
Actions with batch_shape (n_transitions,). Each action takes the corresponding state to its next_state.
- is_terminating
Boolean tensor of shape (n_transitions,) indicating whether the action is the exit action.
- next_states
States with batch_shape (n_transitions,).
- is_backward
Whether the transitions are backward transitions. When is_backward is False, the states are the parents and the next_states are the children; when it is True, the states are the children and the next_states are the parents.
- _log_rewards
(Optional) Tensor of shape (n_transitions,) containing the log rewards of the transitions.
- log_probs
(Optional) Tensor of shape (n_transitions,) containing the log probabilities of the actions.
- __getitem__(index)
Returns a subset of the transitions along the batch dimension.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select transitions.
- Returns:
A new Transitions object with the selected transitions and associated data.
- Return type:
Transitions
- __len__()
Returns the number of transitions in the container.
- Returns:
The number of transitions.
- Return type:
int
- __repr__()
Returns a string representation of the Transitions container.
- Returns:
A string summary of the transitions.
- Return type:
str
- _log_rewards = None
- actions
- property all_log_rewards: torch.Tensor
A helper property that computes the log rewards for all transitions.
This is applicable to environments where all states are terminating. It evaluates the rewards for all transitions that do not end in the sink state, which is useful for the Modified Detailed Balance loss.
- Returns:
Log rewards tensor of shape (n_transitions, 2) for the transitions.
- Return type:
torch.Tensor
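A minimal sketch of what an (n_transitions, 2) result could look like when every state is terminating: one row per transition, pairing the log reward of the state with that of the next state. The column order, the hypothetical helper all_log_rewards, the SINK placeholder, and the use of -inf for transitions ending in the sink state are all assumptions; this page only states the shape and that sink-bound transitions are not evaluated.

```python
import math

SINK = "sf"  # placeholder for the sink state s_f


def all_log_rewards(states, next_states, log_reward_fn):
    """Return one [log R(s), log R(s')] row per transition.

    Assumption: s' == SINK gets -inf instead of a reward evaluation.
    """
    rows = []
    for s, s_next in zip(states, next_states):
        next_value = -math.inf if s_next == SINK else log_reward_fn(s_next)
        rows.append([log_reward_fn(s), next_value])
    return rows


# Toy environment where every state is terminating and log R(s) = s.
rows = all_log_rewards(states=[1, 2], next_states=[2, SINK], log_reward_fn=float)
```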
- property device: torch.device
The device on which the transitions are stored.
- Returns:
The device of self.states.
- Return type:
torch.device
- env
- extend(other)
Extends this Transitions object with another Transitions object.
- Parameters:
other (Transitions) – Another Transitions object to append.
- Return type:
None
- classmethod from_tensordict(env, td)
Reconstruct Transitions from a TensorDict.
- Parameters:
env (gfn.env.Env)
td (gfn.containers.base.TensorDictBase)
- Return type:
Transitions
- is_backward = False
- is_terminating
- log_probs = None
- property log_rewards: torch.Tensor | None
The log rewards for the transitions.
- Returns:
Log rewards tensor of shape (n_transitions,). Non-terminating transitions have value -inf.
- Return type:
torch.Tensor | None
Note
If not provided at initialization, log rewards are computed on demand for terminating transitions.
- property n_transitions: int
The number of transitions in the container.
- Returns:
The number of transitions.
- Return type:
int
- next_states
- states
- property terminating_states: gfn.states.States
The terminating states of the transitions.
- Returns:
The terminating states.
- Return type:
gfn.states.States
- to_tensordict()
Serialize transitions into a TensorDict.
- Return type:
gfn.containers.base.TensorDictBase