gfn.containers.base

Classes

Container

Base class for state containers (states, transitions, or trajectories).

Module Contents

class gfn.containers.base.Container

Bases: abc.ABC

Base class for state containers (states, transitions, or trajectories).
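Because Container derives from abc.ABC, a subclass cannot be instantiated until it overrides every abstract member. The sketch below shows that mechanism with a hypothetical, reduced interface. `MiniContainer`, `Incomplete` and `ListContainer` are illustrative names, not part of gfn, and the real class declares more abstract members (device, extend, from_tensordict, terminating_states, to_tensordict).

```python
import abc
from typing import List


class MiniContainer(abc.ABC):
    """Hypothetical reduced copy of the Container interface."""

    @abc.abstractmethod
    def __getitem__(self, index) -> "MiniContainer": ...

    @abc.abstractmethod
    def __len__(self) -> int: ...

    # Abstract properties stack @property on top of @abc.abstractmethod.
    @property
    @abc.abstractmethod
    def log_rewards(self) -> List[float]: ...


class Incomplete(MiniContainer):
    # Only __len__ is implemented, so this class remains abstract.
    def __len__(self) -> int:
        return 0


class ListContainer(MiniContainer):
    """Toy subclass that stores one log reward per element."""

    def __init__(self, rewards=()):
        self._log_rewards = list(rewards)

    def __getitem__(self, index) -> "ListContainer":
        # Wrap results so indexing always returns a container.
        picked = self._log_rewards[index]
        return ListContainer(picked if isinstance(index, slice) else [picked])

    def __len__(self) -> int:
        return len(self._log_rewards)

    @property
    def log_rewards(self) -> List[float]:
        return self._log_rewards


# Instantiating a subclass with unimplemented abstract members fails.
try:
    Incomplete()
    instantiation_error = None
except TypeError as exc:
    instantiation_error = exc
```

The TypeError is raised at construction time, not when a missing method is first called, so incomplete subclasses fail early.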

abstract __getitem__(index)

Returns a subset of the container based on the provided index.

Parameters:

index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – An integer, slice, tuple, sequence of indices or booleans, or a torch.Tensor specifying which elements to select.

Returns:

A new container containing the selected elements and associated data.

Return type:

Container
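One way to support the accepted index types is to reduce each of them to a list of integer positions before selecting elements. The sketch below assumes a list-backed toy container and a hypothetical helper, `normalize_index`, neither of which is part of gfn. It covers only the one-dimensional cases (int, slice, sequence of ints, boolean mask) and leaves tuple and torch.Tensor indices out.

```python
from typing import List, Sequence, Union


# Hypothetical helper: convert a 1-D index into a list of positions.
# Tuples and torch.Tensor indices are deliberately not handled here.
def normalize_index(
    index: Union[int, slice, Sequence[int], Sequence[bool]], length: int
) -> List[int]:
    if isinstance(index, bool):
        raise TypeError("a bare bool is not a valid index")
    if isinstance(index, int):
        # Accept negative integers, as Python sequences do.
        if not -length <= index < length:
            raise IndexError(f"index {index} out of range for length {length}")
        return [index % length]
    if isinstance(index, slice):
        return list(range(length))[index]
    items = list(index)
    if items and all(isinstance(i, bool) for i in items):
        # Boolean mask: exactly one entry per element.
        if len(items) != length:
            raise IndexError("boolean mask length must match container length")
        return [i for i, keep in enumerate(items) if keep]
    if not all(isinstance(i, int) and not isinstance(i, bool) for i in items):
        raise TypeError("sequence indices must be all ints or all bools")
    return [normalize_index(i, length)[0] for i in items]


class ToyContainer:
    """Illustrative stand-in for a Container subclass backed by a list."""

    def __init__(self, data):
        self.data = list(data)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index) -> "ToyContainer":
        # Always return a new container, even for a single integer index.
        positions = normalize_index(index, len(self))
        return ToyContainer(self.data[p] for p in positions)
```

Returning a container for every index type, including a single int, keeps the return type uniform with the documented `Container` return.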

abstract __len__()

Returns the number of elements in the container.

Returns:

The number of elements in the container.

Return type:

int

abstract property device: torch.device

The device on which the container is stored.

Returns:

The device on which the container is stored.

Return type:

torch.device

abstract extend(other)

Extends the current container in place with the elements of another container.

Parameters:

other (Container) – The other container whose elements will be added.

Return type:

None
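Since extend returns None, it reads as an in-place operation. The following minimal sketch assumes a hypothetical list-backed container with parallel `states` and `log_rewards` fields. It shows the invariant a subclass would need to keep, namely that every per-element field grows together. It is not the gfn implementation.

```python
from typing import List


class ToyContainer:
    """Illustrative list-backed container; not part of gfn."""

    def __init__(self, states: List[str], log_rewards: List[float]):
        if len(states) != len(log_rewards):
            raise ValueError("states and log_rewards must align")
        self.states = list(states)
        self.log_rewards = list(log_rewards)

    def __len__(self) -> int:
        return len(self.states)

    def extend(self, other: "ToyContainer") -> None:
        # Mutates self; every per-element field is concatenated so that
        # states and log_rewards stay aligned after the call.
        self.states.extend(other.states)
        self.log_rewards.extend(other.log_rewards)
```

The other container is left unchanged, and the caller keeps using the original object rather than a return value.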

abstract classmethod from_tensordict(env, td)

Reconstructs a container from a TensorDict.

Parameters:
  • env (gfn.env.Env) – The environment needed to reconstruct States/Actions.

  • td (tensordict.base.TensorDictBase) – The TensorDict produced by to_tensordict().

Returns:

A new container instance.

Return type:

Container

property has_log_probs: bool

Whether the container has log probabilities.

Returns:

True if log probabilities are present and non-empty, False otherwise.

Return type:

bool
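The description names two separate conditions: log probabilities must be present, and they must be non-empty. The sketch below uses a hypothetical `log_probs` attribute stored as an optional list to show both checks. A real container would hold a tensor instead.

```python
from typing import List, Optional


class ToyContainer:
    """Illustrative container; `log_probs` is a hypothetical attribute."""

    def __init__(self, log_probs: Optional[List[float]] = None):
        self.log_probs = log_probs

    @property
    def has_log_probs(self) -> bool:
        # Both conditions matter: the attribute may be None (never computed)
        # or an empty sequence (computed for zero elements).
        return self.log_probs is not None and len(self.log_probs) > 0
```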

classmethod load(env, path)

Loads a container from a .pt file saved by save().

Parameters:
  • env (gfn.env.Env) – The environment needed to reconstruct States/Actions.

  • path (str) – File path to the saved container.

Returns:

A new container instance.

Return type:

Container

abstract property log_rewards: torch.Tensor

The log rewards associated with the container.

Returns:

The log rewards tensor.

Return type:

torch.Tensor

sample(n_samples)

Randomly samples a subset of elements from the container.

Parameters:

n_samples (int) – The number of elements to sample.

Returns:

A new container with the sampled elements.

Return type:

Container
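A natural way to build sample on top of the abstract interface is to draw random positions and pass them to `__getitem__`. This section does not say whether sampling is with or without replacement. The sketch assumes sampling without replacement via `random.sample`, and its list-backed `ToyContainer` is hypothetical.

```python
import random
from typing import List


class ToyContainer:
    """Illustrative list-backed container; not part of gfn."""

    def __init__(self, data):
        self.data = list(data)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: List[int]) -> "ToyContainer":
        return ToyContainer(self.data[i] for i in index)

    def sample(self, n_samples: int) -> "ToyContainer":
        # Assumption: sampling without replacement, so n_samples must not
        # exceed len(self); random.sample raises ValueError otherwise.
        positions = random.sample(range(len(self)), n_samples)
        return self[positions]
```

Routing the selection through `__getitem__` means every field a subclass stores is subset consistently, without sample needing to know about those fields.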

save(path)

Saves the container to a single .pt file.

Parameters:

path (str) – File path (e.g. "trajectories.pt").

Return type:

None
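save and load form a round trip in which the environment is never written to disk and must be passed back in at load time. The `.pt` extension suggests PyTorch serialization. To run without PyTorch, this sketch substitutes the standard-library `pickle` module. It also uses plain-dict stand-ins (`to_dict`/`from_dict`) for `to_tensordict`/`from_tensordict`. `ToyEnv` and `ToyContainer` are hypothetical.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, List


class ToyEnv:
    """Hypothetical stand-in for gfn.env.Env; deliberately not serialized."""

    def __init__(self, name: str):
        self.name = name


class ToyContainer:
    """Illustrative container mirroring the save()/load() contract."""

    def __init__(self, env: ToyEnv, states: List[int]):
        self.env = env
        self.states = list(states)

    def to_dict(self) -> Dict[str, Any]:
        # Stand-in for to_tensordict(): data only, no env reference.
        return {"states": self.states}

    @classmethod
    def from_dict(cls, env: ToyEnv, data: Dict[str, Any]) -> "ToyContainer":
        return cls(env, data["states"])

    def save(self, path: str) -> None:
        # Assumption: pickle stands in for PyTorch serialization here.
        with open(path, "wb") as f:
            pickle.dump(self.to_dict(), f)

    @classmethod
    def load(cls, env: ToyEnv, path: str) -> "ToyContainer":
        with open(path, "rb") as f:
            data = pickle.load(f)
        # The caller supplies env because it was never written to disk.
        return cls.from_dict(env, data)


# Usage: round-trip through a temporary file.
env = ToyEnv("toy")
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "trajectories.pt")
    ToyContainer(env, [1, 2, 3]).save(path)
    restored = ToyContainer.load(env, path)
```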

abstract property terminating_states: gfn.states.States

The last (terminating) states of the container.

Returns:

The terminating states.

Return type:

gfn.states.States

abstract to_tensordict()

Serializes the container’s data into a TensorDict.

Returns:

A TensorDict containing all tensor data and scalar metadata. The env reference is not included; it must be supplied when reconstructing via from_tensordict().

Return type:

tensordict.base.TensorDictBase
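The serialized form holds per-element data plus scalar metadata but no environment reference, and the environment is supplied again when reconstructing. The sketch below shows that split with a plain dict standing in for a TensorDict, so it runs without torch or tensordict. The class, its field names, and the `is_backward` flag are hypothetical.

```python
from typing import Any, Dict, List


class ToyEnv:
    """Hypothetical stand-in for gfn.env.Env."""


class ToyTransitions:
    """Illustrative container; field names are hypothetical."""

    def __init__(
        self,
        env: ToyEnv,
        states: List[int],
        log_rewards: List[float],
        is_backward: bool = False,
    ):
        self.env = env
        self.states = states
        self.log_rewards = log_rewards
        self.is_backward = is_backward

    def to_dict(self) -> Dict[str, Any]:
        # Plain-dict stand-in for to_tensordict(): per-element data plus
        # scalar metadata; the env reference is intentionally left out.
        return {
            "states": list(self.states),
            "log_rewards": list(self.log_rewards),
            "is_backward": self.is_backward,
        }

    @classmethod
    def from_dict(cls, env: ToyEnv, data: Dict[str, Any]) -> "ToyTransitions":
        # Stand-in for from_tensordict(env, td): env comes from the caller.
        return cls(
            env,
            list(data["states"]),
            list(data["log_rewards"]),
            bool(data["is_backward"]),
        )
```

Keeping the environment out of the serialized data avoids persisting objects that may not be serializable and lets the same data be reattached to a freshly constructed environment.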