gfn.containers.trajectories¶
Classes¶
Trajectories | Container for complete trajectories (starting in \(s_0\) and ending in \(s_f\)).
Functions¶
pad_dim0_if_needed(a, b, value) | Pads tensor a or b to match the first dimension of the other.
Module Contents¶
- class gfn.containers.trajectories.Trajectories(env, states=None, actions=None, terminating_idx=None, is_backward=False, log_rewards=None, log_probs=None, estimator_outputs=None)¶
Bases: gfn.containers.base.Container
Container for complete trajectories (starting in \(s_0\) and ending in \(s_f\)).
Trajectories are represented as a States object with a two-dimensional batch shape, and actions as an Actions object with a two-dimensional batch shape. The first dimension is the time step and the second is the trajectory index. Because trajectories may differ in length, shorter ones are padded with the tensor representation of the terminal state (\(s_f\) for forward trajectories, \(s_0\) for backward ones), and their actions are padded with dummy actions. The terminating_idx tensor gives the time step at which each trajectory ends.
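The padded, time-major layout described above can be illustrated with a minimal sketch. It uses plain Python lists in place of tensors so that it runs without dependencies; the helper `pad_trajectories`, the marker `S_F`, and the convention that a trajectory visiting k states ends at time step k (its last action being the exit to \(s_f\)) are assumptions made for illustration, not the library's implementation.

```python
S_F = -1  # hypothetical marker for the sink state s_f, used as padding


def pad_trajectories(trajs, pad=S_F):
    """Stack variable-length forward trajectories into a time-major grid."""
    # Assumption: a trajectory visiting k states takes k actions, the last
    # one being the exit to s_f, so it "ends" at time step k.
    terminating_idx = [len(t) for t in trajs]
    max_length = max(terminating_idx)
    # Row = time step, column = trajectory index; max_length + 1 rows in total.
    states = [
        [t[step] if step < len(t) else pad for t in trajs]
        for step in range(max_length + 1)
    ]
    return states, terminating_idx


states, terminating_idx = pad_trajectories([[0, 1, 2], [0, 5]])
for row in states:
    print(row)
print(terminating_idx)
```

Here the second trajectory is shorter, so its column is filled with the padding marker once it has terminated.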
- Parameters:
env (gfn.env.Env)
states (gfn.states.States | None)
actions (gfn.actions.Actions | None)
terminating_idx (torch.Tensor | None)
is_backward (bool)
log_rewards (torch.Tensor | None)
log_probs (torch.Tensor | None)
estimator_outputs (torch.Tensor | None)
- env¶
The environment where the states and actions are defined.
- states¶
States with batch_shape (max_length+1, batch_size).
- actions¶
Actions with batch_shape (max_length, batch_size).
- terminating_idx¶
Tensor of shape (batch_size,) indicating the time step at which each trajectory ends.
- is_backward¶
Whether the trajectories are backward or forward. When is_backward is False, the states are ordered from initial to terminal states. When is_backward is True, the states are ordered from terminal to initial states.
- _log_rewards¶
(Optional) Tensor of shape (batch_size,) containing the log rewards of the trajectories.
- log_probs¶
(Optional) Tensor of shape (max_length, batch_size) indicating the log probabilities of the trajectories’ actions.
- estimator_outputs¶
(Optional) Tensor of shape (max_length, batch_size, …) containing outputs of a function approximator for each step.
- __getitem__(index)¶
Returns a subset of the trajectories along the batch dimension.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select trajectories.
- Returns:
A new Trajectories object with the selected trajectories and associated data.
- Return type:
Trajectories
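As a rough sketch of what selecting along the batch dimension means for the padded layout, the following picks columns from a time-major grid of plain lists. The function `select_trajectories` is hypothetical, and trimming trailing all-padding time steps after selection is an assumption about the behavior, not something this page states.

```python
def select_trajectories(states, terminating_idx, index):
    """Keep the trajectories (columns) listed in `index`."""
    new_idx = [terminating_idx[i] for i in index]
    # Assumption: trailing time steps that are pure padding get trimmed.
    new_max = max(new_idx) if new_idx else 0
    new_states = [[row[i] for i in index] for row in states[: new_max + 1]]
    return new_states, new_idx


# Two padded trajectories; -1 marks padding (hypothetical sink-state marker).
states = [[0, 0], [1, 5], [2, -1], [-1, -1]]
terminating_idx = [3, 2]

sub_states, sub_idx = select_trajectories(states, terminating_idx, [1])
print(sub_states, sub_idx)
```

The original grid is left untouched, mirroring the fact that `__getitem__` returns a new object.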
- __len__()¶
Returns the number of trajectories in the container.
- Returns:
The number of trajectories.
- Return type:
int
- __repr__()¶
Returns a string representation of the Trajectories container.
- Returns:
A string summary of the trajectories.
- Return type:
str
- _log_rewards = None¶
- actions¶
- property batch_size: int¶
The number of trajectories in the container.
- Returns:
The number of trajectories.
- Return type:
int
- property device: torch.device¶
The device on which the trajectories are stored.
- Returns:
The device of self.states.
- Return type:
torch.device
- env¶
- estimator_outputs = None¶
- extend(other)¶
Extends this Trajectories object with another Trajectories object.
Extends along all attributes in turn (actions, states, terminating_idx, log_probs, log_rewards).
- Parameters:
other (Trajectories) – Another Trajectories to append.
- Return type:
None
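Extending one padded batch with another requires both to share the same time dimension first. The sketch below shows that idea on plain lists; `extend_batches` is a hypothetical helper that returns new values for clarity, whereas the method documented above modifies the object in place and returns None.

```python
def extend_batches(states_a, idx_a, states_b, idx_b, pad=-1):
    """Concatenate two padded batches along the trajectory (column) dimension."""
    target = max(len(states_a), len(states_b))

    def pad_time(rows, width):
        # Append all-padding rows until the time dimension reaches `target`.
        return rows + [[pad] * width for _ in range(target - len(rows))]

    a = pad_time(states_a, len(idx_a))
    b = pad_time(states_b, len(idx_b))
    # Join rows side by side: columns of `a` first, then columns of `b`.
    states = [row_a + row_b for row_a, row_b in zip(a, b)]
    return states, idx_a + idx_b


states, idx = extend_batches([[0], [5], [-1]], [2], [[0], [1], [2], [-1]], [3])
print(states, idx)
```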
- classmethod from_tensordict(env, td)¶
Reconstruct Trajectories from a TensorDict.
- Parameters:
env (gfn.env.Env)
td (gfn.containers.base.TensorDictBase)
- Return type:
Trajectories
- is_backward = False¶
- log_probs = None¶
- property log_rewards: torch.Tensor | None¶
The log rewards for the trajectories.
- Returns:
Log rewards tensor of shape (batch_size,).
- Return type:
torch.Tensor | None
Note
If not provided at initialization, log rewards are computed on demand for terminating states.
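The on-demand behavior in the note can be pictured as a property that computes its value on first access. The class `LazyLogRewards` and its `reward_fn` argument are hypothetical stand-ins, and caching the result after the first access is an assumption for illustration.

```python
import math


class LazyLogRewards:
    """Hypothetical stand-in showing the compute-on-first-access pattern."""

    def __init__(self, terminating_states, reward_fn, log_rewards=None):
        self.terminating_states = terminating_states
        self.reward_fn = reward_fn  # hypothetical: maps a state to a positive reward
        self._log_rewards = log_rewards
        self.n_reward_calls = 0  # bookkeeping for the demo only

    @property
    def log_rewards(self):
        # Compute only if no value was supplied or cached earlier.
        if self._log_rewards is None:
            self.n_reward_calls += 1
            self._log_rewards = [
                math.log(self.reward_fn(s)) for s in self.terminating_states
            ]
        return self._log_rewards


lazy = LazyLogRewards([1, 2, 4], reward_fn=lambda s: float(s))
first = lazy.log_rewards
second = lazy.log_rewards
print(first, lazy.n_reward_calls)
```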
- property max_length: int¶
The maximum length of the trajectories in the container.
- Returns:
The maximum trajectory length.
- Return type:
int
- property n_trajectories: int¶
Deprecated alias for batch_size.
- Return type:
int
- reverse_backward_trajectories()¶
Returns a reversed version of the backward trajectories.
- Return type:
Trajectories
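To make the reversal concrete, here is a minimal sketch on plain lists. It assumes a backward trajectory is stored from its terminating state down to \(s_0\) and padded with \(s_0\), and that the reversed result is a forward trajectory with one extra exit step to \(s_f\). These conventions, and the helper `reverse_backward`, are illustrative assumptions rather than the library's code.

```python
def reverse_backward(states, terminating_idx, s_f=-1):
    """Turn time-major backward trajectories into forward ones."""
    n = len(terminating_idx)
    # Assumption: the forward version has one extra step, the exit to s_f.
    new_idx = [t + 1 for t in terminating_idx]
    out = [[s_f] * n for _ in range(max(new_idx) + 1)]
    for j, t in enumerate(terminating_idx):
        # Backward step t holds s_0; walk it back up to the terminating state.
        for step in range(t + 1):
            out[step][j] = states[t - step][j]
    return out, new_idx


# Backward storage: column 0 is 2 -> 1 -> 0, column 1 is 5 -> 0 (padded with s_0 = 0).
backward = [[2, 5], [1, 0], [0, 0]]
forward, forward_idx = reverse_backward(backward, [2, 1])
print(forward, forward_idx)
```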
- states¶
- terminating_idx¶
- property terminating_states: gfn.states.States¶
The terminating states of the trajectories.
- Returns:
The terminating states.
- Return type:
gfn.states.States
- to_states_container()¶
Returns a StatesContainer object from the current Trajectories.
- Returns:
A StatesContainer object with the same states, actions, and log_rewards as the current Trajectories.
- Return type:
StatesContainer
- to_tensordict()¶
Serialize trajectories into a TensorDict.
- Return type:
gfn.containers.base.TensorDictBase
- to_transitions()¶
Returns a Transitions object from the current Trajectories.
- Returns:
A Transitions object with the same states, actions, and log_rewards as the current Trajectories.
- Return type:
Transitions
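Conceptually, converting trajectories to transitions flattens the time-major grids into (state, action, next state) triples while skipping padded steps. The sketch below shows this on plain lists; the helper `to_transition_triples`, the string action labels, and the per-trajectory ordering of the output are assumptions for illustration.

```python
def to_transition_triples(states, actions, terminating_idx):
    """Flatten padded trajectories into (state, action, next_state) triples."""
    triples = []
    for j, t_end in enumerate(terminating_idx):
        # Only steps before the trajectory's end are real; the rest is padding.
        for t in range(t_end):
            triples.append((states[t][j], actions[t][j], states[t + 1][j]))
    return triples


states = [[0, 0], [1, 5], [2, -1], [-1, -1]]          # (max_length + 1, batch_size)
actions = [["a", "b"], ["c", "exit"], ["exit", "pad"]]  # (max_length, batch_size)
terminating_idx = [3, 2]

triples = to_transition_triples(states, actions, terminating_idx)
print(triples)
```

The number of triples equals the sum of terminating_idx, since each real action yields exactly one transition.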
- gfn.containers.trajectories.pad_dim0_if_needed(a, b, value=-float('inf'))¶
Pads tensor a or b to match the first dimension of the other.
- Parameters:
a (torch.Tensor) – First tensor.
b (torch.Tensor) – Second tensor.
value (float) – Value to use for padding.
- Returns:
Tuple of tensors with the same first dimension.
- Return type:
tuple[torch.Tensor, torch.Tensor]
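The padding behavior described above can be sketched without torch by using nested lists in place of tensors. The function `pad_dim0_if_needed_sketch` is a hypothetical illustration of the idea (pad whichever input is shorter along the first dimension), not the library's implementation.

```python
def _filler_like(row, value):
    # Build a row of `value` with the same nested shape as `row`.
    if isinstance(row, list):
        return [_filler_like(r, value) for r in row]
    return value


def pad_dim0_if_needed_sketch(a, b, value=-float("inf")):
    """Pad the shorter of two nested lists along its first dimension."""
    n = max(len(a), len(b))

    def pad_to(x):
        if len(x) == n:
            return x  # already long enough, returned unchanged
        template = x[0] if x else value
        return x + [_filler_like(template, value) for _ in range(n - len(x))]

    return pad_to(a), pad_to(b)


a = [[0.1, 0.2]]
b = [[0.3, 0.4], [0.5, 0.6]]
pa, pb = pad_dim0_if_needed_sketch(a, b)
print(pa, pb)
```

The default of negative infinity is a natural filler for log-probability tensors, since it corresponds to probability zero.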