gfn.containers.base
Classes
Container | Base class for state containers (states, transitions, or trajectories).
Module Contents
- class gfn.containers.base.Container
Bases: abc.ABC
Base class for state containers (states, transitions, or trajectories).
- abstract __getitem__(index)
Returns a subset of the container based on the provided index.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – An integer, slice, tuple, sequence of indices or booleans, or a torch.Tensor specifying which elements to select.
- Returns:
A new container containing the selected elements and associated data.
- Return type:
Container
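The indexing contract above can be illustrated with a minimal sketch. `ListContainer` is a hypothetical, list-backed stand-in and not part of gfn. It assumes that every supported index form, including a single integer, yields a new container rather than a bare element, and it treats tuples like any other sequence. `torch.Tensor` indices are left out to keep the example dependency-free.

```python
from collections.abc import Sequence


class ListContainer:
    """Hypothetical stand-in for a Container subclass, backed by a plain list."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        # Assumption: every index form returns a new container.
        if isinstance(index, bool):
            # bool is a subclass of int, so reject it before the int branch.
            raise TypeError("a bare bool is not a valid index")
        if isinstance(index, int):
            return ListContainer([self.items[index]])
        if isinstance(index, slice):
            return ListContainer(self.items[index])
        if isinstance(index, Sequence) and not isinstance(index, (str, bytes)):
            if len(index) > 0 and all(isinstance(i, bool) for i in index):
                # Boolean mask: one flag per element of the container.
                if len(index) != len(self.items):
                    raise IndexError("boolean mask length mismatch")
                kept = [x for x, keep in zip(self.items, index) if keep]
                return ListContainer(kept)
            # Otherwise treat the sequence as integer positions.
            return ListContainer([self.items[i] for i in index])
        raise TypeError(f"unsupported index type: {type(index).__name__}")


container = ListContainer(["a", "b", "c", "d"])
single = container[1]                           # integer
window = container[1:3]                         # slice
picked = container[[0, 3]]                      # sequence of indices
masked = container[[True, False, True, False]]  # sequence of booleans
```

Returning a container for every index form keeps downstream code uniform: callers can chain `len()`, further indexing, or `extend()` on the result without checking what kind of index produced it.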
- abstract __len__()
Returns the number of elements in the container.
- Returns:
The number of elements in the container.
- Return type:
int
- abstract property device: torch.device
The device on which the container is stored.
- Returns:
The device on which the container is stored.
- Return type:
torch.device
- abstract extend(other)
Extends the current container with elements from another container object.
- Parameters:
other (Container) – The other container whose elements will be added.
- Return type:
None
- abstract classmethod from_tensordict(env, td)
Reconstructs a container from a TensorDict.
- Parameters:
env (gfn.env.Env) – The environment needed to reconstruct States/Actions.
td (tensordict.base.TensorDictBase) – The TensorDict produced by to_tensordict().
- Returns:
A new container instance.
- Return type:
Container
- property has_log_probs: bool
Whether the container has log probabilities.
- Returns:
True if log probabilities are present and non-empty, False otherwise.
- Return type:
bool
- classmethod load(env, path)
Loads a container from a .pt file saved by save().
- Parameters:
env (gfn.env.Env) – The environment needed to reconstruct States/Actions.
path (str) – File path to the saved container.
- Returns:
A new container instance.
- Return type:
Container
- abstract property log_rewards: torch.Tensor
The log rewards associated with the container.
- Returns:
The log rewards tensor.
- Return type:
torch.Tensor
- sample(n_samples)
Randomly samples a subset of elements from the container.
- Parameters:
n_samples (int) – The number of elements to sample.
- Returns:
A new container with the sampled elements.
- Return type:
Container
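One way to read `sample` is as a thin layer over `__len__` and `__getitem__`: draw random positions, then index with them. The sketch below shows that idea with a hypothetical list-backed `ListContainer` (not part of gfn). It assumes sampling is uniform and without replacement, which the description above does not state, and the `rng` parameter is added here only to make the example reproducible.

```python
import random


class ListContainer:
    """Hypothetical list-backed stand-in; not part of gfn."""

    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        # Only the index form needed here: a sequence of integer positions.
        return ListContainer([self.items[i] for i in index])

    def sample(self, n_samples, rng=random):
        # Assumption: uniform sampling without replacement.
        positions = rng.sample(range(len(self)), n_samples)
        return self[positions]


container = ListContainer(range(10))
subset = container.sample(3, rng=random.Random(0))
```

Under this assumption, requesting more samples than the container holds raises `ValueError` from `random.sample`. Whether the library behaves the same way is not specified here.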
- save(path)
Saves the container to a single .pt file.
- Parameters:
path (str) – File path (e.g. "trajectories.pt").
- Return type:
None
- abstract property terminating_states: gfn.states.States
The last (terminating) states of the container.
- Returns:
The terminating states.
- Return type:
gfn.states.States
- abstract to_tensordict()
Serializes the container's data into a TensorDict.
- Returns:
A TensorDict containing all tensor data and scalar metadata. The env reference is not included; it must be supplied when reconstructing via from_tensordict().
- Return type:
tensordict.base.TensorDictBase
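The serialization pattern described above (data goes into the serialized form, the environment is supplied again at reconstruction) can be sketched without any dependencies. Everything in this block is a hypothetical stand-in: `ToyEnv` replaces `gfn.env.Env`, a plain `dict` replaces the TensorDict, `to_dict`/`from_dict` mirror `to_tensordict`/`from_tensordict`, and `pickle` replaces the `.pt` file format. It also assumes that `save` and `load` are thin wrappers around that round trip, which this page does not confirm.

```python
import os
import pickle
import tempfile


class ToyEnv:
    """Hypothetical environment; stands in for gfn.env.Env."""

    def __init__(self, name):
        self.name = name


class ToyContainer:
    """Hypothetical stand-in for a Container subclass."""

    def __init__(self, env, log_rewards):
        self.env = env
        self.log_rewards = list(log_rewards)

    def to_dict(self):
        # Only data goes in; the env reference is deliberately left out.
        return {"log_rewards": list(self.log_rewards)}

    @classmethod
    def from_dict(cls, env, data):
        # The caller supplies env again at reconstruction time.
        return cls(env, data["log_rewards"])

    def save(self, path):
        # Assumption: saving is just writing the serialized dict to disk.
        with open(path, "wb") as f:
            pickle.dump(self.to_dict(), f)

    @classmethod
    def load(cls, env, path):
        with open(path, "rb") as f:
            return cls.from_dict(env, pickle.load(f))


env = ToyEnv("grid")
original = ToyContainer(env, [-1.5, -0.25])

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_container.pkl")
    original.save(path)
    restored = ToyContainer.load(env, path)
```

Keeping the environment out of the serialized data means a saved file stays small and does not depend on the environment being picklable. The cost is that the caller must pass a compatible environment to `load()` or `from_tensordict()`.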