gfn.gym ======= .. py:module:: gfn.gym .. autoapi-nested-parse:: This module contains all the environments implemented as Gym environments. Submodules ---------- .. toctree:: :maxdepth: 1 /autoapi/gfn/gym/bayesian_structure/index /autoapi/gfn/gym/bitSequence/index /autoapi/gfn/gym/bitSequenceNonAutoregressive/index /autoapi/gfn/gym/box/index /autoapi/gfn/gym/box_cartesian/index /autoapi/gfn/gym/chip_design/index /autoapi/gfn/gym/diffusion_sampling/index /autoapi/gfn/gym/discrete_ebm/index /autoapi/gfn/gym/graph_building/index /autoapi/gfn/gym/helpers/index /autoapi/gfn/gym/hypergrid/index /autoapi/gfn/gym/line/index /autoapi/gfn/gym/perfect_tree/index /autoapi/gfn/gym/set_addition/index Attributes ---------- .. autoapisummary:: gfn.gym.Box Classes ------- .. autoapisummary:: gfn.gym.BitSequence gfn.gym.BitSequencePlus gfn.gym.BoxCartesian gfn.gym.BoxPolar gfn.gym.ChipDesign gfn.gym.ConditionalHyperGrid gfn.gym.DiscreteEBM gfn.gym.GraphBuilding gfn.gym.GraphBuildingOnEdges gfn.gym.HyperGrid gfn.gym.Line gfn.gym.NonAutoregressiveBitSequence gfn.gym.PerfectBinaryTree gfn.gym.SetAddition Package Contents ---------------- .. py:class:: BitSequence(word_size = 4, seq_size = 120, n_modes = 60, temperature = 1.0, H = None, device_str = 'cpu', seed = 0, debug = False) Bases: :py:obj:`gfn.env.DiscreteEnv` Append-only BitSequence environment. This environment represents a sequence of binary words and provides methods to manipulate and evaluate these sequences. Each action appends one whole binary word to the sequence. Each binary word is stored as its decimal value in both states and actions. .. attribute:: word_size The size of each binary word in the sequence. .. attribute:: seq_size The total number of digits of the sequence. .. attribute:: n_modes The number of unique modes in the sequence. .. attribute:: temperature The temperature parameter for reward calculation. .. attribute:: H A tensor used to create the modes. .. 
attribute:: device_str The device to run the computations on ("cpu" or "cuda"). .. attribute:: words_per_seq The number of words per sequence. .. attribute:: modes The set of modes written as binary. .. py:attribute:: H :value: None .. py:attribute:: States :type: type[BitSequenceStates] .. py:method:: _backward_step(states, actions) Perform a backward step in the environment by undoing the given actions to the current states. :param states: The current states of the environment. :param actions: The actions to be applied to the current states. :returns: The new states after performing the backward step. .. py:method:: _step(states, actions) Perform a step in the environment by applying the given actions to the current states. :param states: The current states of the environment. :param actions: The actions to be applied to the current states. :returns: The new states of the environment after applying the actions. .. py:method:: backward_step(states, actions) Performs a backward step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The previous states. .. py:method:: binary_to_integers(binary_tensor, k) :staticmethod: Convert a binary tensor to a tensor of integers. :param binary_tensor: A tensor containing binary values. The tensor must be of type int64. :param k: The number of bits in each integer. :returns: A tensor of integers obtained from the binary tensor. .. py:method:: create_test_set(k, seed = 0) Create a test set by generating k variations of each mode, each with a random number of bits altered. The test set has size n_modes * k. :param k: Number of variations per mode. :param seed: Seed for reproducibility. If None, randomness is not fixed. :returns: The generated test set in the decimal representation. .. py:method:: hamming_distance(candidates, reference) :staticmethod: Compute the smallest edit distance from each candidate row to any reference row. :param candidates: Tensor of shape `(*batch_shape, length)`. 
:param reference: Tensor of shape `(n_ref, length)`. :returns: Tensor of shape `(*batch_shape)` containing the smallest edit distance for each candidate row. .. py:method:: integers_to_binary(tensor, k) :staticmethod: Convert a tensor of integers to their binary representation using k bits. :param tensor: A tensor containing integers. :param k: The number of bits to use for the binary representation of each integer. :returns: A tensor containing the binary representation of the input integers. .. py:method:: log_reward(final_states) Calculates the log-reward for the given final states. :param final_states: The final states for which to calculate the log-reward. :returns: The calculated log-reward. .. py:method:: make_modes_set(seed) Generates a set of unique mode sequences based on the predefined tensor H. :param seed: The seed for random number generation. :returns: A tensor containing the unique mode sequences. :raises ValueError: If the number of requested modes exceeds the number of possible unique sequences. .. py:method:: make_states_class() Creates a BitSequenceStates class implementation. :returns: A BitSequenceStates class implementation. .. py:attribute:: modes .. py:attribute:: n_actions :value: 17 .. py:attribute:: n_modes :type: int :value: 60 .. py:property:: n_states :type: int Returns the total number of states in the environment. .. py:property:: n_terminating_states :type: int Returns the number of terminating states. .. py:method:: reset(batch_shape = None, sink = False) Generates initial or sink states from batch_shape. :param batch_shape: The shape of the batch. If None, defaults to (1,). If an integer is provided, it is converted to a tuple. :param sink: If True, sink states are created. Defaults to False. :returns: The initial states of the environment after reset. .. py:method:: reward(final_states) Calculate the reward for the given final states. 
The reward is computed based on the Hamming distance between the binary representation of the final states and the predefined modes. The reward is then scaled using an exponential function with a temperature parameter. :param final_states: The final states for which the reward is to be calculated. :returns: The calculated reward for the given final states. .. py:attribute:: seq_size :type: int :value: 120 .. py:method:: states_from_tensor(tensor, length = None) Wraps the supplied Tensor in a States instance. :param tensor: The tensor of shape `state_shape` representing the states. :param length: The length of each state in the tensor. :returns: An instance of States. .. py:method:: step(states, actions) Performs a step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The next states. .. py:attribute:: temperature :value: 1.0 .. py:property:: terminating_states :type: BitSequenceStates Returns all terminating states of the environment. .. py:method:: trajectory_from_terminating_states(terminating_states_tensor) Generate trajectories from terminating states. This works because the DAG is a tree in the append-only version of BitSequence. :param terminating_states_tensor: A tensor containing the terminating states from which to generate the trajectories. The shape of the tensor should be `(batch_size, words_per_seq)`. :returns: An object containing the generated trajectories. .. py:method:: true_dist(condition=None) Returns the true probability mass function of the reward distribution. .. py:attribute:: word_size :type: int :value: 4 .. py:attribute:: words_per_seq :type: int :value: 30 .. py:class:: BitSequencePlus(word_size = 4, seq_size = 120, n_modes = 60, temperature = 1.0, H = None, device_str = 'cpu', seed = 0) Bases: :py:obj:`BitSequence` Prepend-Append version of BitSequence env. This environment is similar to BitSequence, but allows words to be either prepended or appended to the sequence. .. 
py:attribute:: H :value: None .. py:method:: backward_step(states, actions) Performs a backward step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The previous states. .. py:method:: make_states_class() Creates a BitSequenceStates class implementation for BitSequencePlus. :returns: A BitSequenceStates class implementation with prepend-append mask logic. .. py:attribute:: modes .. py:attribute:: n_modes :type: int :value: 60 .. py:attribute:: seq_size :type: int :value: 120 .. py:method:: step(states, actions) Performs a step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The next states. .. py:attribute:: temperature :value: 1.0 .. py:method:: trajectory_from_terminating_states(terminating_states_tensor) :abstractmethod: Generates trajectories from terminating states. Not implemented for this environment. :param terminating_states_tensor: A tensor of terminating states. :raises NotImplementedError: This method is not implemented for this environment. .. py:attribute:: word_size :type: int :value: 4 .. py:attribute:: words_per_seq :type: int :value: 30 .. py:data:: Box .. py:class:: BoxCartesian Bases: :py:obj:`gfn.gym.box.BoxPolar` Box environment with Cartesian per-dimension action validation. Inherits all behavior from :class:`BoxPolar` (init, step, backward_step, reward, log_partition, norm, make_random_states). Overrides only :meth:`is_action_valid` to use per-dimension bounds instead of polar norm constraints. Use with the Cartesian estimators/distributions in ``box_cartesian_utils.py``. .. seealso:: :class:`BoxPolar` for the original polar norm-based variant. .. py:method:: is_action_valid(states, actions, backward = False) Checks if the actions are valid (Cartesian per-dimension semantics). 
For Cartesian actions: - Forward from s0: action[i] >= 0 and action[i] <= 1 - Forward from non-s0: action[i] >= delta and state[i] + action[i] <= 1 - Backward: state[i] - action[i] >= 0 - Backward to s0: if all resulting dims < delta, action must equal state :param states: The current states. :param actions: The actions to be taken. :param backward: Whether the actions are backward actions. :returns: True if the actions are valid, False otherwise. .. py:class:: BoxPolar(delta = 0.1, R0 = 0.1, R1 = 0.5, R2 = 2.0, epsilon = 0.0001, device = 'cpu', debug = False) Bases: :py:obj:`gfn.env.Env` Box environment with polar (norm-based) action validation. Corresponds to the environment in Section 4.1 of https://arxiv.org/abs/2301.12594 Actions are 2D vectors whose L2 norm must equal delta (for non-s0 forward steps) or be at most delta (for the initial s0 step). Use with the polar estimators/distributions in ``box_polar_utils.py``. .. seealso:: :class:`~gfn.gym.box_cartesian.BoxCartesian` for a simpler per-dimension Cartesian variant. .. attribute:: delta The step size. .. attribute:: R0 The base reward. .. attribute:: R1 The reward for being outside the first box. .. attribute:: R2 The reward for being inside the second box. .. attribute:: epsilon A small value to avoid numerical issues. .. attribute:: device The device to use. :type: Literal["cpu", "cuda"] | torch.device .. py:attribute:: R0 :value: 0.1 .. py:attribute:: R1 :value: 0.5 .. py:attribute:: R2 :value: 2.0 .. py:method:: backward_step(states, actions) Backward step function for the Box environment. :param states: States object representing the current states. :param actions: Actions object representing the actions to be taken. :returns: The previous states as a States object. .. py:attribute:: delta :value: 0.1 .. py:attribute:: epsilon :value: 0.0001 .. py:method:: is_action_valid(states, actions, backward = False) Checks if the actions are valid (polar norm-based semantics). 
For polar actions: - Forward from s0: norm(action) <= delta - Forward from non-s0: norm(action) == delta (within tolerance) - Backward: state - action >= 0 component-wise - Backward to s0: if norm(state) < delta, action must equal state :param states: The current states. :param actions: The actions to be taken. :param backward: Whether the actions are backward actions. :returns: True if the actions are valid, False otherwise. .. py:method:: log_partition(condition=None) Returns the log partition of the reward function. .. py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False) Generates random states tensor of shape (*batch_shape, 2). :param batch_shape: The shape of the batch. :param conditions: Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets. :param device: The device to use. :param debug: If True, emit States with debug guards (not compile-friendly). :returns: A States object with random states. .. py:method:: norm(x) :staticmethod: Computes the L2 norm of the input tensor along the last dimension. :param x: Input tensor of shape `(*batch_shape, 2)`. :returns: Tensor of L2 norms of shape `batch_shape`. .. py:method:: reward(final_states) Reward is distance from the goal point. :param final_states: States object representing the final states. :returns: The reward tensor of shape `batch_shape`. .. py:method:: step(states, actions) Step function for the Box environment. :param states: States object representing the current states. :param actions: Actions object representing the actions to be taken. :returns: The next states as a States object. .. 
py:class:: ChipDesign(netlist_file = SAMPLE_NETLIST_FILE, init_placement = SAMPLE_INIT_PLACEMENT, std_cell_placer_mode = 'fd', wirelength_weight = 1.0, density_weight = 1.0, congestion_weight = 0.5, device = None, debug = False, singularity_image = None, cd_finetune = True, reward_norm = None, reward_temper = 1.0, reward_norm_samples = 100, cost_stats = None) Bases: :py:obj:`gfn.env.DiscreteEnv` GFlowNet environment for chip placement. The state is a vector of length `n_macros`, where `state[i]` is the grid cell location of the i-th macro to be placed. Unplaced macros have a location of -1. Actions are integers from `0` to `n_grid_cells - 1`, representing the grid cell to place the current macro on. Action `n_grid_cells` is the exit action. .. py:attribute:: States :type: type[ChipDesignStates] .. py:method:: __del__() .. py:method:: _apply_state_to_plc(state_tensor) Applies a single state tensor to the plc object. .. py:method:: _backward_step(states, actions) Wraps parent _backward_step and updates masks. .. py:method:: _estimate_cost_stats(n_samples) Estimates cost distribution from random placements. .. py:attribute:: _hard_macro_indices .. py:method:: _normalize_cost(cost) Applies reward normalization to a raw cost value. .. py:attribute:: _sorted_node_indices .. py:method:: _step(states, actions) Wraps parent _step and updates masks. .. py:method:: analytical_placer() Places standard cells using an analytical placer. .. py:method:: backward_step(states, actions) Performs a backward step in the environment. .. py:attribute:: cd_finetune :value: True .. py:method:: close() Closes the PlacementCost subprocess to free resources. .. py:attribute:: congestion_weight :value: 0.5 .. py:attribute:: density_weight :value: 1.0 .. py:method:: log_reward(final_states) Computes the log reward of the final states. .. py:method:: make_states_class() Creates the ChipDesignStates class. .. py:attribute:: n_grid_cells .. py:attribute:: n_macros .. py:attribute:: plc .. 
py:method:: reset(batch_shape, random = False, sink = False, seed = None, conditions = None) Resets the environment and computes initial masks. .. py:attribute:: reward_norm :value: None .. py:attribute:: reward_temper :value: 1.0 .. py:attribute:: std_cell_placer_mode :value: 'fd' .. py:method:: step(states, actions) Performs a forward step in the environment. .. py:method:: update_masks(states) Updates the forward and backward masks of the states. .. py:attribute:: wirelength_weight :value: 1.0 .. py:class:: ConditionalHyperGrid(*args, **kwargs) Bases: :py:obj:`HyperGrid` HyperGrid environment with condition-aware rewards. Let condition 'c' be a real value in [0, 1]. It defines the reward as a linear interpolation between the uniform reward and the original reward. Special cases are: - c = 0: Uniform reward (all terminal states get reward=R0+R1+R2) - c = 1: Original HyperGrid reward (original multi-modal reward landscape) .. py:attribute:: _log_partition_cache :type: dict[torch.Tensor, float] .. py:attribute:: _max_reward :type: float .. py:attribute:: _original_reward_fn .. py:attribute:: _true_dist_cache :type: dict[torch.Tensor, torch.Tensor] .. py:attribute:: condition_dim :type: int :value: 1 .. py:attribute:: is_conditional :type: bool :value: True .. py:method:: log_partition(condition) Compute the log partition for the given condition. :param condition: The condition to compute the log partition for. condition.shape should be (1,) :returns: The log partition function, as a float. .. py:method:: reward(states) Compute rewards for the conditional environment. A condition is continuous from 0 to 1: - 0: Fully uniform reward (all states get R0+R1+R2) - 1: Fully original HyperGrid reward - In between: Linear interpolation between uniform and original :param states: The states to compute rewards for. states.tensor.shape should be (*batch_shape, *state_shape) :returns: A tensor of shape (*batch_shape,) containing the rewards. .. 
py:method:: sample_conditions(batch_shape) Sample conditions for the environment. .. py:method:: true_dist(condition) Compute the true distribution for the given condition. :param condition: The condition to compute the true distribution for. condition.shape should be (1,) :returns: The true distribution for the given condition as a 1-dimensional tensor. .. py:class:: DiscreteEBM(ndim, energy = None, alpha = 1.0, device = 'cpu', debug = False) Bases: :py:obj:`gfn.env.DiscreteEnv` Environment for discrete energy-based models. This environment is based on the paper https://arxiv.org/pdf/2202.01361.pdf. The states are represented as 1d tensors of length `ndim` with values in `{-1, 0, 1}`. `s0` is empty (represented as -1), so `s0=[-1, -1, ..., -1]`. An action corresponds to replacing a -1 with a 0 or a 1. Action `i` in `[0, ndim - 1]` corresponds to replacing `s[i]` with 0. Action `i` in `[ndim, 2 * ndim - 1]` corresponds to replacing `s[i - ndim]` with 1. The last action is the exit action that is only available for complete states (those with no -1). .. attribute:: ndim Dimension D of the sampling space `{0, 1}^D`. :type: int .. attribute:: energy Energy function of the EBM. :type: EnergyFunction .. attribute:: alpha Interaction strength of the EBM. :type: float .. py:attribute:: States :type: type[gfn.states.DiscreteStates] .. py:property:: all_states :type: gfn.states.DiscreteStates Returns all possible states of the environment. .. py:attribute:: alpha :value: 1.0 .. py:method:: backward_step(states, actions) Performs a backward step. In this env, states are n-dim vectors. `s0` is empty (represented as -1), so `s0=[-1, -1, ..., -1]`, each action is replacing a -1 with either a 0 or 1. Action `i` in `[0, ndim-1]` is replacing `s[i]` with 0, whereas action `i` in `[ndim, 2*ndim-1]` corresponds to replacing `s[i - ndim]` with 1. 
A backward action asks "what index should be set back to -1", hence the fmod to enable wrapping of indices. :param states: The current states. :param actions: The actions to be undone. :returns: The previous states. .. py:attribute:: energy :type: EnergyFunction :value: None .. py:method:: get_states_indices(states) Given that each state is of length `ndim` with values in `{-1, 0, 1}`, there are `3**ndim` states, which we can label from `0` to `3**ndim - 1`. The easiest way to map each state to a unique integer is to consider the state as a number in base 3, where each digit can be in `{0, 1, 2}`. We thus need to shift this number by 1 so that `{-1, 0, 1} -> {0, 1, 2}`. :param states: DiscreteStates object representing the states. :returns: The states indices as tensor of shape `(*batch_shape)`. .. py:method:: get_terminating_states_indices(states) Given that each terminating state is of length `ndim` with values in `{0, 1}`, there are `2**ndim` terminating states, which we can label from `0` to `2**ndim - 1`. The easiest way to map each state to a unique integer is to consider the state as a number in base 2. :param states: DiscreteStates object representing the states. :returns: The indices of the terminating states as tensor of shape `(*batch_shape)`. .. py:method:: is_exit_actions(actions) Determines if the actions are exit actions. :param actions: tensor of actions of shape `(*batch_shape, *action_shape)`. :returns: Tensor of booleans of shape `(*batch_shape)`. .. py:method:: log_partition(condition=None) Returns the log partition of the reward function. .. py:method:: log_reward(final_states) The energy weighted by alpha is our log reward. :param final_states: DiscreteStates object representing the final states. :returns: The log reward as tensor of shape `(*batch_shape)`. .. py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False) Generates random states tensor of shape `(*batch_shape, ndim)`. 
:param batch_shape: The shape of the batch. :param conditions: Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets. :param device: The device to use. :param debug: If True, emit States with debug guards (not compile-friendly). :returns: A `DiscreteStates` object with random states. .. py:method:: make_states_class() Returns the DiscreteStates class for the DiscreteEBM environment. .. py:property:: n_states :type: int Returns the number of states in the environment. .. py:property:: n_terminating_states :type: int Returns the number of terminating states in the environment. .. py:attribute:: ndim .. py:method:: reward(final_states) Computes the reward for a batch of final states. :param final_states: A batch of final states. :returns: A tensor of rewards. .. py:method:: step(states, actions) Performs a step. :param states: States object representing the current states. :param actions: Actions object representing the actions to be taken. :returns: The next states as a `States` object. .. py:property:: terminating_states :type: gfn.states.DiscreteStates Returns all terminating states of the environment. .. py:method:: true_dist(condition=None) Returns the true probability mass function of the reward distribution. .. py:class:: GraphBuilding(num_node_classes, num_edge_classes, state_evaluator, is_directed = True, max_nodes = None, device = 'cpu', s0 = None, sf = None, debug = False) Bases: :py:obj:`gfn.env.GraphEnv` Environment for incrementally building graphs. This environment allows constructing graphs by: - Adding nodes of a given class - Adding edges of a given class between existing nodes - Terminating construction (EXIT) .. attribute:: num_node_classes The number of node classes. .. attribute:: num_edge_classes The number of edge classes. .. attribute:: state_evaluator A callable that computes rewards for final states. .. attribute:: is_directed Whether the graph is directed. .. 
py:method:: backward_step(states, actions) Performs a backward step in the environment. :param states: The current states. :param actions: The actions to undo. :returns: The previous states. .. py:method:: is_action_valid(states, actions, backward = False) Check if actions are valid for the given states. :param states: Current graph states. :param actions: Actions to validate. :param backward: Whether this is a backward step. :returns: True if all actions are valid, False otherwise. .. py:method:: make_actions_class() Returns the GraphActions class for this environment. :returns: A type of a subclass of GraphActions with environment-specific functionalities. .. py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False) Generates random states. :param batch_shape: The shape of the batch. :param conditions: Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets. :param device: The device to use. :param debug: If True, emit States with debug guards (not compile-friendly). :returns: A `GraphStates` object with random states. .. py:method:: make_states_class() Creates a `GraphStates` class for this environment. .. py:attribute:: max_nodes :value: None .. py:method:: reward(final_states) The environment's reward given a state. :param final_states: A batch of final states. :returns: A tensor of shape `(batch_size,)` containing the rewards. .. py:attribute:: state_evaluator .. py:method:: step(states, actions) Performs a step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The next states. .. py:class:: GraphBuildingOnEdges(n_nodes, state_evaluator, directed, device, debug = False) Bases: :py:obj:`GraphBuilding` Environment for building graphs edge by edge with discrete action space. The environment supports both directed and undirected graphs. In each state, the policy can: 1. Add an edge between existing nodes. 2. 
Use the exit action to terminate graph building. The action space is discrete, with size: - For directed graphs: `n_nodes^2 - n_nodes + 1` (all possible directed edges + exit). - For undirected graphs: `(n_nodes^2 - n_nodes)/2 + 1` (upper triangle + exit). .. attribute:: n_nodes The number of nodes in the graph. :type: int .. attribute:: n_possible_edges The number of possible edges. :type: int .. py:method:: is_action_valid(states, actions, backward = False) Checks if the actions are valid. :param states: The current states. :param actions: The actions to validate. :param backward: Whether the actions are backward actions. :returns: `True` if the actions are valid, `False` otherwise. .. py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False) Makes a batch of random graph states with fixed number of nodes. :param batch_shape: Shape of the batch dimensions. :param conditions: Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets. :param device: The device to use. :param debug: If True, emit States with debug guards (not compile-friendly). :returns: A `GraphStates` object containing random graph states. .. py:method:: make_states_class() Creates a `GraphStates` class for this environment. .. py:attribute:: n_nodes .. py:class:: HyperGrid(ndim = 2, height = 8, reward_fn_str = 'original', reward_fn_kwargs = None, device = 'cpu', calculate_partition = False, store_all_states = False, debug = False, validate_modes = True, mode_stats = 'none', mode_stats_samples = 20000) Bases: :py:obj:`gfn.env.DiscreteEnv` HyperGrid environment from the GFlowNets paper. The states are represented as 1-d tensors of length `ndim` with values in `{0, 1, ..., height - 1}`. .. attribute:: ndim The dimension of the grid. .. attribute:: height The height of the grid. .. attribute:: reward_fn The reward function. .. attribute:: calculate_partition Whether to calculate the log partition function. .. 
attribute:: store_all_states Whether to store all states. .. attribute:: validate_modes Whether to check that at least one state reaches the mode threshold at init; raises if not. .. attribute:: mode_stats One of {"none", "approx", "exact"}. If not "none", computes (exact or approximate) `n_modes` and `n_mode_states`. "exact" requires `store_all_states=True` and enumerates all states. .. attribute:: mode_stats_samples Number of random samples when `mode_stats="approx"`. .. py:attribute:: States :type: type[gfn.states.DiscreteStates] .. py:attribute:: _all_states_tensor :value: None .. py:method:: _enumerate_all_states_tensor(batch_size = 20000) Enumerate all grid states, optionally storing them and computing log Z. Iterates over the full Cartesian product ``{0, ..., H-1}^D`` in batches (via multiprocessing) to avoid materializing all ``H^D`` states at once. :param batch_size: Number of states per batch. .. py:method:: _exists_bitwise_xor(thr) Deterministic feasibility check for ``BitwiseXORReward``. Builds the combined GF(2) system [trunk; selector→0; head_0]·b = [c_trunk; 0; c_head_0] for rule 0 and verifies consistency. This works uniformly for n_rules=1 (k_select=0, selector empty) and for K-rule (per-rule coverage was already verified at __init__ by _validate_rule_coverage; this re-checks rule 0 as a defense-in-depth). Feasibility of this combined GF(2) system is necessary and sufficient for a mode to exist; no random sampling is required. Presets use power-of-two heights so every feasible bit-assignment is a valid state (raw coord < height). .. py:method:: _exists_conditional_multiscale(thr) Constructive existence check for ConditionalMultiScaleReward. With filter_shift=[0,...,0] (default) the all-zeros state is always a mode: every per-tier filter passes 0 since (0 + 0) mod B = 0 < f. 
With non-zero filter_shift, we try a few "all-same-v" candidate states chosen so the MSD passes tier 0; one of them typically passes all deeper tiers when the per-rule shift_coeffs map zero lower digits to zero, leaving the constant filter_shift[t] as the only contribution. .. py:method:: _exists_cosine(thr) Analytic upper-bound check for ``CosineReward``. Idea: - The per-dimension factor is ``(cos(50·ax) + 1) · N(0,1)(5·ax)`` with ax in [0,0.5]. We estimate its maximum over the discrete grid by evaluating all candidate ax and taking the maximum value ``m``. - The full reward upper bound is ``R0 + m^D * R1``. If this is at least the mode target and the given threshold, a mode-level state must exist. - We also compute a theoretical per-dimension peak (at ax≈0) to form a slightly conservative target scaled by ``mode_gamma``. .. py:method:: _exists_fallback_random(thr) Random sampling fallback. Draw a modest batch of random states on CPU and accept if any exceed the threshold with a small tolerance. This is a last resort to avoid expensive enumeration for large grids. .. py:method:: _exists_multiplicative_coprime(thr) Number-theoretic constructive check for ``MultiplicativeCoprimeReward``. For each rule, factors the rule's target LCM over allowed primes, tries permutations of prime-to-active-dim assignments, and checks coprime + grid-bound + selector-match. Returns True iff at least one rule has a witness state whose selector maps back to that rule's index AND whose reward reaches the mode threshold. The reward shifts raw coords by +1 internally (raw 0 → internal 1), so witness states are constructed in raw space as ``p**exp - 1`` per active dim, with coprime pair checks evaluated on the post-shift internal values. At n_rules=1 the selector is trivially 0 and only rule 0 is tried, recovering the legacy behavior. .. py:method:: _exists_original_or_deceptive(thr) Constructive check for ``OriginalReward`` and ``DeceptiveReward``. 
Intuition: - These rewards form rings/bands around the center when each coordinate is normalized to [0,1]. The mode lies on a thin band at specific normalized distances from the center. - We translate those fractional band boundaries into integer indices via small inside/outside nudges (using ``EPS_INDEX_CMP``) and test one candidate index from any non-empty feasible interval. - If the reward at that candidate exceeds the threshold (with ``EPS_REWARD_CMP`` tolerance), we return True. .. py:method:: _exists_random_or_corrupted(thr) Check for UniformRandomReward or CorruptedReward. For UniformRandomReward the probe budget is sized so that P(miss all modes | at least one mode exists) < 1e-9, using n = ceil(log(1e-9) / log(1 - mode_prob)). For CorruptedReward a fixed budget of 10 000 is used (mode density is approximately preserved by the promotion/demotion calibration). A seeded generator derived from the reward seed and grid shape makes the result reproducible across calls with the same configuration. .. py:method:: _exists_sparse(thr) Constructive check for ``SparseReward``. This reward assigns positive mass only to a finite set of target configurations. When ``H>=2`` and ``D>=1``, a known target is the zero vector except for certain coordinates fixed at 1 or ``H-2``. We probe a canonical target and confirm the threshold is not above its reward. .. py:method:: _generate_combinations_in_batches(ndim, max_val, batch_size) Yield batches of the Cartesian product {0, ..., max_val}^ndim. Uses multiprocessing to avoid materializing the full product (size ``(max_val+1)^ndim``) in memory. Workers are created via the spawn start method and execute the module-level :func:`_hypergrid_worker` function so the call is safe inside MPI ranks and CUDA contexts (see the start-method comment near the top of this file). 
Pool size is capped at ``MAX_POOL_WORKERS`` because larger pools just multiply per-rank fork/spawn overhead without shrinking the per-task work — and a 64-core node hosting many co-located MPI ranks can otherwise blow up to thousands of worker processes simultaneously. :param ndim: Number of dimensions (tuple length). :param max_val: Maximum coordinate value (inclusive). :param batch_size: Number of tuples per batch. :Yields: A list of tuples for each batch. .. py:method:: _get_states_indices_bigint(states_raw) Compute canonical indices using arbitrary-precision Python ints. Used by :meth:`get_states_indices` when ``height ** ndim > 2 ** 63`` and the int64 path would overflow. Vectorized over the (potentially large) batch dimension via numpy object-dtype broadcasting: the inner Python loop iterates only over the small feature dimension ``ndim``, and each ``k * h + col`` operation dispatches a single C-level loop over all rows that calls Python ``int.__mul__`` / ``int.__add__`` per element. This is a few times faster than a nested Python loop while still preserving arbitrary-precision correctness. Returns a numpy ``object`` array of shape ``states_raw.shape[:-1]`` containing one Python ``int`` per state. .. py:attribute:: _log_partition :value: None .. py:method:: _mode_reward_threshold() Returns the reward threshold used to define a mode. By default, a state is considered in a mode if its reward is at least the schema-defined threshold derived from the configured reward. .. py:attribute:: _mode_stats_kind :type: str :value: 'none' .. py:method:: _modes_exist_quick_check() Lightweight check that a mode-level state exists. In simple terms, this answers: "Is there at least one state whose reward reaches the mode threshold?" without enumerating all states. It proceeds in three stages: 1) If the grid is small (or pre-enumerated), it computes rewards exactly and checks against the threshold. 
2) Otherwise, it dispatches to reward-specific constructive tests that are sufficient to guarantee at least one state reaches the threshold. 3) As a last resort, it samples a small batch of random states. .. py:method:: _modes_exist_quick_check_info() Same as _modes_exist_quick_check but returns (ok, message). .. py:attribute:: _n_mode_states_estimate :type: float | None :value: None .. py:attribute:: _n_mode_states_exact :type: int | None :value: None .. py:method:: _solve_gf2_has_solution(A, c) :staticmethod: Return True if A x = c over GF(2) has at least one solution. Performs Gaussian elimination modulo 2 (XOR arithmetic) without constructing a specific solution. A row that reduces to all-zero coefficients with a non-zero RHS (``0 = 1``) indicates inconsistency. .. py:method:: _solve_gf2_witness(A, c, n_vars) :staticmethod: Return a witness solution to A·b = c over GF(2), or None if none exists. b has length n_vars. A is reduced via Gaussian elimination; free variables are set to 0. .. py:attribute:: _true_dist :value: None .. py:method:: all_indices() Generate all possible indices for the grid. :returns: A list of all possible indices for the grid. .. py:property:: all_states :type: gfn.states.DiscreteStates | None Returns a tensor of all hypergrid states as a `DiscreteStates` instance. .. py:method:: backward_step(states, actions) Performs a backward step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The previous states. .. py:attribute:: calculate_partition :value: False .. py:method:: get_states_indices(states) Get the canonical ordering indices for a batch of states. Returns one canonical index per state computed from the base-``height`` encoding ``sum(s[j] * height^(ndim-1-j))``. The maximum index is ``height^ndim - 1``. 
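As a rough illustration of the base-``height`` encoding, the index of a single state can be computed with plain Python integers. ``canonical_index`` is a hypothetical helper, not part of the library; it is a minimal sketch of the formula above and shows why large grids exceed the signed 64-bit range.

```python
def canonical_index(state, height):
    """Base-`height` index of one state, using arbitrary-precision Python ints.

    Hypothetical helper for illustration only.
    """
    index = 0
    for coord in state:
        # Horner's rule: equivalent to sum(s[j] * height**(ndim - 1 - j)).
        index = index * height + int(coord)
    return index


small = canonical_index((1, 2), height=8)  # 1 * 8 + 2
# Largest index on a height=128, ndim=10 grid: height**ndim - 1.
largest = canonical_index((127,) * 10, height=128)
fits_int64 = largest <= 2**63 - 1
```

Here ``largest`` equals ``128**10 - 1``, which does not fit in a signed 64-bit integer, so an int64 tensor could not represent it without wrapping.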
- **Safe regime** (``height ** ndim <= 2 ** 63``): the index fits in signed int64 and we return a ``torch.Tensor`` of shape ``batch_shape`` with dtype ``torch.int64`` (the historical behavior).
- **Overflow regime** (``height ** ndim > 2 ** 63``): the index would overflow int64 and silently wrap, producing collisions between distinct states (a real bug we hit at e.g. ndim=10, height=128 where ``128**10 == 2**70``). In this regime we fall back to per-row Python ``int`` arithmetic and return a ``numpy.ndarray`` of dtype ``object`` containing arbitrary-precision Python ints. Each element is a unique, hashable canonical index.

The two return types support the same downstream usages we care about (``set(...tolist())`` for mode tracking, boolean masking with ``[mask]`` after converting the mask to numpy if needed). Code paths that need an ``int64`` tensor for tensor indexing (e.g. ``EnumPreprocessor``) implicitly require the safe regime; they will see the numpy fallback and fail loudly, which is the correct behavior because such grids are too large to enumerate anyway. :param states: The states to get the indices of. :returns: Indices in canonical ordering. ``torch.Tensor[int64]`` of shape ``batch_shape`` in the safe regime; ``np.ndarray[object]`` of shape ``batch_shape`` containing Python ints in the overflow regime. .. py:method:: get_terminating_states_indices(states) Get the indices of the terminating states in the canonical ordering. See :meth:`get_states_indices` for the return-type contract: a ``torch.Tensor[int64]`` for grids in the safe regime (``height ** ndim <= 2 ** 63``), or a ``numpy.ndarray[object]`` of Python ints for larger grids that would otherwise overflow. :param states: The states to get the indices of. :returns: The indices of the terminating states in the canonical ordering. .. py:attribute:: height :value: 8 .. py:method:: log_partition(condition=None) Returns the log partition of the reward function. ..
py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False) Creates a batch of random states. :param batch_shape: The shape of the batch. :param conditions: Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets. :param device: The device to use. :param debug: If True, emit States with debug guards (not compile-friendly). :returns: A `DiscreteStates` object with random states. .. py:method:: make_states_class() Returns the DiscreteStates class for the HyperGrid environment. .. py:method:: mode_mask(states) Boolean mask indicating which states are in a mode. A state is flagged as mode if its reward is greater-or-equal to the threshold based on `reward_fn_kwargs` (R0+R1+R2 by default). .. py:method:: modes_found(states) Returns the set of canonical state indices for mode states in the batch. Each mode state is identified by its unique canonical index (from ``get_states_indices``), not by a quadrant-based grouping. This allows correct mode-state tracking for all reward functions. .. py:property:: n_mode_states :type: int | float | None Number of states inside a mode (exact, approx, or None). - If mode_stats="exact", returns an exact integer count. - If mode_stats="approx", returns a floating-point estimate. - If store_all_states is True (but mode_stats was "none"), computes on demand from all_states. - Otherwise, returns None. .. py:property:: n_modes :type: int | float | None Returns the total number of mode states for this environment. Equivalent to ``n_mode_states``. Each individual grid cell whose reward meets the mode threshold counts as one mode. .. py:property:: n_states :type: int Returns the number of states in the environment. .. py:property:: n_terminating_states :type: int Returns the number of terminating states in the environment. .. py:attribute:: ndim :value: 2 .. py:method:: reward(states) Computes the reward for a batch of final states. 
In the normal setting, the reward is: :math:`R(s) = R_0 + 0.5 \prod_{d=1}^D \mathbf{1} \left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5] \right) + 2 \prod_{d=1}^D \mathbf{1} \left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)` :param states: The final states. :returns: The reward of the final states. .. py:attribute:: reward_fn .. py:attribute:: reward_fn_kwargs :value: None .. py:method:: step(states, actions) Performs a step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The next states. .. py:attribute:: store_all_states :value: False .. py:property:: terminating_states :type: gfn.states.DiscreteStates | None Returns all terminating states of the environment. .. py:method:: true_dist(condition=None) Returns the pmf over all states in the hypergrid. .. py:class:: Line(mus, sigmas, init_value, n_sd = 4.5, n_steps_per_trajectory = 5, device = 'cpu', debug = False) Bases: :py:obj:`gfn.env.Env` Mixture of Gaussians Line environment. .. attribute:: mus The means of the Gaussians. .. attribute:: sigmas The standard deviations of the Gaussians. .. attribute:: n_sd The number of standard deviations to consider for the bounds. .. attribute:: n_steps_per_trajectory The number of steps per trajectory. .. attribute:: mixture The mixture of Gaussians. .. attribute:: init_value The initial value of the state. .. py:method:: backward_step(states, actions) Performs a backward step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The previous states. .. py:attribute:: init_value .. py:method:: is_action_valid(states, actions, backward = False) Checks if the actions are valid. :param states: The current states. :param actions: The actions to check. :param backward: Whether to check for backward actions. :returns: `True` if the actions are valid, `False` otherwise. ..
py:method:: log_partition(condition=None) Returns the log partition of the reward function. .. py:method:: log_reward(final_states) Computes the log reward of the environment. :param final_states: The final states of the environment. :returns: The log reward. .. py:attribute:: mixture .. py:attribute:: mus .. py:attribute:: n_sd :value: 4.5 .. py:attribute:: n_steps_per_trajectory :value: 5 .. py:attribute:: sigmas .. py:method:: step(states, actions) Performs a step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The next states. .. py:class:: NonAutoregressiveBitSequence(word_size = 1, seq_size = 4, n_modes = 2, reward_exponent = 2.0, H = None, device_str = 'cpu', seed = 0, debug = False) Bases: :py:obj:`gfn.env.DiscreteEnv` Non-autoregressive BitSequence environment. In this environment, the agent constructs a binary sequence by placing words at arbitrary positions. Each action specifies both which position to fill and which word value to place there. The episode ends when all positions are filled. The reward is based on the minimum Hamming distance (computed at the bit level) between the completed sequence and a set of target "mode" sequences. :param word_size: Number of bits per word (e.g., 1 for single-bit actions). :param seq_size: Total number of bits in the sequence. Must be divisible by ``word_size``. :param n_modes: Number of target mode sequences. :param reward_exponent: Controls reward sharpness. Higher values make the reward more peaked around the modes. :param H: Optional tensor of shape ``(n_modes, seq_size)`` specifying the target modes in binary. If None, modes are generated randomly using block patterns. :param device_str: Device to use (``"cpu"`` or ``"cuda"``). :param seed: Random seed for mode generation. :param debug: If True, enable runtime guards (not compile-friendly). .. attribute:: word_size Number of bits per word. .. attribute:: seq_size Total number of bits. .. 
attribute:: words_per_seq Number of word positions (``seq_size // word_size``). .. attribute:: n_words Number of possible word values (``2 ** word_size``). .. attribute:: n_modes Number of target modes. .. attribute:: reward_exponent Reward sharpness parameter. .. attribute:: modes Target mode sequences as a binary tensor of shape ``(n_modes, seq_size)``.

.. rubric:: Example

>>> env = NonAutoregressiveBitSequence(word_size=1, seq_size=4, n_modes=2)
>>> # Action space: 4 positions * 2 word values + 1 exit = 9 actions
>>> env.n_actions
9
>>> # State shape: 4 word positions
>>> env.s0
tensor([-1, -1, -1, -1])

.. py:attribute:: H :value: None .. py:attribute:: States :type: type[NonAutoregressiveBitSequenceStates] .. py:method:: _decode_action(action) Decode a flat action index into (position, word) pair. :param action: Action tensor of shape ``(*batch_shape, 1)``. :returns: Tuple of (position, word) tensors, each of shape ``(*batch_shape, 1)``. .. py:method:: _integers_to_binary(tensor, k) :staticmethod: Convert a tensor of word integers to their binary representation. :param tensor: Integer tensor of shape ``(*batch_shape, words_per_seq)`` with values in ``{0, ..., 2^k - 1}``. :param k: Number of bits per word. :returns: Binary tensor of shape ``(*batch_shape, words_per_seq * k)`` with values in ``{0, 1}``. .. py:method:: _make_modes(seed, device) Generate target mode sequences in binary representation. If ``H`` is provided, it is used directly as the modes. Otherwise, modes are constructed by randomly combining 8-bit block patterns, following the procedure from the Trajectory Balance paper. :param seed: Random seed. :param device: Device to place the modes tensor on. :returns: Binary tensor of shape ``(n_modes, seq_size)`` with values in {0, 1}. .. py:method:: _min_hamming_distance(candidates, references) :staticmethod: Compute minimum Hamming distance from each candidate to any reference. :param candidates: Binary tensor of shape ``(*batch_shape, seq_size)``.
:param references: Binary tensor of shape ``(n_refs, seq_size)``. :returns: Tensor of shape ``(*batch_shape,)`` with the minimum distance. .. py:method:: backward_step(states, actions) Undo a word placement by clearing the position back to -1. The backward action has the same encoding as the forward action: ``action = position * n_words + word``. The word component is used to identify which position to clear. :param states: Current states. :param actions: Backward actions to undo. :returns: Previous states with the specified positions cleared. .. py:method:: log_reward(final_states) Compute log-reward based on Hamming distance to nearest mode. The log-reward is: ``log R(x) = -reward_exponent * min_d(x, modes) / seq_size`` where ``min_d`` is the minimum bit-level Hamming distance between the completed sequence and any target mode. :param final_states: Terminal states with all positions filled. :returns: Log-reward tensor of shape ``(*batch_shape,)``. .. py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False) Generate random partially-filled states. Each position is independently either unfilled (-1) or filled with a random word value. :param batch_shape: Shape of the batch. :param conditions: Optional conditions tensor. :param device: Device to use. :param debug: If True, enable debug mode. :returns: Random states. .. py:method:: make_states_class() Create the States class with environment-specific constants. .. py:attribute:: modes .. py:attribute:: n_modes_count :value: 2 .. py:property:: n_terminating_states :type: int Total number of possible terminal states. .. py:attribute:: n_words :value: 2 .. py:method:: reward(final_states) Compute reward as ``exp(log_reward)``. :param final_states: Terminal states. :returns: Reward tensor of shape ``(*batch_shape,)``. .. py:attribute:: reward_exponent :value: 2.0 .. py:attribute:: seq_size :value: 4 .. py:method:: step(states, actions) Place a word at the specified position. 
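The flat encoding ``action = position * n_words + word`` can be inverted with integer division and remainder. The following is a minimal sketch on plain integers; ``encode_action`` and ``decode_action`` are hypothetical names, and the environment itself operates on batched tensors rather than scalars.

```python
def encode_action(position, word, n_words):
    """Flatten a (position, word) pair into a single action index."""
    return position * n_words + word


def decode_action(action, n_words):
    """Recover (position, word) from a flat action index."""
    return divmod(action, n_words)


# With word_size=1 there are n_words = 2 word values; with seq_size=4 there
# are 4 positions, giving 8 distinct placement actions.
n_words, words_per_seq = 2, 4
placements = [
    encode_action(p, w, n_words)
    for p in range(words_per_seq)
    for w in range(n_words)
]
roundtrip_ok = all(
    decode_action(encode_action(p, w, n_words), n_words) == (p, w)
    for p in range(words_per_seq)
    for w in range(n_words)
)
```

The eight placement indices are distinct and contiguous, which leaves one further index for the exit action in the 9-action example given for this class.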
The action encodes ``(position, word)`` as a flat index: ``action = position * n_words + word``. :param states: Current states. :param actions: Actions encoding (position, word) pairs. :returns: Next states with the specified positions filled. .. py:property:: terminating_states :type: NonAutoregressiveBitSequenceStates Enumerate all terminal states (only feasible for small environments). .. py:method:: true_dist(condition=None) Compute the true reward distribution over all terminal states. .. py:attribute:: word_size :value: 1 .. py:attribute:: words_per_seq :value: 4 .. py:class:: PerfectBinaryTree(reward_fn, depth = 4, device = None, debug = False) Bases: :py:obj:`gfn.env.DiscreteEnv` Perfect Tree Environment. This environment is a perfect binary tree, where there is a bijection between trajectories and terminating states. Nodes are represented by integers, starting from 0 for the root. States are represented by a single integer tensor corresponding to the node index. Actions are integers: 0 (left child), 1 (right child), 2 (exit).

e.g.::

                0 (root)
           /         \
          1           2
        /   \       /   \
       3     4     5     6
      / \   / \   / \   / \
     7   8 9  10 11 12 13 14   (terminating states if depth=3)

Recommended preprocessor: `OneHotPreprocessor`. .. attribute:: reward_fn A function that computes the reward for a given state. :type: Callable .. attribute:: depth The depth of the tree. :type: int .. attribute:: branching_factor The branching factor of the tree. :type: int .. attribute:: n_actions The number of actions. :type: int .. attribute:: n_nodes The number of nodes in the tree. :type: int .. attribute:: transition_table A dictionary that maps (state, action) to the next state. :type: dict .. attribute:: inverse_transition_table A dictionary that maps (state, action) to the previous state. :type: dict .. attribute:: term_states The terminating states. :type: DiscreteStates .. py:method:: _build_tree() Builds the tree and the transition tables.
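One way to picture the tables is the sketch below. It assumes the heap-style numbering shown in the class diagram (the children of node ``n`` are ``2n + 1`` and ``2n + 2``); ``build_tables`` is a hypothetical stand-in that returns plain Python containers and omits the exit action, so it is not the library's implementation.

```python
def build_tables(depth):
    """Build forward and inverse transition tables for a perfect binary tree.

    Hypothetical sketch: keys are (node, action) with action 0 = left child
    and action 1 = right child; the exit action is not modelled here.
    """
    n_nodes = 2 ** (depth + 1) - 1
    first_leaf = 2 ** depth - 1
    transition, inverse = {}, {}
    for node in range(first_leaf):  # internal nodes only
        for action in (0, 1):
            child = 2 * node + 1 + action
            transition[(node, action)] = child
            inverse[(child, action)] = node
    leaves = list(range(first_leaf, n_nodes))
    return transition, inverse, leaves


transition, inverse, leaves = build_tables(depth=3)
```

For ``depth=3`` the leaves are nodes 7 through 14, matching the diagram in the class description, and ``depth=4`` gives the 31 nodes listed under ``n_nodes``.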
:returns: A tuple containing the transition table, the inverse transition table, and the terminating states. .. py:property:: all_states :type: gfn.env.DiscreteStates Returns all the states of the environment. .. py:method:: backward_step(states, actions) Performs a backward step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The previous states. .. py:attribute:: branching_factor :value: 2 .. py:attribute:: depth :value: 4 .. py:method:: get_states_indices(states) Returns the indices of the states. :param states: The states to get the indices of. :returns: The indices of the states. .. py:method:: make_states_class() Returns the DiscreteStates class for the PerfectBinaryTree environment. .. py:attribute:: n_actions :value: 3 .. py:attribute:: n_nodes :value: 31 .. py:method:: reward(final_states) Computes the reward for a batch of final states. :param final_states: The final states. :returns: The reward of the final states. .. py:attribute:: reward_fn .. py:attribute:: s0 .. py:attribute:: sf .. py:method:: step(states, actions) Performs a step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The next states. .. py:property:: terminating_states :type: gfn.env.DiscreteStates Returns the terminating states of the environment. .. py:class:: SetAddition(n_items, max_items, reward_fn, fixed_length = False, device = None, debug = False) Bases: :py:obj:`gfn.env.DiscreteEnv` Append-only MDP, similar to the one described in Remark 8 of Shen et al. (2023), `Towards Understanding and Improving GFlowNet Training <https://proceedings.mlr.press/v202/shen23a.html>`_. The state is a binary vector of length `n_items`, where 1 indicates the presence of an item. Actions are integers from 0 to `n_items - 1` to add the corresponding item, or `n_items` to exit. Adding an existing item is invalid. The trajectory must end when `max_items` are present.
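The action rules above can be sketched on plain Python lists. ``valid_actions`` is a hypothetical helper, and the sketch assumes ``fixed_length=False`` with the exit action permitted at any point; the environment's own masking logic may differ, in particular when ``fixed_length=True``.

```python
def valid_actions(state, max_items):
    """List the valid forward actions for one binary state vector.

    Hypothetical sketch: actions 0..n_items-1 add an item, n_items exits.
    Assumes exit is always allowed (fixed_length=False).
    """
    n_items = len(state)
    exit_action = n_items
    if sum(state) >= max_items:
        # The set is full, so the trajectory must end.
        return [exit_action]
    # Only items that are not yet present may be added.
    addable = [i for i, present in enumerate(state) if not present]
    return addable + [exit_action]


partial = valid_actions([1, 0, 0], max_items=2)
full = valid_actions([1, 1, 0], max_items=2)
```

With one of three items present and ``max_items=2``, items 1 and 2 can still be added; once two items are present, only the exit action remains.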
Recommended preprocessor: `IdentityPreprocessor`. .. attribute:: n_items The number of items in the set. :type: int .. attribute:: max_items The maximum number of items that can be added to the set. :type: int .. attribute:: reward_fn The reward function. :type: Callable .. attribute:: fixed_length Whether the trajectories have a fixed length. :type: bool .. py:attribute:: States :type: type[gfn.env.DiscreteStates] .. py:property:: all_states :type: gfn.env.DiscreteStates Returns all the states of the environment. .. py:method:: backward_step(states, actions) Performs a backward step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The previous states. .. py:attribute:: fixed_length :value: False .. py:method:: get_states_indices(states) Returns the indices of the states. :param states: The states to get the indices of. :returns: The indices of the states. .. py:method:: make_states_class() Returns the DiscreteStates class for the SetAddition environment. .. py:attribute:: max_traj_len .. py:attribute:: n_items .. py:method:: reward(final_states) Computes the reward for a batch of final states. :param final_states: The final states. :returns: The reward of the final states. .. py:attribute:: reward_fn .. py:method:: step(states, actions) Performs a step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The next states. .. py:property:: terminating_states :type: gfn.env.DiscreteStates Returns the terminating states of the environment.