gfn.containers.transitions

Classes

Transitions

Container for a batch of transitions.

Module Contents

class gfn.containers.transitions.Transitions(env, states=None, actions=None, is_terminating=None, next_states=None, is_backward=False, log_rewards=None, log_probs=None)

Bases: gfn.containers.base.Container

Container for a batch of transitions.

This class manages a collection of transitions (triplets of states, actions, and next states) and their corresponding properties.

Parameters:
env

The environment where the states and actions are defined.

states

States with batch_shape (n_transitions,).

actions

Actions with batch_shape (n_transitions,). Each action takes its entry in states to the corresponding entry in next_states.

is_terminating

Boolean tensor of shape (n_transitions,) indicating whether the action is the exit action.

next_states

States with batch_shape (n_transitions,).

is_backward

Whether the transitions are backward transitions. When not is_backward, the states are the parents of the transitions and the next_states are the children. When is_backward, the states are the children of the transitions and the next_states are the parents.

log_rewards

(Optional) Tensor of shape (n_transitions,) containing the log rewards of the transitions. Stored on the instance as _log_rewards.

log_probs

(Optional) Tensor of shape (n_transitions,) containing the log probabilities of the actions.
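The is_backward convention described above can be illustrated with a small plain-Python stand-in. This is a sketch only, not the library's implementation: the `ToyTransition` dataclass and the `parent_child` helper are hypothetical names, and integers stand in for real states.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTransition:
    """Hypothetical stand-in for one transition; integers play the role of states."""

    state: int
    next_state: int
    is_backward: bool = False


def parent_child(t: ToyTransition) -> tuple[int, int]:
    """Return (parent, child) following the is_backward convention."""
    if t.is_backward:
        # Backward: `state` is the child, `next_state` is the parent.
        return t.next_state, t.state
    # Forward: `state` is the parent, `next_state` is the child.
    return t.state, t.next_state


forward = ToyTransition(state=3, next_state=4)
backward = ToyTransition(state=4, next_state=3, is_backward=True)
print(parent_child(forward))   # (3, 4)
print(parent_child(backward))  # (3, 4)
```

Both toy transitions describe the same parent-to-child edge; only the direction in which it is stored differs.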

__getitem__(index)

Returns a subset of the transitions along the batch dimension.

Parameters:

index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select transitions.

Returns:

A new Transitions object with the selected transitions and associated data.

Return type:

Transitions
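Selecting along the batch dimension means that one index is applied to every per-transition field, so the fields stay aligned. The sketch below mimics this with plain Python lists. The `select` helper is hypothetical, and returning a batch of size one for an integer index is an assumption of the sketch, not a statement about the library.

```python
from typing import Sequence, Union

Index = Union[int, slice, Sequence[int], Sequence[bool]]


def select(field: list, index: Index) -> list:
    """Hypothetical helper: pick entries of one per-transition field."""
    if isinstance(index, int):
        return [field[index]]  # assumption: keep a batch of size 1
    if isinstance(index, slice):
        return field[index]
    index = list(index)
    if index and all(isinstance(i, bool) for i in index):
        # Boolean mask: keep the entries where the mask is True.
        if len(index) != len(field):
            raise IndexError("boolean mask must match the batch size")
        return [x for x, keep in zip(field, index) if keep]
    # Sequence of integer positions.
    return [field[i] for i in index]


states = ["s0", "s1", "s2", "s3"]
is_terminating = [False, False, True, True]
mask = [False, True, False, True]

# The same index is applied to every field so the rows stay aligned.
print(select(states, mask))          # ['s1', 's3']
print(select(is_terminating, mask))  # [False, True]
print(select(states, slice(1, 3)))   # ['s1', 's2']
print(select(states, [3, 0]))        # ['s3', 's0']
```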

__len__()

Returns the number of transitions in the container.

Returns:

The number of transitions.

Return type:

int

__repr__()

Returns a string representation of the Transitions container.

Returns:

A string summary of the transitions.

_log_rewards = None
actions
property all_log_rewards: torch.Tensor

Computes the log rewards for all transitions.

This is applicable to environments where all states are terminating. This function evaluates the rewards for all transitions that do not end in the sink state. This is useful for the Modified Detailed Balance loss.

Returns:

Log rewards tensor of shape (n_transitions, 2) for the transitions.

Return type:

torch.Tensor
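The (n_transitions, 2) shape can be pictured as one row per transition with two reward entries. The plain-Python sketch below rests on assumptions that the text above does not confirm: column 0 holds the log reward of the state, column 1 that of the next state, and rows whose next state is the sink state stay at -inf. The `SINK` marker and `toy_log_reward` function are hypothetical.

```python
import math

SINK = None  # hypothetical marker for the sink state


def toy_log_reward(state: int) -> float:
    """Hypothetical log reward: log(1 + state)."""
    return math.log1p(state)


def all_log_rewards(states, next_states):
    """Build one [log R(state), log R(next_state)] row per transition."""
    rows = []
    for s, s_next in zip(states, next_states):
        if s_next is SINK:
            # Transitions ending in the sink state are not evaluated.
            rows.append([-math.inf, -math.inf])
        else:
            rows.append([toy_log_reward(s), toy_log_reward(s_next)])
    return rows


rewards = all_log_rewards([0, 1, 2], [1, SINK, 3])
print(len(rewards), len(rewards[0]))  # 3 2
```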

property device: torch.device

The device on which the transitions are stored.

Returns:

The device of self.states.

Return type:

torch.device

env
extend(other)

Extends this Transitions object with another Transitions object.

Parameters:
  • other (Transitions) – Another Transitions object to append.

Return type:

None

classmethod from_tensordict(env, td)

Reconstruct Transitions from a TensorDict.

Parameters:
  • env (gfn.env.Env)

  • td (gfn.containers.base.TensorDictBase)

Return type:

Transitions

is_backward = False
is_terminating
log_probs = None
property log_rewards: torch.Tensor | None

The log rewards for the transitions.

Returns:

Log rewards tensor of shape (n_transitions,). Non-terminating transitions have value -inf.

Return type:

torch.Tensor | None

Note

If not provided at initialization, log rewards are computed on demand for terminating transitions.
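The on-demand behaviour in the note can be sketched with a lazily evaluated property. This is a plain-Python stand-in under stated assumptions, not the library class: `ToyTransitions` and `toy_log_reward` are hypothetical, the reward is evaluated on the entry of `states` for each terminating transition, and the result is cached after the first access.

```python
import math


class ToyTransitions:
    """Hypothetical stand-in showing lazily computed log rewards."""

    def __init__(self, states, is_terminating, log_reward_fn, log_rewards=None):
        self.states = states
        self.is_terminating = is_terminating
        self._log_reward_fn = log_reward_fn  # assumed stand-in for the env's reward
        self._log_rewards = log_rewards      # None means "not provided"

    @property
    def log_rewards(self):
        if self._log_rewards is None:
            # Compute once, only for terminating transitions; others get -inf.
            self._log_rewards = [
                self._log_reward_fn(s) if done else -math.inf
                for s, done in zip(self.states, self.is_terminating)
            ]
        return self._log_rewards


calls = []


def toy_log_reward(state):
    calls.append(state)  # record which states were actually evaluated
    return float(state)


batch = ToyTransitions([1, 2, 3], [False, True, False], toy_log_reward)
print(batch.log_rewards)  # [-inf, 2.0, -inf]
print(calls)              # [2]
```

In this sketch only the terminating transition triggers a reward evaluation, and a second access reuses the cached values.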

property n_transitions: int

The number of transitions in the container.

Returns:

The number of transitions.

Return type:

int

next_states
states
property terminating_states: gfn.states.States

The terminating states of the transitions.

Returns:

The terminating states.

Return type:

gfn.states.States

to_tensordict()

Serialize transitions into a TensorDict.

Return type:

gfn.containers.base.TensorDictBase