gfn.containers.states_container

Attributes

StateType

Classes

StatesContainer

Container for a batch of states (mainly used for FMGFlowNet).

Module Contents

gfn.containers.states_container.StateType
class gfn.containers.states_container.StatesContainer(env, states=None, is_terminating=None, log_rewards=None)

Bases: gfn.containers.base.Container, Generic[StateType]

Container for a batch of states (mainly used for FMGFlowNet).

This class manages a collection of states and their corresponding properties. It is mainly used for Flow Matching GFlowNet algorithms.

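Parameters:
  • env (gfn.env.Env) – The environment where the states are defined.

  • states (StateType | None) – States with batch_shape (n_states,).

  • is_terminating (torch.Tensor | None) – Boolean tensor of shape (n_states,) indicating which states are terminating.

  • log_rewards (torch.Tensor | None) – Optional tensor of shape (n_states,) containing the log rewards for terminating states.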

env

The environment where the states are defined.

states

States with batch_shape (n_states,).

is_terminating

Boolean tensor of shape (n_states,) indicating which states are terminating.

_log_rewards

(Optional) Tensor of shape (n_states,) containing the log rewards for terminating states.

__getitem__(index)

Returns a subset of the states along the batch dimension.

Parameters:

index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select states.

Returns:

A new StatesContainer with the selected states and associated data.

Return type:

StatesContainer[StateType]
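
To illustrate the selection semantics, here is a minimal sketch in plain Python that uses lists in place of tensors. `ToyContainer` and its fields are hypothetical stand-ins, not the library's implementation; the assumption is that a single index is applied consistently to the states, the terminating flags, and any stored log rewards.

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class ToyContainer:
    """Hypothetical stand-in for StatesContainer, backed by plain lists."""

    states: List[int]
    is_terminating: List[bool]
    log_rewards: Optional[List[float]] = None

    def __len__(self) -> int:
        return len(self.states)

    def __getitem__(self, index: Union[int, slice, List[int]]) -> "ToyContainer":
        # Normalize every supported index form to a list of positions.
        if isinstance(index, int):
            positions = [index]
        elif isinstance(index, slice):
            positions = list(range(len(self)))[index]
        else:
            positions = list(index)
        # Apply the same positions to every per-state field.
        return ToyContainer(
            states=[self.states[i] for i in positions],
            is_terminating=[self.is_terminating[i] for i in positions],
            log_rewards=(
                None
                if self.log_rewards is None
                else [self.log_rewards[i] for i in positions]
            ),
        )


batch = ToyContainer(
    states=[10, 11, 12, 13],
    is_terminating=[False, True, False, True],
    log_rewards=[float("-inf"), 0.5, float("-inf"), 0.25],
)
subset = batch[1:3]    # slice selection
picked = batch[[0, 3]]  # selection by a list of positions
```

In this sketch an integer index still yields a container (of length 1), mirroring the documented return type; whether the real class behaves the same way for every index form is not confirmed here.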

__len__()

Returns the number of states in the container.

Returns:

The number of states.

Return type:

int

__repr__()

Returns a string representation of the StatesContainer.

Returns:

A string summary of the container.

Return type:

str

_log_rewards = None
property device: torch.device

The device on which the states are stored.

Returns:

The device of self.states.

Return type:

torch.device

env
extend(other)

Extends this container with another StatesContainer object.

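Parameters:

other (StatesContainer[StateType]) – Another StatesContainer whose states are appended to this one.

Return type: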

None
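
As a rough illustration of in-place extension, the sketch below appends the per-state fields of one batch to another using plain Python lists. `ToyContainer` is a hypothetical stand-in, and the handling of missing log rewards (dropping them unless both sides carry them) is an assumption made for the example, not documented behavior.

```python
from typing import List, Optional


class ToyContainer:
    """Hypothetical stand-in for StatesContainer, backed by plain lists."""

    def __init__(self, states, is_terminating, log_rewards=None):
        self.states: List[int] = list(states)
        self.is_terminating: List[bool] = list(is_terminating)
        self.log_rewards: Optional[List[float]] = (
            None if log_rewards is None else list(log_rewards)
        )

    def __len__(self) -> int:
        return len(self.states)

    def extend(self, other: "ToyContainer") -> None:
        # Assumption: stored log rewards survive only if both sides have them.
        if self.log_rewards is not None and other.log_rewards is not None:
            self.log_rewards.extend(other.log_rewards)
        else:
            self.log_rewards = None
        # Append the other batch's per-state fields in order, in place.
        self.states.extend(other.states)
        self.is_terminating.extend(other.is_terminating)


a = ToyContainer([0, 1], [False, True], [float("-inf"), 0.5])
b = ToyContainer([2], [True], [1.5])
result = a.extend(b)  # mutates `a`; returns None, like the documented method
```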

classmethod from_tensordict(env, td)

Reconstruct a StatesContainer from a TensorDict.

Parameters:
  • env (gfn.env.Env)

  • td (gfn.containers.base.TensorDictBase)

Return type:

StatesContainer[StateType]

property intermediary_states: StateType

The intermediary states (not initial states) of the StatesContainer.

Returns:

The intermediary states.

Return type:

StateType

is_terminating
property log_rewards: torch.Tensor

The log rewards for all states.

Returns:

Log rewards tensor of shape (len(self.states),). Intermediate states have value -inf.

Return type:

torch.Tensor

Note

If not provided at initialization, log rewards are computed on demand for terminating states.
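
The on-demand behavior described in the note can be sketched as follows, in plain Python with lists in place of tensors. `fill_log_rewards` and `toy_log_reward` are hypothetical names introduced for illustration; the sketch assumes that values given at initialization are returned as-is, and that otherwise only terminating states are passed to the reward function while every other entry stays at -inf.

```python
import math
from typing import Callable, List, Optional


def fill_log_rewards(
    states: List[int],
    is_terminating: List[bool],
    log_reward_fn: Callable[[int], float],
    cached: Optional[List[float]] = None,
) -> List[float]:
    """Return one log reward per state, with -inf for non-terminating states."""
    if cached is not None:
        # Values supplied at initialization are returned unchanged.
        return cached
    # Start from -inf everywhere, then evaluate only the terminating states.
    out = [-math.inf] * len(states)
    for i, (state, done) in enumerate(zip(states, is_terminating)):
        if done:
            out[i] = log_reward_fn(state)
    return out


def toy_log_reward(state: int) -> float:
    # Hypothetical reward used only for this example: log(1 + state).
    return math.log1p(state)


rewards = fill_log_rewards([0, 3, 7], [False, True, True], toy_log_reward)
```

The result always has one entry per state, which matches the documented shape (len(self.states),).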

states
property terminating_log_rewards: torch.Tensor

The log rewards for terminating states only.

Returns:

The log rewards for terminating states.

Return type:

torch.Tensor

property terminating_states: StateType

The last (terminating) states of the StatesContainer.

Returns:

The terminating states.

Return type:

StateType
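
Both terminating_states and terminating_log_rewards can be read as boolean-mask selections driven by is_terminating. The sketch below shows that idea with plain Python lists; `select` is a hypothetical helper, and the assumption is that the same mask is used for both properties so their results stay aligned.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def select(values: Sequence[T], mask: Sequence[bool]) -> List[T]:
    """Keep the entries of `values` whose mask entry is True."""
    if len(values) != len(mask):
        raise ValueError("values and mask must have the same length")
    return [v for v, keep in zip(values, mask) if keep]


states = ["s0", "s1", "s2", "s3"]
is_terminating = [False, True, False, True]
log_rewards = [float("-inf"), -0.7, float("-inf"), -0.1]

# The same mask filters both fields, so entry i of one matches entry i of the other.
terminating_states = select(states, is_terminating)
terminating_log_rewards = select(log_rewards, is_terminating)
```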

to_tensordict()

Serialize the states container into a TensorDict.

Return type:

gfn.containers.base.TensorDictBase