gfn.actions
===========

.. py:module:: gfn.actions


Classes
-------

.. autoapisummary::

   gfn.actions.Actions
   gfn.actions.GraphActionType
   gfn.actions.GraphActions


Module Contents
---------------

.. py:class:: Actions(tensor, debug = False)

   Bases: :py:obj:`abc.ABC`


   Base class for actions, representing edges in the DAG of a GFlowNet.

   Each environment needs to define a subclass of `Actions` to represent its
   specific action space. Two useful subclasses of `Actions` are provided:

   - `DiscreteActions` for discrete environments, which represents actions as a
     tensor of shape ``(*batch_shape, *action_shape)``.
   - `GraphActions` for graph-based environments, which represents actions as a
     tensor of shape ``(*batch_shape, 5)`` containing the action type, node
     class, node index, edge class, and edge index components.

   .. attribute:: tensor

      Tensor of shape ``(*batch_shape, *action_shape)`` representing a batch of
      actions.

   .. attribute:: action_shape

      Class variable, a tuple defining the shape of a single action.

   .. attribute:: dummy_action

      Class variable, a tensor of shape ``(*action_shape,)`` representing the
      dummy action used to pad shorter trajectories.

   .. attribute:: exit_action

      Class variable, a tensor of shape ``(*action_shape,)`` representing the
      action that transitions to the sink state.


   .. py:method:: __getitem__(index)

      Returns a subset of the actions along the batch dimension.

      :param index: Indices of the actions to select.

      :returns: A new Actions object with the selected actions.


   .. py:method:: __len__()

      Returns the number of actions in the batch.

      :returns: The number of actions.


   .. py:method:: __repr__()

      Returns a string representation of the Actions object.

      :returns: A string summary of the Actions object.


   .. py:method:: __setitem__(index, actions)

      Replaces the actions at the given batch indices with those of another
      Actions object.

      :param index: Indices to set.
      :param actions: Actions object containing the new actions.


   .. py:method:: _compare(other)

      Compares the actions to a tensor of actions.
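
      The sketch below illustrates the shape contract of ``_compare`` with plain
      Python lists so that it runs without torch: components are compared
      element-wise and the result is reduced over the action dimensions, leaving
      one boolean per batch element. ``compare_actions`` is a hypothetical
      stand-in, not the library's implementation, and the batch is assumed to be
      1-D.

      .. code-block:: python

         # Hypothetical plain-Python sketch (no torch). Each action is a tuple of
         # its action_shape components; batch_shape is 1-D in this toy example.
         from typing import Sequence


         def compare_actions(
             actions: Sequence[Sequence[int]], other: Sequence[Sequence[int]]
         ) -> list[bool]:
             """Return one boolean per batch element: True if every component matches."""
             if len(actions) != len(other):
                 raise ValueError("batch shapes must match")
             return [
                 len(a) == len(b) and all(x == y for x, y in zip(a, b))
                 for a, b in zip(actions, other)
             ]


         batch = [(0, 1), (2, 2), (0, 1)]
         reference = [(0, 1)] * 3
         print(compare_actions(batch, reference))  # [True, False, True]
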

      :param other: Tensor of actions to compare, with shape ``(*batch_shape, *action_shape)``.

      :returns: A boolean tensor of shape ``(*batch_shape,)`` indicating whether each action equals the corresponding action in ``other``.


   .. py:attribute:: action_shape
      :type: ClassVar[tuple[int, Ellipsis]]


   .. py:property:: batch_shape
      :type: tuple[int, Ellipsis]


      The batch shape of the actions.

      :returns: The batch shape as a tuple.


   .. py:method:: clone()

      Returns a clone of the Actions object.

      :returns: A new Actions object with the same tensor values.


   .. py:attribute:: debug
      :value: False


   .. py:property:: device
      :type: torch.device


      The device on which the actions are stored.

      :returns: The device of the underlying tensor.


   .. py:attribute:: dummy_action
      :type: ClassVar[torch.Tensor]


   .. py:attribute:: exit_action
      :type: ClassVar[torch.Tensor]


   .. py:method:: extend(other)

      Concatenates another Actions object along the final batch dimension.

      Both Actions objects must have the same number of batch dimensions, which
      should be 1 or 2.

      :param other: Actions object to concatenate to the current Actions object.


   .. py:method:: extend_with_dummy_actions(required_first_dim)

      Extends an Actions instance along the first dimension with dummy actions.

      The batch shape of the Actions instance must be 2-dimensional. This is used
      to pad the actions of a batch of trajectories to a common length.

      :param required_first_dim: The target size of the first dimension after the extension.


   .. py:property:: is_dummy
      :type: torch.Tensor


      Returns a boolean tensor indicating whether the actions are dummy actions.

      :returns: A boolean tensor of shape ``(*batch_shape,)`` that is True for dummy actions.


   .. py:property:: is_exit
      :type: torch.Tensor


      Returns a boolean tensor indicating whether the actions are exit actions.

      :returns: A boolean tensor of shape ``(*batch_shape,)`` that is True for exit actions.


   .. py:method:: make_dummy_actions(batch_shape, device = None, debug = False)
      :classmethod:


      Creates an Actions object filled with dummy actions.
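
      A minimal sketch of how dummy actions pad a 2-D batch of trajectory
      actions, indexed as ``[step][trajectory]``, to a common number of steps,
      in the spirit of ``make_dummy_actions``, ``extend_with_dummy_actions``, and
      ``is_dummy``. It uses nested Python lists instead of tensors; the helper
      names mirror the methods documented here but are hypothetical, and
      ``DUMMY_ACTION = (-1,)`` is an assumed sentinel (each environment defines
      its own ``dummy_action``).

      .. code-block:: python

         # Hypothetical sketch with nested lists; DUMMY_ACTION is an assumed sentinel.
         DUMMY_ACTION = (-1,)


         def make_dummy_actions(batch_shape: tuple[int, int]) -> list[list[tuple[int, ...]]]:
             """Return a 2-D batch of shape batch_shape filled with the dummy action."""
             n_steps, n_traj = batch_shape
             return [[DUMMY_ACTION for _ in range(n_traj)] for _ in range(n_steps)]


         def extend_with_dummy_actions(actions, required_first_dim: int):
             """Pad the step (first) dimension with dummy rows; never truncates here."""
             n_steps = len(actions)
             n_traj = len(actions[0]) if actions else 0
             n_missing = max(required_first_dim - n_steps, 0)
             return actions + make_dummy_actions((n_missing, n_traj))


         def is_dummy(actions) -> list[list[bool]]:
             """Boolean mask with the same 2-D batch shape, True for dummy actions."""
             return [[action == DUMMY_ACTION for action in row] for row in actions]


         # Two trajectories with two recorded steps each, padded to four steps.
         steps = [[(0,), (3,)], [(1,), (2,)]]
         padded = extend_with_dummy_actions(steps, 4)
         print(len(padded))          # 4
         print(is_dummy(padded)[2])  # [True, True]
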

      :param batch_shape: Shape of the batch dimensions.
      :param device: The device to create the actions on.
      :param debug: Whether to run debug validations on the constructed Actions.

      :returns: An Actions object with the specified batch shape filled with dummy actions.


   .. py:method:: make_exit_actions(batch_shape, device = None, debug = False)
      :classmethod:


      Creates an Actions object filled with exit actions.

      :param batch_shape: Shape of the batch dimensions.
      :param device: The device to create the actions on.
      :param debug: Whether to run debug validations on the constructed Actions.

      :returns: An Actions object with the specified batch shape filled with exit actions.


   .. py:method:: stack(actions_list, debug = None)
      :classmethod:


      Stacks a list of Actions objects along a new dimension (0).

      The individual actions must have the same batch shape. For example, when
      the individual actions are the per-step actions of a batch of trajectories,
      the common batch shape is ``(batch_size,)`` and the resulting Actions object
      has batch shape ``(n_steps, batch_size)``.

      :param actions_list: List of Actions objects to stack.

      :returns: A new Actions object with the stacked actions.


   .. py:attribute:: tensor


.. py:class:: GraphActionType

   Bases: :py:obj:`enum.IntEnum`


   Integer codes for the types of graph actions used by `GraphActions`.


   .. py:attribute:: ADD_EDGE
      :value: 1


   .. py:attribute:: ADD_NODE
      :value: 0


   .. py:attribute:: DUMMY
      :value: 3


   .. py:attribute:: EXIT
      :value: 2


.. py:class:: GraphActions(tensor, debug = False)

   Bases: :py:obj:`Actions`


   Actions for graph-based environments.

   Each action is one of these types:

   - ADD_NODE: Add a node with given features
   - ADD_EDGE: Add an edge between two nodes with given features
   - EXIT: Terminate the trajectory
   - DUMMY: Placeholder used to pad shorter trajectories

   .. attribute:: tensor

      Tensor of shape ``(*batch_shape, 5)`` containing the action type, node
      class, node index, edge class, and edge index components.

   .. attribute:: ACTION_TYPE_KEY

      Class variable, key for the action type component.

   .. attribute:: NODE_CLASS_KEY

      Class variable, key for the node class component.

   .. attribute:: NODE_INDEX_KEY

      Class variable, key for the node index component.

   .. attribute:: EDGE_CLASS_KEY

      Class variable, key for the edge class component.

   .. attribute:: EDGE_INDEX_KEY

      Class variable, key for the edge index component.

   .. attribute:: ACTION_INDICES

      Class variable, mapping from keys to tensor indices.


   .. py:attribute:: ACTION_INDICES
      :type: ClassVar[dict[str, int]]


   .. py:attribute:: ACTION_TYPE_KEY
      :type: ClassVar[str]
      :value: 'action_type'


   .. py:attribute:: EDGE_CLASS_KEY
      :type: ClassVar[str]
      :value: 'edge_class'


   .. py:attribute:: EDGE_INDEX_KEY
      :type: ClassVar[str]
      :value: 'edge_index'


   .. py:attribute:: NODE_CLASS_KEY
      :type: ClassVar[str]
      :value: 'node_class'


   .. py:attribute:: NODE_INDEX_KEY
      :type: ClassVar[str]
      :value: 'node_index'


   .. py:method:: __repr__()

      Returns a string representation of the GraphActions object.

      :returns: A string summary of the GraphActions object.


   .. py:attribute:: action_shape
      :value: (5,)


   .. py:property:: action_type
      :type: torch.Tensor


      Returns the action type tensor.

      :returns: A tensor of shape ``(*batch_shape,)`` containing the action types.


   .. py:property:: batch_shape
      :type: tuple[int, Ellipsis]


      The batch shape of the graph actions.

      :returns: The batch shape as a tuple.


   .. py:attribute:: debug
      :value: False


   .. py:attribute:: dummy_action


   .. py:property:: edge_class
      :type: torch.Tensor


      Returns the edge class tensor.

      :returns: A tensor of shape ``(*batch_shape,)`` containing the edge classes.


   .. py:property:: edge_index
      :type: torch.Tensor


      Returns the edge index tensor.

      :returns: A tensor of shape ``(*batch_shape,)`` containing the edge indices.


   .. py:method:: edge_index_action_to_src_dst(edge_index_action, n_nodes)
      :classmethod:
      :abstractmethod:


      Converts the edge index action to source and destination node indices.


   .. py:attribute:: exit_action


   .. py:method:: from_tensor_dict(tensor_dict, debug = False)
      :classmethod:


      Creates a GraphActions object from a TensorDict.
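
      A minimal sketch of how per-component values could be packed into the
      five-column layout of ``GraphActions.tensor``. It uses a plain ``dict`` of
      Python lists instead of a ``TensorDict`` so it runs without torch.
      ``pack_components`` is a hypothetical helper, and the column order assumed
      in its ``ACTION_INDICES`` dict is illustrative only; the real mapping is
      the ``ACTION_INDICES`` class variable.

      .. code-block:: python

         # Hypothetical helper, not part of gfn.actions. The key strings match the
         # class variables documented here; the column ORDER is an assumption.
         ACTION_INDICES = {
             "action_type": 0,
             "node_class": 1,
             "node_index": 2,
             "edge_class": 3,
             "edge_index": 4,
         }


         def pack_components(components: dict[str, list[int]]) -> list[list[int]]:
             """Turn {key: [one value per batch element]} into rows of length 5."""
             missing = set(ACTION_INDICES) - set(components)
             if missing:
                 raise KeyError(f"missing components: {sorted(missing)}")
             lengths = {len(components[key]) for key in ACTION_INDICES}
             if len(lengths) != 1:
                 raise ValueError("all components must have the same batch size")
             (batch_size,) = lengths
             rows = [[0] * len(ACTION_INDICES) for _ in range(batch_size)]
             for key, column in ACTION_INDICES.items():
                 for i, value in enumerate(components[key]):
                     rows[i][column] = value
             return rows


         # One ADD_NODE action (type 0) and one EXIT action (type 2); fields an
         # action does not use are filled with 0 here, which is also an assumption.
         packed = pack_components({
             "action_type": [0, 2],
             "node_class": [4, 0],
             "node_index": [0, 0],
             "edge_class": [0, 0],
             "edge_index": [0, 0],
         })
         print(packed)  # [[0, 4, 0, 0, 0], [2, 0, 0, 0, 0]]
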

      :param tensor_dict: A TensorDict containing the action components under the keys ACTION_TYPE_KEY, NODE_CLASS_KEY, NODE_INDEX_KEY, EDGE_CLASS_KEY, and EDGE_INDEX_KEY.
      :param debug: Whether to run debug validations on the constructed GraphActions.

      :returns: A GraphActions object constructed from the tensor dict.


   .. py:property:: is_dummy
      :type: torch.Tensor


      Returns a boolean tensor indicating whether the actions are dummy actions.

      :returns: A boolean tensor of shape ``(*batch_shape,)`` that is True for dummy actions.


   .. py:property:: is_exit
      :type: torch.Tensor


      Returns a boolean tensor indicating whether the actions are exit actions.

      :returns: A boolean tensor of shape ``(*batch_shape,)`` that is True for exit actions.


   .. py:method:: make_dummy_actions(batch_shape, device = None, debug = False)
      :classmethod:


      Creates a GraphActions object filled with dummy actions.

      :param batch_shape: Shape of the batch dimensions.
      :param device: The device to create the actions on.
      :param debug: Whether to run debug validations on the constructed GraphActions.

      :returns: A GraphActions object with the specified batch shape filled with dummy actions.


   .. py:method:: make_exit_actions(batch_shape, device = None, debug = False)
      :classmethod:


      Creates a GraphActions object filled with exit actions.

      :param batch_shape: Shape of the batch dimensions.
      :param device: The device to create the actions on.
      :param debug: Whether to run debug validations on the constructed GraphActions.

      :returns: A GraphActions object with the specified batch shape filled with exit actions.


   .. py:property:: node_class
      :type: torch.Tensor


      Returns the node class tensor.

      :returns: A tensor of shape ``(*batch_shape,)`` containing the node classes.


   .. py:property:: node_index
      :type: torch.Tensor


      Returns the node index tensor.

      :returns: A tensor of shape ``(*batch_shape,)`` containing the node indices.


   .. py:attribute:: tensor
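
   A minimal sketch of how the integer codes of `GraphActionType` could be used
   to build exit and dummy masks from the action-type column of a batch. It
   reproduces the documented enum values with the standard-library
   ``enum.IntEnum`` and plain lists; deriving ``is_exit`` and ``is_dummy`` from
   the action type alone is an assumption, and ``type_mask`` is a hypothetical
   helper.

   .. code-block:: python

      # Sketch only: mirrors the documented GraphActionType values. Building the
      # masks from the action type alone is an assumption about GraphActions.
      from enum import IntEnum


      class GraphActionType(IntEnum):
          ADD_NODE = 0
          ADD_EDGE = 1
          EXIT = 2
          DUMMY = 3


      def type_mask(action_types: list[int], target: GraphActionType) -> list[bool]:
          """True where the action type equals target (IntEnum compares equal to int)."""
          return [t == target for t in action_types]


      action_types = [0, 1, 2, 3, 2]
      print(type_mask(action_types, GraphActionType.EXIT))   # [False, False, True, False, True]
      print(type_mask(action_types, GraphActionType.DUMMY))  # [False, False, False, True, False]
      print(GraphActionType(1).name)                         # ADD_EDGE

   Because ``IntEnum`` members compare equal to plain integers, the same
   comparison works whether the action types are stored as enum members or as
   raw integers taken from the action tensor.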