gfn.containers.states_container
Attributes
- StateType
Classes
- StatesContainer – Container for a batch of states (mainly used for FMGFlowNet).
Module Contents
- gfn.containers.states_container.StateType
- class gfn.containers.states_container.StatesContainer(env, states=None, is_terminating=None, log_rewards=None)
Bases: gfn.containers.base.Container, Generic[StateType]
Container for a batch of states (mainly used for FMGFlowNet).
This class manages a collection of states together with their per-state data (terminating flags and, optionally, log rewards). It is mainly used by Flow Matching GFlowNet algorithms.
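- Parameters:
env (gfn.env.Env) – The environment where the states are defined.
states (StateType | None) – States with batch_shape (n_states,).
is_terminating (torch.Tensor | None) – Boolean tensor of shape (n_states,) indicating which states are terminating.
log_rewards (torch.Tensor | None) – Optional tensor of shape (n_states,) containing the log rewards for terminating states.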
- env
The environment where the states are defined.
- states
States with batch_shape (n_states,).
- is_terminating
Boolean tensor of shape (n_states,) indicating which states are terminating.
- _log_rewards
(Optional) Tensor of shape (n_states,) containing the log rewards for terminating states.
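To make the shape contract above concrete, here is a minimal plain-Python sketch, assuming Python lists stand in for the batched States object and the torch tensors. The class name ToyStatesContainer is hypothetical, and the length checks in __post_init__ are an assumption of this sketch, not documented torchgfn behavior.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyStatesContainer:
    """Hypothetical plain-Python stand-in for StatesContainer.

    Lists play the role of the batched States object and the torch tensors;
    this is not the torchgfn implementation.
    """

    states: List[tuple] = field(default_factory=list)         # batch_shape (n_states,)
    is_terminating: List[bool] = field(default_factory=list)  # shape (n_states,)
    log_rewards: Optional[List[float]] = None                 # shape (n_states,) or None

    def __post_init__(self) -> None:
        # Assumption of this sketch: every per-state field shares the batch dimension.
        n = len(self.states)
        if len(self.is_terminating) != n:
            raise ValueError("is_terminating must have shape (n_states,)")
        if self.log_rewards is not None and len(self.log_rewards) != n:
            raise ValueError("log_rewards must have shape (n_states,)")

    def __len__(self) -> int:
        # Number of states, i.e. n_states.
        return len(self.states)


container = ToyStatesContainer(
    states=[(0, 0), (0, 1), (1, 1)],
    is_terminating=[False, False, True],
    log_rewards=[-math.inf, -math.inf, 0.5],
)
```

The point of the sketch is that states, is_terminating and log_rewards are indexed by the same position, so every operation below has to keep them aligned.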
- __getitem__(index)
Returns a subset of the states along the batch dimension.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select states.
- Returns:
A new StatesContainer with the selected states and associated data.
- Return type:
StatesContainer[StateType]
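The "selected states and associated data" wording implies that one index is applied to every per-state field. The sketch below models that with plain lists and a dict whose keys mirror the attributes; the helpers select and getitem are hypothetical, cover only slices, integer sequences and boolean masks, and are not the torchgfn code.

```python
from typing import Any, Dict, List, Optional, Sequence, Union

Index = Union[slice, Sequence[int], Sequence[bool]]


def select(values: List[Any], index: Index) -> List[Any]:
    """Apply one batch index to a list, mimicking tensor indexing (hypothetical helper)."""
    if isinstance(index, slice):
        return values[index]
    index = list(index)
    # Check for a boolean mask first: bool is a subclass of int in Python.
    if index and all(isinstance(i, bool) for i in index):
        if len(index) != len(values):
            raise IndexError("boolean mask must have one entry per state")
        return [v for v, keep in zip(values, index) if keep]
    return [values[i] for i in index]


def getitem(container: Dict[str, Any], index: Index) -> Dict[str, Any]:
    """Index every per-state field with the same index so they stay aligned."""
    log_rewards: Optional[List[float]] = container["log_rewards"]
    return {
        "states": select(container["states"], index),
        "is_terminating": select(container["is_terminating"], index),
        "log_rewards": None if log_rewards is None else select(log_rewards, index),
    }


batch = {
    "states": ["s0", "s1", "s2", "s3"],
    "is_terminating": [False, True, False, True],
    "log_rewards": None,
}
subset = getitem(batch, batch["is_terminating"])  # keep terminating states only
```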
- __len__()
Returns the number of states in the container.
- Returns:
The number of states.
- Return type:
int
- __repr__()
Returns a string representation of the StatesContainer.
- Returns:
A string summary of the container.
- Return type:
str
- _log_rewards = None
- property device: torch.device
The device on which the states are stored.
- Returns:
The device of self.states.
- Return type:
torch.device
- extend(other)
Extends this container in place with another StatesContainer object.
- Parameters:
other (StatesContainer[StateType]) – Another StatesContainer to append.
- Return type:
None
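Since extend returns None, it mutates the container in place by concatenating along the batch dimension. The plain-Python sketch below assumes lists stand in for States and tensors; the class name is hypothetical, and the rule for cached log rewards (kept only when both sides have them, otherwise dropped so they can be recomputed on demand) is an assumption, not confirmed torchgfn behavior.

```python
from typing import List, Optional


class ToyStatesContainer:
    """Hypothetical plain-Python stand-in; lists replace States and tensors."""

    def __init__(self, states: List[str], is_terminating: List[bool],
                 log_rewards: Optional[List[float]] = None) -> None:
        self.states = states
        self.is_terminating = is_terminating
        self._log_rewards = log_rewards

    def __len__(self) -> int:
        return len(self.states)

    def extend(self, other: "ToyStatesContainer") -> None:
        """Append `other` in place along the batch dimension; returns None."""
        self.states = self.states + other.states
        self.is_terminating = self.is_terminating + other.is_terminating
        # Assumption: cached log rewards survive only if both sides have them;
        # otherwise drop the cache so it can be recomputed on demand.
        if self._log_rewards is not None and other._log_rewards is not None:
            self._log_rewards = self._log_rewards + other._log_rewards
        else:
            self._log_rewards = None


a = ToyStatesContainer(["s0", "s1"], [False, True], [float("-inf"), 1.0])
b = ToyStatesContainer(["s2"], [True], [2.0])
a.extend(b)  # a now holds three states
```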
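- classmethod from_tensordict(env, td)
Reconstructs a StatesContainer from a TensorDict.
- Parameters:
env (gfn.env.Env) – The environment where the states are defined.
td (gfn.containers.base.TensorDictBase) – The TensorDict to reconstruct from, e.g. one produced by to_tensordict().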
- Return type:
StatesContainer[StateType]
- property intermediary_states: StateType
The intermediary states (not initial states) of the StatesContainer.
- Returns:
The intermediary states.
- Return type:
StateType
- property log_rewards: torch.Tensor
The log rewards for all states.
- Returns:
Log rewards tensor of shape (len(self.states),). Non-terminating (intermediate) states have value -inf.
- Return type:
torch.Tensor
Note
If not provided at initialization, log rewards are computed on demand for terminating states.
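The rule above (one entry per state, -inf for non-terminating states, computed on demand when not supplied) can be sketched in plain Python. Here log_reward_fn and toy_log_reward are hypothetical stand-ins for the environment's log-reward computation, lists replace tensors, and unlike a real implementation this sketch does not cache the computed values.

```python
import math
from typing import Callable, List, Optional


def full_log_rewards(
    states: List[int],
    is_terminating: List[bool],
    cached: Optional[List[float]],
    log_reward_fn: Callable[[int], float],
) -> List[float]:
    """One log reward per state; non-terminating states get -inf.

    `log_reward_fn` is a hypothetical stand-in for the environment's log-reward
    computation and is only called on terminating states.
    """
    if cached is not None:
        # Log rewards supplied at initialization are returned as-is.
        return cached
    return [
        log_reward_fn(s) if term else -math.inf
        for s, term in zip(states, is_terminating)
    ]


def toy_log_reward(s: int) -> float:
    # Hypothetical reward R(s) = s + 1, so log R(s) = log(s + 1).
    return math.log(s + 1)


values = full_log_rewards([0, 1, 3], [False, False, True], None, toy_log_reward)
```

Using -inf is consistent with treating intermediate states as having zero reward, since exp(-inf) is 0.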
- property terminating_log_rewards: torch.Tensor
The log rewards for terminating states only.
- Returns:
The log rewards for terminating states.
- Return type:
torch.Tensor
- property terminating_states: StateType
The terminating states of the StatesContainer, i.e. those flagged in is_terminating.
- Returns:
The terminating states.
- Return type:
StateType
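Both terminating_states and terminating_log_rewards can be read as the same boolean mask applied to different fields. The sketch below assumes that reading (it is not confirmed by the source), uses lists in place of States and tensors, and the function name split_terminating is hypothetical.

```python
import math
from typing import List, Tuple


def split_terminating(
    states: List[str], is_terminating: List[bool], log_rewards: List[float]
) -> Tuple[List[str], List[float]]:
    """Select terminating states and their log rewards with one boolean mask."""
    terminating_states = [s for s, t in zip(states, is_terminating) if t]
    terminating_log_rewards = [r for r, t in zip(log_rewards, is_terminating) if t]
    return terminating_states, terminating_log_rewards


term_states, term_log_rewards = split_terminating(
    ["a", "b", "c"], [False, True, True], [-math.inf, 0.1, 0.2]
)
```

Applying the mask drops the -inf placeholders, so the terminating log rewards line up one-to-one with the terminating states.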
- to_tensordict()
Serializes the states container into a TensorDict.
- Return type:
gfn.containers.base.TensorDictBase
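The signature from_tensordict(env, td) suggests the environment is supplied separately rather than serialized with the per-state data. The round trip below sketches that split with a plain dict in place of a TensorDict; the function names, the dict key names and the placeholder environment string are all hypothetical, not torchgfn's actual keys or API.

```python
from typing import Any, Dict, List, Optional


def to_dict(
    states: List[Any], is_terminating: List[bool], log_rewards: Optional[List[float]]
) -> Dict[str, Any]:
    """Serialize per-state data only; the environment is not part of the payload."""
    # Key names are hypothetical, not torchgfn's actual TensorDict keys.
    payload: Dict[str, Any] = {
        "states": list(states),
        "is_terminating": list(is_terminating),
    }
    if log_rewards is not None:
        payload["log_rewards"] = list(log_rewards)
    return payload


def from_dict(env: Any, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild the container; the caller passes `env` back in, mirroring from_tensordict(env, td)."""
    return {
        "env": env,
        "states": payload["states"],
        "is_terminating": payload["is_terminating"],
        "log_rewards": payload.get("log_rewards"),
    }


payload = to_dict(["s0", "s1"], [False, True], None)
restored = from_dict("my_env", payload)  # "my_env" is a placeholder environment
```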