gfn.actions

Classes

Actions

Base class for actions, representing edges in the DAG of a GFlowNet.

GraphActionType

Integer enum of the graph action types (ADD_NODE, ADD_EDGE, EXIT, DUMMY).

GraphActions

Actions for graph-based environments.

Module Contents

class gfn.actions.Actions(tensor, debug=False)

Bases: abc.ABC

Base class for actions, representing edges in the DAG of a GFlowNet.

Each environment needs to define a subclass of Actions to represent its specific action space.

Two useful subclasses of Actions are provided:

  • DiscreteActions for discrete environments, which represents actions as a tensor of shape (*batch_shape, *action_shape).

  • GraphActions for graph-based environments, which represents actions as a tensor of shape (*batch_shape, 5) containing the action type, node class, node index, edge class, and edge index components.

Parameters:
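  • tensor (torch.Tensor) – Tensor of shape (*batch_shape, *action_shape) representing a batch of actions.

  • debug (bool) – Whether to run debug validations on the constructed Actions.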

tensor

Tensor of shape (*batch_shape, *action_shape) representing a batch of actions.

action_shape

Class variable, a tuple defining the shape of a single action.

dummy_action

Class variable, a tensor of shape (*action_shape,) representing the dummy action for padding shorter trajectories.

exit_action

Class variable, a tensor of shape (*action_shape,) representing the action to transition to the sink state.
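As a sketch of what a subclass declares, the standalone class below mirrors the three class variables for a hypothetical environment with five discrete actions. The values are assumptions for illustration only: index 4 is taken as the exit action and -1 as the dummy (padding) value. It does not import or subclass `gfn.actions.Actions`; in a real environment these would be set on your own `Actions` subclass.

```python
from typing import ClassVar

import torch


class HypotheticalGridActions:
    """Stand-in mirroring the class variables an Actions subclass defines.

    Assumptions (not from the gfn docs): 5 discrete actions, index 4 exits,
    and -1 marks padding.
    """

    # Shape of a single action: one integer index.
    action_shape: ClassVar[tuple[int, ...]] = (1,)
    # Padding value for shorter trajectories; its shape equals action_shape.
    dummy_action: ClassVar[torch.Tensor] = torch.tensor([-1])
    # Action that moves a state to the sink state; its shape equals action_shape.
    exit_action: ClassVar[torch.Tensor] = torch.tensor([4])


# A batch of 3 actions has shape (*batch_shape, *action_shape) == (3, 1).
batch = torch.tensor([[0], [4], [2]])
print(batch.shape)  # torch.Size([3, 1])
```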

__getitem__(index)

Returns a subset of the actions along the batch dimension.

Parameters:

index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select actions.

Returns:

A new Actions object with the selected actions.

Return type:

Actions
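Indexing acts on the leading batch dimensions, so the trailing action dimensions are preserved. The plain-tensor sketch below illustrates this with a hypothetical batch of shape (4,) and action_shape (1,); it assumes `__getitem__` forwards the index to the underlying tensor in this way.

```python
import torch

action_shape = (1,)
# Batch of 4 actions: shape (*batch_shape, *action_shape) == (4, 1).
tensor = torch.tensor([[0], [3], [1], [2]])

# Slice, list, and boolean-mask indices all select along the batch dimension.
by_slice = tensor[1:3]                                      # shape (2, 1)
by_mask = tensor[torch.tensor([True, False, False, True])]  # shape (2, 1)
by_list = tensor[[0, 2]]                                    # shape (2, 1)

# The batch shape of the result is everything except the action dimensions.
batch_shape = tuple(by_mask.shape[: by_mask.dim() - len(action_shape)])
print(batch_shape)  # (2,)
```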

__len__()

Returns the number of actions in the batch.

Returns:

The number of actions.

Return type:

int

__repr__()

Returns a string representation of the Actions object.

Returns:

A string summary of the Actions object.

__setitem__(index, actions)

Sets particular actions of the batch to a new Actions object.

Parameters:
  • index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to set.

  • actions (Actions) – Actions object containing the new actions.

Return type:

None

_compare(other)

Compares the actions to a tensor of actions.

Parameters:

other (torch.Tensor) – Tensor of actions to compare, with shape (*batch_shape, *action_shape).

Returns:

A boolean tensor of shape (*batch_shape,) indicating whether the actions are equal.

Return type:

torch.Tensor
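The comparison reduces over the action dimensions so that one boolean is left per batch element. A minimal plain-tensor sketch of that reduction, assuming element-wise equality followed by `all` over the trailing action dimensions (the hypothetical `compare` helper below is not part of gfn, and the real method may also validate shapes):

```python
import torch


def compare(tensor: torch.Tensor, other: torch.Tensor, action_ndim: int) -> torch.Tensor:
    """Return a (*batch_shape,) mask that is True where all action entries match.

    Hypothetical helper mirroring the documented behaviour of Actions._compare.
    """
    equal = tensor == other  # shape (*batch_shape, *action_shape)
    # Collapse the action dimensions into one, then require every entry to match.
    equal = equal.flatten(start_dim=equal.dim() - action_ndim)
    return equal.all(dim=-1)


# batch_shape (2, 2), action_shape (2,)
actions = torch.tensor([[[0, 1], [2, 2]], [[0, 1], [0, 0]]])
target = torch.tensor([[[0, 1], [2, 3]], [[0, 1], [0, 0]]])
print(compare(actions, target, action_ndim=1).tolist())  # [[True, False], [True, True]]
```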

action_shape: ClassVar[tuple[int, Ellipsis]]
property batch_shape: tuple[int, Ellipsis]

The batch shape of the actions.

Returns:

The batch shape as a tuple.

Return type:

tuple[int, Ellipsis]

clone()

Returns a clone of the Actions object.

Returns:

A new Actions object with the same tensor.

Return type:

Actions

debug = False
property device: torch.device

The device on which the actions are stored.

Returns:

The device of the underlying tensor.

Return type:

torch.device

dummy_action: ClassVar[torch.Tensor]
exit_action: ClassVar[torch.Tensor]
extend(other)

Concatenates another Actions object along the final batch dimension.

Both Actions objects must have the same number of batch dimensions, which should be 1 or 2.

Parameters:

other (Actions) – Actions object to be concatenated to the current Actions object.

Return type:

None
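For intuition, the sketch below reproduces the documented concatenation on raw tensors: with one batch dimension the batches are joined end to end, and with two (for example `(n_steps, n_trajectories)`) they are joined along the final batch dimension. It assumes both inputs already share the same first dimension in the 2-D case; padding to a common number of steps is the job of `extend_with_dummy_actions`.

```python
import torch

# 1-D batches, action_shape (1,): shapes (2, 1) and (3, 1) -> (5, 1).
a = torch.tensor([[0], [1]])
b = torch.tensor([[2], [3], [4]])
joined_1d = torch.cat((a, b), dim=0)

# 2-D batches (n_steps, n_trajectories, *action_shape): the final batch
# dimension is dim 1, so (4, 2, 1) + (4, 3, 1) -> (4, 5, 1).
c = torch.zeros(4, 2, 1, dtype=torch.long)
d = torch.ones(4, 3, 1, dtype=torch.long)
joined_2d = torch.cat((c, d), dim=1)

print(joined_1d.shape, joined_2d.shape)  # torch.Size([5, 1]) torch.Size([4, 5, 1])
```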

extend_with_dummy_actions(required_first_dim)

Extends an Actions instance along the first dimension with dummy actions.

The Actions instance batch_shape must be 2-dimensional. This is used to pad actions in a batch of trajectories to a common length.

Parameters:

required_first_dim (int) – The target size of the first dimension after expansion.

Return type:

None
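The padding step can be sketched on a raw `(n_steps, batch_size, *action_shape)` tensor. The helper `pad_with_dummy` and the dummy value -1 are hypothetical; the idea is to broadcast the dummy action over the missing steps and append them along dimension 0.

```python
import torch

dummy_action = torch.tensor([-1])  # assumed dummy value, shape == action_shape == (1,)


def pad_with_dummy(tensor: torch.Tensor, required_first_dim: int) -> torch.Tensor:
    """Pad a (n_steps, batch_size, *action_shape) tensor to required_first_dim steps.

    Hypothetical helper illustrating Actions.extend_with_dummy_actions.
    """
    n_steps, batch_size = tensor.shape[0], tensor.shape[1]
    missing = required_first_dim - n_steps
    if missing < 0:
        raise ValueError("tensor already has more steps than required_first_dim")
    if missing == 0:
        return tensor
    # Broadcast the dummy action to (missing, batch_size, *action_shape).
    padding = dummy_action.expand(missing, batch_size, *dummy_action.shape)
    return torch.cat((tensor, padding), dim=0)


# 2 steps for 2 trajectories, padded to 4 steps.
actions = torch.tensor([[[0], [2]], [[4], [4]]])
padded = pad_with_dummy(actions, required_first_dim=4)
print(padded.shape)  # torch.Size([4, 2, 1])
```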

property is_dummy: torch.Tensor

Returns a boolean tensor indicating whether the actions are dummy actions.

Returns:

A boolean tensor of shape (*batch_shape,) that is True for dummy actions.

Return type:

torch.Tensor

property is_exit: torch.Tensor

Returns a boolean tensor indicating whether the actions are exit actions.

Returns:

A boolean tensor of shape (*batch_shape,) that is True for exit actions.

Return type:

torch.Tensor
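A common use of these masks is recovering per-trajectory lengths from a padded batch. The plain-tensor sketch below assumes `action_shape == (1,)`, an exit action of 4, and a dummy action of -1 (all hypothetical), builds the masks by comparing against those class tensors and reducing over the action dimension, and counts the non-padding steps in each trajectory.

```python
import torch

exit_action = torch.tensor([4])    # assumed exit action
dummy_action = torch.tensor([-1])  # assumed dummy (padding) action

# (n_steps, n_trajectories, *action_shape): two trajectories padded to 4 steps.
actions = torch.tensor([
    [[0], [1]],
    [[4], [0]],
    [[-1], [4]],
    [[-1], [-1]],
])

# Masks analogous to is_exit / is_dummy, each of shape (4, 2).
is_exit = (actions == exit_action).all(dim=-1)
is_dummy = (actions == dummy_action).all(dim=-1)

# A trajectory's length is its number of non-dummy steps, exit included.
lengths = (~is_dummy).sum(dim=0)
print(lengths.tolist())  # [2, 3]
```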

classmethod make_dummy_actions(batch_shape, device=None, debug=False)

Creates an Actions object filled with dummy actions.

Parameters:
  • batch_shape (tuple[int, Ellipsis]) – Shape of the batch dimensions.

  • device (torch.device | None) – The device to create the actions on.

  • debug (bool) – Whether to run debug validations on the constructed Actions.

Returns:

An Actions object with the specified batch shape filled with dummy actions.

Return type:

Actions

classmethod make_exit_actions(batch_shape, device=None, debug=False)

Creates an Actions object filled with exit actions.

Parameters:
  • batch_shape (tuple[int, Ellipsis]) – Shape of the batch dimensions.

  • device (torch.device | None) – The device to create the actions on.

  • debug (bool) – Whether to run debug validations on the constructed Actions.

Returns:

An Actions object with the specified batch shape filled with exit actions.

Return type:

Actions
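Both factory methods produce a batch in which every entry equals one class-level action. A minimal sketch of that tiling, using a hypothetical `make_filled` helper and an assumed exit action of 4 and dummy action of -1:

```python
from typing import Optional

import torch

exit_action = torch.tensor([4])    # assumed; shape == action_shape == (1,)
dummy_action = torch.tensor([-1])  # assumed; shape == action_shape == (1,)


def make_filled(
    action: torch.Tensor,
    batch_shape: tuple[int, ...],
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Tile a single action to shape (*batch_shape, *action_shape).

    Hypothetical helper illustrating make_exit_actions / make_dummy_actions.
    """
    # Repeat across the batch dimensions, once (no repetition) across action dims.
    tiled = action.repeat(*batch_shape, *([1] * action.dim()))
    return tiled.to(device) if device is not None else tiled


exits = make_filled(exit_action, (2, 3))
dummies = make_filled(dummy_action, (4,))
print(exits.shape, dummies.shape)  # torch.Size([2, 3, 1]) torch.Size([4, 1])
```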

classmethod stack(actions_list, debug=None)

Stacks a list of Actions objects along a new leading dimension (dim 0).

The individual actions need to have the same batch shape. An example application is when the individual actions represent per-step actions of a batch of trajectories (in which case, the common batch_shape would be (batch_size,), and the resulting Actions object would have batch_shape (n_steps, batch_size)).

Parameters:
  • actions_list (List[Actions]) – List of Actions objects to stack.

  • debug (bool | None)

Returns:

A new Actions object with the stacked actions.

Return type:

Actions
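The trajectory example above can be sketched on raw tensors with `torch.stack`. Three per-step batches, each of batch_shape (3,) and an assumed action_shape of (1,), become one tensor with batch_shape (n_steps, batch_size) == (3, 3); the action values are arbitrary.

```python
import torch

# Per-step actions for 3 trajectories: each has batch_shape (3,), action_shape (1,).
step_0 = torch.tensor([[0], [1], [2]])
step_1 = torch.tensor([[1], [4], [0]])
step_2 = torch.tensor([[4], [-1], [4]])

# Stacking along a new leading dimension gives batch_shape (n_steps, batch_size).
stacked = torch.stack([step_0, step_1, step_2], dim=0)
print(stacked.shape)  # torch.Size([3, 3, 1])

# stacked[t] recovers the actions taken at step t.
```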

tensor
class gfn.actions.GraphActionType

Bases: enum.IntEnum

Integer enum of the graph action types. Because the members are ints, they can be compared directly with integer action-type values.

ADD_NODE = 0
ADD_EDGE = 1
EXIT = 2
DUMMY = 3
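The members above can be reproduced with the standard library to show how an integer action-type column is tested against them. The class below is a local copy of the documented values for illustration, not an import of `gfn.actions.GraphActionType`.

```python
from enum import IntEnum

import torch


class GraphActionType(IntEnum):
    """Local copy of the documented members, for illustration."""

    ADD_NODE = 0
    ADD_EDGE = 1
    EXIT = 2
    DUMMY = 3


# An action-type column as plain integers.
action_types = torch.tensor(
    [m.value for m in (GraphActionType.ADD_NODE, GraphActionType.EXIT, GraphActionType.DUMMY)]
)

# IntEnum members are ints, so they compare against the column element-wise.
print((action_types == int(GraphActionType.EXIT)).tolist())  # [False, True, False]
```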
class gfn.actions.GraphActions(tensor, debug=False)

Bases: Actions

Actions for graph-based environments.

Each action is one of these types:

  • ADD_NODE: add a node with the given features.

  • ADD_EDGE: add an edge between two nodes with the given features.

  • EXIT: terminate the trajectory.

  • DUMMY: padding action for shorter trajectories.

Parameters:
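  • tensor (torch.Tensor) – Tensor of shape (*batch_shape, 5) holding the action components.

  • debug (bool) – Whether to run debug validations on the constructed GraphActions.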

tensor

Tensor of shape (*batch_shape, 5) containing the action type, node class, node index, edge class, and edge index components.

ACTION_TYPE_KEY

Class variable, key for the action type component.

NODE_CLASS_KEY

Class variable, key for the node class component.

NODE_INDEX_KEY

Class variable, key for the node index component.

EDGE_CLASS_KEY

Class variable, key for the edge class component.

EDGE_INDEX_KEY

Class variable, key for the edge index component.

ACTION_INDICES

Class variable, mapping from keys to tensor indices.

ACTION_INDICES: ClassVar[dict[str, int]]
ACTION_TYPE_KEY: ClassVar[str] = 'action_type'
EDGE_CLASS_KEY: ClassVar[str] = 'edge_class'
EDGE_INDEX_KEY: ClassVar[str] = 'edge_index'
NODE_CLASS_KEY: ClassVar[str] = 'node_class'
NODE_INDEX_KEY: ClassVar[str] = 'node_index'
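`ACTION_INDICES` maps each key to a column of the (*batch_shape, 5) tensor, and each component property (`action_type`, `node_class`, and so on) reads one column. The column order in the sketch below is an assumption for illustration; the actual mapping is whatever `GraphActions.ACTION_INDICES` defines. The helper `component` is hypothetical.

```python
import torch

# Hypothetical key -> column mapping; consult GraphActions.ACTION_INDICES for the real one.
ACTION_INDICES = {
    "action_type": 0,
    "node_class": 1,
    "node_index": 2,
    "edge_class": 3,
    "edge_index": 4,
}

# Batch of 2 graph actions, shape (*batch_shape, 5).
tensor = torch.tensor([
    [0, 3, 0, 0, 0],  # ADD_NODE with node class 3
    [1, 0, 0, 2, 5],  # ADD_EDGE with edge class 2 at edge index 5
])


def component(tensor: torch.Tensor, key: str) -> torch.Tensor:
    """Return the (*batch_shape,) slice holding one component."""
    return tensor[..., ACTION_INDICES[key]]


print(component(tensor, "action_type").tolist())  # [0, 1]
print(component(tensor, "edge_index").tolist())   # [0, 5]
```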
__repr__()

Returns a string representation of the GraphActions object.

Returns:

A string summary of the GraphActions object.

action_shape = (5,)
property action_type: torch.Tensor

Returns the action type tensor.

Returns:

A tensor of shape (*batch_shape,) containing the action types.

Return type:

torch.Tensor

property batch_shape: tuple[int, Ellipsis]

The batch shape of the graph actions.

Returns:

The batch shape as a tuple.

Return type:

tuple[int, Ellipsis]

debug = False
dummy_action
property edge_class: torch.Tensor

Returns the edge class tensor.

Returns:

A tensor of shape (*batch_shape,) containing the edge classes.

Return type:

torch.Tensor

property edge_index: torch.Tensor

Returns the edge index tensor.

Returns:

A tensor of shape (*batch_shape,) containing the edge indices.

Return type:

torch.Tensor

classmethod edge_index_action_to_src_dst(edge_index_action, n_nodes)

Abstract method. Converts the edge index action to source and destination node indices.

Parameters:
  • edge_index_action (torch.Tensor) – Tensor of edge index components.

  • n_nodes (int) – Number of nodes in the graph.

Return type:

tuple[torch.Tensor, torch.Tensor]
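Because the method is abstract, each subclass chooses how a flat edge index maps to a node pair. One possible convention, shown here purely as a hypothetical sketch, is an undirected graph without self-loops whose candidate edges are the pairs (i, j) with i < j, enumerated in row-major order as `torch.triu_indices` produces them.

```python
import torch


def edge_index_action_to_src_dst(edge_index_action: torch.Tensor, n_nodes: int):
    """Map flat edge indices to (src, dst) node indices.

    Hypothetical convention: pairs (i, j) with i < j in row-major order.
    Pass device=edge_index_action.device to triu_indices for GPU tensors.
    """
    # Shape (2, n_nodes * (n_nodes - 1) // 2): row 0 holds sources, row 1 destinations.
    pairs = torch.triu_indices(n_nodes, n_nodes, offset=1)
    src = pairs[0][edge_index_action]
    dst = pairs[1][edge_index_action]
    return src, dst


# With 4 nodes the pairs are (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
src, dst = edge_index_action_to_src_dst(torch.tensor([0, 2, 5]), n_nodes=4)
print(src.tolist(), dst.tolist())  # [0, 0, 2] [1, 3, 3]
```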

exit_action
classmethod from_tensor_dict(tensor_dict, debug=False)

Creates a GraphActions object from a tensor dict.

Parameters:
  • tensor_dict (tensordict.TensorDict) – A TensorDict containing the action components with keys ACTION_TYPE_KEY, NODE_CLASS_KEY, NODE_INDEX_KEY, EDGE_CLASS_KEY, and EDGE_INDEX_KEY.

  • debug (bool) – Whether to run debug validations on the constructed GraphActions.

Returns:

A GraphActions object constructed from the tensor dict.

Return type:

GraphActions
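Conceptually, building the action tensor from named components means stacking the per-component (*batch_shape,) tensors along a new last dimension in column order. The sketch below uses a plain `dict` in place of a `tensordict.TensorDict` and a hypothetical key-to-column mapping; the helper `stack_components` is not part of gfn.

```python
import torch

# Hypothetical key -> column mapping (see GraphActions.ACTION_INDICES for the real one).
ACTION_INDICES = {"action_type": 0, "node_class": 1, "node_index": 2, "edge_class": 3, "edge_index": 4}


def stack_components(components: dict[str, torch.Tensor]) -> torch.Tensor:
    """Assemble a (*batch_shape, 5) tensor from per-component (*batch_shape,) tensors."""
    # Order the keys by their column index, then stack as the last dimension.
    ordered = sorted(ACTION_INDICES, key=ACTION_INDICES.get)
    return torch.stack([components[key] for key in ordered], dim=-1)


components = {
    "action_type": torch.tensor([0, 2]),  # ADD_NODE, EXIT
    "node_class": torch.tensor([3, 0]),
    "node_index": torch.tensor([0, 0]),
    "edge_class": torch.tensor([0, 0]),
    "edge_index": torch.tensor([0, 0]),
}
tensor = stack_components(components)
print(tensor.tolist())  # [[0, 3, 0, 0, 0], [2, 0, 0, 0, 0]]
```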

property is_dummy: torch.Tensor

Returns a boolean tensor indicating whether the actions are dummy actions.

Returns:

A boolean tensor of shape (*batch_shape,) that is True for dummy actions.

Return type:

torch.Tensor

property is_exit: torch.Tensor

Returns a boolean tensor indicating whether the actions are exit actions.

Returns:

A boolean tensor of shape (*batch_shape,) that is True for exit actions.

Return type:

torch.Tensor
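For graph actions, a natural way to obtain these masks is to test the action-type column against the GraphActionType values. This sketch assumes that the action type alone determines whether an action is an exit or dummy action, and that the action type sits in column 0; both are assumptions, not documented behaviour.

```python
import torch

ACTION_TYPE_COLUMN = 0  # assumed position of the action type (see ACTION_INDICES)
EXIT, DUMMY = 2, 3      # GraphActionType.EXIT and GraphActionType.DUMMY values

# Batch of 3 graph actions, shape (*batch_shape, 5).
tensor = torch.tensor([
    [0, 3, 0, 0, 0],  # ADD_NODE
    [2, 0, 0, 0, 0],  # EXIT
    [3, 0, 0, 0, 0],  # DUMMY
])

action_type = tensor[..., ACTION_TYPE_COLUMN]
is_exit = action_type == EXIT    # shape (*batch_shape,)
is_dummy = action_type == DUMMY  # shape (*batch_shape,)
print(is_exit.tolist(), is_dummy.tolist())  # [False, True, False] [False, False, True]
```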

classmethod make_dummy_actions(batch_shape, device=None, debug=False)

Creates a GraphActions object filled with dummy actions.

Parameters:
  • batch_shape (tuple[int]) – Shape of the batch dimensions.

  • device (torch.device | None) – The device to create the actions on.

  • debug (bool) – Whether to run debug validations on the constructed GraphActions.

Returns:

A GraphActions object with the specified batch shape filled with dummy actions.

Return type:

GraphActions

classmethod make_exit_actions(batch_shape, device=None, debug=False)

Creates a GraphActions object filled with exit actions.

Parameters:
  • batch_shape (tuple[int]) – Shape of the batch dimensions.

  • device (torch.device | None) – The device to create the actions on.

  • debug (bool) – Whether to run debug validations on the constructed GraphActions.

Returns:

A GraphActions object with the specified batch shape filled with exit actions.

Return type:

GraphActions

property node_class: torch.Tensor

Returns the node class tensor.

Returns:

A tensor of shape (*batch_shape,) containing the node classes.

Return type:

torch.Tensor

property node_index: torch.Tensor

Returns the node index tensor.

Returns:

A tensor of shape (*batch_shape,) containing the node indices.

Return type:

torch.Tensor

tensor