gfn.states
==========

.. py:module:: gfn.states


Attributes
----------

.. autoapisummary::

   gfn.states.logger


Classes
-------

.. autoapisummary::

   gfn.states.DiscreteStates
   gfn.states.GraphStates
   gfn.states.States


Functions
---------

.. autoapisummary::

   gfn.states._assert_factory_accepts_debug


Module Contents
---------------

.. py:class:: DiscreteStates(tensor, conditions = None, device = None, debug = False)

   Bases: :py:obj:`States`, :py:obj:`abc.ABC`


   Base class for states of discrete environments.

   DiscreteStates provide `forward_masks` and `backward_masks` as cached properties
   that compute which actions are allowed at each state on demand. This approach
   (similar to GraphStates) makes slicing operations faster since masks don't need
   to be sliced - they are recomputed only when accessed.

   Subclasses must implement `_compute_forward_masks` and `_compute_backward_masks`
   to define the mask computation logic for their specific environment.

   .. attribute:: n_actions

      Number of possible actions.

   .. attribute:: device

      The device on which the states are stored.

   .. attribute:: forward_masks

      Property that returns boolean tensor of allowed forward actions.

   .. attribute:: backward_masks

      Property that returns boolean tensor of allowed backward actions.

   Compile-related expectations:

   - Inputs (state tensor) should already be on the target device with correct shapes;
     debug can be used to validate during development/tests.
   - Masks are computed on-demand and cached; cache is invalidated when needed.


   .. py:method:: __getitem__(index)

      Returns a subset of the discrete states.

      Masks are computed on demand for the new states rather than being sliced,
      which makes this operation faster.

      :param index: Indices to select states.

      :returns: A new DiscreteStates object with the selected states and conditions.


   .. py:method:: __repr__()

      Returns a detailed string representation of the DiscreteStates object.

      :returns: A string summary of the DiscreteStates object.


   .. py:method:: __setitem__(index, states)

      Sets particular discrete states.

      :param index: Indices to set.
      :param states: DiscreteStates object containing the new states.


   .. py:attribute:: _backward_masks_cache
      :type: Optional[torch.Tensor]
      :value: None


   .. py:method:: _compute_backward_masks()

      Computes backward action masks for the current states.

      By default, all backward actions are allowed. Typically, this method should be
      overridden by subclasses to define environment-specific mask logic.

      :returns: Boolean tensor of shape (*batch_shape, n_actions - 1).


   .. py:method:: _compute_forward_masks()

      Computes forward action masks for the current states.

      By default, all forward actions are allowed. Typically, this method should be
      overridden by subclasses to define environment-specific mask logic.

      :returns: Boolean tensor of shape (*batch_shape, n_actions).


   .. py:attribute:: _forward_masks_cache
      :type: Optional[torch.Tensor]
      :value: None


   .. py:method:: _invalidate_masks_cache()

      Invalidates the cached masks, forcing recomputation on next access.


   .. py:method:: _make_view(tensor, conditions = None, debug = False)
      :classmethod:

      Fast constructor for slicing: extends base _make_view with mask cache
      initialization (left empty for on-demand computation).


   .. py:property:: backward_masks
      :type: torch.Tensor

      Returns backward action masks, computing and caching if needed.

      :returns: Boolean tensor of shape (*batch_shape, n_actions - 1) indicating
                which backward actions are allowed at each state.


   .. py:method:: clone()

      Returns a clone of the current instance.

      :returns: A new DiscreteStates object with the same data and conditions.
                Masks are recomputed on demand for the cloned states.


   .. py:method:: extend(other)

      Concatenates another DiscreteStates object along the batch dimension.

      :param other: DiscreteStates object to concatenate with.


   .. py:method:: flatten()

      Flattens the batch dimension of the discrete states.

      Masks are computed on demand for the flattened states.

      :returns: A new DiscreteStates object with the batch dimension flattened.


   .. py:property:: forward_masks
      :type: torch.Tensor

      Returns forward action masks, computing and caching if needed.

      :returns: Boolean tensor of shape (*batch_shape, n_actions) indicating
                which forward actions are allowed at each state.


   .. py:method:: init_forward_masks(set_ones = True)

      Initializes forward masks.

      A convenience function for common mask operations.

      :param set_ones: if True, forward masks are initialized to all ones. Otherwise,
                       they are initialized to all zeros.


   .. py:attribute:: n_actions
      :type: ClassVar[int]


   .. py:method:: pad_dim0_with_sf(required_first_dim)

      Extends states along the first batch dimension with sink states.

      Given a batch of states (i.e. of `batch_shape=(a, b)`), extends `a` to a
      DiscreteStates object of `batch_shape = (required_first_dim, b)`, by adding the
      required number of $s_f$ tensors. This is useful to extend trajectories of
      different lengths.

      :param required_first_dim: The size of the first batch dimension post-expansion.


   .. py:method:: set_exit_masks(batch_idx)

      Sets forward masks such that the only allowable next action is to exit.

      A convenience function for common mask operations.

      :param batch_idx: A boolean index along the batch dimension, along which to
                        enforce exits.

      .. rubric:: Notes

      - Works for 1D or 2D batch shapes; `batch_idx` must match `batch_shape`.
      - Clears all actions for the selected batch entries, then sets only the exit
        action True via masked_fill to stay torch.compile friendly.
      - Does not move devices; expects masks/tensors already on the target device.


   .. py:method:: set_nonexit_action_masks(cond, allow_exit)

      Masks denoting disallowed actions according to cond, appending the exit mask.

      A convenience function for common mask operations.

      :param cond: a boolean of shape (*batch_shape,) + (n_actions - 1,), which
                   denotes which actions are *not* allowed. For example, if a state element
                   represents action count, and no action can be repeated more than 5
                   times, cond might be state.tensor > 5 (assuming count starts at 0).
      :param allow_exit: sets whether exiting can happen at any point in the
                         trajectory - if so, it should be set to True.

      .. rubric:: Notes

      - Always resets `forward_masks` to all True before applying the new mask so
        updates do not leak across steps.
      - Works for 1D or 2D batch shapes; cond must match `batch_shape`.
      - Debug guards validate shape/dtype but should be off in compiled regions.


   .. py:method:: stack(states)
      :classmethod:

      Stacks a list of DiscreteStates objects along a new dimension (0).

      :param states: List of DiscreteStates objects to stack.

      :returns: A new DiscreteStates object with the stacked states and conditions.
                Masks are computed on demand for the stacked states.


   .. py:method:: to(device)

      Moves the tensor to the specified device in-place.

      Masks will be recomputed on the new device when accessed.

      :param device: The device to move to.

      :returns: The DiscreteStates object on the specified device.


.. py:class:: GraphStates(data, categorical_node_features = False, categorical_edge_features = False, conditions = None, device = None, debug = False)

   Bases: :py:obj:`States`


   Base class for graph-based state representations.

   A `GraphStates` object is a collection of multiple graph objects stored as a numpy
   array of `GeometricData` objects. This supports batched management of graphs.

   .. attribute:: num_node_classes

      Number of node classes.

   .. attribute:: num_edge_classes

      Number of edge classes.

   .. attribute:: is_directed

      Whether the graph is directed.

   .. attribute:: s0

      Initial state (graph).

   .. attribute:: sf

      Final state (graph).

   .. attribute:: data

      A numpy array of `GeometricData` objects representing individual graphs.

   .. attribute:: _device

      The device on which the graphs are stored.


   .. py:method:: __getitem__(index)

      Returns a subset of the GraphStates.

      :param index: Index or indices to select.

      :returns: A new GraphStates object containing the selected graphs.


   .. py:method:: __len__()

      Returns the total number of graphs.

      :returns: The number of graphs in the batch.


   .. py:method:: __repr__()

      Returns a detailed string representation of the GraphStates object.

      :returns: A string summary of the GraphStates object.


   .. py:method:: __setitem__(index, graph)

      Sets a subset of the GraphStates.

      :param index: Index or indices to set.
      :param graph: GraphStates object containing the new graphs.


   .. py:method:: _compare(other)

      Compares the current batch of graphs with another graph.

      Note that this does not check if the conditions are equal.

      :param other: A `GeometricData` object to compare with.

      :returns: A boolean tensor of shape (*batch_shape,) indicating which graphs in
                the batch are equal to `other`.


   .. py:method:: _compare_reference(ref)

      Compares batch against a reference graph (s0 or sf), handling device mismatch.


   .. py:attribute:: _conditions
      :type: torch.Tensor | None
      :value: None


   .. py:attribute:: _device


   .. py:method:: _get_index_np(index)

      Converts a tensor-based index to a numpy index.

      :param index: The index to convert.

      :returns: The converted index.


   .. py:property:: backward_masks
      :type: tensordict.TensorDict

      Computes masks for valid backward actions from the current state.

      A backward action is valid if:

      1. The edge exists in the current graph (i.e., can be removed)
      2. The node exists in the current graph and no edges are connected to it

      For directed graphs, all existing edges are considered for removal.
      For undirected graphs, only the upper triangular edges are considered.

      The EXIT action is not included in backward masks.

      :returns: Boolean mask where True indicates valid actions.
      :rtype: TensorDict


   .. py:property:: batch_shape
      :type: tuple[int, Ellipsis]

      The batch shape of the graphs.

      :returns: The batch shape as a tuple.


   .. py:attribute:: categorical_edge_features
      :value: False


   .. py:attribute:: categorical_node_features
      :value: False


   .. py:method:: clone()

      Returns a detached clone of the current instance.

      :returns: A new GraphStates object with the same data.


   .. py:attribute:: data


   .. py:attribute:: debug
      :value: False


   .. py:property:: device
      :type: torch.device

      The device on which the states are stored.

      :returns: The device of the underlying array of GeometricData.


   .. py:method:: extend(other)

      Concatenates another GraphStates object along the batch dimension.

      :param other: GraphStates object to concatenate with.


   .. py:property:: forward_masks
      :type: tensordict.TensorDict

      Computes masks for valid forward actions from the current state.

      A forward action is valid if:

      1. The edge doesn't already exist in the graph
      2. The edge connects two distinct nodes

      For directed graphs, all possible src->dst edges are considered.
      For undirected graphs, only the upper triangular portion of the adjacency
      matrix is used.

      :returns: Boolean mask where True indicates valid actions.
      :rtype: TensorDict


   .. py:attribute:: is_directed
      :type: ClassVar[bool]


   .. py:property:: is_initial_state
      :type: torch.Tensor

      Returns a boolean tensor indicating which graphs are initial states ($s_0$).

      :returns: A boolean tensor of shape (*batch_shape,) that is True for initial states.


   .. py:property:: is_sink_state
      :type: torch.Tensor

      Returns a boolean tensor indicating which graphs are sink states ($s_f$).

      :returns: A boolean tensor of shape (*batch_shape,) that is True for sink states.


   .. py:method:: make_initial_states(batch_shape, conditions = None, device = None, debug = False)
      :classmethod:

      Creates a numpy array of graphs consisting of initial states ($s_0$).

      :param batch_shape: Shape of the batch dimensions.
      :param conditions: Optional tensor of shape (*batch_shape, condition_dim)
                         containing condition vectors for conditional GFlowNets.
      :param device: Device to create the graphs on.
      :param debug: If True, keeps compile graph-breaking checks in the logic for safety.

      :returns: A GraphStates object containing copies of the initial state.


   .. py:method:: make_sink_states(batch_shape, conditions = None, device = None, debug = False)
      :classmethod:

      Creates a numpy array of graphs consisting of sink states ($s_f$).

      :param batch_shape: Shape of the batch dimensions.
      :param conditions: Optional tensor of shape (*batch_shape, condition_dim)
                         containing condition vectors for conditional GFlowNets.
      :param device: Device to create the graphs on.
      :param debug: If True, keeps compile graph-breaking checks in the logic for safety.

      :returns: A GraphStates object containing copies of the sink state.


   .. py:attribute:: max_nodes
      :type: ClassVar[int | None]


   .. py:attribute:: num_edge_classes
      :type: ClassVar[int]


   .. py:attribute:: num_node_classes
      :type: ClassVar[int]


   .. py:method:: pad_dim0_with_sf(required_first_dim)

      Extends a 2-dimensional batch of graph states along the first batch dimension.

      Given a batch of states (i.e. of `batch_shape=(a, b)`), extends `a` to a
      GraphStates object of `batch_shape = (required_first_dim, b)`, by adding the
      required number of $s_f$ graphs. This is useful to extend trajectories of
      different lengths.

      :param required_first_dim: The size of the first batch dimension post-expansion.


   .. py:attribute:: s0
      :type: ClassVar[torch_geometric.data.Data]


   .. py:attribute:: sf
      :type: ClassVar[torch_geometric.data.Data]


   .. py:method:: stack(states)
      :classmethod:

      Stacks a list of GraphStates objects along a new dimension (0).

      :param states: List of GraphStates objects to stack.

      :returns: A new GraphStates object with the stacked graphs and conditions.


   .. py:property:: tensor
      :type: gfn.utils.graphs.GeometricBatch

      Returns the batch representation of the data as a GeometricBatch.

      :returns: A GeometricBatch object representing the batch of graphs.


   .. py:method:: to(device)

      Moves the GraphStates to the specified device.

      :param device: The device to move to.

      :returns: The GraphStates object on the specified device.


.. py:class:: States(tensor, conditions = None, device = None, debug = False)

   Bases: :py:obj:`abc.ABC`


   Base class for states, representing nodes in the DAG of a GFlowNet.

   Each environment needs to define a subclass of `States`. A `States` object is a
   collection of multiple states (nodes of the DAG) that supports batching.
   Generally, if a state is represented with a tensor of shape (*state_shape), a
   batch of states is represented with a `States` object, with the attribute `tensor`
   of shape (*batch_shape, *state_shape). Other representations are possible (e.g.,
   state as a string, numpy array, graph, etc.), but these may need additional logic
   to support batching (see `GraphStates` for an example).

   Two useful subclasses of `States` are provided:

   - `DiscreteStates` for discrete environments, which represents discrete states
     with a tensor of shape (*batch_shape, *state_shape).
   - `GraphStates` for graph-based environments, which represents graphs as a numpy
     object array of shape (*batch_shape,) containing `GeometricData` objects.

   A `batch_shape` property keeps track of the batch dimension. A trajectory can be
   represented by a States object with `batch_shape = (n_states,)`. Multiple
   trajectories can be represented by a States object with
   `batch_shape = (n_states, batch_size)`.

   Because multiple trajectories can have different lengths, batching requires
   appending a dummy state ($s_f$) to trajectories that are shorter than the longest
   trajectory. This dummy state should never be processed, and is used to pad the
   batch of states only.

   Compile-related expectations:

   - Hot paths should be called with tensors already on the target device and with
     correct shapes; debug guards can be enabled during development/tests to validate.
   - Set `debug=False` inside torch.compile regions to avoid Python-side graph
     breaks; enable `debug=True` only when running eager checks.

   .. attribute:: tensor

      Tensor of shape (*batch_shape, *state_shape) representing a batch of states.

   .. attribute:: state_shape

      Class variable, a tuple defining the shape of a single state.

   .. attribute:: s0

      Class variable, a tensor of shape (*state_shape,) representing the initial state.

   .. attribute:: sf

      Class variable, a tensor of shape (*state_shape,) representing the sink state.

   .. attribute:: make_random_states

      Class variable, a callable that returns a random state. This is used to
      initialize random states.


   .. py:method:: __getitem__(index)

      Returns a subset of the states along the batch dimension.

      :param index: Indices to select states.

      :returns: A new States object with the selected states and conditions.


   .. py:method:: __len__()

      Returns the number of states in the batch.

      :returns: The number of states.


   .. py:method:: __repr__()

      Returns a string representation of the States object.

      :returns: A string summary of the States object.


   .. py:method:: __setitem__(index, states)

      Sets particular states of the batch to a new States object.

      :param index: Indices to set.
      :param states: States object containing the new states.


   .. py:method:: _compare(other)

      Computes elementwise equality between state tensor and an external tensor.

      Note that this does not check if the conditions are equal.

      :param other: Tensor with shape (*batch_shape, *state_shape) representing
                    states to compare to.

      :returns: A boolean tensor of shape (*batch_shape,) indicating whether the states
                are equal to `other`.


   .. py:attribute:: _conditions
      :type: torch.Tensor | None
      :value: None


   .. py:attribute:: _is_initial_cache
      :type: torch.Tensor | None
      :value: None


   .. py:attribute:: _is_sink_cache
      :type: torch.Tensor | None
      :value: None


   .. py:method:: _make_view(tensor, conditions = None, debug = False)
      :classmethod:

      Fast constructor for internal slicing operations.

      Bypasses __init__'s device resolution, .to() dispatch, and conditions
      shape/dtype validation, all of which are redundant when the source States
      object was already validated. Used by __getitem__ to avoid per-step overhead
      in the sampling loop.


   .. py:method:: _merge_conditions(other, self_was_empty)

      Merges conditions after extending tensors.

      When self was empty its conditions are None by construction, so we adopt the
      other's conditions rather than warning about inconsistency. Symmetrically,
      extending with an empty other is a no-op for conditions.


   .. py:property:: batch_shape
      :type: tuple[int, Ellipsis]

      The batch shape of the states.

      :returns: The batch shape as a tuple.


   .. py:method:: clone()

      Returns a clone of the current instance.

      :returns: A new States object with the same data and conditions.


   .. py:property:: conditions
      :type: torch.Tensor | None

      The conditions attached to these states for conditional GFlowNets.

      :returns: Tensor of shape (*batch_shape, condition_dim) or None if no conditions.


   .. py:attribute:: debug
      :value: False


   .. py:property:: device
      :type: torch.device

      The device on which the states are stored.

      :returns: The device of the underlying tensor.


   .. py:method:: extend(other)

      Concatenates another States object along the final batch dimension.

      Both States objects must have the same number of batch dimensions, which
      should be 1 or 2.

      :param other: States object to be concatenated to the current States object.


   .. py:method:: flatten()

      Flattens the batch dimension of the states.

      Useful for example when extracting individual states from trajectories.

      :returns: A new States object with the batch dimension flattened.


   .. py:method:: from_batch_shape(batch_shape, random = False, sink = False, conditions = None, device = None, debug = False)
      :classmethod:

      Creates a States object with the given batch shape.

      By default, all states are initialized to $s_0$, the initial state. Optionally,
      one can initialize random state, which requires that the environment implements
      the `make_random_states` class method. Sink can be used to initialize states at
      $s_f$, the sink state. Both random and sink cannot be True at the same time.

      :param batch_shape: Shape of the batch dimensions.
      :param random: If True, initialize states randomly.
      :param sink: If True, initialize states as sink states ($s_f$).
      :param conditions: Optional tensor of shape (*batch_shape, condition_dim)
                         containing condition vectors for conditional GFlowNets.
      :param device: The device to create the states on.
      :param debug: If True, keeps compile graph-breaking checks in the logic for safety.

      :returns: A States object with the specified batch shape and initialization.


   .. py:property:: has_conditions
      :type: bool

      Whether conditions are attached to these states.


   .. py:property:: is_initial_state
      :type: torch.Tensor

      Returns a boolean tensor indicating which states are initial ($s_0$).

      :returns: A boolean tensor of shape (*batch_shape,) that is True for initial states.


   .. py:property:: is_sink_state
      :type: torch.Tensor

      Returns a boolean tensor indicating which states are sink ($s_f$).

      :returns: A boolean tensor of shape (*batch_shape,) that is True for sink states.


   .. py:method:: make_initial_states(batch_shape, conditions = None, device = None, debug = False)
      :classmethod:

      Creates a States object with all states set to $s_0$.

      :param batch_shape: Shape of the batch dimensions.
      :param conditions: Optional tensor of shape (*batch_shape, condition_dim)
                         containing condition vectors for conditional GFlowNets.
      :param device: The device to create the states on.
      :param debug: If True, keeps compile graph-breaking checks in the logic for safety.

      :returns: A States object with all states set to $s_0$.


   .. py:attribute:: make_random_states
      :type: Callable


   .. py:method:: make_sink_states(batch_shape, conditions = None, device = None, debug = False)
      :classmethod:

      Creates a States object with all states set to $s_f$.

      :param batch_shape: Shape of the batch dimensions.
      :param conditions: Optional tensor of shape (*batch_shape, condition_dim)
                         containing condition vectors for conditional GFlowNets.
      :param device: The device to create the states on.
      :param debug: If True, keeps compile graph-breaking checks in the logic for safety.

      :returns: A States object with all states set to $s_f$.


   .. py:method:: pad_dim0_with_sf(required_first_dim)

      Extends a 2-dimensional batch of states along the first batch dimension.

      Given a batch of states (i.e. of `batch_shape=(a, b)`), extends `a` to a States
      object of `batch_shape = (required_first_dim, b)`, by adding the required number
      of $s_f$ tensors. This is useful to extend trajectories of different lengths.

      :param required_first_dim: The size of the first batch dimension post-expansion.


   .. py:attribute:: s0
      :type: ClassVar[torch.Tensor | torch_geometric.data.Data]


   .. py:method:: sample(n_samples)

      Randomly samples a subset of states from the batch.

      :param n_samples: The number of states to sample.

      :returns: A new States object with the sampled states.


   .. py:attribute:: sf
      :type: ClassVar[torch.Tensor | torch_geometric.data.Data]


   .. py:method:: stack(states)
      :classmethod:

      Stacks a list of States objects along a new dimension (0).

      :param states: List of States objects to stack.

      :returns: A new States object with the stacked states and conditions.


   .. py:attribute:: state_shape
      :type: ClassVar[tuple[int, Ellipsis]]


   .. py:attribute:: tensor


   .. py:method:: to(device)

      Moves the States tensor to the specified device in-place.

      :param device: The device to move to.

      :returns: The States object on the specified device.


.. py:function:: _assert_factory_accepts_debug(factory, factory_name)

   Ensure the factory can accept a debug kwarg (explicit or via **kwargs).


.. py:data:: logger
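
The on-demand, cached mask behaviour described for `DiscreteStates` can be illustrated with a dependency-free toy class. This is only a sketch under stated assumptions: `LazyMaskStates`, `max_count`, and `n_computations` are hypothetical names, plain Python lists stand in for tensors, and none of it is the actual `gfn.states` implementation.

```python
from typing import List, Optional


class LazyMaskStates:
    """Toy stand-in for a batch of discrete states (NOT the gfn.states API)."""

    n_actions = 3  # assumed layout: two "increment" actions plus one exit action
    max_count = 2  # hypothetical cap on each state coordinate

    def __init__(self, rows: List[List[int]]) -> None:
        self.rows = [list(r) for r in rows]
        self._forward_masks_cache: Optional[List[List[bool]]] = None
        self.n_computations = 0  # counts recomputations, for illustration only

    def _compute_forward_masks(self) -> List[List[bool]]:
        self.n_computations += 1
        # Incrementing a coordinate is allowed while it is below the cap;
        # the last entry (exit) is always allowed in this toy environment.
        return [[v < self.max_count for v in row] + [True] for row in self.rows]

    @property
    def forward_masks(self) -> List[List[bool]]:
        # Compute on first access, then reuse the cached result.
        if self._forward_masks_cache is None:
            self._forward_masks_cache = self._compute_forward_masks()
        return self._forward_masks_cache

    def _invalidate_masks_cache(self) -> None:
        self._forward_masks_cache = None

    def __getitem__(self, index: slice) -> "LazyMaskStates":
        # The new object starts with an empty cache: masks are never sliced.
        return LazyMaskStates(self.rows[index])

    def __setitem__(self, index: int, row: List[int]) -> None:
        self.rows[index] = list(row)
        self._invalidate_masks_cache()  # stale masks must not survive a write


states = LazyMaskStates([[0, 2], [1, 1], [2, 2]])
first_masks = states.forward_masks   # triggers the first computation
again = states.forward_masks         # served from the cache
subset = states[1:]                  # slicing does not touch any masks
states[0] = [2, 0]                   # write invalidates the cache
updated = states.forward_masks       # triggers a second computation
```

In this sketch slicing is cheap because only the state rows are copied; the masks of `subset` are computed the first time `subset.forward_masks` is read.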
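
The padding described for `pad_dim0_with_sf` (extending `batch_shape=(a, b)` to `(required_first_dim, b)` with copies of $s_f$) can be sketched without any tensor library. The sink value `SINK` and the free functions below are hypothetical illustrations using nested lists; they are assumptions made for this example, not the library's code.

```python
from typing import List

State = List[int]
SINK: State = [-1, -1]  # hypothetical s_f for a 2-dimensional state


def pad_dim0_with_sf(
    batch: List[List[State]], required_first_dim: int
) -> List[List[State]]:
    """Pad an (a, b) nested list of states to (required_first_dim, b) with SINK."""
    a = len(batch)
    if required_first_dim < a:
        raise ValueError("required_first_dim must be >= the current first dim")
    b = len(batch[0]) if batch else 0
    # Each padding step holds one fresh copy of the sink state per trajectory.
    padding = [[list(SINK) for _ in range(b)] for _ in range(required_first_dim - a)]
    return batch + padding


def is_sink_state(batch: List[List[State]]) -> List[List[bool]]:
    """Boolean mask of shape (a, b) that is True where the state equals SINK."""
    return [[state == SINK for state in step] for step in batch]


# Two trajectories (b = 2) of length a = 2, padded to length 4.
batch = [
    [[0, 0], [0, 0]],
    [[1, 0], [0, 1]],
]
padded = pad_dim0_with_sf(batch, required_first_dim=4)
sink_mask = is_sink_state(padded)
```

The mask makes it easy to skip the padded entries downstream, in line with the note that the dummy state is used for padding only and should never be processed.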