gfn.gym.hypergrid
=================

.. py:module:: gfn.gym.hypergrid

.. autoapi-nested-parse::

   Adapted from https://github.com/Tikquuss/GflowNets_Tutorial


Attributes
----------

.. autoapisummary::

   gfn.gym.hypergrid._MAX_POOL_WORKERS
   gfn.gym.hypergrid.logger


Classes
-------

.. autoapisummary::

   gfn.gym.hypergrid.BitwiseXORReward
   gfn.gym.hypergrid.ConditionalHyperGrid
   gfn.gym.hypergrid.ConditionalMultiScaleReward
   gfn.gym.hypergrid.CorruptedReward
   gfn.gym.hypergrid.CosineReward
   gfn.gym.hypergrid.DeceptiveReward
   gfn.gym.hypergrid.GridReward
   gfn.gym.hypergrid.HyperGrid
   gfn.gym.hypergrid.MultiplicativeCoprimeReward
   gfn.gym.hypergrid.OriginalReward
   gfn.gym.hypergrid.SparseReward
   gfn.gym.hypergrid.UniformRandomReward


Functions
---------

.. autoapisummary::

   gfn.gym.hypergrid._first_k_dims
   gfn.gym.hypergrid._gf2_random_fullrank
   gfn.gym.hypergrid._gf2_rank
   gfn.gym.hypergrid._hypergrid_worker
   gfn.gym.hypergrid._preset_seed
   gfn.gym.hypergrid._state_hash_uniform
   gfn.gym.hypergrid.get_bitwise_xor_presets
   gfn.gym.hypergrid.get_conditional_multiscale_presets
   gfn.gym.hypergrid.get_corrupted_presets
   gfn.gym.hypergrid.get_cosine_presets
   gfn.gym.hypergrid.get_deceptive_presets
   gfn.gym.hypergrid.get_multiplicative_coprime_presets
   gfn.gym.hypergrid.get_original_presets
   gfn.gym.hypergrid.get_reward_presets
   gfn.gym.hypergrid.get_sparse_presets
   gfn.gym.hypergrid.get_uniform_random_presets
   gfn.gym.hypergrid.lcm
   gfn.gym.hypergrid.lcm_multiple
   gfn.gym.hypergrid.smallest_multiplier_to_integers


Module Contents
---------------

.. py:class:: BitwiseXORReward(height, ndim, **kwargs)

   Bases: :py:obj:`GridReward`

   Tiered, compositional reward based on bitwise XOR/parity constraints.

   Curriculum motivation (rule reuse): this reward tests whether a GFlowNet can
   learn a global algebraic rule (GF(2) parity) and reuse it across tiers of
   increasing strictness.
   Unlike the other compositional rewards, modes are NOT spatially concentrated
   near the origin; they are distributed non-locally across the grid according to
   algebraic structure. This is intentional: it probes the model's ability to learn
   abstract, non-spatial compositionality.

   The curriculum operates through constraint accumulation: tier 0 applies few
   parity checks (many modes, easy to discover), tier 1 adds more checks (fewer
   modes, same rule type), and so on. A model that learns the parity computation at
   tier 0 can reuse that same computation to satisfy the constraints of tier 1 and
   beyond, providing a form of compositional transfer for long-horizon credit
   assignment.

   This class implements the "Bitwise/XOR fractal" environment family, in which
   tiers progressively constrain bit-planes across a subset of dimensions via linear
   parity checks over GF(2). It supports easy sharding by high-bit prefixes and
   difficulty control by adjusting which bit-planes and how many dimensions are
   constrained per tier.

   GF(2) is the finite field with two elements {0, 1}, in which addition and
   multiplication are performed modulo 2. In this context, vector addition is
   equivalent to bitwise XOR, and matrix-vector products (A @ b) are evaluated
   entrywise modulo 2.

   Reward form::

      R(s) = R0 + Σ_t tier_weights[t] · 1[state satisfies all constraints up to tier t]

   Key kwargs (with reasonable defaults):

   - R0: float, base reward (default 0.0).
   - tier_weights: list[float], strictly increasing weights, one per tier.
   - dims_constrained: Optional[list[int]], subset of dims to constrain (default:
     all dims).
   - bits_per_tier: list[tuple[int, int]]; for each tier t, the inclusive bit range
     (low_bit, high_bit). Example: [(0, 5), (0, 7), (0, 9)].
   - parity_checks: Optional[list[dict]]; per tier, an optional parity system. Each
     entry may contain ``{"A": IntTensor[num_checks, m], "c": IntTensor[num_checks]}``,
     where m = len(dims_constrained).
     Constraints apply identically to every bit-plane specified for that tier:
     ``A @ b (mod 2) == c``, where b holds the bit values across the constrained
     dimensions at the tested bit-plane. If omitted for a tier, a single even-parity
     check across all constrained dims is used by default: ``sum(b) mod 2 == 0``.

   Difficulty presets align with step ranges by controlling the highest bit used and
   the number of constrained dimensions. The typical distance from the origin of
   valid modes scales roughly like (constrained_dims · 2^{highest_bit}).

   K-rule structure (n_rules >= 1):

   - Trunk: the per-tier parity stack above, shared across all rules.
   - Selector: a fixed GF(2) matrix S of shape (k_select, M*B) projects the bits to
     a rule index r = pack(S·b mod 2) ∈ [0, n_rules). Here
     k_select = ceil(log2(n_rules)); for n_rules=1, k_select=0 and r is always 0.
   - Head: a per-rule parity matrix H_r of shape (head_check_count, M), applied at
     every bit-plane in head_bit_range. A state is a mode iff the trunk passes AND
     H_r·b == c_r at the head's bit-planes.

   Reward::

      R = R0 + Σ_t w_t·1[trunk_0..t pass] + head_weight · 1[trunk all pass ∧ head_{σ(b)} pass]

   At n_rules=1 with head_check_count=0 and head_weight=0 (the defaults), the head
   is empty and the K-rule code path collapses bit-exactly to the legacy reward.

   When n_rules is a power of 2 (and the selector is independent of the trunk), the
   total mode count is invariant in n_rules: the selector's k_select parity
   equations partition the trunk-passing space into n_rules cosets, and each rule's
   head adds head_check_count·n_head_bits parity constraints within its coset.

   Comparison with other compositional rewards:

   - MultiplicativeCoprimeReward: number-theoretic (prime factorization); knowledge
     composition, since learning prime structure enables the coprimality and LCM
     constraints at higher tiers.
   - ConditionalMultiScaleReward: base-B digit decomposition with conditional
     constraints across scales; conditional hierarchy, since coarse-scale structure
     predicts fine-scale constraints.
   - This class: GF(2) linear algebra on bit-planes; rule reuse, since the same
     parity check type is applied with increasing strictness per tier.
     Modes are non-local (algebraic, not spatial).

   .. py:attribute:: R0
      :type: float

   .. py:attribute:: _B

   .. py:attribute:: _RULE_SEED_STRIDE
      :type: int
      :value: 1000003

   .. py:method:: __call__(states_tensor)

   .. py:method:: _apply_parity_checks(bits_plane, tier_idx)

      Apply GF(2) linear parity checks at a single bit-plane.

      bits_plane: (..., m) with m = len(dims_constrained), integer in {0, 1}.
      Returns a boolean mask of shape (...,).

   .. py:attribute:: _bit_positions

   .. py:attribute:: _dim_idx

   .. py:method:: _even_parity_mask(bits)

      bits: (..., m) int/bool. Returns a (...,) bool mask that is True for even
      parity.

   .. py:attribute:: _head_A_per_rule

   .. py:attribute:: _head_c_per_rule

   .. py:method:: _per_rule_mode_counts()

      Bit-config mode count per rule (without the unconstrained-dim factor).

      For rule k, returns 2^(M·B − rank([trunk; selector; head_k])) when the
      combined system is consistent, else 0. Inconsistency means the rule is
      unreachable, so its bit-config mode count is 0.

   .. py:attribute:: _select_powers

   .. py:attribute:: _tier_check_counts
      :value: []

   .. py:attribute:: _tier_weights_t

   .. py:attribute:: _uniform_partition
      :type: bool
      :value: True

   .. py:method:: _validate_rule_coverage()

      Ensure every rule index has >= 1 mode (trunk + selector→rule + head_rule).

      For each rule k, build the combined GF(2) system::

         [H_trunk; S; H_k] · b = [c_trunk; pack^-1(k); c_k]

      and verify that it is consistent via ``_solve_gf2_has_solution``. With
      k_select bits encoding the rule index, the selector contributes k_select
      scalar equations whose RHS is determined by the rule.

      Also tracks ``_uniform_partition``: True iff n_rules is a power of 2 AND the
      combined trunk+selector matrix has rank r_trunk + k_select (i.e. the selector
      is independent of the trunk, so every rule is reachable with the same per-rule
      mode count).

   .. py:method:: analytic_mode_count(per_rule = False)

      Total mode count via per-rule GF(2) rank summation.

      For each rule k, modes_k = 2^(M·B − rank([trunk; selector; head_k])).
      Total = Σ_k modes_k.
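The rank formula above can be checked on a toy grid. The following is a minimal, pure-Python sketch, not the library's implementation: it packs each parity row into an integer bitmask over the M·B constrained bits, computes the GF(2) rank with an XOR basis, and compares 2^(M·B − rank) · H^(ndim − M) against brute-force enumeration. The helper names (``gf2_rank``, ``passes_all_tiers``, ``rank_mode_count``) and the tuple layout ``(low_bit, high_bit, A, c)`` for a tier are hypothetical. The sketch also assumes a power-of-two height and a consistent system (for example all c = 0), and it omits the selector and per-rule heads.

```python
import itertools


def gf2_rank(rows: list[int]) -> int:
    """Rank over GF(2) of row vectors packed as int bitmasks (XOR basis)."""
    basis = []
    for row in rows:
        for vec in basis:
            # Keep row ^ vec only if it is smaller, i.e. it clears vec's leading bit.
            row = min(row, row ^ vec)
        if row:
            basis.append(row)
    return len(basis)


def passes_all_tiers(state, dims, tiers) -> bool:
    """Brute force: every tier's checks hold at every bit-plane in its range."""
    for low, high, A, c in tiers:
        for plane in range(low, high + 1):
            bits = [(state[j] >> plane) & 1 for j in dims]
            for row, rhs in zip(A, c):
                if sum(a * b for a, b in zip(row, bits)) % 2 != rhs:
                    return False
    return True


def rank_mode_count(height: int, ndim: int, dims, tiers) -> int:
    """2^(M*B - rank) * H^(ndim - M), assuming height is a power of two."""
    n_bits = height.bit_length() - 1  # B bits per coordinate
    m = len(dims)
    rows = []
    for low, high, A, _c in tiers:
        for plane in range(low, high + 1):
            for row in A:
                # Variable (k, plane) is stored at bit position plane * m + k.
                rows.append(sum(1 << (plane * m + k) for k, a in enumerate(row) if a))
    return 2 ** (m * n_bits - gf2_rank(rows)) * height ** (ndim - m)


# Tier 0: even parity of dims 0 and 1 at plane 0; tier 1: same check at planes 0..2.
tiers = [(0, 0, [[1, 1]], [0]), (0, 2, [[1, 1]], [0])]
brute = sum(
    passes_all_tiers(s, [0, 1], tiers) for s in itertools.product(range(8), repeat=2)
)
print(brute, rank_mode_count(8, 2, [0, 1], tiers))  # 8 8 (modes are states with x0 == x1)
```

The plane-0 check repeated in tier 1 adds no rank, which is why cumulative tiers with overlapping bit ranges do not double-count constraints.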

      With random full-rank head matrices the per-rule ranks coincide and
      total = K · modes_0; for small or degenerate configurations they may differ.
      The result is multiplied by H^(ndim − M) for the unconstrained dims.

   .. py:attribute:: bits_per_tier
      :type: list[tuple[int, int]]

   .. py:attribute:: dims_constrained
      :type: list[int]

   .. py:attribute:: head_bit_range
      :type: tuple[int, int]

   .. py:attribute:: head_check_count
      :type: int

   .. py:attribute:: head_seed
      :type: int

   .. py:attribute:: head_weight
      :type: float

   .. py:attribute:: k_select
      :type: int

   .. py:attribute:: n_rules
      :type: int

   .. py:attribute:: parity_checks

   .. py:method:: tier_indicators(states_tensor)

      Per-tier independent pass/fail indicators.

      Returns a list of boolean tensors (one per tier), each of shape
      ``states_tensor.shape[:-1]``. ``indicators[t]`` is True for states that
      satisfy tier t's GF(2) parity constraints *independently* (not cumulatively).

   .. py:attribute:: tier_weights
      :type: list[float]

.. py:class:: ConditionalHyperGrid(*args, **kwargs)

   Bases: :py:obj:`HyperGrid`

   HyperGrid environment with condition-aware rewards.

   The condition c is a real value in [0, 1], and the reward is a linear
   interpolation between the uniform reward and the original reward. Special cases:

   - c = 0: uniform reward (every terminal state gets reward R0+R1+R2).
   - c = 1: the original HyperGrid reward (the original multi-modal reward
     landscape).

   .. py:attribute:: _log_partition_cache
      :type: dict[torch.Tensor, float]

   .. py:attribute:: _max_reward
      :type: float

   .. py:attribute:: _original_reward_fn

   .. py:attribute:: _true_dist_cache
      :type: dict[torch.Tensor, torch.Tensor]

   .. py:attribute:: condition_dim
      :type: int
      :value: 1

   .. py:attribute:: is_conditional
      :type: bool
      :value: True

   .. py:method:: log_partition(condition)

      Compute the log partition for the given condition.

      :param condition: The condition to compute the log partition for.
          condition.shape should be (1,).
      :returns: The log partition function, as a float.
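Because the partition function is linear in the reward, the interpolated landscape described above admits a closed form: Z(c) = (1 − c) · (R0+R1+R2) · N + c · Z_original, where N is the number of terminal states. The sketch below assumes the interpolation is exactly R_c(s) = (1 − c) · (R0+R1+R2) + c · R(s). The function ``conditional_log_partition`` is hypothetical (not part of this module), the R0, R1, R2 values are assumed defaults, and the toy reward values are made up for illustration.

```python
import math


def conditional_log_partition(c: float, rewards: list[float], r_uniform: float) -> float:
    """log Z(c) for R_c(s) = (1 - c) * r_uniform + c * R(s), summed over terminal states.

    ``rewards`` lists the original reward R(s) of every terminal state.
    """
    if not 0.0 <= c <= 1.0:
        raise ValueError("condition must lie in [0, 1]")
    z_original = sum(rewards)
    # Z is linear in the reward, so it interpolates linearly in c as well.
    return math.log((1.0 - c) * r_uniform * len(rewards) + c * z_original)


toy_rewards = [0.1, 0.6, 2.6, 0.1]  # hypothetical R(s) values for four terminal states
r_uniform = 0.1 + 0.5 + 2.0  # R0 + R1 + R2 under assumed default coefficients
print(conditional_log_partition(0.0, toy_rewards, r_uniform))  # equals log(4 * r_uniform)
print(conditional_log_partition(1.0, toy_rewards, r_uniform))  # equals log(sum(toy_rewards))
```

With this form, the per-condition cache only needs Z_original once; every other condition is a scalar interpolation.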

   .. py:method:: reward(states)

      Compute rewards for the conditional environment.

      The condition is continuous from 0 to 1:

      - 0: fully uniform reward (all states get R0+R1+R2)
      - 1: fully original HyperGrid reward
      - In between: linear interpolation between uniform and original

      :param states: The states to compute rewards for. states.tensor.shape should
          be (*batch_shape, *state_shape).
      :returns: A tensor of shape (*batch_shape,) containing the rewards.

   .. py:method:: sample_conditions(batch_shape)

      Sample conditions for the environment.

   .. py:method:: true_dist(condition)

      Compute the true distribution for the given condition.

      :param condition: The condition to compute the true distribution for.
          condition.shape should be (1,).
      :returns: The true distribution for the given condition as a 1-dimensional
          tensor.

.. py:class:: ConditionalMultiScaleReward(height, ndim, **kwargs)

   Bases: :py:obj:`GridReward`

   Tiered reward via conditional digit constraints across spatial scales.

   Curriculum motivation (conditional hierarchy): this reward tests whether a
   GFlowNet can learn hierarchical, conditional structure. Each tier's constraint
   depends on what was learned at prior tiers, creating the strongest form of
   compositional transfer among the three compositional reward types.

   Digit ordering is coarse-to-fine: tier 0 constrains the most significant digit
   (coarsest spatial scale), tier 1 constrains the next digit conditioned on the
   coarse digit, and so on. This creates natural distance-correlated difficulty:
   states near the origin have small coordinates (their high digits are 0, which
   trivially passes the coarse filters), while states far from the origin have
   nonzero high digits that must satisfy the filter. Learning coarse-scale
   structure first provides an early training signal and directly informs which
   fine-scale configurations are valid, enabling compositional transfer for
   long-horizon credit assignment.

   Each coordinate is decomposed in base B into L = log_B(H) digits.
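A pure-Python sketch of this decomposition, assuming H is an exact power of B. It follows the convention documented for ``_extract_digits`` (``digits[k] = floor(x / B^k) mod B``, finest scale first), so tier t reads ``digits[L-1-t]``. The function names are illustrative and are not the module's API.

```python
def num_levels(height: int, base: int) -> int:
    """L = log_B(H); raises if H is not an exact power of B."""
    levels, h = 0, height
    while h > 1:
        if h % base:
            raise ValueError(f"height={height} is not a power of base={base}")
        h //= base
        levels += 1
    return levels


def extract_digits(x: int, base: int, levels: int) -> list[int]:
    """digits[k] = floor(x / B**k) mod B, i.e. scale k, finest scale first."""
    return [(x // base**k) % base for k in range(levels)]


def tier_digit(x: int, base: int, levels: int, tier: int) -> int:
    """Digit constrained at `tier`: tier 0 reads the most significant digit."""
    return extract_digits(x, base, levels)[levels - 1 - tier]


L = num_levels(64, 4)  # 3 digits per coordinate
print(extract_digits(27, 4, L))  # [3, 2, 1] since 27 = 1*16 + 2*4 + 3
print([tier_digit(5, 4, L, t) for t in range(L)])  # [0, 1, 1]
```

The second call shows why states near the origin pass the coarse filters for free: the most significant digit of 5 on a 64-wide grid is 0.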

   Tier t constrains digit (L-1-t), the (t+1)-th most significant digit, via a
   shifted filter that depends on all coarser-scale digits already constrained. This
   creates a hierarchy in which learning coarse-scale structure is a prerequisite for
   predicting fine-scale constraints.

   Per-dimension constraint at tier t (0-indexed)::

      (d_{L-1-t}(i) + sigma_t(i; r)) mod B < f

   where ``sigma_t(i; r) = sum_{k=0}^{t-1} a_{t,k}^{(r)} * d_{L-1-k}(i) mod B`` is a
   linear function of the coarser-scale digits with seed-derived coefficients,
   parameterized by the rule index r. Tier 0 has no shift (sigma_0 = 0) and is
   shared across all rules; it forms the trunk.

   K-rule structure (n_rules >= 1):

   - Tier 0 is the shared trunk: it constrains the most significant digit of each
     active dim via the filter [0, f), identically for all rules.
   - The selector deterministically projects each state to a rule index
     r ∈ [0, n_rules): r(s) = packed_MSD(s) mod n_rules, where packed_MSD packs the
     most significant digits of the active dims, after the tier-0 filter shift, as a
     base-f integer (see ``_selector``).
   - Tiers >= 1 use rule-specific shift coefficients derived from (head_seed, r), so
     each rule has a different head.

   At n_rules=1, the selector always returns 0 and there is a single head whose
   coefficients are derived from `seed`, which reproduces the single-rule reward
   bit-exactly.

   Optional cross-dimensional constraint at tier t (applies to all rules)::

      sum_i d_{L-1-t}(i) ≡ 0 (mod m_t)

   Reward form (cumulative: tier t requires all tiers 0..t under the rule)::

      R(s) = R0 + sum_t tier_weights[t] * 1[s satisfies tiers 0..t]

   Mode count (closed form, total across all rules)::

      Without cross-dim:  modes_T = (f^T)^d * B^{(L-T)*d}
      With cross-dim:     modes_T = (f^T)^d * B^{(L-T)*d} / prod_t m_t

   The total mode count is INVARIANT in n_rules: rules partition the canonical mode
   set. When n_rules divides f^d_active, the partition is uniform and
   modes_per_rule = total / n_rules.

   Partition function (analytic, no enumeration)::

      Z = R0 * H^d + sum_t w_t * modes_t

   Key kwargs:

   - R0: float, base reward (default 0.0).
   - tier_weights: list[float], reward weight per tier.
   - base: int, digit base B (default 4). H must be a power of B.
   - filter_width: int, number of passing digit values per tier (default B//2).
     Constant across tiers to avoid mode collapse at deep tiers.
   - seed: int, PRNG seed for generating shift coefficients (default 42).
   - n_rules: int, number of rules K (default 1). The selector partitions
     tier-0-passing states into K buckets; a uniform partition requires
     K | f^d_active.
   - head_seed: int, PRNG seed for the per-rule head shift coefficients (default:
     same as seed, which ensures that n_rules=1 reproduces the single-rule reward).
   - cross_dim_mods: Optional[list[int | None]], per-tier modular cross-dim
     constraint. m_t must divide filter_width for exact mode counts. Default: no
     cross-dim constraints.
   - active_dims: Optional[list[int]], subset of dims to constrain (default: all
     dims).

   Comparison with other compositional rewards:

   - BitwiseXORReward: GF(2) parity checks on bit-planes; rule reuse, since the same
     parity check type is applied with increasing strictness. Non-local modes.
   - MultiplicativeCoprimeReward: prime factorization with progressive constraint
     types; knowledge composition, since each tier requires understanding the prior
     tier's structure.
   - This class: conditional hierarchy, since each tier introduces a constraint
     whose form depends on what was learned at prior tiers. Coarse-to-fine ordering
     creates distance-correlated difficulty.

   .. py:attribute:: R0
      :type: float

   .. py:attribute:: _RULE_SEED_STRIDE
      :type: int
      :value: 1000003

   .. py:method:: __call__(states_tensor)

   .. py:method:: _extract_digits(x, num_levels)

      Extract base-B digits from x.

      :param x: (..., m) integer tensor with coordinate values.
      :param num_levels: How many digit levels to extract.
      :returns: List of num_levels tensors, each (..., m), with digit values in
          [0, B). digits[k] is the k-th digit (scale k), i.e.,
          floor(x / B^k) mod B.

   .. py:method:: _selector(msd_digits)

      Map state MSDs (..., d_active) -> rule index in [0, n_rules).

      Applies the same shift used by tier 0's filter (sigma_0 = 0 + filter_shift[0])
      before packing as base-f. For trunk-passing states the shifted MSD is in
      [0, f), so the packing is bijective on the trunk-passing set; non-trunk-passing
      states still get a rule index but cannot become modes (their tier-0 check
      fails).

   .. py:attribute:: _shift_coeffs_tensor
      :type: torch.Tensor

   .. py:attribute:: _uniform_partition
      :type: bool
      :value: True

   .. py:method:: _validate_rule_coverage()

      Ensure every rule index is hit by at least one trunk-passing state.

      Trunk = tier 0 = each active dim's shifted MSD in [0, filter_width) AND (if
      cross_dim_mods[0] is set) MSD-sum mod m_0 == 0. The selector packs the MSDs as
      base-f and reduces the result modulo n_rules.

      The implementation relies on cyclic uniformity: trunk-passing MSD patterns map
      bijectively to the integers [0, f^d_active) via the base-f packing, so
      `pat mod n_rules` is uniformly distributed when `n_rules | n_surv`. The
      cross-dim filter at tier 0 thins this set further; for non-trivial cross_mod_0,
      coverage is verified by sampling. (The previous enumeration over the
      f^d_active patterns was intractable for d_active > ~10.)

      Sets `_uniform_partition` according to whether n_rules divides the
      trunk-passing count evenly.

   .. py:attribute:: active_dims
      :type: list[int]

   .. py:method:: analytic_log_partition()

      Compute log(Z) analytically::

         Z = R0 * H^ndim + sum_t w_t * modes_t

   .. py:method:: analytic_mode_count(tier = None, per_rule = False)

      Compute the exact mode count at a given tier (1-indexed) or, by default, at
      the highest tier.

      The total mode count is invariant in `n_rules`: rules partition the canonical
      mode set rather than multiplying it.

      :param tier: 1-indexed tier number. If None, returns the count for the
          highest (most constrained) tier.
      :param per_rule: If True, returns the per-rule count (total // n_rules).
          Requires a uniform partition (n_rules divides the trunk-passing pattern
          count); raises ValueError otherwise.
      :returns: Number of states satisfying all constraints up to the given tier
          (over all rules combined if per_rule=False, per rule otherwise).

   .. py:attribute:: base
      :type: int

   .. py:attribute:: cross_dim_mods
      :type: list[int | None]
      :value: []

   .. py:attribute:: filter_shift
      :type: list[int]

   .. py:attribute:: filter_width
      :type: int

   .. py:attribute:: head_seed
      :type: int

   .. py:method:: mode_threshold(target_sparsity = 0.1)

      Return the reward threshold for mode counting at the adaptive tier.

      States with reward >= this value are counted as modes. The tier is chosen via
      ``mode_tier(target_sparsity)`` so that mode coverage adapts to
      dimensionality.

      :param target_sparsity: Passed to ``mode_tier()``.
      :returns: Reward threshold (R0 + sum of weights up to the mode tier).

   .. py:method:: mode_tier(target_sparsity = 0.1)

      Return the lowest tier whose mode coverage is below target_sparsity.

      Coverage at tier t = (f/B)^(t*d) (before cross-dim constraints). This adapts
      the mode definition to dimensionality: at low d, deeper tiers are needed for
      modes to be sparse; at high d, even tier 1 is already a needle in a haystack.

      :param target_sparsity: Fraction of the total state space below which modes
          are considered "interesting" (default 0.10 = 10%).
      :returns: 1-indexed tier number, clamped to [1, num_tiers].

   .. py:attribute:: n_rules
      :type: int

   .. py:attribute:: num_levels
      :type: int
      :value: 0

   .. py:attribute:: seed
      :type: int

   .. py:attribute:: shift_coeffs_per_rule
      :type: list[list[list[int]]]
      :value: []

   .. py:method:: tier_indicators(states_tensor)

      Per-tier independent pass/fail indicators.

      Returns a list of boolean tensors (one per tier), each of shape
      ``states_tensor.shape[:-1]``. ``indicators[t]`` is True for states that
      satisfy tier t's digit constraint *independently* (not cumulatively).

      Note: the shift at tier t still depends on coarser digits, so the constraint
      is state-dependent but evaluated per tier.

   .. py:attribute:: tier_weights
      :type: list[float]
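The closed-form count ``modes_T = (f^T)^d * B^{(L-T)*d}`` from the class description can be sanity-checked by brute force on a tiny grid. The sketch below is a simplified reading of the documented per-dimension constraint, not the class's implementation: it assumes a single rule, no cross-dim constraints, no ``filter_shift``, and arbitrary (here random) shift coefficients ``coeffs[t][k]`` standing in for a_{t,k}. The count is exact because, once the coarser digits of a coordinate are fixed, the shift sigma_t is fixed too, so exactly f of the B values of the constrained digit pass.

```python
import itertools
import random


def coord_passes(x: int, base: int, levels: int, f: int, coeffs, num_tiers: int) -> bool:
    """Tiers 0..num_tiers-1 for one coordinate; coeffs[t][k] plays the role of a_{t,k}."""
    # msd_first[t] is digit L-1-t, the digit constrained at tier t.
    msd_first = [(x // base ** (levels - 1 - t)) % base for t in range(levels)]
    for t in range(num_tiers):
        # sigma_t: linear combination (mod B) of the coarser digits from tiers < t.
        sigma = sum(coeffs[t][k] * msd_first[k] for k in range(t)) % base
        if (msd_first[t] + sigma) % base >= f:
            return False
    return True


def brute_force_modes(base, levels, ndim, f, coeffs, num_tiers) -> int:
    height = base**levels
    return sum(
        all(coord_passes(x, base, levels, f, coeffs, num_tiers) for x in state)
        for state in itertools.product(range(height), repeat=ndim)
    )


def closed_form_modes(base, levels, ndim, f, num_tiers) -> int:
    # f passing values per constrained digit, B free values per unconstrained digit.
    return (f**num_tiers * base ** (levels - num_tiers)) ** ndim


rng = random.Random(0)
base, levels, ndim, f = 4, 2, 2, 2  # H = 16
coeffs = [[rng.randrange(base) for _ in range(levels)] for _ in range(levels)]
for T in (1, 2):
    print(T, brute_force_modes(base, levels, ndim, f, coeffs, T),
          closed_form_modes(base, levels, ndim, f, T))
# 1 64 64
# 2 16 16
```

Changing the seed changes which states are modes but not how many, which is what makes the analytic partition function possible without enumeration.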

.. py:class:: CorruptedReward(height, ndim, **kwargs)

   Bases: :py:obj:`GridReward`

   Wraps a tiered structured reward and applies per-tier corruption.

   Conceptually, at each tier, a fraction ``corruption_rate`` of the states that
   earned that tier's bonus have it "moved" to a random location. This degrades the
   compositional structure at every level proportionally.

   Per-tier corruption logic, for each tier *t* and each state *s*:

   1. Compute the base reward's per-tier indicator ``pass_t(s)``.
   2. **Demote**: if ``pass_t(s)`` and ``hash(s, seed + 2*t) < corruption_rate``,
      remove tier *t*'s contribution.
   3. **Promote**: if not ``pass_t(s)`` and
      ``hash(s, seed + 2*t + 1) < replacement_rate_t``, add tier *t*'s
      contribution.

   ``replacement_rate_t`` is calibrated at init so that the expected number of
   promotions matches the expected number of demotions.

   Final reward::

      R(s) = R0 + sum_t w_t * corrupted_pass_t(s)

   For non-tiered base rewards, falls back to a single-level binary corruption at
   the mode threshold.

   Key kwargs:

   - base_reward: str, name of the base reward (default "conditional_multiscale").
   - base_kwargs: dict, kwargs for the base reward constructor.
   - corruption_rate: float in [0, 1), fraction of tier-passing states to demote per
     tier (default 0.2).
   - seed: int, hash seed (default 137).

   .. py:attribute:: _REWARD_CLASSES
      :type: dict[str, type[GridReward]]

   .. py:method:: __call__(states_tensor)

   .. py:method:: _call_simple(states_tensor)

      Fallback for non-tiered base rewards: binary corruption.

   .. py:method:: _estimate_replacement_rates()

      Sample states to estimate the per-tier pass fraction, then set the replacement
      rates so that promotions ~ demotions in expectation.

   .. py:attribute:: _is_tiered

   .. py:attribute:: _replacement_rates
      :type: list[float]
      :value: []

   .. py:attribute:: base_fn
      :type: GridReward

   .. py:attribute:: base_reward_str
      :value: ''

   .. py:attribute:: corruption_rate
      :type: float

   .. py:method:: mode_threshold()

      Return the mode threshold derived from the base reward.

   .. py:attribute:: seed
      :type: int
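The demote/promote rule above can be sketched in a few lines of pure Python. This is an illustration under stated assumptions, not the module's code: ``hash_uniform`` is a hypothetical stand-in for a state hash built on ``hashlib.blake2b`` (the module lists a ``_state_hash_uniform`` helper whose construction is not documented here), and ``replacement_rate`` solves the balance condition p · c = (1 − p) · q for the promotion probability q, where p is a tier's pass fraction and c is ``corruption_rate``.

```python
import hashlib
import itertools


def hash_uniform(state: tuple[int, ...], seed: int) -> float:
    """Deterministic pseudo-uniform value in [0, 1) for a (state, seed) pair."""
    digest = hashlib.blake2b(repr((seed, state)).encode(), digest_size=8).digest()
    return int.from_bytes(digest, "big") / 2**64


def replacement_rate(pass_fraction: float, corruption_rate: float) -> float:
    """Promotion probability q balancing demotions: p * c == (1 - p) * q."""
    if pass_fraction >= 1.0:
        return 0.0  # every state already passes, nothing left to promote
    return min(1.0, corruption_rate * pass_fraction / (1.0 - pass_fraction))


def corrupted_reward(state, tier_passes, weights, r0, corruption_rate, rates, seed):
    """R0 + sum_t w_t * corrupted_pass_t(s), using an independent hash per tier."""
    reward = r0
    for t, (passed, weight) in enumerate(zip(tier_passes, weights)):
        if passed:
            # Demote: drop the bonus when hash(s, seed + 2t) < corruption_rate.
            earned = hash_uniform(state, seed + 2 * t) >= corruption_rate
        else:
            # Promote: grant the bonus when hash(s, seed + 2t + 1) < rate_t.
            earned = hash_uniform(state, seed + 2 * t + 1) < rates[t]
        if earned:
            reward += weight
    return reward


# Toy single-tier example on a 32 x 32 grid: the tier passes iff x < 8 (p = 0.25).
states = list(itertools.product(range(32), repeat=2))
passes = {s: s[0] < 8 for s in states}
p = sum(passes.values()) / len(states)
q = replacement_rate(p, 0.2)  # 0.2 * 0.25 / 0.75 = 1/15
rewards = {s: corrupted_reward(s, [passes[s]], [1.0], 0.0, 0.2, [q], seed=137) for s in states}
demoted = sum(1 for s in states if passes[s] and rewards[s] == 0.0)
promoted = sum(1 for s in states if not passes[s] and rewards[s] == 1.0)
print(demoted, promoted)  # both are 0.2 * 256 = 51.2 in expectation
```

Using separate hash seeds for demotion (seed + 2t) and promotion (seed + 2t + 1) keeps the two decisions independent across tiers, and hashing rather than sampling makes the corrupted reward a fixed function of the state.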

.. py:class:: CosineReward(height, ndim, **kwargs)

   Bases: :py:obj:`GridReward`

   Cosine reward function.

   .. py:method:: __call__(states_tensor)

.. py:class:: DeceptiveReward(height, ndim, **kwargs)

   Bases: :py:obj:`GridReward`

   Deceptive reward function from Adaptive Teachers (Kim et al., 2025).

   Note that the reward definition in the paper (eq. (9)) is incorrect, so we follow
   the official implementation
   (https://github.com/alstn12088/adaptive-teacher/blob/8cfcb2298fce3f46eb36ead03791eeee75b7d066/grid/env.py#L27),
   modified to use EPS = 1e-12 so that its inequality comparisons are robust to
   floating-point error.

   .. py:method:: __call__(states_tensor)

.. py:class:: GridReward(height, ndim, **kwargs)

   Bases: :py:obj:`abc.ABC`

   Base class for reward functions that can be pickled.

   .. py:attribute:: _EPS
      :value: 1e-12

   .. py:method:: __call__(states_tensor)
      :abstractmethod:

   .. py:attribute:: height

   .. py:attribute:: kwargs

   .. py:attribute:: ndim

.. py:class:: HyperGrid(ndim = 2, height = 8, reward_fn_str = 'original', reward_fn_kwargs = None, device = 'cpu', calculate_partition = False, store_all_states = False, debug = False, validate_modes = True, mode_stats = 'none', mode_stats_samples = 20000)

   Bases: :py:obj:`gfn.env.DiscreteEnv`

   HyperGrid environment from the GFlowNets paper.

   The states are represented as 1-d tensors of length `ndim` with values in
   `{0, 1, ..., height - 1}`.

   .. attribute:: ndim

      The dimension of the grid.

   .. attribute:: height

      The height of the grid.

   .. attribute:: reward_fn

      The reward function.

   .. attribute:: calculate_partition

      Whether to calculate the log partition function.

   .. attribute:: store_all_states

      Whether to store all states.

   .. attribute:: validate_modes

      Whether to check at init that at least one state reaches the mode threshold;
      raises if not.

   .. attribute:: mode_stats

      One of {"none", "approx", "exact"}. If not "none", computes (exact or
      approximate) `n_modes` and `n_mode_states`. "exact" requires
      `store_all_states=True` and enumerates all states.
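One plausible reading of ``mode_stats="approx"`` is a Monte Carlo estimate: draw ``mode_stats_samples`` uniform states and scale the fraction that reach the mode threshold by H^D. The sketch below assumes exactly that; it is not taken from the implementation, and ``estimate_n_mode_states``, ``exact_n_mode_states``, and the ``on_diagonal`` predicate are hypothetical.

```python
import itertools
import random
from typing import Callable, Sequence


def estimate_n_mode_states(
    is_mode: Callable[[Sequence[int]], bool],
    height: int,
    ndim: int,
    n_samples: int = 20000,
    seed: int = 0,
) -> float:
    """Monte Carlo estimate of the number of mode states: H^D times the hit rate."""
    rng = random.Random(seed)
    hits = sum(
        bool(is_mode([rng.randrange(height) for _ in range(ndim)]))
        for _ in range(n_samples)
    )
    return height**ndim * hits / n_samples


def exact_n_mode_states(is_mode, height: int, ndim: int) -> int:
    """Exact count by enumerating all H^D states (only feasible for small grids)."""
    return sum(bool(is_mode(s)) for s in itertools.product(range(height), repeat=ndim))


def on_diagonal(state) -> bool:  # hypothetical mode predicate
    return state[0] == state[1]


print(exact_n_mode_states(on_diagonal, 10, 2))  # 10
print(estimate_n_mode_states(on_diagonal, 10, 2))  # close to 10
```

For a mode fraction p, the standard error of this estimate is H^D · sqrt(p(1 − p)/n), about 0.21 states in the example above; very sparse modes need many more samples than the default for a useful relative error.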

   .. attribute:: mode_stats_samples

      Number of random samples when `mode_stats="approx"`.

   .. py:attribute:: States
      :type: type[gfn.states.DiscreteStates]

   .. py:attribute:: _all_states_tensor
      :value: None

   .. py:method:: _enumerate_all_states_tensor(batch_size = 20000)

      Enumerate all grid states, optionally storing them and computing log Z.

      Iterates over the full Cartesian product ``{0, ..., H-1}^D`` in batches (via
      multiprocessing) to avoid materializing all ``H^D`` states at once.

      :param batch_size: Number of states per batch.

   .. py:method:: _exists_bitwise_xor(thr)

      Deterministic feasibility check for ``BitwiseXORReward``.

      Builds the combined GF(2) system
      [trunk; selector→0; head_0]·b = [c_trunk; 0; c_head_0] for rule 0 and
      verifies that it is consistent. This works uniformly for n_rules=1
      (k_select=0, empty selector) and for K rules (per-rule coverage was already
      verified at ``__init__`` by ``_validate_rule_coverage``; this re-checks rule 0
      as defense in depth).

      Feasibility of this combined GF(2) system is necessary and sufficient for a
      mode to exist; no random sampling is required. Presets use power-of-two
      heights, so every feasible bit assignment is a valid state
      (raw coord < height).

   .. py:method:: _exists_conditional_multiscale(thr)

      Constructive existence check for ConditionalMultiScaleReward.

      With filter_shift=[0, ..., 0] (the default), the all-zeros state is always a
      mode: every per-tier filter passes 0, since (0 + 0) mod B = 0 < f. With
      non-zero filter_shift, we try a few "all-same-v" candidate states chosen so
      that the MSD passes tier 0; one of them typically passes all deeper tiers when
      the per-rule shift_coeffs map zero lower digits to zero, leaving the constant
      filter_shift[t] as the only contribution.

   .. py:method:: _exists_cosine(thr)

      Analytic upper-bound check for ``CosineReward``.

      Idea:

      - The per-dimension factor is ``(cos(50·ax) + 1) · N(0,1)(5·ax)`` with ax in
        [0, 0.5]. We estimate its maximum over the discrete grid by evaluating all
        candidate ax and taking the maximum value ``m``.
      - The full reward upper bound is ``R0 + m^D * R1``. If this is at least both
        the mode target and the given threshold, a mode-level state must exist,
        since the bound is attained by placing every coordinate at a maximizing ax.
      - We also compute a theoretical per-dimension peak (at ax≈0) to form a
        slightly conservative target scaled by ``mode_gamma``.

   .. py:method:: _exists_fallback_random(thr)

      Random sampling fallback.

      Draw a modest batch of random states on CPU and accept if any exceeds the
      threshold within a small tolerance. This is a last resort to avoid expensive
      enumeration for large grids.

   .. py:method:: _exists_multiplicative_coprime(thr)

      Number-theoretic constructive check for ``MultiplicativeCoprimeReward``.

      For each rule, factors the rule's target LCM over the allowed primes, tries
      permutations of prime-to-active-dim assignments, and checks the coprimality,
      grid-bound, and selector-match conditions. Returns True iff at least one rule
      has a witness state whose selector maps back to that rule's index AND whose
      reward reaches the mode threshold.

      The reward shifts raw coords by +1 internally (raw 0 → internal 1), so witness
      states are constructed in raw space as ``p**exp - 1`` per active dim, with the
      coprime pair checks evaluated on the post-shift internal values. At n_rules=1
      the selector is trivially 0 and only rule 0 is tried, recovering the legacy
      behavior.

   .. py:method:: _exists_original_or_deceptive(thr)

      Constructive check for ``OriginalReward`` and ``DeceptiveReward``.

      Intuition:

      - These rewards form rings/bands around the center when each coordinate is
        normalized to [0, 1]. The mode lies on a thin band at specific normalized
        distances from the center.
      - We translate those fractional band boundaries into integer indices via small
        inside/outside nudges (using ``EPS_INDEX_CMP``) and test one candidate index
        from any non-empty feasible interval.
      - If the reward at that candidate exceeds the threshold (with
        ``EPS_REWARD_CMP`` tolerance), we return True.

   .. py:method:: _exists_random_or_corrupted(thr)

      Check for UniformRandomReward or CorruptedReward.

      For UniformRandomReward the probe budget is sized so that
      P(miss all modes | at least one mode exists) < 1e-9, using
      n = ceil(log(1e-9) / log(1 - mode_prob)). For CorruptedReward a fixed budget of
      10 000 is used (mode density is approximately preserved by the
      promotion/demotion calibration).

      A seeded generator derived from the reward seed and grid shape makes the
      result reproducible across calls with the same configuration.

   .. py:method:: _exists_sparse(thr)

      Constructive check for ``SparseReward``.

      This reward assigns positive mass only to a finite set of target
      configurations. When ``H>=2`` and ``D>=1``, a known target is the zero vector
      except for certain coordinates fixed at 1 or ``H-2``. We probe a canonical
      target and confirm that the threshold does not exceed its reward.

   .. py:method:: _generate_combinations_in_batches(ndim, max_val, batch_size)

      Yield batches of the Cartesian product {0, ..., max_val}^ndim.

      Uses multiprocessing to avoid materializing the full product (of size
      ``(max_val+1)^ndim``) in memory. Workers are created via the spawn start
      method and execute the module-level :func:`_hypergrid_worker` function, so the
      call is safe inside MPI ranks and CUDA contexts (see the start-method comment
      near the top of the module source).

      The pool size is capped at ``_MAX_POOL_WORKERS`` because larger pools only
      multiply the per-rank fork/spawn overhead without shrinking the per-task work,
      and a 64-core node hosting many co-located MPI ranks could otherwise spawn
      thousands of worker processes simultaneously.

      :param ndim: Number of dimensions (tuple length).
      :param max_val: Maximum coordinate value (inclusive).
      :param batch_size: Number of tuples per batch.
      :Yields: A list of tuples for each batch.

   .. py:method:: _get_states_indices_bigint(states_raw)

      Compute canonical indices using arbitrary-precision Python ints.

      Used by :meth:`get_states_indices` when ``height ** ndim > 2 ** 63`` and the
      int64 path would overflow.

      Vectorized over the (potentially large) batch dimension via numpy object-dtype
      broadcasting: the inner Python loop iterates only over the small feature
      dimension ``ndim``, and each ``k * h + col`` operation dispatches a single
      C-level loop over all rows that calls Python ``int.__mul__`` /
      ``int.__add__`` per element. This is a few times faster than a nested Python
      loop while still preserving arbitrary-precision correctness.

      Returns a numpy ``object`` array of shape ``states_raw.shape[:-1]`` containing
      one Python ``int`` per state.

   .. py:attribute:: _log_partition
      :value: None

   .. py:method:: _mode_reward_threshold()

      Return the reward threshold used to define a mode.

      By default, a state is considered to be in a mode if its reward is at least
      the schema-defined threshold derived from the configured reward.

   .. py:attribute:: _mode_stats_kind
      :type: str
      :value: 'none'

   .. py:method:: _modes_exist_quick_check()

      Lightweight check that a mode-level state exists.

      In simple terms, this answers "Is there at least one state whose reward
      reaches the mode threshold?" without enumerating all states. It proceeds in
      three stages:

      1. If the grid is small (or pre-enumerated), it computes rewards exactly and
         checks them against the threshold.
      2. Otherwise, it dispatches to reward-specific constructive tests that are
         sufficient to guarantee that at least one state reaches the threshold.
      3. As a last resort, it samples a small batch of random states.

   .. py:method:: _modes_exist_quick_check_info()

      Same as ``_modes_exist_quick_check`` but returns (ok, message).

   .. py:attribute:: _n_mode_states_estimate
      :type: float | None
      :value: None

   .. py:attribute:: _n_mode_states_exact
      :type: int | None
      :value: None

   .. py:method:: _solve_gf2_has_solution(A, c)
      :staticmethod:

      Return True if A x = c over GF(2) has at least one solution.

      Performs Gaussian elimination modulo 2 (XOR arithmetic) without constructing a
      specific solution.
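A compact pure-Python sketch of this elimination, extended to return a witness with free variables set to 0 (the behavior documented for ``_solve_gf2_witness``). It uses lists of 0/1 entries and full (Gauss-Jordan) reduction so that each pivot variable can be read off directly, and it flags inconsistency through the all-zero-coefficient row with RHS 1. The function name ``solve_gf2`` is hypothetical; the module's staticmethods may be organized differently.

```python
from typing import Optional


def solve_gf2(A: list[list[int]], c: list[int], n_vars: int) -> Optional[list[int]]:
    """Return b with A·b = c (mod 2), free variables set to 0, or None if inconsistent."""
    rows = [[v & 1 for v in row] + [rhs & 1] for row, rhs in zip(A, c)]  # augmented
    pivot_cols: list[int] = []
    r = 0
    for col in range(n_vars):
        pivot = next((i for i in range(r, len(rows)) if rows[i][col]), None)
        if pivot is None:
            continue  # no pivot: this is a free variable
        rows[r], rows[pivot] = rows[pivot], rows[r]
        for i in range(len(rows)):
            if i != r and rows[i][col]:
                rows[i] = [x ^ y for x, y in zip(rows[i], rows[r])]  # XOR = addition mod 2
        pivot_cols.append(col)
        r += 1
    # A row with all-zero coefficients but RHS 1 encodes 0 = 1: no solution.
    if any(row[n_vars] and not any(row[:n_vars]) for row in rows):
        return None
    b = [0] * n_vars
    for i, col in enumerate(pivot_cols):
        b[col] = rows[i][n_vars]  # fully reduced, so the pivot variable equals the RHS
    return b


print(solve_gf2([[1, 1, 0], [0, 1, 1]], [1, 0], 3))  # [1, 0, 0]
print(solve_gf2([[1, 1], [1, 1]], [0, 1], 2))  # None: x0 + x1 cannot be both 0 and 1
```

Checking only feasibility, as ``_solve_gf2_has_solution`` does, amounts to stopping after the inconsistency test.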

      A row that reduces to all-zero coefficients with a non-zero RHS (``0 = 1``)
      indicates inconsistency.

   .. py:method:: _solve_gf2_witness(A, c, n_vars)
      :staticmethod:

      Return a witness solution to A·b = c over GF(2), or None if none exists.

      b has length n_vars. A is reduced via Gaussian elimination; free variables are
      set to 0.

   .. py:attribute:: _true_dist
      :value: None

   .. py:method:: all_indices()

      Generate all possible indices for the grid.

      :returns: A list of all possible indices for the grid.

   .. py:property:: all_states
      :type: gfn.states.DiscreteStates | None

      Returns all hypergrid states as a `DiscreteStates` instance.

   .. py:method:: backward_step(states, actions)

      Performs a backward step in the environment.

      :param states: The current states.
      :param actions: The actions to take.
      :returns: The previous states.

   .. py:attribute:: calculate_partition
      :value: False

   .. py:method:: get_states_indices(states)

      Get the canonical ordering indices for a batch of states.

      Returns one canonical index per state, computed from the base-``height``
      encoding ``sum(s[j] * height^(ndim-1-j))``. The maximum index is
      ``height^ndim - 1``.

      - **Safe regime** (``height ** ndim <= 2 ** 63``): the index fits in a signed
        int64, and we return a ``torch.Tensor`` of shape ``batch_shape`` with dtype
        ``torch.int64`` (the historical behaviour).
      - **Overflow regime** (``height ** ndim > 2 ** 63``): the index would overflow
        int64 and silently wrap, producing collisions between distinct states (a
        real bug we hit at e.g. ndim=10, height=128, where
        ``128**10 == 2**70``). In this regime we fall back to per-row Python ``int``
        arithmetic and return a ``numpy.ndarray`` of dtype ``object`` containing
        arbitrary-precision Python ints. Each element is a unique, hashable
        canonical index.

      The two return types support the same downstream usages we care about
      (``set(...tolist())`` for mode tracking, and boolean masking with ``[mask]``
      after converting the mask to numpy if needed).
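The encoding and the overflow condition above can be reproduced with plain Python ints. The sketch is illustrative only: ``canonical_index`` evaluates the base-``height`` encoding in Horner form (the same ``k * h + col`` update described for ``_get_states_indices_bigint``), ``fits_int64`` restates the safe-regime test, and ``wrap_int64`` emulates two's-complement int64 wraparound to show the kind of collision the overflow regime avoids. All three helpers are hypothetical.

```python
def canonical_index(state: list[int], height: int) -> int:
    """sum(s[j] * height**(ndim-1-j)), evaluated in Horner form with exact ints."""
    idx = 0
    for coord in state:
        idx = idx * height + coord
    return idx


def fits_int64(height: int, ndim: int) -> bool:
    # The largest index is height**ndim - 1, which fits iff height**ndim <= 2**63.
    return height**ndim <= 2**63


def wrap_int64(value: int) -> int:
    """Emulate silent two's-complement int64 wraparound."""
    return (value + 2**63) % 2**64 - 2**63


a = [0] * 10
b = [2] + [0] * 9  # index 2 * 128**9 == 2**64
print(fits_int64(128, 10))  # False: 128**10 == 2**70
print(canonical_index(a, 128) == canonical_index(b, 128))  # False: exact ints differ
print(wrap_int64(canonical_index(a, 128)) == wrap_int64(canonical_index(b, 128)))  # True: collision
```

Two distinct states whose exact indices differ by a multiple of 2**64 map to the same wrapped value, which is why the overflow regime switches to exact Python ints instead.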
Code paths that need an ``int64`` tensor for tensor indexing (e.g. ``EnumPreprocessor``) implicitly require the safe regime; they'll see the numpy fallback and fail loudly, which is the correct behavior because such grids are too large to enumerate anyway. :param states: The states to get the indices of. :returns: Indices in canonical ordering. ``torch.Tensor[int64]`` of shape ``batch_shape`` in the safe regime; ``np.ndarray[object]`` of shape ``batch_shape`` containing Python ints in the overflow regime. .. py:method:: get_terminating_states_indices(states) Get the indices of the terminating states in the canonical ordering. See :meth:`get_states_indices` for the return-type contract: a ``torch.Tensor[int64]`` for grids whose indices fit in signed int64 (``height ** ndim <= 2 ** 63``), or a ``numpy.ndarray[object]`` of Python ints for larger grids that would otherwise overflow. :param states: The states to get the indices of. :returns: The indices of the terminating states in the canonical ordering. .. py:attribute:: height :value: 8 .. py:method:: log_partition(condition=None) Returns the log partition of the reward function. .. py:method:: make_random_states(batch_shape, conditions = None, device = None, debug = False) Creates a batch of random states. :param batch_shape: The shape of the batch. :param conditions: Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets. :param device: The device to use. :param debug: If True, emit States with debug guards (not compile-friendly). :returns: A `DiscreteStates` object with random states. .. py:method:: make_states_class() Returns the DiscreteStates class for the HyperGrid environment. .. py:method:: mode_mask(states) Boolean mask indicating which states are in a mode. A state is flagged as a mode if its reward is greater than or equal to the threshold based on `reward_fn_kwargs` (R0+R1+R2 by default). ..
py:method:: modes_found(states) Returns the set of canonical state indices for mode states in the batch. Each mode state is identified by its unique canonical index (from ``get_states_indices``), not by a quadrant-based grouping. This allows correct mode-state tracking for all reward functions. .. py:property:: n_mode_states :type: int | float | None Number of states inside a mode (exact, approx, or None). - If mode_stats="exact", returns an exact integer count. - If mode_stats="approx", returns a floating-point estimate. - If store_all_states is True (but mode_stats was "none"), computes on demand from all_states. - Otherwise, returns None. .. py:property:: n_modes :type: int | float | None Returns the total number of mode states for this environment. Equivalent to ``n_mode_states``. Each individual grid cell whose reward meets the mode threshold counts as one mode. .. py:property:: n_states :type: int Returns the number of states in the environment. .. py:property:: n_terminating_states :type: int Returns the number of terminating states in the environment. .. py:attribute:: ndim :value: 2 .. py:method:: reward(states) Computes the reward for a batch of final states. For the original reward, with outer-ring weight 0.5 and thin-band weight 2, the reward is: :math:`R(s) = R_0 + 0.5 \prod_{d=1}^D \mathbf{1} \left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5] \right) + 2 \prod_{d=1}^D \mathbf{1} \left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)` :param states: The final states. :returns: The reward of the final states. .. py:attribute:: reward_fn .. py:attribute:: reward_fn_kwargs :value: None .. py:method:: step(states, actions) Performs a step in the environment. :param states: The current states. :param actions: The actions to take. :returns: The next states. .. py:attribute:: store_all_states :value: False .. py:property:: terminating_states :type: gfn.states.DiscreteStates | None Returns all terminating states of the environment. ..
py:method:: true_dist(condition=None) Returns the pmf over all states in the hypergrid. .. py:class:: MultiplicativeCoprimeReward(height, ndim, **kwargs) Bases: :py:obj:`GridReward` Tiered reward based on prime-support and coprimality/lcm composition. Curriculum motivation — knowledge composition: This reward tests whether a GFlowNet can learn number-theoretic structure progressively: first discovering which coordinates factor over allowed primes (tier 0), then learning exponent bounds (tier 1), then cross-dimensional coprimality (tier 2), and finally global LCM targets (tier 3). Each tier builds on knowledge from prior tiers: learning prime factorization at tier 0 is a prerequisite for reasoning about exponent caps at tier 1, which in turn enables the coprimality reasoning needed at tier 2. This tests compositional transfer where each level requires a qualitatively different type of constraint, not just more of the same. Coordinates are shifted by +1 internally (state 0 -> value 1) so that the origin is valid and short trajectories immediately encounter small prime-factorable numbers (2, 3, 4, 5, 6, ...), providing early training signal for long-horizon credit assignment. Each tier progressively adds new constraint types: - Tier 0: Prime support — coordinates must factor over allowed primes. - Tier 1+: Exponent caps — prime exponents bounded per tier. - coprime_start_tier+: Coprime pairs — cross-dimensional coupling. - target_lcms: LCM targets — global compositional constraint. Reward form: R(s) = R0 + Σ_t tier_weights[t] · 1[ constraints_0..t all satisfied ] Key kwargs: - R0: float, base reward (default 0.0) - tier_weights: list[float] - primes: list[int], e.g., [2,3,5,7,11]. Primes exceeding height are auto-filtered with a warning. - exponent_caps: list[int], same length as tier_weights. Cap for every prime at tier t (uniform cap across primes for simplicity). Auto-capped to floor(log_p(height)) for each prime p.
- active_dims: Optional[list[int]]; constraints only apply to these dims (default: all dims). Other dims are ignored in constraints. - coprime_pairs: Optional[list[tuple[int,int]]]; indices relative to active_dims. - coprime_start_tier: int, first tier at which coprime constraints apply (default: 0, preserving backward compatibility). - target_lcms: Optional[list[int | None | str]]; per-tier target lcm across active dims. Use "auto" to derive from primes and exponent_caps. Notes: - Coordinates are shifted by +1 internally: state value 0 maps to reward value 1, making the origin (0,...,0) trivially valid. - The implementation strips each allowed prime up to the current tier's cap and checks that the residue equals 1. Exponent counts are accumulated to evaluate LCM targets. Comparison with other compositional rewards: - BitwiseXORReward: GF(2) parity checks on bit-planes; rule reuse — same parity check type with increasing strictness. Non-local modes. - ConditionalMultiScaleReward: base-B digit decomposition with conditional constraints across scales; conditional hierarchy — coarse-scale structure predicts fine-scale constraints. - This class: prime factorization with progressive constraint types (support -> caps -> coprimality -> LCM). Knowledge composition — each tier requires understanding the prior tier's structure. .. py:attribute:: R0 :type: float .. py:method:: __call__(states_tensor) .. py:method:: _factor_exponents_up_to_cap(v, cap) Trial-divide each element by allowed primes, returning residue and exponents. :param v: (...,) LongTensor of non-negative values to factorize. :param cap: Maximum number of times each prime may divide a value. :returns: A tuple ``(residue, exps)``: ``residue`` is a (...,) tensor of values after stripping allowed primes (1 if fully factored), and ``exps`` is a [num_primes, ...] tensor of exponent counts per prime (leading axis is primes). .. py:method:: _generate_rule_targets(n_rules, seed) Generate n_rules distinct LCM targets from a deterministic enumeration.
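The trial division performed by ``_factor_exponents_up_to_cap`` can be illustrated for a single value with a hypothetical scalar helper, ``factor_up_to_cap``. The real method is batched over tensors and takes its primes from the reward's configuration, so treat this only as a sketch. The calls below also apply the +1 coordinate shift from the Notes, so the origin maps to value 1 and passes trivially.

```python
def factor_up_to_cap(v, primes, cap):
    """Strip each allowed prime from v at most `cap` times.

    Hypothetical scalar sketch. Returns (residue, exps); residue == 1 means
    v factors over `primes` with every exponent <= cap.
    """
    exps = []
    for p in primes:
        e = 0
        while e < cap and v % p == 0:
            v //= p
            e += 1
        exps.append(e)
    return v, exps


primes = [2, 3, 5]
print(factor_up_to_cap(11 + 1, primes, cap=2))  # coordinate 11 -> value 12: (1, [2, 1, 0])
print(factor_up_to_cap(7 + 1, primes, cap=2))   # value 8 = 2**3 exceeds the cap: (2, [2, 0, 0])
print(factor_up_to_cap(0 + 1, primes, cap=2))   # origin -> value 1: (1, [0, 0, 0])
```

A residue other than 1 therefore signals either a disallowed prime factor or an exponent above the cap. The exponent list is what an LCM check would compare against a target's factorization.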
Enumerates cap-tuples (cap_p ∈ {1, ..., top_cap}) over allowed primes — cap=0 (prime absent from LCM) is excluded so every rule has a non-trivial target. Total tuples = top_cap^n_primes; permutes by seed and picks the first n_rules. The prime universe must satisfy top_cap^n_primes >= n_rules. Note: cap=0 is intentionally excluded. With cap=0, the LCM head constraint "no active dim has prime p as factor" combined with coprime-pair and exponent-cap constraints typically yields zero or near-zero modes — degenerate rules. .. py:method:: _lcm_ok(exps, target_lcm) Check whether max exponents across dims match target LCM's factorization. :param exps: [num_primes, ..., num_active_dims] exponent counts. :param target_lcm: The target LCM value to match. :returns: (...,) bool mask, True where the LCM of active-dim values equals target. .. py:method:: _pairwise_coprime_ok(v) Check that configured dimension pairs share no common allowed prime. :param v: (..., num_active_dims) coordinate values. :returns: (...,) bool mask, True where all coprime pair constraints hold. .. py:attribute:: _rule_target_exps .. py:method:: _selector(x_active) Map state's active-dim values to a rule index in [0, n_rules). Selector is sum_i x_i mod n_rules (over active dims, shifted values). Returns shape (...,) long tensor. .. py:method:: _validate_rule_coverage() Ensure each rule's LCM target is achievable. A target is achievable iff (a) it factors over allowed primes (already checked at __init__), and (b) every required exponent is <= top cap. Both conditions are necessary; with active_dims >= n_primes_with_nonzero_exp and coprime_pairs constraints, sufficiency requires enumeration. Here we check only the necessary conditions and verify that the rules produce non-degenerate distinct targets. .. py:attribute:: active_dims :type: list[int] .. py:attribute:: coprime_pairs .. py:attribute:: coprime_start_tier :type: int .. py:attribute:: exponent_caps :type: list[int] :value: [] .. 
py:attribute:: head_seed :type: int .. py:attribute:: n_rules :type: int .. py:attribute:: primes :type: list[int] .. py:attribute:: rule_targets :type: list[int | None] .. py:attribute:: target_lcms :type: list[int | None] :value: [] .. py:method:: tier_indicators(states_tensor) Per-tier independent pass/fail indicators. Returns a list of boolean tensors (one per tier), each of shape ``states_tensor.shape[:-1]``. ``indicators[t]`` is True for states that satisfy tier t's constraints *independently* (not cumulatively). .. py:attribute:: tier_weights :type: list[float] .. py:class:: OriginalReward(height, ndim, **kwargs) Bases: :py:obj:`GridReward` The reward function from the original GFlowNet paper (Bengio et al., 2021; https://arxiv.org/abs/2106.04399). .. py:method:: __call__(states_tensor) .. py:class:: SparseReward(height, ndim, **kwargs) Bases: :py:obj:`GridReward` Sparse reward function from the GAFN paper (Pan et al., 2022; https://arxiv.org/abs/2210.03308). .. py:method:: __call__(states_tensor) .. py:attribute:: targets .. py:class:: UniformRandomReward(height, ndim, **kwargs) Bases: :py:obj:`GridReward` Each state is independently a mode with probability ``mode_prob``. Uses a deterministic hash on state coordinates so mode membership is reproducible without storing or enumerating all states. There is no exploitable spatial or algebraic structure. Reward form:: R(s) = R0 + R_mode if hash(s, seed) < mode_prob R(s) = R0 otherwise Key kwargs: - R0: float, base reward for non-mode states (default 0.1). - R_mode: float, additional reward for mode states (default 2.0). - mode_prob: float in (0, 1), probability each state is a mode (default 0.01). - seed: int, hash seed for reproducibility (default 42). .. py:attribute:: R0 :type: float .. py:attribute:: R_mode :type: float .. py:method:: __call__(states_tensor) .. py:attribute:: mode_prob :type: float .. py:attribute:: seed :type: int .. py:data:: _MAX_POOL_WORKERS :value: 8 .. 
py:function:: _first_k_dims(k, ndim) Return indices [0, 1, ..., min(k, ndim)-1] for the first k dimensions. .. py:function:: _gf2_random_fullrank(n_checks, n_vars, seed) Generate a full-rank random binary matrix A and target vector c over GF(2). Uses a deterministic seed for reproducibility across runs. :param n_checks: Number of independent GF(2) equations (rows of A). :param n_vars: Number of binary variables (columns of A). :param seed: Deterministic RNG seed. :returns: (A, c) where A is (n_checks, n_vars) int tensor and c is (n_checks,) int tensor, both with values in {0, 1}. A is guaranteed to have full row-rank over GF(2). .. py:function:: _gf2_rank(A) Compute the rank of a binary matrix over GF(2). .. py:function:: _hypergrid_worker(task) Module-level worker for ``HyperGrid._generate_combinations_in_batches``. Returns the requested slice of the Cartesian product as a concrete ``list``. Lives at module level (rather than as a bound method) so it can be pickled to a spawned ``multiprocessing.Pool`` worker — bound methods of ``HyperGrid`` are not picklable because the env's States subclass is created locally inside ``make_states_class``. :param task: ``(values, ndim, start_idx, end_idx)`` where ``values`` is the list of coordinate values, ``ndim`` is the number of dimensions, and ``[start_idx, end_idx)`` is the index range within the full Cartesian product. :returns: A list of length ``end_idx - start_idx`` containing tuples of length ``ndim``. Returning a concrete list (rather than an ``itertools.islice``) keeps the result picklable across workers and future-proofs against the Python 3.14 removal of itertools pickle support. .. py:function:: _preset_seed(name) Deterministic seed from a preset name. .. py:function:: _state_hash_uniform(states_tensor, seed) Deterministic hash mapping each grid state to a float in [0, 1). Uses a polynomial rolling hash over coordinate values computed in int64 arithmetic. 
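The docstring does not give the hash's multiplier or any final mixing step, so the sketch below is an assumption and not the library's function. It uses a polynomial rolling step with 64-bit wraparound (an FNV-style odd base, chosen here for illustration) followed by a SplitMix64 finalizer. Without the finalizer, small coordinates would all hash to values near 0. ``state_hash_uniform`` and ``_mix64`` here are hypothetical pure-Python stand-ins for a single state.

```python
MASK64 = (1 << 64) - 1


def _mix64(x):
    # SplitMix64 finalizer: spreads nearby 64-bit inputs across the range.
    x = ((x ^ (x >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    x = ((x ^ (x >> 27)) * 0x94D049BB133111EB) & MASK64
    return x ^ (x >> 31)


def state_hash_uniform(state, seed, base=0x100000001B3):
    """Map integer coordinates to a reproducible float in [0, 1).

    Hypothetical sketch: polynomial rolling hash with 64-bit wraparound,
    followed by an assumed SplitMix64 finalizer.
    """
    h = seed & MASK64
    for coord in state:
        h = (h * base + coord) & MASK64  # rolling step, wraps like uint64
    # Keep the top 53 bits so the result is an exact double strictly below 1.
    return (_mix64(h) >> 11) / 2**53


# Mode assignment in the style of UniformRandomReward: a mode iff u < mode_prob.
mode_prob = 0.25
u = [state_hash_uniform((i, j), seed=42) for i in range(64) for j in range(64)]
print(sum(x < mode_prob for x in u) / len(u))  # close to 0.25
```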
Suitable for pseudo-random but reproducible per-state decisions (e.g., mode assignment, corruption masks). :param states_tensor: (..., ndim) integer tensor of coordinates. :param seed: Integer seed for determinism. :returns: Tensor of shape ``states_tensor.shape[:-1]`` with values in [0.0, 1.0). .. py:function:: get_bitwise_xor_presets(ndim, height) Return five difficulty presets for BitwiseXORReward. Difficulty is controlled by the number of constrained dimensions M. More constrained dims mean more independent GF(2) checks per bit position, leading to fewer modes. Each preset uses 3 tiers with increasing numbers of GF(2) checks: - Tier 0 (curriculum): few checks, many states pass - Tier 1 (intermediate): moderate checks - Tier 2 (mode): strictest checks, defines the modes Mode counts for ndim=10, height=16 (B=4 bit-planes). Per-bit-position solutions = 2^(M − cum_top_checks); raised to the B-th power, then multiplied by 16^(ndim − M) free-dim configurations: - easy (M=3, cum=2): ~4.3B modes (16 × 16^7) - medium (M=5, cum=4): ~16.8M modes (16 × 16^5) - hard (M=8, cum=6): ~65K modes (256 × 16^2) - challenging (M=10, cum=7): ~4K modes (4096 × 1) - impossible (M=12→10, cum=9): 16 modes (16 × 1; M capped to ndim) Notes - Uses fixed seeds per preset name for reproducibility. - Parity checks are random full-rank GF(2) matrices. - Bit ranges are capped to ceil(log2(height)) - 1. .. py:function:: get_conditional_multiscale_presets(ndim, height) Return difficulty presets for ConditionalMultiScaleReward. All presets use base=4 (requiring H to be a power of 4). The number of available digit levels is L = log_4(H). Difficulty is controlled by: - Number of tiers (more tiers = deeper compositional hierarchy) - Number of active dims (more = exponentially sparser modes) - Cross-dim modular constraints (further sparsification) Mode counts are computed via: modes_T = (f^T * 4^{L-T})^d * H^(ndim-d) / prod_t m_t with f = filter_width = 2 (i.e.
B//2), so each tier halves modes per coord. Digit ordering is coarse-to-fine: tier 0 constrains the most significant digit. Near the origin (small coordinates), high digits are 0 and trivially pass the filter. Deeper tiers constrain progressively finer digits, creating distance-correlated difficulty. Two preset families: Difficulty presets (single-rule legacy, assuming H=256, i.e. L=4): - easy: 2 tiers, 3 active dims -> ~2M modes at tier 2 - medium: 3 tiers, 4 active dims -> ~1M modes at tier 3 - hard: 3 tiers, 6 active dims, cross-dim -> ~260K modes at tier 3 - challenging: 4 tiers, 8 active dims, cross-dim -> ~65K modes at tier 4 - impossible: 4 tiers, 12 active dims, cross-dim -> ~4K modes at tier 4 K-rule trunk+heads presets (sparsity matched, K rules sharing tier-0 trunk): - K1, K16, K64: 3 tiers, 6 active dims, cross_dim_mods=[2,2,2]. Total modes invariant in K (~32K total at H=64, density ~5e-7); rules partition the canonical mode set. Designed for ndim=6, H=64. If height provides fewer digit levels than a preset requires, the preset's tier_weights and cross_dim_mods are auto-truncated with a warning. .. py:function:: get_corrupted_presets(ndim, height) Return five difficulty presets for CorruptedReward. Each preset wraps a ``conditional_multiscale`` "medium" base and applies increasing corruption. A single ``corruption_rate`` parameter controls the fraction of per-tier structure that is randomized. Difficulty progression: - easy: 10% corruption -> mostly structured - medium: 30% corruption -> noticeable randomness - hard: 50% corruption -> half structured, half random - challenging: 70% corruption -> mostly random - impossible: 90% corruption -> near-total randomness Note: requires ``height`` to be a power of 4 (same as the base reward). .. py:function:: get_cosine_presets(ndim, height) Return five presets for CosineReward. 
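The mode counts listed for ``get_bitwise_xor_presets`` follow from the counting rule stated there. A full-rank GF(2) system with ``cum`` checks in ``M`` variables has ``2^(M - cum)`` solutions per bit position, and that count is raised to the number of bit-planes and multiplied by the free-dim configurations. The hypothetical helper below reproduces that arithmetic. It assumes ``height`` is a power of 2 and that every bit-plane carries the same top-tier checks, as the docstring's formula does.

```python
import math


def bitwise_xor_mode_count(ndim, height, m, cum_checks):
    """Mode count from the get_bitwise_xor_presets docstring (hypothetical helper)."""
    m = min(m, ndim)                   # constrained dims are capped at ndim
    n_bits = int(math.log2(height))    # B bit-planes; assumes height is a power of 2
    per_bit = 2 ** (m - cum_checks)    # solutions of a full-rank GF(2) system in m vars
    return per_bit ** n_bits * height ** (ndim - m)  # times free-dim configurations


presets = {
    "easy": (3, 2),
    "medium": (5, 4),
    "hard": (8, 6),
    "challenging": (10, 7),
    "impossible": (12, 9),
}
for name, (m, cum) in presets.items():
    print(name, bitwise_xor_mode_count(ndim=10, height=16, m=m, cum_checks=cum))
```

For ndim=10 and height=16 this prints 4294967296, 16777216, 65536, 4096 and 16, matching the approximate figures in the preset list.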
R1 scales the oscillatory product, and `mode_gamma` (used only for mode detection thresholding) tightens what is considered a "mode-like" maximum. .. py:function:: get_deceptive_presets(ndim, height) Return five presets for DeceptiveReward. Increase R2 to accentuate the thin band, and set a small but non-zero R0. R1 controls the center emphasis vs. the cancelled outer region. .. py:function:: get_multiplicative_coprime_presets(ndim, height) Return five difficulty presets for MultiplicativeCoprimeReward. Each preset uses a progressive tier structure where each tier adds a new constraint type: - Tier 0: Prime support only (coords must factor over allowed primes) - Tier 1: + Exponent caps (tighten factorization) - coprime_start_tier+: + Coprime pairs (cross-dim coupling) - Final tier: + LCM target (global compositional constraint) Coordinates are shifted by +1 internally (origin -> all-ones), so short trajectories immediately encounter small prime-factorable numbers. Primes exceeding height and exponent caps exceeding log_p(height) are auto-filtered/capped in the reward constructor. Notes - `active_dims` indices are relative to state dims; we pick the first k. - `coprime_pairs` are pairs within `active_dims` index space. - Tier weights are geometric. - Use target_lcms="auto" to derive from primes and exponent_caps. .. py:function:: get_original_presets(ndim, height) Return five presets for OriginalReward. These presets primarily control the relative importance of the outer ring (R1) and thin band (R2). Exploration difficulty (distance from s0) is more a function of (D, H) than of these weights; tune D and H externally to match your distance bands. .. py:function:: get_reward_presets(reward_fn_str, ndim, height) Return presets for a given reward function name. Usage:: presets = get_reward_presets("bitwise_xor", D, H) kwargs = presets["hard"] env = HyperGrid(ndim=D, height=H, reward_fn_str="bitwise_xor", reward_fn_kwargs=kwargs) ..
py:function:: get_sparse_presets(ndim, height) Return five presets for SparseReward. SparseReward has built-in targets; it ignores most kwargs. Presets are provided for API symmetry and future extensibility. .. py:function:: get_uniform_random_presets(ndim, height) Return five difficulty presets for UniformRandomReward. Difficulty is controlled by ``mode_prob``: lower probability means sparser modes, which are harder for GFlowNets to discover. .. py:function:: lcm(a, b) Returns the least common multiple of a and b. .. py:function:: lcm_multiple(numbers) Find the least common multiple across a list of numbers. .. py:data:: logger .. py:function:: smallest_multiplier_to_integers(float_vector, precision=3) Computes a scale factor that turns the entries of ``float_vector`` into integers at the given decimal ``precision``, so that subsequent arithmetic avoids imprecise floating-point values.
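The two LCM helpers can be sketched with the standard identity ``lcm(a, b) * gcd(a, b) == a * b`` for positive integers. These are hypothetical re-implementations, not the module's code; on Python 3.9+ the built-in ``math.lcm`` gives the same results. The second example also shows the property that ``_lcm_ok`` relies on: the LCM's exponent for each prime is the maximum exponent across the inputs.

```python
import math
from functools import reduce


def lcm(a, b):
    # For positive integers, lcm(a, b) * gcd(a, b) == a * b.
    return a * b // math.gcd(a, b)


def lcm_multiple(numbers):
    # Fold pairwise lcm over the list; the empty list gives the identity 1.
    return reduce(lcm, numbers, 1)


print(lcm(4, 6))               # 12
print(lcm_multiple([12, 18]))  # 36 == 2**2 * 3**2 (max exponents of 2 and 3)
```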