gfn.env
=======

.. py:module:: gfn.env


Attributes
----------

.. autoapisummary::

   gfn.env.NonValidActionsError


Classes
-------

.. autoapisummary::

   gfn.env.DiscreteEnv
   gfn.env.Env
   gfn.env.GraphEnv


Module Contents
---------------

.. py:class:: DiscreteEnv(n_actions, s0, state_shape, dummy_action = None, exit_action = None, sf = None, debug = False)

   Bases: :py:obj:`Env`, :py:obj:`abc.ABC`

   Base class for discrete environments, where states are defined in a discrete
   space and actions are represented by an integer in ``{0, ..., n_actions - 1}``,
   the last one being the exit action.

   For a guide on creating your own environments, see
   :doc:`guides/creating_environments`. For a complete example, see the HyperGrid
   environment in ``src/gfn/gym/hypergrid.py``.

   .. attribute:: s0

      Tensor of shape ``(*state_shape)`` representing the initial state.

   .. attribute:: sf

      Tensor of shape ``(*state_shape)`` representing the sink (final) state.

   .. attribute:: n_actions

      The number of actions in the environment.

   .. attribute:: state_shape

      Tuple representing the shape of the states.

   .. attribute:: dummy_action

      Tensor of shape ``(1,)`` representing the dummy action.

   .. attribute:: exit_action

      Tensor of shape ``(1,)`` representing the exit action.

   .. attribute:: States

      The States class associated with this environment.

   .. attribute:: Actions

      The Actions class associated with this environment.

   .. attribute:: is_discrete

      Class variable indicating whether the environment is discrete.

   .. py:method:: _add_logz_diff(validation_info, gflownet, validate_condition)

      Compute ``|learned_logZ - true_logZ|`` and add it to ``validation_info``.

   .. py:method:: _backward_step(states, actions)

      Wrapper for the user-defined ``backward_step`` function.

      :param states: The batch of discrete states.
      :param actions: The batch of actions.
      :returns: The batch of previous discrete states.

   .. py:method:: _jsd(p, q)
      :staticmethod:

      Jensen-Shannon divergence between two discrete distributions.

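      As an illustration only, the masked computation can be sketched in plain
      Python. This is a minimal sketch that assumes both distributions are
      equal-length lists of floats. The function name ``jsd`` is hypothetical,
      and the library's ``_jsd`` operates on tensors instead.

      .. code-block:: python

         import math


         def jsd(p, q):
             """Jensen-Shannon divergence, in nats, between two discrete distributions.

             Bins where a distribution is zero are skipped, which implements the
             convention 0 * log(0 / x) = 0 without clamping.
             """
             if len(p) != len(q):
                 raise ValueError("p and q must have the same length")
             # Mixture distribution M = (P + Q) / 2.
             m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

             def kl(a, b):
                 # KL(a || b), summed only over bins with a > 0 (the mask).
                 # b > 0 wherever a > 0, because b is the mixture that contains a.
                 return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

             return 0.5 * kl(p, m) + 0.5 * kl(q, m)


         print(jsd([1.0, 0.0], [0.0, 1.0]))  # disjoint supports: ln(2), about 0.693

      Identical inputs give ``0.0``. Distributions with disjoint supports reach
      the upper bound ``ln(2)``.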

      Uses the convention ``0 * log(0 / x) = 0`` via masking, so that
      zero-probability bins contribute nothing (no clamping, no mass distortion).

      :param p: First distribution (1-D, sums to ~1).
      :param q: Second distribution (1-D, sums to ~1).
      :returns: JSD in nats (natural logarithm), bounded in ``[0, ln(2)]``.

   .. py:method:: _step(states, actions)

      Wrapper for the user-defined ``step`` function.

      :param states: The batch of discrete states.
      :param actions: The batch of actions.
      :returns: The batch of next discrete states.

   .. py:method:: _warn_if_insufficient_samples(n_validation_samples)

      Emit a warning if the validation sample count is too low for the size of
      the state space.

   .. py:property:: all_states
      :type: gfn.states.DiscreteStates
      :abstractmethod:

      Optional method to return a batch of all discrete states in the environment.

      :returns: A batch of all discrete states (``batch_shape = (n_states,)``).

      .. note::

         ``self.get_states_indices(self.all_states)`` and
         ``torch.arange(self.n_states)`` should be equivalent.

   .. py:method:: get_states_indices(states)
      :abstractmethod:

      Optional method to return the indices of the states in the environment.

      Most implementations return a ``torch.Tensor`` of shape ``(*batch_shape,)``
      with dtype ``torch.int64``. Implementations whose canonical index space
      exceeds int64 (e.g. :class:`gfn.gym.HyperGrid` with
      ``height ** ndim > 2 ** 63``) may instead return a ``numpy.ndarray`` of
      dtype ``object`` containing arbitrary-precision Python ints. In that regime
      an int64 tensor would silently overflow and produce hash collisions between
      distinct states.

      :param states: The batch of states.
      :returns: Tensor or numpy object array of shape ``(*batch_shape,)``
         containing the canonical indices of the states.

   .. py:method:: get_terminating_state_dist(states)

      Computes the empirical distribution over terminating states.

      Uses a vectorized ``scatter_add_`` for efficient histogram computation.

      :param states: A batch of terminating ``DiscreteStates``.
      :returns: A 1-D CPU tensor of shape ``(n_terminating_states,)`` with
         empirical frequencies summing to 1.
      :raises NotImplementedError: If the environment lacks
         ``get_terminating_states_indices`` or ``n_terminating_states``.
      :raises ValueError: If *states* is empty, or if the environment's state
         space is too large to histogram (``get_terminating_states_indices``
         returned something other than a ``torch.Tensor``).

   .. py:method:: get_terminating_states_indices(states)
      :abstractmethod:

      Optional method to return the indices of the terminating states in the
      environment. See :meth:`get_states_indices` for the return-type contract.

      :param states: The batch of states.
      :returns: Tensor or numpy object array of shape ``(*batch_shape,)``
         containing the canonical indices of the terminating states.

   .. py:method:: is_action_valid(states, actions, backward = False)

      Checks whether the actions are valid in the given discrete states.

      :param states: The batch of discrete states.
      :param actions: The batch of actions.
      :param backward: If True, checks validity for backward actions.
      :returns: True if all actions are valid in the given states, False
         otherwise. When ``debug`` is False, returns True without checking, to
         keep hot paths compile-friendly.

   .. py:attribute:: is_discrete
      :type: bool
      :value: True

   .. py:method:: make_actions_class()

      Returns the Actions class for this environment.

      :returns: A subclass of Actions (the class itself, not an instance) with
         environment-specific functionality.

   .. py:method:: make_states_class()

      Returns the DiscreteStates class for this environment.

      :returns: A subclass of DiscreteStates (the class itself, not an instance)
         with environment-specific functionality.

   .. py:attribute:: n_actions

   .. py:property:: n_states
      :type: int
      :abstractmethod:

      Optional method to return the number of states in the environment.

      :returns: The number of states.

   .. py:property:: n_terminating_states
      :type: int
      :abstractmethod:

      Optional method to return the number of terminating states in the environment.

      :returns: The number of terminating states.

   .. py:method:: reset(batch_shape, random = False, sink = False, seed = None, conditions = None)

      Instantiates a batch of random, initial, or sink states.

      :param batch_shape: Shape of the batch (int or tuple).
      :param random: If True, initialize states randomly.
      :param sink: If True, initialize states as sink states ($s_f$).
      :param seed: (Optional) Random seed for reproducibility.
      :param conditions: (Optional) Tensor of shape ``(*batch_shape, condition_dim)``
         containing the conditions.
      :returns: A batch of random, initial, or sink states.

   .. py:attribute:: s0
      :type: torch.Tensor

   .. py:attribute:: sf
      :type: torch.Tensor

   .. py:method:: states_from_batch_shape(batch_shape, random = False, sink = False, conditions = None)

      Returns a batch of random, initial, or sink states with a given batch shape.

      :param batch_shape: Tuple representing the shape of the batch of states.
      :param random: If True, initialize states randomly.
      :param sink: If True, initialize states as sink states ($s_f$).
      :param conditions: Optional tensor of shape ``(*batch_shape, condition_dim)``
         containing condition vectors for conditional GFlowNets.
      :returns: A batch of random, initial, or sink states.
      :rtype: DiscreteStates

   .. py:method:: states_from_tensor(tensor, conditions = None)

      Wraps the supplied tensor in a DiscreteStates instance.

      :param tensor: Tensor of shape ``(*batch_shape, *state_shape)`` representing
         the states.
      :param conditions: Optional tensor of shape ``(*batch_shape, condition_dim)``
         containing condition vectors for conditional GFlowNets.
      :returns: An instance of DiscreteStates.

   .. py:property:: terminating_states
      :type: gfn.states.DiscreteStates
      :abstractmethod:

      Optional method to return a batch of all terminating states in the environment.

      :returns: A batch of all terminating states
         (``batch_shape = (n_terminating_states,)``).

      .. note::

         ``self.get_terminating_states_indices(self.terminating_states)`` and
         ``torch.arange(self.n_terminating_states)`` should be equivalent.

   .. py:method:: validate(gflownet, n_validation_samples = 1000, visited_terminating_states = None, validate_condition = None, sampling_chunk_size = 5000, check_sample_sufficiency = True)

      Evaluate a GFlowNet against this environment's true distribution.

      Always draws fresh samples from the current policy, so the estimate is
      unbiased. Computes the L1 distance and the Jensen-Shannon divergence between
      the empirical and true distributions. If the GFlowNet has a learned ``logZ``
      and the environment implements ``log_partition``, also reports the absolute
      difference between the learned and true log partition functions.

      :param gflownet: The GFlowNet to evaluate.
      :param n_validation_samples: Number of fresh trajectories to sample.
      :param visited_terminating_states: **Deprecated.** Ignored if passed; a
         ``DeprecationWarning`` is emitted.
      :param validate_condition: Optional condition tensor for conditional
         environments.
      :param sampling_chunk_size: Maximum number of trajectories to sample at once
         (avoids out-of-memory errors).
      :param check_sample_sufficiency: If True, emits a one-time warning when
         ``n_validation_samples`` is too small relative to the state space. Set to
         False to suppress the warning.
      :returns: ``(metrics_dict, sampled_terminating_states)``, where
         *metrics_dict* contains ``"l1_dist"``, ``"jsd"``, and optionally
         ``"logZ_diff"``.
      :raises ValueError: If ``true_dist`` is unavailable, ``n_validation_samples``
         is non-positive, or enumeration APIs are missing.


.. py:class:: Env(s0, state_shape, action_shape, dummy_action, exit_action, sf = None, debug = False)

   Bases: :py:obj:`abc.ABC`

   Base class for all environments.

   Environments define the state and action spaces, as well as the forward and
   backward transition and reward functions.

   .. attribute:: s0

      The initial state (tensor or GeometricData).

   .. attribute:: sf

      The sink (final) state (tensor or GeometricData).

   .. attribute:: state_shape

      Tuple representing the shape of the states.

   .. attribute:: action_shape

      Tuple representing the shape of the actions.

   .. attribute:: dummy_action

      Tensor representing the dummy action used for padding.

   .. attribute:: exit_action

      Tensor representing the exit action.

   .. attribute:: States

      The States class associated with this environment.

   .. attribute:: Actions

      The Actions class associated with this environment.

   .. attribute:: is_discrete

      Class variable indicating whether the environment is discrete.

   .. attribute:: is_conditional

      Class variable indicating whether the environment is conditional.

   .. py:attribute:: Actions

   .. py:attribute:: States

   .. py:method:: _backward_step(states, actions)

      Wrapper for the user-defined ``backward_step`` function.

      This wrapper ensures that ``backward_step`` is called only on valid states
      and actions, and sets the states to the initial state when the action is
      not valid. It also ensures that the returned states are a distinct object
      from the input states.

      :param states: A batch of states.
      :param actions: A batch of actions.
      :returns: A batch of previous states.

   .. py:method:: _step(states, actions)

      Wrapper for the user-defined ``step`` function.

      This wrapper ensures that ``step`` is called only on valid states and
      actions, and sets the states to the sink state when the action is the exit
      action. It also ensures that the returned states are a distinct object from
      the input states.

      :param states: A batch of states.
      :param actions: A batch of actions.
      :returns: A batch of next states.

   .. py:attribute:: action_shape

   .. py:method:: actions_from_batch_shape(batch_shape)

      Returns a batch of dummy actions with the supplied batch shape.

      :param batch_shape: Tuple representing the shape of the batch of actions.
      :returns: A batch of dummy actions.

   .. py:method:: actions_from_tensor(tensor)

      Wraps the supplied tensor in an Actions instance.

      :param tensor: Tensor of shape ``(*batch_shape, *action_shape)``
         representing the actions.
      :returns: An Actions instance.

   .. py:method:: backward_step(states, actions)
      :abstractmethod:

      Backward transition function of the environment.

      This method takes a batch of states and actions and returns a batch of
      previous states.
      It does not need to check whether the actions are valid or whether the
      states are sink states, because the ``_backward_step`` wrapper performs
      these checks.

      :param states: A batch of states.
      :param actions: A batch of actions.
      :returns: A batch of previous states.

   .. py:attribute:: debug
      :value: False

   .. py:property:: device
      :type: torch.device

      The device on which the environment's elements are stored.

      :returns: The device of the initial state.

   .. py:attribute:: dummy_action

   .. py:attribute:: exit_action

   .. py:method:: is_action_valid(states, actions, backward = False)
      :abstractmethod:

      Checks whether the actions are valid in the given states.

      :param states: A batch of states.
      :param actions: A batch of actions.
      :param backward: If True, checks validity for backward actions.
      :returns: True if all actions are valid in the given states, False otherwise.

   .. py:attribute:: is_conditional
      :type: bool
      :value: False

   .. py:attribute:: is_discrete
      :type: bool
      :value: False

   .. py:method:: log_partition(condition = None)
      :abstractmethod:

      Optional method to return the logarithm of the partition function.

      :param condition: Optional tensor of shape ``(condition_dim,)`` containing
         the condition.
      :returns: The log partition function.

   .. py:method:: log_reward(states)

      Returns the environment's log-rewards for a batch of states. Either this
      method or ``reward`` must be implemented by the environment.

      :param states: A batch of states with shape ``batch_shape``.
      :returns: Tensor of shape ``(*batch_shape)`` containing the log-rewards.

   .. py:method:: make_actions_class()

      Returns the Actions class for this environment.

      Defines a custom Actions class that inherits from Actions and implements
      the assumed methods. Override ``make_actions_class`` to add more
      environment-specific Actions functionality.

      :returns: A subclass of Actions (the class itself, not an instance) with
         environment-specific functionality.

   .. py:method:: make_random_states(batch_shape, conditions = None, device = None)
      :abstractmethod:

      Optional method to return a batch of random states.

      :param batch_shape: Tuple representing the shape of the batch of states.
      :param conditions: Optional tensor of shape ``(*batch_shape, condition_dim)``
         containing condition vectors for conditional GFlowNets.
      :param device: The device to create the states on.
      :returns: A batch of random states.

   .. py:method:: make_states_class()

      Returns the States class for this environment.

      Defines a custom States class that inherits from States and implements
      the assumed methods. Override ``make_states_class`` to add more
      environment-specific States functionality.

      :returns: A subclass of States (the class itself, not an instance) with
         environment-specific functionality.

   .. py:method:: reset(batch_shape, random = False, sink = False, seed = None, conditions = None)

      Instantiates a batch of random, initial, or sink states.

      :param batch_shape: Shape of the batch (int, tuple, or list).
      :param random: If True, initialize states randomly.
      :param sink: If True, initialize states as sink states ($s_f$).
      :param seed: (Optional) Random seed for reproducibility.
      :param conditions: (Optional) Tensor of shape ``(*batch_shape, condition_dim)``
         containing the conditions.
      :returns: A batch of random, initial, or sink states.

   .. py:method:: reward(states)
      :abstractmethod:

      Returns the environment's rewards for a batch of states. Either this
      method or ``log_reward`` must be implemented by the environment.

      :param states: A batch of states with shape ``batch_shape``.
      :returns: Tensor of shape ``(*batch_shape)`` containing the rewards.

   .. py:attribute:: s0

   .. py:method:: sample_conditions(batch_shape)
      :abstractmethod:

      Samples conditions for the environment. Required for conditional
      environments.

      :param batch_shape: The shape of the batch of conditions to sample.
      :returns: A tensor of shape ``(*batch_shape, condition_dim)`` containing the
         conditions.

   .. py:attribute:: sf

   .. py:attribute:: state_shape

   .. py:method:: states_from_batch_shape(batch_shape, random = False, sink = False, conditions = None)

      Returns a batch of random, initial, or sink states with a given batch shape.

      :param batch_shape: Tuple representing the shape of the batch of states.
      :param random: If True, initialize states randomly (requires implementation).
      :param sink: If True, initialize states as sink states ($s_f$).
      :param conditions: Optional tensor of shape ``(*batch_shape, condition_dim)``
         containing condition vectors for conditional GFlowNets.
      :returns: A batch of random, initial, or sink states.

   .. py:method:: states_from_tensor(tensor, conditions = None)

      Wraps the supplied tensor in a States instance.

      :param tensor: Tensor of shape ``(*batch_shape, *state_shape)`` representing
         the states.
      :param conditions: Optional tensor of shape ``(*batch_shape, condition_dim)``
         containing condition vectors for conditional GFlowNets.
      :returns: A States instance.

   .. py:method:: step(states, actions)
      :abstractmethod:

      Forward transition function of the environment.

      This method takes a batch of states and actions and returns a batch of next
      states. It does not need to check whether the actions are valid or whether
      the states are sink states, because the ``_step`` wrapper performs these
      checks.

      :param states: A batch of states.
      :param actions: A batch of actions.
      :returns: A batch of next states.

   .. py:method:: true_dist(condition = None)
      :abstractmethod:

      Optional method to return the true distribution.

      :param condition: Optional tensor of shape ``(condition_dim,)`` containing
         the condition.
      :returns: The true distribution as a 1-D tensor.


.. py:class:: GraphEnv(s0, sf, num_node_classes, num_edge_classes, is_directed, debug = False)

   Bases: :py:obj:`Env`

   Base class for graph-based environments.

   Graph environments represent states as graphs (``torch_geometric`` Data
   objects) and actions as graph modifications.

   .. attribute:: s0

      GeometricData representing the initial graph state.

   .. attribute:: sf

      GeometricData representing the sink (final) graph state.

   .. attribute:: num_node_classes

      Number of node classes.

   .. attribute:: num_edge_classes

      Number of edge classes.

   .. attribute:: is_directed

      Whether the graph is directed.

   .. attribute:: States

      The States class associated with this environment.

   .. attribute:: Actions

      The Actions class associated with this environment.

   .. py:attribute:: Actions

   .. py:attribute:: States

   .. py:method:: backward_step(states, actions)
      :abstractmethod:

      Backward transition function of the graph environment.

      This method takes a batch of graph states and actions and returns a batch
      of previous graph states.

      :param states: A batch of graph states.
      :param actions: A batch of graph actions.
      :returns: A batch of previous graph states.

   .. py:attribute:: debug
      :value: False

   .. py:property:: device
      :type: torch.device

      The device on which the graph states are stored.

      :returns: The device of the initial graph state's node features.

   .. py:attribute:: dummy_action

   .. py:attribute:: exit_action

   .. py:attribute:: is_directed

   .. py:method:: make_actions_class()

      Returns the GraphActions class for this environment.

      :returns: A subclass of GraphActions (the class itself, not an instance)
         with environment-specific functionality.

   .. py:method:: make_random_states(batch_shape, conditions = None, device = None)
      :abstractmethod:

      Optional method to return a batch of random graph states.

      :param batch_shape: Shape of the batch (int or tuple).
      :param conditions: Optional tensor of shape ``(*batch_shape, condition_dim)``
         containing condition vectors for conditional GFlowNets.
      :param device: The device to create the graph states on.
      :returns: A batch of random graph states.

   .. py:method:: make_states_class()

      Returns the GraphStates class for this environment.

      :returns: A subclass of GraphStates (the class itself, not an instance)
         with environment-specific functionality.

   .. py:attribute:: num_edge_classes

   .. py:attribute:: num_node_classes

   .. py:method:: reset(batch_shape, random = False, sink = False, seed = None, conditions = None)

      Instantiates a batch of random, initial, or sink graph states.

      :param batch_shape: Shape of the batch (int or tuple).
      :param random: If True, initialize states randomly.
      :param sink: If True, initialize states as sink states ($s_f$).
      :param seed: (Optional) Random seed for reproducibility.
      :param conditions: (Optional) Tensor of shape ``(*batch_shape, condition_dim)``
         containing the conditions.
      :returns: A batch of random, initial, or sink graph states.

   .. py:attribute:: s0
      :type: torch_geometric.data.Data

   .. py:attribute:: sf
      :type: torch_geometric.data.Data

   .. py:method:: step(states, actions)
      :abstractmethod:

      Forward transition function of the graph environment.

      This method takes a batch of graph states and actions and returns a batch
      of next graph states.

      :param states: A batch of graph states.
      :param actions: A batch of graph actions.
      :returns: A batch of next graph states.

.. py:data:: NonValidActionsError
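
The empirical distribution and L1 distance that :py:meth:`DiscreteEnv.validate`
reports can be illustrated without tensors. The following is a minimal sketch.
It assumes terminating states have already been mapped to integer indices in
``range(n_terminating_states)``, as
:py:meth:`DiscreteEnv.get_terminating_states_indices` provides. The helper names
``empirical_dist`` and ``l1_dist`` are hypothetical and are not part of
``gfn.env``.

.. code-block:: python

   from collections import Counter


   def empirical_dist(indices, n_terminating_states):
       """Normalized histogram of terminating-state indices.

       Counting with a Counter plays the role that ``scatter_add_`` plays in a
       tensor implementation: each sampled index adds one to its bin.
       """
       if not indices:
           raise ValueError("indices must be non-empty")
       if any(not 0 <= i < n_terminating_states for i in indices):
           raise ValueError("index out of range")
       counts = Counter(indices)
       total = len(indices)
       # Counter returns 0 for bins that were never visited.
       return [counts[i] / total for i in range(n_terminating_states)]


   def l1_dist(p, q):
       """Sum of absolute differences between two distributions on the same bins."""
       return sum(abs(pi - qi) for pi, qi in zip(p, q))


   p_hat = empirical_dist([0, 2, 2, 3], n_terminating_states=4)
   print(p_hat)                       # [0.25, 0.0, 0.5, 0.25]
   print(l1_dist(p_hat, [0.25] * 4))  # 0.5

This sketch uses the plain sum of absolute differences. The library may
normalize its L1 distance differently (for example, by averaging over states),
so the two sets of numbers are not necessarily directly comparable.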