gfn.actions
Classes
Actions | Base class for actions, representing edges in the DAG of a GFlowNet.
GraphActionType | Enum where members are also (and must be) ints.
GraphActions | Actions for graph-based environments.
Module Contents
- class gfn.actions.Actions(tensor, debug=False)
Bases: abc.ABC
Base class for actions, representing edges in the DAG of a GFlowNet.
Each environment needs to define a subclass of Actions to represent its specific action space.
Two useful subclasses of Actions are provided:
- DiscreteActions for discrete environments, which represents actions as a tensor.
- GraphActions for graph-based environments, which represents actions as a tensor of shape (*batch_shape, 5) containing five components: action type, node class, node index, edge class, and edge index.
- Parameters:
tensor (torch.Tensor)
debug (bool)
- action_shape
Class variable, a tuple defining the shape of a single action.
- dummy_action
Class variable, a tensor of shape (*action_shape,) representing the dummy action for padding shorter trajectories.
- exit_action
Class variable, a tensor of shape (*action_shape,) representing the action to transition to the sink state.
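The three class variables are what a subclass fixes to describe its action space. The following is a minimal pure-Python sketch of that pattern, not the gfn API: `ToyDiscreteActions`, `N_ACTIONS`, encoding the exit action as the last index, and using `-1` as the dummy value are all hypothetical choices, and tuples and lists stand in for tensors so the snippet runs without PyTorch.

```python
from typing import ClassVar


class ToyDiscreteActions:
    """Hypothetical stand-in for an Actions subclass; not part of gfn."""

    # Assumed size of the discrete action space.
    N_ACTIONS: ClassVar[int] = 4
    # Shape of a single action: one integer index.
    action_shape: ClassVar[tuple[int, ...]] = (1,)
    # Assumed encoding: the last index moves to the sink state.
    exit_action: ClassVar[tuple[int, ...]] = (N_ACTIONS - 1,)
    # Assumed encoding: -1 never names a real action, so it is safe for padding.
    dummy_action: ClassVar[tuple[int, ...]] = (-1,)

    def __init__(self, actions: list[tuple[int, ...]]) -> None:
        # batch_shape is (len(actions),) in this one-dimensional sketch.
        for action in actions:
            if len(action) != self.action_shape[0]:
                raise ValueError("every action must match action_shape")
        self.actions = actions

    @classmethod
    def make_exit_actions(cls, batch_size: int) -> "ToyDiscreteActions":
        # Fill every batch slot with the exit sentinel.
        return cls([cls.exit_action] * batch_size)

    @classmethod
    def make_dummy_actions(cls, batch_size: int) -> "ToyDiscreteActions":
        # Fill every batch slot with the padding sentinel.
        return cls([cls.dummy_action] * batch_size)


exits = ToyDiscreteActions.make_exit_actions(3)
print(exits.actions)  # [(3,), (3,), (3,)]
```

Keeping the sentinels at class level means every instance, and every classmethod constructor, agrees on them without needing an instance first.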
- __getitem__(index)
Returns a subset of the actions along the batch dimension.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to select actions.
- Returns:
A new Actions object with the selected actions.
- Return type:
Actions
- __len__()
Returns the number of actions in the batch.
- Returns:
The number of actions.
- Return type:
int
- __repr__()
Returns a string representation of the Actions object.
- Returns:
A string summary of the Actions object.
- __setitem__(index, actions)
Replaces the actions at the given indices with the actions from another Actions object.
- Parameters:
index (int | slice | tuple | Sequence[int] | Sequence[bool] | torch.Tensor) – Indices to set.
actions (Actions) – Actions object containing the new actions.
- Return type:
None
- _compare(other)
Compares the actions to a tensor of actions.
- action_shape: ClassVar[tuple[int, ...]]
- property batch_shape: tuple[int, ...]
The batch shape of the actions.
- Returns:
The batch shape as a tuple.
- Return type:
tuple[int, ...]
- clone()
Returns a clone of the Actions object.
- Returns:
A new Actions object with the same tensor.
- Return type:
Actions
- debug = False
- property device: torch.device
The device on which the actions are stored.
- Returns:
The device of the underlying tensor.
- Return type:
torch.device
- dummy_action: ClassVar[torch.Tensor]
- exit_action: ClassVar[torch.Tensor]
- extend(other)
Concatenates another Actions object along the final batch dimension.
Both Actions objects must have the same number of batch dimensions, which should be 1 or 2.
- Parameters:
other (Actions) – Actions object to be concatenated to the current Actions object.
- Return type:
None
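Concatenating along the final batch dimension means different things for the two supported ranks. The pure-Python sketch below (lists stand in for tensors; this is not the library's code) appends actions for a 1-D batch and widens every step row for a 2-D (n_steps, batch_size) batch. Like the documented method it mutates in place and returns None; it simply rejects 2-D batches whose step counts differ, which is an assumption of the sketch rather than documented behaviour.

```python
def extend(a: list, b: list) -> None:
    """Concatenate batch b onto batch a, in place, along the final batch dimension.

    A batch is either a flat list of actions, batch_shape (batch_size,), or a
    list of per-step rows, batch_shape (n_steps, batch_size).
    """
    a_is_2d = bool(a) and isinstance(a[0], list)
    b_is_2d = bool(b) and isinstance(b[0], list)
    if a_is_2d != b_is_2d:
        raise ValueError("both batches must have the same number of batch dimensions")
    if not a_is_2d:
        # One batch dimension: append b's actions after a's.
        a.extend(b)
        return
    # Two batch dimensions: the final dimension indexes trajectories, so each
    # step row grows. This sketch requires equal step counts.
    if len(a) != len(b):
        raise ValueError("pad both batches to the same number of steps first")
    for row_a, row_b in zip(a, b):
        row_a.extend(row_b)


flat = [(0,), (1,)]
extend(flat, [(2,)])
print(flat)  # [(0,), (1,), (2,)]
```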
- extend_with_dummy_actions(required_first_dim)
Extends an Actions instance along the first dimension with dummy actions.
The batch_shape of the Actions instance must be 2-dimensional. This is used to pad the actions of a batch of trajectories to a common length.
- Parameters:
required_first_dim (int) – The target size of the first dimension after padding.
- Return type:
None
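Padding a (n_steps, batch_size) batch means appending whole rows of dummy actions until the first dimension reaches required_first_dim. A minimal pure-Python sketch, assuming a hypothetical dummy value of `(-1,)` in place of the class variable dummy_action; refusing to truncate is a choice of this sketch, since the reference above does not say what happens when required_first_dim is smaller than the current length.

```python
# Hypothetical dummy sentinel; the real one is the class variable dummy_action.
DUMMY_ACTION = (-1,)


def extend_with_dummy_actions(batch: list[list[tuple[int, ...]]], required_first_dim: int) -> None:
    """Pad a (n_steps, batch_size) batch in place with rows of dummy actions.

    After the call, len(batch) == required_first_dim.
    """
    n_steps = len(batch)
    if required_first_dim < n_steps:
        # This sketch never truncates.
        raise ValueError("required_first_dim is smaller than the current number of steps")
    batch_size = len(batch[0]) if batch else 0
    for _ in range(required_first_dim - n_steps):
        batch.append([DUMMY_ACTION] * batch_size)


# Two trajectories sampled for 2 steps, padded to 4 steps.
actions = [[(0,), (1,)], [(3,), (3,)]]
extend_with_dummy_actions(actions, 4)
print(len(actions))  # 4
```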
- property is_dummy: torch.Tensor
Returns a boolean tensor indicating whether the actions are dummy actions.
- Returns:
A boolean tensor of shape (*batch_shape,) that is True for dummy actions.
- Return type:
torch.Tensor
- property is_exit: torch.Tensor
Returns a boolean tensor indicating whether the actions are exit actions.
- Returns:
A boolean tensor of shape (*batch_shape,) that is True for exit actions.
- Return type:
torch.Tensor
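One way to picture is_dummy and is_exit is as an element-wise comparison of every action in the batch against the corresponding sentinel, yielding a boolean grid with the same batch shape. The sketch below assumes hypothetical sentinels `(3,)` and `(-1,)` and treats an action as matching only when all of its components equal the sentinel; it illustrates the idea and is not the library's _compare implementation.

```python
# Hypothetical sentinels for a one-component action space (not the library's values).
EXIT_ACTION = (3,)
DUMMY_ACTION = (-1,)


def compare(batch: list[list[tuple[int, ...]]], sentinel: tuple[int, ...]) -> list[list[bool]]:
    """Element-wise test over a (n_steps, batch_size) batch.

    An entry is True only when every component of the action equals the sentinel.
    """
    return [[action == sentinel for action in row] for row in batch]


# batch_shape (3, 2): trajectory 0 exits at step 1 and is padded at step 2.
actions = [
    [(0,), (2,)],
    [(3,), (1,)],
    [(-1,), (3,)],
]
is_exit = compare(actions, EXIT_ACTION)    # [[False, False], [True, False], [False, True]]
is_dummy = compare(actions, DUMMY_ACTION)  # [[False, False], [False, False], [True, False]]
```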
- classmethod make_dummy_actions(batch_shape, device=None, debug=False)
Creates an Actions object filled with dummy actions.
- Parameters:
batch_shape (tuple[int, ...]) – Shape of the batch dimensions.
device (torch.device | None) – The device to create the actions on.
debug (bool) – Whether to run debug validations on the constructed Actions.
- Returns:
An Actions object with the specified batch shape filled with dummy actions.
- Return type:
Actions
- classmethod make_exit_actions(batch_shape, device=None, debug=False)
Creates an Actions object filled with exit actions.
- Parameters:
batch_shape (tuple[int, ...]) – Shape of the batch dimensions.
device (torch.device | None) – The device to create the actions on.
debug (bool) – Whether to run debug validations on the constructed Actions.
- Returns:
An Actions object with the specified batch shape filled with exit actions.
- Return type:
Actions
- classmethod stack(actions_list, debug=None)
Stacks a list of Actions objects along a new leading dimension (dimension 0).
All Actions objects in the list must have the same batch shape. An example application is when the individual actions represent per-step actions of a batch of trajectories; in that case the common batch_shape is (batch_size,), and the resulting Actions object has batch_shape (n_steps, batch_size).
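The per-step example above can be sketched in pure Python: each step contributes one row of batch_shape (batch_size,), and stacking adds a new leading step dimension, so column j of the result is the action sequence of trajectory j. Lists stand in for tensors, and the `(stacked, batch_shape)` return pair is a convenience of this sketch, not the library's signature.

```python
def stack(actions_list: list[list[tuple[int, ...]]]) -> tuple[list[list[tuple[int, ...]]], tuple[int, int]]:
    """Stack per-step batches (each of batch_shape (batch_size,)) along a new dimension 0.

    Returns the stacked batch and its batch_shape (n_steps, batch_size).
    """
    if not actions_list:
        raise ValueError("need at least one batch to stack")
    batch_size = len(actions_list[0])
    if any(len(step) != batch_size for step in actions_list):
        raise ValueError("all batches must share the same batch_shape")
    stacked = [list(step) for step in actions_list]  # copy each step row
    return stacked, (len(stacked), batch_size)


# Per-step actions collected while sampling 3 trajectories for 2 steps.
step_0 = [(0,), (1,), (2,)]
step_1 = [(3,), (0,), (1,)]
trajectory_actions, batch_shape = stack([step_0, step_1])
print(batch_shape)  # (2, 3)
```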
- tensor
- class gfn.actions.GraphActionType
Bases: enum.IntEnum
Enum where members are also (and must be) ints.
- ADD_EDGE = 1
- ADD_NODE = 0
- DUMMY = 3
- EXIT = 2
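Because GraphActionType is an IntEnum, its members can be written directly into an integer tensor column and decoded again from a raw value. The standalone copy below reproduces the documented values with the standard-library enum module (it does not import gfn) to show that round trip.

```python
from enum import IntEnum


class GraphActionType(IntEnum):
    # Values as listed in the reference above.
    ADD_NODE = 0
    ADD_EDGE = 1
    EXIT = 2
    DUMMY = 3


# Members compare equal to plain ints, so they can be stored as integers...
raw_action_types = [int(GraphActionType.ADD_NODE), int(GraphActionType.EXIT)]
# ...and recovered by value lookup.
decoded = [GraphActionType(v) for v in raw_action_types]
print([member.name for member in decoded])  # ['ADD_NODE', 'EXIT']
```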
- class gfn.actions.GraphActions(tensor, debug=False)
Bases: Actions
Actions for graph-based environments.
Each action is one of these types:
- ADD_NODE: Add a node with given features.
- ADD_EDGE: Add an edge between two nodes with given features.
- EXIT: Terminate the trajectory.
- DUMMY: Padding action for trajectories that have already terminated.
- Parameters:
tensor (torch.Tensor)
debug (bool)
- tensor
Tensor of shape (*batch_shape, 5) containing five components: action type, node class, node index, edge class, and edge index.
- ACTION_TYPE_KEY
Class variable, key for the action type component.
- NODE_CLASS_KEY
Class variable, key for the node class component.
- NODE_INDEX_KEY
Class variable, key for the node index component.
- EDGE_CLASS_KEY
Class variable, key for the edge class component.
- EDGE_INDEX_KEY
Class variable, key for the edge index component.
- ACTION_INDICES
Class variable, mapping from keys to tensor indices.
- ACTION_INDICES: ClassVar[dict[str, int]]
- ACTION_TYPE_KEY: ClassVar[str] = 'action_type'
- EDGE_CLASS_KEY: ClassVar[str] = 'edge_class'
- EDGE_INDEX_KEY: ClassVar[str] = 'edge_index'
- NODE_CLASS_KEY: ClassVar[str] = 'node_class'
- NODE_INDEX_KEY: ClassVar[str] = 'node_index'
- __repr__()
Returns a string representation of the GraphActions object.
- Returns:
A string summary of the GraphActions object.
- action_shape = (5,)
- property action_type: torch.Tensor
Returns the action type tensor.
- Returns:
A tensor of shape (*batch_shape,) containing the action types.
- Return type:
torch.Tensor
- property batch_shape: tuple[int, ...]
The batch shape of the graph actions.
- Returns:
The batch shape as a tuple.
- Return type:
tuple[int, ...]
- debug = False
- dummy_action
- property edge_class: torch.Tensor
Returns the edge class tensor.
- Returns:
A tensor of shape (*batch_shape,) containing the edge classes.
- Return type:
torch.Tensor
- property edge_index: torch.Tensor
Returns the edge index tensor.
- Returns:
A tensor of shape (*batch_shape,) containing the edge indices.
- Return type:
torch.Tensor
- classmethod edge_index_action_to_src_dst(edge_index_action, n_nodes)
Abstract method. Converts the edge index action to source and destination node indices.
- Parameters:
edge_index_action (torch.Tensor)
n_nodes (int)
- Return type:
tuple[torch.Tensor, torch.Tensor]
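Since the method is abstract, each environment chooses how a single integer edge index maps to a (source, destination) pair. As one hypothetical encoding (not necessarily what any gfn environment uses), index k can enumerate the n_nodes * (n_nodes - 1) ordered pairs without self-loops, grouped by source node. The pure-Python sketch below decodes a batch of such indices into two lists, mirroring the documented tuple return type.

```python
def edge_index_action_to_src_dst(edge_index_action: list[int], n_nodes: int) -> tuple[list[int], list[int]]:
    """Decode flat edge indices into (src, dst) node indices.

    Hypothetical encoding: k enumerates the ordered pairs (src, dst) with
    src != dst, grouped by src, so there are n_nodes * (n_nodes - 1) valid indices.
    """
    if n_nodes < 2:
        raise ValueError("need at least two nodes to form an edge")
    n_pairs = n_nodes * (n_nodes - 1)
    src, dst = [], []
    for k in edge_index_action:
        if not 0 <= k < n_pairs:
            raise ValueError(f"edge index {k} out of range for {n_nodes} nodes")
        s, offset = divmod(k, n_nodes - 1)
        # Skip the self-loop slot: destinations at or after s shift up by one.
        d = offset if offset < s else offset + 1
        src.append(s)
        dst.append(d)
    return src, dst


src, dst = edge_index_action_to_src_dst([0, 1, 2, 3, 4, 5], n_nodes=3)
print(list(zip(src, dst)))  # [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
```

An undirected environment would more likely enumerate only pairs with src < dst; the decoding changes, but the contract of returning matching source and destination index sequences stays the same.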
- exit_action
- classmethod from_tensor_dict(tensor_dict, debug=False)
Creates a GraphActions object from a tensor dict.
- Parameters:
tensor_dict (tensordict.TensorDict) – A TensorDict containing the action components with keys ACTION_TYPE_KEY, NODE_CLASS_KEY, NODE_INDEX_KEY, EDGE_CLASS_KEY, and EDGE_INDEX_KEY.
debug (bool)
- Returns:
A GraphActions object constructed from the tensor dict.
- Return type:
GraphActions
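from_tensor_dict, ACTION_INDICES, and the per-component properties (action_type, node_class, node_index, edge_class, edge_index) together describe a pack/unpack scheme between named components and the five columns of tensor. The sketch below uses the documented key strings, but the column order in its ACTION_INDICES is an assumption, and plain dicts and lists stand in for TensorDict and torch.Tensor.

```python
# Key names as documented; the column order below is hypothetical.
ACTION_TYPE_KEY = "action_type"
NODE_CLASS_KEY = "node_class"
NODE_INDEX_KEY = "node_index"
EDGE_CLASS_KEY = "edge_class"
EDGE_INDEX_KEY = "edge_index"

ACTION_INDICES = {
    ACTION_TYPE_KEY: 0,
    NODE_CLASS_KEY: 1,
    NODE_INDEX_KEY: 2,
    EDGE_CLASS_KEY: 3,
    EDGE_INDEX_KEY: 4,
}


def pack(components: dict[str, list[int]]) -> list[list[int]]:
    """Turn per-key columns (each of length batch_size) into rows of length 5."""
    missing = set(ACTION_INDICES) - set(components)
    if missing:
        raise KeyError(f"missing components: {sorted(missing)}")
    batch_size = len(components[ACTION_TYPE_KEY])
    rows = [[0] * len(ACTION_INDICES) for _ in range(batch_size)]
    for key, col in ACTION_INDICES.items():
        values = components[key]
        if len(values) != batch_size:
            raise ValueError(f"{key!r} has {len(values)} entries, expected {batch_size}")
        for i, value in enumerate(values):
            rows[i][col] = value
    return rows


def unpack(rows: list[list[int]], key: str) -> list[int]:
    """Read one component back, in the spirit of the per-component properties."""
    col = ACTION_INDICES[key]
    return [row[col] for row in rows]


# An ADD_NODE action (type 0) and an ADD_EDGE action (type 1).
rows = pack({
    "action_type": [0, 1],
    "node_class": [1, 0],
    "node_index": [0, 0],
    "edge_class": [0, 2],
    "edge_index": [0, 4],
})
print(rows)  # [[0, 1, 0, 0, 0], [1, 0, 0, 2, 4]]
```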
- property is_dummy: torch.Tensor
Returns a boolean tensor indicating whether the actions are dummy actions.
- Returns:
A boolean tensor of shape (*batch_shape,) that is True for dummy actions.
- Return type:
torch.Tensor
- property is_exit: torch.Tensor
Returns a boolean tensor indicating whether the actions are exit actions.
- Returns:
A boolean tensor of shape (*batch_shape,) that is True for exit actions.
- Return type:
torch.Tensor
- classmethod make_dummy_actions(batch_shape, device=None, debug=False)
Creates a GraphActions object filled with dummy actions.
- Parameters:
batch_shape (tuple[int]) – Shape of the batch dimensions.
device (torch.device | None) – The device to create the actions on.
debug (bool)
- Returns:
A GraphActions object with the specified batch shape filled with dummy actions.
- Return type:
GraphActions
- classmethod make_exit_actions(batch_shape, device=None, debug=False)
Creates a GraphActions object filled with exit actions.
- Parameters:
batch_shape (tuple[int]) – Shape of the batch dimensions.
device (torch.device | None) – The device to create the actions on.
debug (bool)
- Returns:
A GraphActions object with the specified batch shape filled with exit actions.
- Return type:
GraphActions
- property node_class: torch.Tensor
Returns the node class tensor.
- Returns:
A tensor of shape (*batch_shape,) containing the node classes.
- Return type:
torch.Tensor
- property node_index: torch.Tensor
Returns the node index tensor.
- Returns:
A tensor of shape (*batch_shape,) containing the node indices.
- Return type:
torch.Tensor
- tensor