gfn.containers.transitions
Classes
Transitions | Container for a batch of transitions.
Module Contents
- class gfn.containers.transitions.Transitions(env, states=None, actions=None, is_terminating=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)
Bases: gfn.containers.base.Container
Container for a batch of transitions.
This class manages a collection of transitions (triplets of states, actions, and next states) and their associated properties.
- Parameters:
env (gfn.env.Env)
states (gfn.states.States | None)
actions (gfn.actions.Actions | None)
is_terminating (torch.Tensor | None)
next_states (gfn.states.States | None)
is_backward (bool)
log_rewards (torch.Tensor | None)
log_probs (torch.Tensor | None)
- env
The environment in which the states and actions are defined.
- states
States with batch_shape (n_transitions,).
- actions
Actions with batch_shape (n_transitions,). Each action takes the corresponding state to its next state.
- is_terminating
Boolean tensor of shape (n_transitions,) indicating whether each action is the exit action.
- next_states
States with batch_shape (n_transitions,).
- is_backward
Whether the transitions are backward transitions. When is_backward is False, states are the parents and next_states are the children; when it is True, states are the children and next_states are the parents.
- _log_rewards
(Optional) Tensor of shape (n_transitions,) containing the log rewards of the transitions.
- log_probs
(Optional) Tensor of shape (n_transitions,) containing the log probabilities of the actions.
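The attributes above describe parallel per-transition fields that share one batch dimension, with is_backward deciding which side of each pair is the parent. The sketch below illustrates that layout with plain Python lists instead of real States/Actions/tensors; ToyTransitions, parent_child_pairs, the grid coordinates, the action encoding and the SINK marker are all hypothetical and not part of torchgfn.

```python
from dataclasses import dataclass
from typing import List, Tuple

State = Tuple[int, int]  # hypothetical 2-D grid coordinate


@dataclass
class ToyTransitions:
    """Hypothetical stand-in for Transitions; lists replace States/Actions/tensors."""

    states: List[State]
    actions: List[int]
    is_terminating: List[bool]
    next_states: List[State]
    is_backward: bool = False

    def __post_init__(self) -> None:
        # All per-transition fields share the batch dimension (n_transitions,).
        n = len(self.states)
        if not (len(self.actions) == len(self.is_terminating) == len(self.next_states) == n):
            raise ValueError("all fields must have length n_transitions")

    def parent_child_pairs(self) -> List[Tuple[State, State]]:
        # Forward: states are parents and next_states are children.
        # Backward: the roles are swapped.
        if self.is_backward:
            return list(zip(self.next_states, self.states))
        return list(zip(self.states, self.next_states))


SINK: State = (-1, -1)  # hypothetical sink-state encoding

fwd = ToyTransitions(
    states=[(0, 0), (1, 0), (1, 1)],
    actions=[0, 1, 2],  # hypothetical encoding: 0 = right, 1 = up, 2 = exit
    is_terminating=[False, False, True],
    next_states=[(1, 0), (1, 1), SINK],
)
bwd = ToyTransitions(
    states=[(1, 1)], actions=[1], is_terminating=[False], next_states=[(1, 0)], is_backward=True
)
print(fwd.parent_child_pairs()[0])  # ((0, 0), (1, 0))
print(bwd.parent_child_pairs())     # [((1, 0), (1, 1))]
```

In the forward batch the exit transition is the only one flagged in is_terminating, and its next state is the sink.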
- __getitem__(index)
Returns a subset of the transitions along the batch dimension.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select transitions.
- Returns:
A new Transitions object with the selected transitions and associated data.
- Return type:
Transitions
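The key property of __getitem__ is that a single index selects the same rows from every per-transition field, so states, actions, flags and next states stay aligned. A minimal sketch of that bookkeeping, assuming a plain dict of lists as a stand-in for the real fields; index_batch is a hypothetical helper, and keeping a length-1 batch for an integer index is an assumption of this sketch.

```python
from typing import Dict, List, Sequence, Union

Index = Union[int, slice, Sequence[int], Sequence[bool]]


def index_batch(batch: Dict[str, list], index: Index) -> Dict[str, list]:
    """Apply one index to every per-transition field (hypothetical helper)."""
    n = len(next(iter(batch.values())))
    if isinstance(index, int):
        positions = [index]  # assumption: keep a batch dimension of length 1
    elif isinstance(index, slice):
        positions = list(range(n))[index]
    elif len(index) == n and all(isinstance(i, bool) for i in index):
        positions = [i for i, keep in enumerate(index) if keep]  # boolean mask
    else:
        positions = list(index)  # explicit integer positions
    # The same positions are used for every field, so rows stay aligned.
    return {name: [values[p] for p in positions] for name, values in batch.items()}


batch = {
    "states": ["s0", "s1", "s2"],
    "actions": ["a0", "a1", "exit"],
    "is_terminating": [False, False, True],
    "next_states": ["s1", "s2", "sf"],
}
sub = index_batch(batch, [False, True, True])
```

Here sub keeps the last two transitions, including the exit transition and its sink next state "sf".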
- __len__()
Returns the number of transitions in the container.
- Returns:
The number of transitions.
- Return type:
int
- __repr__()
Returns a string representation of the Transitions container.
- Returns:
A string summary of the transitions.
- _log_rewards = None
- actions
- property all_log_rewards: torch.Tensor
A helper property that computes the log rewards for all transitions.
This is applicable to environments where all states are terminating. It evaluates the rewards for all transitions that do not end in the sink state, which is useful for the Modified Detailed Balance loss.
- Returns:
Log rewards tensor of shape (n_transitions, 2) for the transitions.
- Return type:
torch.Tensor
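The (n_transitions, 2) shape pairs one log reward per endpoint of each transition. A sketch of that table in plain Python, assuming column 0 holds log R of states and column 1 holds log R of next_states, with transitions into the sink left at -inf in both columns; this column layout, all_log_rewards_sketch, log_reward and SINK are assumptions for illustration, not the library's code.

```python
import math
from typing import Callable, List, Sequence

SINK = "sf"  # hypothetical sink-state marker


def all_log_rewards_sketch(
    states: Sequence[str],
    next_states: Sequence[str],
    log_reward: Callable[[str], float],
) -> List[List[float]]:
    """Return an (n_transitions, 2) table of [log R(s), log R(s')] (assumed layout)."""
    table = [[-math.inf, -math.inf] for _ in states]
    for i, (s, s_next) in enumerate(zip(states, next_states)):
        if s_next == SINK:
            continue  # transitions into the sink keep -inf in both columns
        table[i][0] = log_reward(s)       # every state is terminating, so R(s) exists
        table[i][1] = log_reward(s_next)
    return table


rewards = {"s0": 1.0, "s1": 2.0, "s2": 4.0}
table = all_log_rewards_sketch(
    ["s0", "s1", "s2"], ["s1", "s2", SINK], lambda s: math.log(rewards[s])
)
```

Because the sink row is skipped, the reward function is never evaluated on the sink state.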
- property device: torch.device
The device on which the transitions are stored.
- Returns:
The device of self.states.
- Return type:
torch.device
- env
- extend(other)
Extends this Transitions object in place with another Transitions object.
- Parameters:
other (Transitions) – Another Transitions object to append.
- Return type:
None
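Since extend returns None, it mutates the receiving container. A minimal sketch of field-by-field in-place concatenation, assuming a dict of lists as the container; extend_batch is hypothetical, and both the rejection of mixed forward/backward batches and the rule that an optional field survives only when both batches carry it are choices of this sketch.

```python
import math
from typing import Any, Dict

REQUIRED_FIELDS = ("states", "actions", "is_terminating", "next_states")
OPTIONAL_FIELDS = ("log_rewards", "log_probs")


def extend_batch(batch: Dict[str, Any], other: Dict[str, Any]) -> None:
    """Append `other` to `batch` in place, field by field (hypothetical helper)."""
    if batch["is_backward"] != other["is_backward"]:
        raise ValueError("cannot mix forward and backward transitions")
    for name in REQUIRED_FIELDS:
        batch[name].extend(other[name])
    for name in OPTIONAL_FIELDS:
        # Sketch's assumption: keep an optional field only if both batches have it.
        if batch[name] is not None and other[name] is not None:
            batch[name].extend(other[name])
        else:
            batch[name] = None


a = {"states": ["s0"], "actions": ["a0"], "is_terminating": [False],
     "next_states": ["s1"], "is_backward": False,
     "log_rewards": [-math.inf], "log_probs": None}
b = {"states": ["s1"], "actions": ["exit"], "is_terminating": [True],
     "next_states": ["sf"], "is_backward": False,
     "log_rewards": [0.5], "log_probs": [-0.1]}
extend_batch(a, b)
```

After the call, a holds two transitions, and log_probs is dropped because a never carried it.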
- classmethod from_tensordict(env, td)
Reconstruct Transitions from a TensorDict.
- Parameters:
env (gfn.env.Env)
td (gfn.containers.base.TensorDictBase)
- Return type:
Transitions
- is_backward = False
- is_terminating
- log_probs = None
- property log_rewards: torch.Tensor | None
The log rewards for the transitions.
- Returns:
Log rewards tensor of shape (n_transitions,). Non-terminating transitions have value -inf.
- Return type:
torch.Tensor | None
Note
If not provided at initialization, log rewards are computed on demand for terminating transitions.
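The note describes lazy evaluation: rewards are only computed when requested, and only for terminating transitions, with -inf elsewhere. A sketch of that pattern, assuming the reward is evaluated at states (the state from which the exit action is taken); LazyLogRewards and log_reward_fn are hypothetical, and caching the computed list is a choice of this sketch rather than documented behaviour.

```python
import math
from typing import Callable, List, Optional, Sequence


class LazyLogRewards:
    """Hypothetical sketch of an on-demand log_rewards property."""

    def __init__(
        self,
        states: Sequence[str],
        is_terminating: Sequence[bool],
        log_reward_fn: Callable[[str], float],
        log_rewards: Optional[List[float]] = None,
    ) -> None:
        self.states = list(states)
        self.is_terminating = list(is_terminating)
        self._log_reward_fn = log_reward_fn
        self._log_rewards = log_rewards  # may be supplied at initialization
        self.n_calls = 0                 # counts reward evaluations

    @property
    def log_rewards(self) -> List[float]:
        if self._log_rewards is None:
            values: List[float] = []
            for s, term in zip(self.states, self.is_terminating):
                if term:
                    self.n_calls += 1
                    values.append(self._log_reward_fn(s))
                else:
                    values.append(-math.inf)  # non-terminating transitions
            self._log_rewards = values  # sketch's choice: cache the result
        return self._log_rewards


rewards = {"s2": 3.0}  # only the terminating state needs a reward
t = LazyLogRewards(["s0", "s1", "s2"], [False, False, True],
                   lambda s: math.log(rewards[s]))
first = t.log_rewards
```

The reward function is called once, for the single terminating transition, so missing rewards for "s0" and "s1" never matter.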
- property n_transitions: int
The number of transitions in the container.
- Returns:
The number of transitions.
- Return type:
int
- next_states
- states
- property terminating_states: gfn.states.States
The terminating states of the transitions.
- Returns:
The terminating states.
- Return type:
gfn.states.States
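For a terminating transition the next state is the sink, so the informative object is the state from which the exit action was taken. A sketch assuming terminating_states amounts to masking states with is_terminating; terminating_states_sketch is a hypothetical helper operating on plain lists.

```python
from typing import List, Sequence, TypeVar

S = TypeVar("S")


def terminating_states_sketch(states: Sequence[S], is_terminating: Sequence[bool]) -> List[S]:
    """Keep the states from which the exit action was taken (assumed semantics)."""
    if len(states) != len(is_terminating):
        raise ValueError("states and is_terminating must share the batch dimension")
    return [s for s, term in zip(states, is_terminating) if term]


done = terminating_states_sketch([(0, 0), (1, 0), (1, 1)], [False, False, True])
```

Masking states rather than next_states avoids returning a batch of identical sink states.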
- to_tensordict()
Serialize transitions into a TensorDict.
- Return type:
gfn.containers.base.TensorDictBase
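to_tensordict and from_tensordict form a round trip, and the signature of from_tensordict(env, td) implies the environment is not stored in the TensorDict and must be supplied again. A sketch of that round trip using a plain dict in place of a TensorDict so it runs without torch; ToyTransitions, to_dict and from_dict are hypothetical stand-ins, and the key names are assumptions.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional


@dataclass
class ToyTransitions:
    """Hypothetical container; a plain dict stands in for TensorDict."""

    env: Any  # not serialized; passed back in on reconstruction
    states: List[int]
    actions: List[int]
    is_terminating: List[bool]
    next_states: List[int]
    is_backward: bool = False
    log_rewards: Optional[List[float]] = None
    log_probs: Optional[List[float]] = None

    def to_dict(self) -> Dict[str, Any]:
        # Every field except the environment goes into the flat mapping.
        return {f.name: getattr(self, f.name) for f in fields(self) if f.name != "env"}

    @classmethod
    def from_dict(cls, env: Any, data: Dict[str, Any]) -> "ToyTransitions":
        # Mirrors from_tensordict(env, td): env comes from the caller.
        return cls(env=env, **data)


env = "grid-env"  # placeholder for a real gfn.env.Env
t = ToyTransitions(env, [0, 1], [1, 1], [False, True], [1, -1], log_probs=[-0.7, -0.2])
d = t.to_dict()
t2 = ToyTransitions.from_dict(env, d)
```

Keeping the environment out of the serialized mapping leaves only per-transition data in it, which is why the loader needs env as a separate argument.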