gfn.gym.hypergrid

Adapted from https://github.com/Tikquuss/GflowNets_Tutorial

Attributes

_MAX_POOL_WORKERS

logger

Classes

BitwiseXORReward

Tiered, compositional reward based on bitwise XOR/parity constraints.

ConditionalHyperGrid

HyperGrid environment with condition-aware rewards.

ConditionalMultiScaleReward

Tiered reward via conditional digit constraints across spatial scales.

CorruptedReward

Wraps a tiered structured reward and applies per-tier corruption.

CosineReward

Cosine reward function.

DeceptiveReward

Deceptive reward function from Adaptive Teachers (Kim et al., 2025).

GridReward

Base class for reward functions that can be pickled.

HyperGrid

HyperGrid environment from the GFlowNets paper.

MultiplicativeCoprimeReward

Tiered reward based on prime-support and coprimality/lcm composition.

OriginalReward

The reward function from the original GFlowNet paper (Bengio et al., 2021;

SparseReward

Sparse reward function from the GAFN paper (Pan et al., 2022;

UniformRandomReward

Each state is independently a mode with probability mode_prob.

Functions

_first_k_dims(k, ndim)

Return indices [0, 1, ..., min(k, ndim)-1] for the first k dimensions.

_gf2_random_fullrank(n_checks, n_vars, seed)

Generate a full-rank random binary matrix A and target vector c

_gf2_rank(A)

Compute the rank of a binary matrix over GF(2).

_hypergrid_worker(task)

Module-level worker for HyperGrid._generate_combinations_in_batches.

_preset_seed(name)

Deterministic seed from a preset name.

_state_hash_uniform(states_tensor, seed)

Deterministic hash mapping each grid state to a float in [0, 1).

get_bitwise_xor_presets(ndim, height)

Return five difficulty presets for BitwiseXORReward.

get_conditional_multiscale_presets(ndim, height)

Return difficulty presets for ConditionalMultiScaleReward.

get_corrupted_presets(ndim, height)

Return five difficulty presets for CorruptedReward.

get_cosine_presets(ndim, height)

Return five presets for CosineReward.

get_deceptive_presets(ndim, height)

Return five presets for DeceptiveReward.

get_multiplicative_coprime_presets(ndim, height)

Return five difficulty presets for MultiplicativeCoprimeReward.

get_original_presets(ndim, height)

Return five presets for OriginalReward.

get_reward_presets(reward_fn_str, ndim, height)

Return presets for a given reward function name.

get_sparse_presets(ndim, height)

Return five presets for SparseReward.

get_uniform_random_presets(ndim, height)

Return five difficulty presets for UniformRandomReward.

lcm(a, b)

Return the least common multiple of a and b.

lcm_multiple(numbers)

Return the least common multiple of a list of numbers.

smallest_multiplier_to_integers(float_vector[, precision])

Used to calculate a scale factor to avoid imprecise floating point arithmetic.

Module Contents

class gfn.gym.hypergrid.BitwiseXORReward(height, ndim, **kwargs)

Bases: GridReward

Tiered, compositional reward based on bitwise XOR/parity constraints.

Curriculum motivation — rule reuse:

This reward tests whether a GFlowNet can learn a global algebraic rule (GF(2) parity) and reuse it across tiers of increasing strictness. Unlike the other compositional rewards, modes are NOT spatially concentrated near the origin — they are distributed non-locally across the grid according to algebraic structure. This is intentional: it probes the model’s ability to learn abstract, non-spatial compositionality.

The curriculum operates through constraint accumulation: tier 0 applies few parity checks (many modes, easy to discover), tier 1 adds more checks (fewer modes, same rule type), etc. A model that learns the parity computation at tier 0 can reuse that same computation to satisfy tier 1+ constraints, providing a form of compositional transfer for long-horizon credit assignment.

This class implements the “Bitwise/XOR fractal” environment family, in which tiers progressively constrain bit-planes across a subset of dimensions via linear parity checks over GF(2). It supports easy sharding by high-bit prefixes, and difficulty control by adjusting which bit-planes and how many dimensions are constrained per tier.

GF(2) is the finite field with two elements {0, 1}, where addition and multiplication are performed modulo 2. In this context, vector addition is equivalent to bitwise XOR, and matrix-vector products (A @ b) are evaluated entrywise modulo 2.

Reward form:

R(s) = R0 + Σ_t tier_weights[t] · 1[ state satisfies all constraints up to tier t ]

Key kwargs (with reasonable defaults):
  • R0: float, base reward (default 0.0)

  • tier_weights: list[float], strictly increasing weights for each tier

  • dims_constrained: Optional[list[int]] subset of dims to constrain (default: all dims)

  • bits_per_tier: list[tuple[int,int]]; for each tier t, inclusive bit range (low_bit, high_bit). Example: [(0,5), (0,7), (0,9)].

  • parity_checks: Optional[list[dict]]; per tier, optional parity system:
    Each entry may contain:

    { “A”: IntTensor[num_checks, m], “c”: IntTensor[num_checks] }

    where m = len(dims_constrained). Constraints apply identically to every bit-plane specified for that tier: (A @ b) mod 2 == c, where b holds the bit values across the constrained dimensions at the tested bit-plane. If omitted for a tier, a single even-parity check across all constrained dims is used by default: sum(b) mod 2 == 0.
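As a rough illustration of the default and custom parity checks described above, the following pure-Python sketch extracts one bit-plane from a state's constrained coordinates and tests (A @ b) mod 2 == c. The helper names bit_plane and passes_parity are hypothetical and are not part of gfn.gym.hypergrid; the class itself operates on torch tensors.

```python
from typing import List, Optional, Sequence


def bit_plane(coords: Sequence[int], bit: int) -> List[int]:
    """Return bit number `bit` of every coordinate (values in {0, 1})."""
    return [(x >> bit) & 1 for x in coords]


def passes_parity(
    b: Sequence[int],
    A: Optional[Sequence[Sequence[int]]] = None,
    c: Optional[Sequence[int]] = None,
) -> bool:
    """Check (A @ b) mod 2 == c; fall back to one even-parity check."""
    if A is None or c is None:
        return sum(b) % 2 == 0
    return all(
        sum(a_ij * b_j for a_ij, b_j in zip(row, b)) % 2 == c_i
        for row, c_i in zip(A, c)
    )


state = [5, 3, 6]            # binary 101, 011, 110 (illustrative values)
b0 = bit_plane(state, 0)     # lowest bit of each coordinate
default_ok = passes_parity(b0)
custom_ok = passes_parity(b0, A=[[1, 1, 0], [0, 1, 1]], c=[0, 1])
```

A tier that lists several bit-planes would repeat the same check once per plane and require every plane to pass.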

Difficulty presets align with step ranges by controlling the highest bit used and the number of constrained dimensions. Typical distance from origin for valid modes scales roughly like (constrained_dims · 2^{highest_bit}).

K-rule structure (n_rules >= 1):
  • Trunk: the per-tier parity stack above, shared across all rules.

  • Selector: a fixed GF(2) matrix S of shape (k_select, M*B) projects bits to a rule index r = pack(S·b mod 2) ∈ [0, n_rules). Here k_select = ceil(log2(n_rules)); for n_rules=1, k_select=0 and r is always 0.

  • Head: per-rule parity matrix H_r of shape (head_check_count, M), applied at every bit-plane in head_bit_range. Mode iff trunk passes AND H_r·b == c_r at the head’s bit-planes.

Reward:
R = R0 + Σ_t w_t·1[trunk_0..t pass] + head_weight · 1[trunk all pass ∧ head_{σ(b)} pass]

At n_rules=1 with head_check_count=0 and head_weight=0 (defaults), the head is empty and the K-rule code path collapses to the legacy reward bit-exactly.

Total mode count is invariant in n_rules when it’s a power of 2 — the selector adds k_select bits that partition the trunk-passing space, and per-rule head adds head_check_count·n_head_bits bits per coset.

Comparison with other compositional rewards:
  • MultiplicativeCoprimeReward: number-theoretic (prime factorization); knowledge composition — learning prime structure enables coprimality and LCM constraints at higher tiers.

  • ConditionalMultiScaleReward: base-B digit decomposition with conditional constraints across scales; conditional hierarchy — coarse-scale structure predicts fine-scale constraints.

  • This class: GF(2) linear algebra on bit-planes; rule reuse — the same parity check type is applied with increasing strictness per tier. Modes are non-local (algebraic, not spatial).

Parameters:
  • height (int)

  • ndim (int)

R0: float
_B
_RULE_SEED_STRIDE: int = 1000003
__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

_apply_parity_checks(bits_plane, tier_idx)

Apply GF(2) linear parity checks at a single bit-plane.

bits_plane: (…, m) with m=len(dims_constrained), integer in {0,1}. Returns mask (…,) bool.

Parameters:
  • bits_plane (torch.Tensor)

  • tier_idx (int)

Return type:

torch.Tensor

_bit_positions
_dim_idx
_even_parity_mask(bits)

bits: (…, m) int/bool -> returns (…,) bool for even parity.

Parameters:

bits (torch.Tensor)

Return type:

torch.Tensor

_head_A_per_rule
_head_c_per_rule
_per_rule_mode_counts()

Bit-config mode count per rule (without unconstrained-dim factor).

For rule k, returns 2^(M·B − rank([trunk; selector; head_k])) when the combined system is consistent, else 0. Inconsistency means the rule is unreachable — its bit-config mode count is 0.

Return type:

list[int]
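The rank-based count 2^(M·B − rank) can be illustrated with a small pure-Python sketch. It packs each row of a GF(2) matrix into an integer, computes the rank by elimination on leading bits, and counts the solutions of a consistent system. The function names and the example matrix are assumptions made for illustration; they do not reproduce the module's _gf2_rank.

```python
from typing import Dict, List


def gf2_rank(rows: List[int]) -> int:
    """Rank over GF(2) of a matrix whose rows are packed into Python ints."""
    pivots: Dict[int, int] = {}          # leading-bit position -> reduced row
    for row in rows:
        while row:
            lead = row.bit_length() - 1
            if lead not in pivots:
                pivots[lead] = row       # new independent row
                break
            row ^= pivots[lead]          # XOR is addition over GF(2)
    return len(pivots)


def solution_count(rows: List[int], n_vars: int) -> int:
    """Number of solutions of a *consistent* GF(2) system with these rows."""
    return 2 ** (n_vars - gf2_rank(rows))


# Two checks on 4 bits: even parity over all bits, and even parity over the
# two lowest bits. The third row is the XOR of the first two (dependent).
checks = [0b1111, 0b0011, 0b1100]
rank = gf2_rank(checks)
count = solution_count(checks, n_vars=4)
```

For a homogeneous system (all right-hand sides zero) consistency is automatic, so the count can be cross-checked by enumerating all 2^4 bit vectors.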

_select_powers
_tier_check_counts = []
_tier_weights_t
_uniform_partition: bool = True
_validate_rule_coverage()

Ensure every rule index has >= 1 mode (trunk + selector→rule + head_rule).

For each rule k, build the combined GF(2) system:

[H_trunk; S; H_k] · b = [c_trunk; pack^-1(k); c_k]

and verify it’s consistent via _solve_gf2_has_solution. With k_select bits encoding the rule index, the selector contributes k_select scalar equations whose RHS is determined by the rule.

Also tracks _uniform_partition: True iff n_rules is a power of 2 AND the combined trunk+selector matrix has rank r_trunk + k_select (i.e. the selector is independent of the trunk, so each rule is reachable with the same per-rule mode count).

Return type:

None

analytic_mode_count(per_rule=False)

Total mode count via per-rule GF(2) rank summation.

For each rule k, modes_k = 2^(M·B − rank([trunk; selector; head_k])). Total = Σ_k modes_k. With random full-rank head matrices the per-rule ranks coincide and total = K · modes_0; for small or degenerate configurations they may differ. Multiplied by H^(ndim − M) for unconstrained dims.

Parameters:

per_rule (bool)

Return type:

int

bits_per_tier: list[tuple[int, int]]
dims_constrained: list[int]
head_bit_range: tuple[int, int]
head_check_count: int
head_seed: int
head_weight: float
k_select: int
n_rules: int
parity_checks
tier_indicators(states_tensor)

Per-tier independent pass/fail indicators.

Returns a list of boolean tensors (one per tier), each of shape states_tensor.shape[:-1]. indicators[t] is True for states that satisfy tier t’s GF(2) parity constraints independently (not cumulatively).

Parameters:

states_tensor (torch.Tensor)

Return type:

list[torch.Tensor]

tier_weights: list[float]
class gfn.gym.hypergrid.ConditionalHyperGrid(*args, **kwargs)

Bases: HyperGrid

HyperGrid environment with condition-aware rewards.

Let the condition c be a real value in [0, 1]. The reward is a linear interpolation between the uniform reward and the original reward. Special cases:
  • c = 0: Uniform reward (all terminal states get reward R0+R1+R2).

  • c = 1: Original HyperGrid reward (the original multi-modal reward landscape).

_log_partition_cache: dict[torch.Tensor, float]
_max_reward: float
_original_reward_fn
_true_dist_cache: dict[torch.Tensor, torch.Tensor]
condition_dim: int = 1
is_conditional: bool = True
log_partition(condition)

Compute the log partition for the given condition.

Parameters:

condition (torch.Tensor) – The condition to compute the log partition for. condition.shape should be (1,)

Returns:

The log partition function, as a float.

Return type:

float

reward(states)

Compute rewards for the conditional environment.

The condition varies continuously from 0 to 1:
  • 0: Fully uniform reward (all states get R0+R1+R2).

  • 1: Fully original HyperGrid reward.

  • In between: Linear interpolation between uniform and original.

Parameters:

states (gfn.states.DiscreteStates) – The states to compute rewards for. states.tensor.shape should be (*batch_shape, *state_shape)

Returns:

A tensor of shape (*batch_shape,) containing the rewards.

Return type:

torch.Tensor
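The interpolation described for reward(states) can be sketched for a single scalar reward as follows. This is only an assumption about the form of the blend (a convex combination of the uniform value and the original reward); the function name and the numeric values of R0, R1, and R2 are illustrative and not taken from the library.

```python
def conditional_reward(original: float, uniform: float, c: float) -> float:
    """Blend a uniform reward with the original reward using condition c."""
    if not 0.0 <= c <= 1.0:
        raise ValueError("condition must lie in [0, 1]")
    # c = 0 gives the uniform reward, c = 1 gives the original reward.
    return (1.0 - c) * uniform + c * original


R0, R1, R2 = 0.1, 0.5, 2.0        # illustrative values only
uniform_value = R0 + R1 + R2      # reward assigned to every state when c = 0
low_reward_state = R0             # a state outside every mode

at_zero = conditional_reward(low_reward_state, uniform_value, 0.0)
at_half = conditional_reward(low_reward_state, uniform_value, 0.5)
at_one = conditional_reward(low_reward_state, uniform_value, 1.0)
```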

sample_conditions(batch_shape)

Sample conditions for the environment.

Parameters:

batch_shape (int | tuple[int, Ellipsis])

Return type:

torch.Tensor

true_dist(condition)

Compute the true distribution for the given condition.

Parameters:
  • condition (torch.Tensor) – The condition to compute the true distribution for. condition.shape should be …

Returns:

The true distribution for the given condition as a 1-dimensional tensor.

Return type:

torch.Tensor

class gfn.gym.hypergrid.ConditionalMultiScaleReward(height, ndim, **kwargs)

Bases: GridReward

Tiered reward via conditional digit constraints across spatial scales.

Curriculum motivation — conditional hierarchy:

This reward tests whether a GFlowNet can learn hierarchical, conditional structure: each tier’s constraint depends on what was learned at prior tiers, creating the strongest form of compositional transfer among the three reward types.

Digit ordering is coarse-to-fine: tier 0 constrains the most significant digit (coarsest spatial scale), tier 1 constrains the next digit conditioned on the coarse digit, and so on. This creates natural distance-correlated difficulty: states near the origin have small coordinates (high digits are 0, trivially passing coarse filters), while states far from the origin have nonzero high digits that must satisfy the filter. Learning coarse-scale structure first provides early training signal and directly informs which fine-scale configurations are valid, enabling compositional transfer for long-horizon credit assignment.

Each coordinate is decomposed in base B into L = log_B(H) digits. Tier t constrains digit (L-1-t) — the (t+1)-th most significant digit — via a shifted filter that depends on all coarser-scale digits already constrained, creating a hierarchy where learning coarse-scale structure is prerequisite for predicting fine-scale constraints.

Per-dimension constraint at tier t (0-indexed):

(d_{L-1-t}(i) + sigma_t(i; r)) mod B < f

where sigma_t(i; r) = sum_{k=0}^{t-1} a_{t,k}^{(r)} * d_{L-1-k}(i) mod B is a linear function of coarser-scale digits with seed-derived coefficients, parameterized by the rule index r. Tier 0 has no shift (sigma_0 = 0) and is shared across all rules — it forms the trunk.

K-rule structure (n_rules >= 1):
  • Tier 0 is the shared trunk: constrains the most-significant digit per active dim via filter [0, f). Same across all rules.

  • The selector projects each state to a rule index r ∈ [0, n_rules) deterministically: r(s) = packed_MSD(s) mod n_rules where packed_MSD = sum_i d_{L-1, i} * base^i across active dims.

  • Tiers >= 1 use rule-specific shift coefficients derived from (head_seed, r), so each rule has a different head.

At n_rules=1, the selector always returns 0 and there is one head, with coefficients derived from seed — bit-exact reproduction of the single-rule reward.

Optional cross-dimensional constraint at tier t (applies to all rules):

sum_i d_{L-1-t}(i) ≡ 0 (mod m_t)

Reward form (cumulative — tier t requires all tiers 0..t under the rule):

R(s) = R0 + sum_t tier_weights[t] * 1[s satisfies tiers 0..t]

Mode count (closed form, total across all rules):

Without cross-dim: modes_T = (f^T)^d * B^{(L-T)*d}

With cross-dim: modes_T = (f^T)^d * B^{(L-T)*d} / prod_t m_t

The total mode count is INVARIANT in n_rules: rules partition the canonical mode set. When n_rules divides f^d_active, the partition is uniform and modes_per_rule = total / n_rules.

Partition function (analytic, no enumeration):

Z = R0 * H^d + sum_t w_t * modes_t
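The closed forms above can be checked by brute force on a tiny grid. The sketch below assumes a single rule, no cross-dimensional constraint, no filter_shift, and a hand-picked shift coefficient a_{1,0} = 3; all constants and function names are hypothetical. In the partition sum, weight w_t is paired with the number of states passing tiers 0..t, which is modes_T for T = t + 1.

```python
from itertools import product

B, L, D, F = 4, 2, 2, 2          # base, digits per coordinate, dims, filter width
H = B ** L                       # grid height (16)
R0, WEIGHTS = 0.1, [1.0, 10.0]   # base reward and per-tier weights
COEFFS = [[], [3]]               # COEFFS[t][k] = a_{t,k}; tier 0 has no shift


def digit(x: int, k: int) -> int:
    """k-th base-B digit of x (k = 0 is the least significant digit)."""
    return (x // B ** k) % B


def tiers_passed(state) -> int:
    """Number of leading tiers (0, 1, ...) satisfied by every dimension."""
    passed = 0
    for t in range(len(WEIGHTS)):
        for x in state:
            # Shift sigma_t is a linear function of the coarser digits.
            sigma = sum(COEFFS[t][k] * digit(x, L - 1 - k) for k in range(t)) % B
            if (digit(x, L - 1 - t) + sigma) % B >= F:
                return passed
        passed += 1
    return passed


def reward(state) -> float:
    """Cumulative reward: R0 plus the weights of all leading tiers passed."""
    return R0 + sum(WEIGHTS[: tiers_passed(state)])


states = list(product(range(H), repeat=D))
brute_modes = [sum(tiers_passed(s) >= T for s in states) for T in (1, 2)]
closed_modes = [(F ** T) ** D * B ** ((L - T) * D) for T in (1, 2)]
brute_Z = sum(reward(s) for s in states)
closed_Z = R0 * H ** D + sum(w * m for w, m in zip(WEIGHTS, closed_modes))
```

The count is independent of the shift coefficients because, for any fixed coarser digits, exactly f of the B possible digit values satisfy the shifted filter.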

Key kwargs:
  • R0: float, base reward (default 0.0).

  • tier_weights: list[float], reward weight per tier.

  • base: int, digit base B (default 4). H must be a power of B.

  • filter_width: int, number of passing digit values per tier (default B//2). Constant across tiers to avoid mode collapse at deep tiers.

  • seed: int, PRNG seed for generating shift coefficients (default 42).

  • n_rules: int, number of rules K (default 1). Selector partitions tier-0-passing states into K buckets; uniform partition requires K | f^d_active.

  • head_seed: int, PRNG seed for per-rule head shift coefficients (default: same as seed; ensures n_rules=1 reproduces single-rule).

  • cross_dim_mods: Optional[list[int|None]], per-tier modular cross-dim constraint. m_t must divide filter_width for exact mode counts. Default: no cross-dim constraints.

  • active_dims: Optional[list[int]], subset of dims to constrain (default: all dims).

Comparison with other compositional rewards:
  • BitwiseXORReward: GF(2) parity checks on bit-planes; rule reuse — same parity check type with increasing strictness. Non-local modes.

  • MultiplicativeCoprimeReward: prime factorization with progressive constraint types; knowledge composition — each tier requires understanding the prior tier’s structure.

  • This class: conditional hierarchy — each tier introduces a constraint whose form depends on what was learned at prior tiers. Coarse-to-fine ordering creates distance-correlated difficulty.

Parameters:
  • height (int)

  • ndim (int)

R0: float
_RULE_SEED_STRIDE: int = 1000003
__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

_extract_digits(x, num_levels)

Extract base-B digits from x.

Parameters:
  • x (torch.Tensor) – (…, m) integer tensor with coordinate values.

  • num_levels (int) – how many digit levels to extract.

Returns:

List of num_levels tensors, each (…, m), with digit values in [0, B). digits[k] is the k-th digit (scale k), i.e., floor(x / B^k) mod B.

Return type:

list[torch.Tensor]
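The digit layout returned by _extract_digits (one entry per scale, each holding one digit per coordinate) can be mimicked with plain lists. This sketch uses Python integers instead of tensors, and the function name is an assumption for illustration.

```python
from typing import List, Sequence


def extract_digits(x: Sequence[int], base: int, num_levels: int) -> List[List[int]]:
    """Return digits[k][i] = floor(x[i] / base**k) mod base."""
    return [[(v // base ** k) % base for v in x] for k in range(num_levels)]


coords = [27, 5]                       # 27 = 1*16 + 2*4 + 3, 5 = 1*4 + 1
digits = extract_digits(coords, base=4, num_levels=3)
# Reassemble each coordinate from its digits as a consistency check.
rebuilt = [sum(digits[k][i] * 4 ** k for k in range(3)) for i in range(len(coords))]
```

With this layout, digits[num_levels - 1] holds the most significant digits, which is the scale that tier 0 constrains.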

_selector(msd_digits)

Map state MSDs (…, d_active) -> rule index in [0, n_rules).

Applies the same shift used by tier 0’s filter (sigma_0 = 0 + filter_shift[0]) before packing as base-f. For trunk-passing states the shifted MSD is in [0, f) so the packing is bijective on the trunk-passing set; non-trunk-passing states still get a rule index but cannot become modes (their tier-0 check fails).

Parameters:

msd_digits (torch.Tensor)

Return type:

torch.Tensor

_shift_coeffs_tensor: torch.Tensor
_uniform_partition: bool = True
_validate_rule_coverage()

Ensure every rule index is hit by at least one trunk-passing state.

Trunk = tier 0 = each active dim’s shifted MSD in [0, filter_width) AND (if cross_dim_mods[0] is set) MSD-sum mod m_0 == 0. The selector packs MSDs as base-f and mods by n_rules.

Implementation uses cyclic-uniformity: trunk-passing MSD patterns map bijectively to integers [0, f^d_active) via the base-f packing, so pat mod n_rules distributes uniformly when n_rules | n_surv. The cross-dim filter at tier 0 thins this set further; for non-trivial cross_mod_0 we sample to verify coverage. (Old enumeration over f^d_active patterns was intractable at d_active > ~10.)

Sets _uniform_partition based on whether n_rules divides the trunk-passing count evenly.

Return type:

None

active_dims: list[int]
analytic_log_partition()

Compute log(Z) analytically.

Z = R0 * H^ndim + sum_t w_t * modes_t

Return type:

float

analytic_mode_count(tier=None, per_rule=False)

Compute exact mode count for a given tier (1-indexed) or all tiers.

Total mode count is invariant in n_rules — rules partition the canonical mode set rather than multiplying it.

Parameters:
  • tier (int | None) – 1-indexed tier number. If None, returns count for the highest tier (most constrained).

  • per_rule (bool) – If True, returns the per-rule count (total // n_rules). Requires uniform partition (n_rules divides the trunk-passing pattern count); raises ValueError otherwise.

Returns:

Number of states satisfying all constraints up to the given tier (over all rules combined if per_rule=False, per rule otherwise).

Return type:

int

base: int
cross_dim_mods: list[int | None] = []
filter_shift: list[int]
filter_width: int
head_seed: int
mode_threshold(target_sparsity=0.1)

Return the reward threshold for mode counting at the adaptive tier.

States with reward >= this value are counted as modes. The tier is chosen via mode_tier(target_sparsity) so that mode coverage adapts to dimensionality.

Parameters:

target_sparsity (float) – Passed to mode_tier().

Returns:

Reward threshold (R0 + sum of weights up to the mode tier).

Return type:

float

mode_tier(target_sparsity=0.1)

Return the lowest tier whose mode coverage is below target_sparsity.

Coverage at tier t = (f/B)^(t*d) (before cross-dim constraints). This adapts the mode definition to dimensionality: at low d, deeper tiers are needed for modes to be sparse; at high d, even tier 1 is already a needle in a haystack.

Parameters:

target_sparsity (float) – Fraction of total state space below which modes are considered “interesting” (default 0.10 = 10%).

Returns:

1-indexed tier number. Clamped to [1, num_tiers].

Return type:

int
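A minimal sketch of the tier selection described above, assuming coverage is compared with a strict "<" and that cross-dimensional constraints are ignored. The function names mirror the documented methods but are standalone stand-ins, and the threshold formula follows the documented "R0 + sum of weights up to the mode tier".

```python
from typing import Sequence


def mode_tier(f: int, B: int, d: int, num_tiers: int,
              target_sparsity: float = 0.1) -> int:
    """Lowest 1-indexed tier whose coverage (f/B)^(t*d) is below the target."""
    for t in range(1, num_tiers + 1):
        if (f / B) ** (t * d) < target_sparsity:
            return t
    return num_tiers                      # clamp to the deepest tier


def mode_threshold(R0: float, tier_weights: Sequence[float], tier: int) -> float:
    """Reward reached by a state that passes tiers 1..tier (1-indexed)."""
    return R0 + sum(tier_weights[:tier])


low_d = mode_tier(f=2, B=4, d=2, num_tiers=3)    # coverage 0.25, then 0.0625
high_d = mode_tier(f=2, B=4, d=8, num_tiers=3)   # coverage 0.5**8 at tier 1
clamped = mode_tier(f=2, B=4, d=1, num_tiers=2)  # never below 0.1, so clamped
threshold = mode_threshold(0.0, [1.0, 10.0, 100.0], low_d)
```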

n_rules: int
num_levels: int = 0
seed: int
shift_coeffs_per_rule: list[list[list[int]]] = []
tier_indicators(states_tensor)

Per-tier independent pass/fail indicators.

Returns a list of boolean tensors (one per tier), each of shape states_tensor.shape[:-1]. indicators[t] is True for states that satisfy tier t’s digit constraint independently (not cumulatively). Note: the shift at tier t still depends on coarser digits, so the constraint is state-dependent but evaluated per-tier.

Parameters:

states_tensor (torch.Tensor)

Return type:

list[torch.Tensor]

tier_weights: list[float]
class gfn.gym.hypergrid.CorruptedReward(height, ndim, **kwargs)

Bases: GridReward

Wraps a tiered structured reward and applies per-tier corruption.

Conceptually, at each tier, a fraction corruption_rate of states that earned that tier’s bonus have it “moved” to a random location. This degrades the compositional structure at every level proportionally.

Per-tier corruption logic:

For each tier t and each state s:

  1. Compute the base reward’s per-tier indicator pass_t(s).

  2. Demote: if pass_t(s) and hash(s, seed + 2*t) < corruption_rate, remove tier t’s contribution.

  3. Promote: if not pass_t(s) and hash(s, seed + 2*t + 1) < replacement_rate_t, add tier t’s contribution. replacement_rate_t is calibrated at init so that the expected number of promotions matches demotions.

Final reward:

R(s) = R0 + sum_t w_t * corrupted_pass_t(s)
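The demote and promote steps can be sketched in pure Python as below. The SHA-256 based hash_uniform is a stand-in for the module's _state_hash_uniform, whose actual hashing scheme is not shown here. The calibration formula for the replacement rate is an assumption derived from equating expected promotions with expected demotions, whereas the module estimates the pass fraction by sampling.

```python
import hashlib
from typing import Sequence


def hash_uniform(state: Sequence[int], seed: int) -> float:
    """Deterministic stand-in hash mapping (state, seed) to a float in [0, 1)."""
    digest = hashlib.sha256(repr((tuple(state), seed)).encode()).digest()
    return (int.from_bytes(digest[:8], "big") >> 11) / 2 ** 53


def replacement_rate(pass_fraction: float, corruption_rate: float) -> float:
    """Promotion rate so that promotions match demotions in expectation."""
    if pass_fraction >= 1.0:
        return 0.0                      # nothing left to promote
    return min(1.0, corruption_rate * pass_fraction / (1.0 - pass_fraction))


def corrupted_pass(state, passes: bool, tier: int, seed: int,
                   corruption_rate: float, repl_rate: float) -> bool:
    """Apply the demote/promote rule to one tier indicator."""
    if passes:                          # demote a fraction of passing states
        return hash_uniform(state, seed + 2 * tier) >= corruption_rate
    return hash_uniform(state, seed + 2 * tier + 1) < repl_rate


def corrupted_reward(state, indicators, weights, R0, seed,
                     corruption_rate, repl_rates) -> float:
    """R(s) = R0 + sum_t w_t * corrupted_pass_t(s)."""
    return R0 + sum(
        w for t, (w, ok) in enumerate(zip(weights, indicators))
        if corrupted_pass(state, ok, t, seed, corruption_rate, repl_rates[t])
    )


rate = replacement_rate(pass_fraction=0.2, corruption_rate=0.2)
clean = corrupted_reward([3, 1], [True, False], [1.0, 10.0], 0.1, 137, 0.0, [0.0, 0.0])
```

Using distinct seeds for demotion (seed + 2*t) and promotion (seed + 2*t + 1) keeps the two decisions independent for the same state and tier.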

For non-tiered base rewards, falls back to a single-level binary corruption at the mode threshold.

Key kwargs:
  • base_reward: str, name of the base reward (default “conditional_multiscale”).

  • base_kwargs: dict, kwargs for the base reward constructor.

  • corruption_rate: float in [0, 1), fraction of tier-passing states to demote per tier (default 0.2).

  • seed: int, hash seed (default 137).

Parameters:
  • height (int)

  • ndim (int)

_REWARD_CLASSES: dict[str, type[GridReward]]
__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

_call_simple(states_tensor)

Fallback for non-tiered base rewards: binary corruption.

Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

_estimate_replacement_rates()

Sample states to estimate per-tier pass fraction, then set replacement rates so promotions ~ demotions in expectation.

Return type:

None

_is_tiered
_replacement_rates: list[float] = []
base_fn: GridReward
base_reward_str = ''
corruption_rate: float
mode_threshold()

Return the mode threshold derived from the base reward.

Return type:

float

seed: int
class gfn.gym.hypergrid.CosineReward(height, ndim, **kwargs)

Bases: GridReward

Cosine reward function.

Parameters:
  • height (int)

  • ndim (int)

__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

class gfn.gym.hypergrid.DeceptiveReward(height, ndim, **kwargs)

Bases: GridReward

Deceptive reward function from Adaptive Teachers (Kim et al., 2025).

Note that the reward definition in the paper (eq. (9)) is incorrect, and we follow the official implementation (https://github.com/alstn12088/adaptive-teacher/blob/8cfcb2298fce3f46eb36ead03791eeee75b7d066/grid/env.py#L27) while modifying it to use EPS = 1e-12 to handle inequalities with floating points.

Parameters:
  • height (int)

  • ndim (int)

__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

class gfn.gym.hypergrid.GridReward(height, ndim, **kwargs)

Bases: abc.ABC

Base class for reward functions that can be pickled.

Parameters:
  • height (int)

  • ndim (int)

_EPS = 1e-12
abstract __call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

height
kwargs
ndim
class gfn.gym.hypergrid.HyperGrid(ndim=2, height=8, reward_fn_str='original', reward_fn_kwargs=None, device='cpu', calculate_partition=False, store_all_states=False, debug=False, validate_modes=True, mode_stats='none', mode_stats_samples=20000)

Bases: gfn.env.DiscreteEnv

HyperGrid environment from the GFlowNets paper.

The states are represented as 1-d tensors of length ndim with values in {0, 1, …, height - 1}.

Parameters:
  • ndim (int)

  • height (int)

  • reward_fn_str (str)

  • reward_fn_kwargs (dict | None)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • calculate_partition (bool)

  • store_all_states (bool)

  • debug (bool)

  • validate_modes (bool)

  • mode_stats (Literal['none', 'approx', 'exact'])

  • mode_stats_samples (int)

ndim

The dimension of the grid.

height

The height of the grid.

reward_fn

The reward function.

calculate_partition

Whether to calculate the log partition function.

store_all_states

Whether to store all states.

validate_modes

Whether to check that at least one state reaches the mode threshold at init; raises if not.

mode_stats

One of {“none”, “approx”, “exact”}. If not “none”, computes (exact or approximate) n_modes and n_mode_states. “exact” requires store_all_states=True and enumerates all states.

mode_stats_samples

Number of random samples when mode_stats=”approx”.

States: type[gfn.states.DiscreteStates]
_all_states_tensor = None
_enumerate_all_states_tensor(batch_size=20000)

Enumerate all grid states, optionally storing them and computing log Z.

Iterates over the full Cartesian product {0, ..., H-1}^D in batches (via multiprocessing) to avoid materializing all H^D states at once.

Parameters:

batch_size (int) – Number of states per batch.

_exists_bitwise_xor(thr)

Deterministic feasibility check for BitwiseXORReward.

Builds the combined GF(2) system [trunk; selector→0; head_0]·b = [c_trunk; 0; c_head_0] for rule 0 and verifies consistency. This works uniformly for n_rules=1 (k_select=0, selector empty) and for the K-rule case; per-rule coverage is already verified at __init__ by _validate_rule_coverage, so re-checking rule 0 here serves as defense in depth.

Feasibility of this combined GF(2) system is necessary and sufficient for a mode to exist; no random sampling is required. Presets use power-of-two heights so every feasible bit-assignment is a valid state (raw coord < height).

Parameters:

thr (float)

Return type:

bool

_exists_conditional_multiscale(thr)

Constructive existence check for ConditionalMultiScaleReward.

With filter_shift=[0,…,0] (default) the all-zeros state is always a mode: every per-tier filter passes 0 since (0 + 0) mod B = 0 < f. With non-zero filter_shift, we try a few “all-same-v” candidate states chosen so the MSD passes tier 0; one of them typically passes all deeper tiers when the per-rule shift_coeffs map zero lower digits to zero, leaving the constant filter_shift[t] as the only contribution.

Parameters:

thr (float)

Return type:

bool

_exists_cosine(thr)

Analytic upper-bound check for CosineReward.

Idea:
  • The per-dimension factor is (cos(50·ax) + 1) · N(0,1)(5·ax) with ax in [0,0.5]. We estimate its maximum over the discrete grid by evaluating all candidate ax and taking the maximum value m.

  • The full reward upper bound is R0 + m^D * R1. If this is at least the mode target and the given threshold, a mode-level state must exist.

  • We also compute a theoretical per-dimension peak (at ax≈0) to form a slightly conservative target scaled by mode_gamma.

Parameters:

thr (float)

Return type:

bool

_exists_fallback_random(thr)

Random sampling fallback.

Draw a modest batch of random states on CPU and accept if any exceed the threshold with a small tolerance. This is a last resort to avoid expensive enumeration for large grids.

Parameters:

thr (float)

Return type:

bool

_exists_multiplicative_coprime(thr)

Number-theoretic constructive check for MultiplicativeCoprimeReward.

For each rule, factors the rule’s target LCM over allowed primes, tries permutations of prime-to-active-dim assignments, and checks coprime + grid-bound + selector-match. Returns True iff at least one rule has a witness state whose selector maps back to that rule’s index AND whose reward reaches the mode threshold.

The reward shifts raw coords by +1 internally (raw 0 → internal 1), so witness states are constructed in raw space as p**exp - 1 per active dim, with coprime pair checks evaluated on the post-shift internal values.

At n_rules=1 the selector is trivially 0 and only rule 0 is tried, recovering the legacy behavior.

Parameters:

thr (float)

Return type:

bool

_exists_original_or_deceptive(thr)

Constructive check for OriginalReward and DeceptiveReward.

Intuition:
  • These rewards form rings/bands around the center when each coordinate is normalized to [0,1]. The mode lies on a thin band at specific normalized distances from the center.

  • We translate those fractional band boundaries into integer indices via small inside/outside nudges (using EPS_INDEX_CMP) and test one candidate index from any non-empty feasible interval.

  • If the reward at that candidate exceeds the threshold (with EPS_REWARD_CMP tolerance), we return True.

Parameters:

thr (float)

Return type:

bool

_exists_random_or_corrupted(thr)

Check for UniformRandomReward or CorruptedReward.

For UniformRandomReward the probe budget is sized so that P(miss all modes | at least one mode exists) < 1e-9, using n = ceil(log(1e-9) / log(1 - mode_prob)). For CorruptedReward a fixed budget of 10 000 is used (mode density is approximately preserved by the promotion/demotion calibration).

A seeded generator derived from the reward seed and grid shape makes the result reproducible across calls with the same configuration.

Parameters:

thr (float)

Return type:

bool
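The probe budget formula for UniformRandomReward follows from requiring (1 - mode_prob)^n < 1e-9 for n independent probes. A small sketch, with a hypothetical function name:

```python
import math


def probe_budget(mode_prob: float, miss_prob: float = 1e-9) -> int:
    """Smallest n with (1 - mode_prob)**n below miss_prob, via the log ratio."""
    if not 0.0 < mode_prob < 1.0:
        raise ValueError("mode_prob must lie strictly between 0 and 1")
    return math.ceil(math.log(miss_prob) / math.log(1.0 - mode_prob))


n_sparse = probe_budget(0.01)   # rare modes need a few thousand probes
n_dense = probe_budget(0.5)     # dense modes need only a few dozen
```

The budget grows roughly like log(1/miss_prob) / mode_prob for small mode_prob, so very sparse rewards make this check correspondingly more expensive.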

_exists_sparse(thr)

Constructive check for SparseReward.

This reward assigns positive mass only to a finite set of target configurations. When H>=2 and D>=1, a known target is the zero vector except for certain coordinates fixed at 1 or H-2. We probe a canonical target and confirm the threshold is not above its reward.

Parameters:

thr (float)

Return type:

bool

_generate_combinations_in_batches(ndim, max_val, batch_size)

Yield batches of the Cartesian product {0, …, max_val}^ndim.

Uses multiprocessing to avoid materializing the full product (size (max_val+1)^ndim) in memory.

Workers are created via the spawn start method and execute the module-level _hypergrid_worker() function so the call is safe inside MPI ranks and CUDA contexts (see the start-method comment near the top of this file). Pool size is capped at _MAX_POOL_WORKERS because larger pools only multiply per-rank fork/spawn overhead without shrinking the per-task work; a 64-core node hosting many co-located MPI ranks could otherwise end up running thousands of worker processes simultaneously.

Parameters:
  • ndim (int) – Number of dimensions (tuple length).

  • max_val (int) – Maximum coordinate value (inclusive).

  • batch_size (int) – Number of tuples per batch.

Yields:

A list of tuples for each batch.
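The batching idea can be shown with a single-process generator built on itertools. This sketch deliberately leaves out the multiprocessing pool and the spawn start method described above, so it only illustrates how the Cartesian product is consumed lazily in fixed-size chunks; the function name is hypothetical.

```python
from itertools import islice, product
from typing import Iterator, List, Tuple


def combinations_in_batches(
    ndim: int, max_val: int, batch_size: int
) -> Iterator[List[Tuple[int, ...]]]:
    """Yield lists of at most batch_size tuples from {0, ..., max_val}^ndim."""
    lazy_product = product(range(max_val + 1), repeat=ndim)
    while True:
        batch = list(islice(lazy_product, batch_size))
        if not batch:
            return                      # product exhausted
        yield batch


batches = list(combinations_in_batches(ndim=2, max_val=2, batch_size=4))
batch_sizes = [len(b) for b in batches]
```

Only one batch is held in memory at a time, which is the property that matters when (max_val+1)^ndim is large.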

_get_states_indices_bigint(states_raw)

Compute canonical indices using arbitrary-precision Python ints.

Used by get_states_indices() when height ** ndim > 2 ** 63 and the int64 path would overflow.

Vectorized over the (potentially large) batch dimension via numpy object-dtype broadcasting: the inner Python loop iterates only over the small feature dimension ndim, and each k * h + col operation dispatches a single C-level loop over all rows that calls Python int.__mul__ / int.__add__ per element. This is a few times faster than a nested Python loop while still preserving arbitrary-precision correctness.

Returns a numpy object array of shape states_raw.shape[:-1] containing one Python int per state.

Parameters:

states_raw (torch.Tensor)

Return type:

numpy.ndarray
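The k * h + col recurrence can be illustrated for a single state with plain Python integers, which never overflow. This sketch assumes the first coordinate is the most significant one, as the recurrence suggests, and it omits the numpy object-dtype vectorization over the batch dimension; the function name is hypothetical.

```python
from typing import Sequence


def state_index(state: Sequence[int], height: int) -> int:
    """Canonical index of a state, computed with arbitrary-precision ints."""
    index = 0
    for col in state:
        index = index * height + col   # Horner-style accumulation
    return index


small = state_index([1, 2, 3], height=4)           # (1*4 + 2)*4 + 3
big_height = 2 ** 20
big = state_index([big_height - 1] * 4, big_height)  # exceeds the int64 range
```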

_log_partition = None
_mode_reward_threshold()

Returns the reward threshold used to define a mode.

By default, a state is considered in a mode if its reward is at least the schema-defined threshold derived from the configured reward.

Return type:

float

_mode_stats_kind: str = 'none'
_modes_exist_quick_check()

Lightweight check that a mode-level state exists.

In simple terms, this answers: “Is there at least one state whose reward reaches the mode threshold?” without enumerating all states. It proceeds in three stages:

  1. If the grid is small (or pre-enumerated), it computes rewards exactly and checks against the threshold.

  2. Otherwise, it dispatches to reward-specific constructive tests that are sufficient to guarantee at least one state reaches the threshold.

  3. As a last resort, it samples a small batch of random states.

Return type:

bool

_modes_exist_quick_check_info()

Same as _modes_exist_quick_check but returns (ok, message).

Return type:

tuple[bool, str]

_n_mode_states_estimate: float | None = None
_n_mode_states_exact: int | None = None
static _solve_gf2_has_solution(A, c)

Return True if A x = c over GF(2) has at least one solution.

Performs Gaussian elimination modulo 2 (XOR arithmetic) without constructing a specific solution. A row that reduces to all-zero coefficients with a non-zero RHS (0 = 1) indicates inconsistency.

Parameters:
  • A (torch.Tensor)

  • c (torch.Tensor)

Return type:

bool

static _solve_gf2_witness(A, c, n_vars)

Return a witness solution to A·b = c over GF(2), or None if none exists.

b has length n_vars. A is reduced via Gaussian elimination; free variables are set to 0.

Parameters:
  • A (torch.Tensor)

  • c (torch.Tensor)

  • n_vars (int)

Return type:

torch.Tensor | None
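A minimal sketch of the elimination described above, assuming plain Python lists of 0/1 values rather than torch tensors; the function name solve_gf2_witness_sketch is hypothetical. It reduces the augmented matrix to reduced row echelon form with XOR, reports inconsistency on a 0 = 1 row, and sets free variables to 0.

```python
def solve_gf2_witness_sketch(A, c, n_vars):
    """Return one solution b of A b = c over GF(2), or None if inconsistent."""
    # Augmented rows [a_0, ..., a_{n-1} | rhs], all entries reduced mod 2.
    rows = [[v & 1 for v in row] + [rhs & 1] for row, rhs in zip(A, c)]
    pivot_cols = []
    r = 0  # index of the next pivot row
    for col in range(n_vars):
        pivot = next((i for i in range(r, len(rows)) if rows[i][col]), None)
        if pivot is None:
            continue  # no pivot in this column: it is a free variable
        rows[r], rows[pivot] = rows[pivot], rows[r]
        # XOR the pivot row into every other row that has a 1 in this column.
        for i in range(len(rows)):
            if i != r and rows[i][col]:
                rows[i] = [x ^ y for x, y in zip(rows[i], rows[r])]
        pivot_cols.append(col)
        r += 1
    # A row with all-zero coefficients and RHS 1 encodes 0 = 1.
    if any(not any(row[:n_vars]) and row[n_vars] for row in rows):
        return None
    b = [0] * n_vars  # free variables stay at 0
    for i, col in enumerate(pivot_cols):
        b[col] = rows[i][n_vars]
    return b


A = [[1, 1, 0], [0, 1, 1]]
c = [1, 0]
b = solve_gf2_witness_sketch(A, c, n_vars=3)
```

Because free variables are fixed at 0, each pivot variable simply takes the right-hand side of its reduced row.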

_true_dist = None
all_indices()

Generate all possible indices for the grid.

Returns:

A list of all possible indices for the grid.

Return type:

List[Tuple[int, ...]]

property all_states: gfn.states.DiscreteStates | None

Returns a tensor of all hypergrid states as a DiscreteStates instance.

Return type:

gfn.states.DiscreteStates | None

backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
Returns:

The previous states.

Return type:

gfn.states.DiscreteStates

calculate_partition = False
get_states_indices(states)

Get the canonical ordering indices for a batch of states.

Returns one canonical index per state computed from the base-height encoding sum(s[j] * height^(ndim-1-j)). The maximum index is height^ndim - 1.

  • Safe regime (height ** ndim <= 2 ** 63): the index fits in signed int64 and we return a torch.Tensor of shape batch_shape with dtype torch.int64 (the historical behaviour).

  • Overflow regime (height ** ndim > 2 ** 63): the index would overflow int64 and silently wrap, producing collisions between distinct states (a real bug we hit at e.g. ndim=10, height=128 where 128**10 == 2**70). In this regime we fall back to per-row Python int arithmetic and return a numpy.ndarray of dtype object containing arbitrary-precision Python ints. Each element is a unique, hashable canonical index.

The two return types support the same downstream usages we care about (set(...tolist()) for mode tracking, boolean masking with [mask] after converting the mask to numpy if needed). Code paths that need an int64 tensor for tensor indexing (e.g. EnumPreprocessor) implicitly require the safe regime — they’ll see the numpy fallback and fail loudly, which is the correct behavior because such grids are too large to enumerate anyway.

Parameters:

states (Union[gfn.states.DiscreteStates, torch.Tensor]) – The states to get the indices of.

Returns:

Indices in canonical ordering. torch.Tensor[int64] of shape batch_shape in the safe regime; np.ndarray[object] of shape batch_shape containing Python ints in the overflow regime.

Return type:

Union[torch.Tensor, numpy.ndarray]
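The collision risk in the overflow regime can be reproduced without torch. The sketch below emulates int64 wraparound on exact Python ints; all helper names (fits_int64, wrap_int64, exact_index) are hypothetical and are not part of the module.

```python
INT64_LIMIT = 2 ** 63


def fits_int64(height: int, ndim: int) -> bool:
    """Safe regime: the maximum index height ** ndim - 1 fits in signed int64."""
    return height ** ndim <= INT64_LIMIT


def wrap_int64(value: int) -> int:
    """Emulate two's-complement int64 wraparound of an exact Python int."""
    return (value + INT64_LIMIT) % (2 ** 64) - INT64_LIMIT


def exact_index(state, height: int) -> int:
    """Canonical index sum(s[j] * height ** (ndim - 1 - j)) as a Python int."""
    index = 0
    for coord in state:
        index = index * height + coord
    return index


height, ndim = 128, 10
origin = (0,) * ndim
other = (2,) + (0,) * (ndim - 1)  # exact index 2 * 128 ** 9 == 2 ** 64
collide = wrap_int64(exact_index(other, height)) == wrap_int64(exact_index(origin, height))
```

Here two distinct states map to the same wrapped value, which is exactly the kind of collision the arbitrary-precision fallback avoids.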

get_terminating_states_indices(states)

Get the indices of the terminating states in the canonical ordering.

See get_states_indices() for the return-type contract: a torch.Tensor[int64] for grids whose indices fit in signed int64 (height ** ndim <= 2 ** 63), or a numpy.ndarray[object] of Python ints for larger grids that would otherwise overflow.

Parameters:

states (gfn.states.DiscreteStates) – The states to get the indices of.

Returns:

The indices of the terminating states in the canonical ordering.

Return type:

Union[torch.Tensor, numpy.ndarray]

height = 8
log_partition(condition=None)

Returns the log partition of the reward function.

Return type:

float | None

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Creates a batch of random states.

Parameters:
  • batch_shape (Tuple[int, ...]) – The shape of the batch.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A DiscreteStates object with random states.

Return type:

gfn.states.DiscreteStates

make_states_class()

Returns the DiscreteStates class for the HyperGrid environment.

Return type:

type[gfn.states.DiscreteStates]

mode_mask(states)

Boolean mask indicating which states are in a mode.

A state is flagged as a mode if its reward is greater than or equal to the threshold derived from reward_fn_kwargs (R0+R1+R2 by default).

Parameters:

states (gfn.states.DiscreteStates)

Return type:

torch.Tensor

modes_found(states)

Returns the set of canonical state indices for mode states in the batch.

Each mode state is identified by its unique canonical index (from get_states_indices), not by a quadrant-based grouping. This allows correct mode-state tracking for all reward functions.

Parameters:

states (gfn.states.DiscreteStates)

Return type:

set[int]

property n_mode_states: int | float | None

Number of states inside a mode (exact, approx, or None).

  • If mode_stats="exact", returns an exact integer count.

  • If mode_stats="approx", returns a floating-point estimate.

  • If store_all_states is True (but mode_stats was "none"), computes on demand from all_states.

  • Otherwise, returns None.

Return type:

int | float | None

property n_modes: int | float | None

Returns the total number of mode states for this environment.

Equivalent to n_mode_states. Each individual grid cell whose reward meets the mode threshold counts as one mode.

Return type:

int | float | None

property n_states: int

Returns the number of states in the environment.

Return type:

int

property n_terminating_states: int

Returns the number of terminating states in the environment.

Return type:

int

ndim = 2
reward(states)

Computes the reward for a batch of final states.

In the normal setting, the reward is:

`R(s) = R_0 + 0.5 \prod_{d=1}^D \mathbf{1} \left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5] \right) + 2 \prod_{d=1}^D \mathbf{1} \left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)`

Parameters:
Returns:

The reward of the final states.

Return type:

torch.Tensor
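As a worked example of the reward formula in the "normal setting", here is a minimal per-state sketch in plain Python. The coefficients 0.5 and 2 come from the formula; the default R0 = 0.1 and the function name original_reward_sketch are assumptions made only for illustration.

```python
def original_reward_sketch(state, height, R0=0.1, R1=0.5, R2=2.0):
    """R(s) = R0 + R1 * [all dims in outer ring] + R2 * [all dims in thin band]."""
    # Distance of each normalized coordinate s / (H - 1) from the grid centre.
    dist = [abs(s / (height - 1) - 0.5) for s in state]
    outer = all(0.25 < d <= 0.5 for d in dist)
    band = all(0.3 < d < 0.4 for d in dist)
    return R0 + R1 * outer + R2 * band


# With height 8, coordinates 1 and 6 fall in both the outer ring and the band,
# coordinate 0 only in the outer ring, and coordinates 3 and 4 in neither.
high = original_reward_sketch((1, 6), height=8)
ring_only = original_reward_sketch((0, 0), height=8)
base = original_reward_sketch((3, 4), height=8)
```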

reward_fn
reward_fn_kwargs = None
step(states, actions)

Performs a step in the environment.

Parameters:
Returns:

The next states.

Return type:

gfn.states.DiscreteStates

store_all_states = False
property terminating_states: gfn.states.DiscreteStates | None

Returns all terminating states of the environment.

Return type:

gfn.states.DiscreteStates | None

true_dist(condition=None)

Returns the pmf over all states in the hypergrid.

Return type:

torch.Tensor | None

class gfn.gym.hypergrid.MultiplicativeCoprimeReward(height, ndim, **kwargs)

Bases: GridReward

Tiered reward based on prime-support and coprimality/lcm composition.

Curriculum motivation — knowledge composition:

This reward tests whether a GFlowNet can learn number-theoretic structure progressively: first discovering which coordinates factor over allowed primes (tier 0), then learning exponent bounds (tier 1), then cross-dimensional coprimality (tier 2), and finally global LCM targets (tier 3).

Each tier builds on knowledge from prior tiers: learning prime factorization at tier 0 is prerequisite for reasoning about exponent caps at tier 1, which in turn enables the coprimality reasoning needed at tier 2. This tests compositional transfer where each level requires a qualitatively different type of constraint, not just more of the same.

Coordinates are shifted by +1 internally (state 0 -> value 1) so that the origin is valid and short trajectories immediately encounter small prime-factorable numbers (2, 3, 4, 5, 6, …), providing early training signal for long-horizon credit assignment.

Each tier progressively adds new constraint types:
  • Tier 0: Prime support — coordinates must factor over allowed primes.

  • Tier 1+: Exponent caps — prime exponents bounded per tier.

  • coprime_start_tier+: Coprime pairs — cross-dimensional coupling.

  • target_lcms: LCM targets — global compositional constraint.

Reward form:

R(s) = R0 + Σ_t tier_weights[t] · 1[ constraints_0..t all satisfied ]

Key kwargs:
  • R0: float, base reward (default 0.0)

  • tier_weights: list[float]

  • primes: list[int], e.g., [2,3,5,7,11]. Primes exceeding height are auto-filtered with a warning.

  • exponent_caps: list[int], same length as tier_weights. Cap for every prime at tier t (uniform cap across primes for simplicity). Auto-capped to floor(log_p(height)) for each prime p.

  • active_dims: Optional[list[int]]; constraints only apply to these dims (default: all dims). Other dims are ignored in constraints.

  • coprime_pairs: Optional[list[tuple[int,int]]]; indices relative to active_dims.

  • coprime_start_tier: int, first tier at which coprime constraints apply (default: 0, preserving backward compatibility).

  • target_lcms: Optional[list[int | None | str]]; per-tier target lcm across active dims. Use “auto” to derive from primes and exponent_caps.

Notes:
  • Coordinates are shifted by +1 internally: state value 0 maps to reward value 1, making the origin (0,…,0) trivially valid.

  • Implementation removes primes up to the current tier cap and checks residue == 1. Exponent counts are accumulated to evaluate LCM targets.

Comparison with other compositional rewards:
  • BitwiseXORReward: GF(2) parity checks on bit-planes; rule reuse — same parity check type with increasing strictness. Non-local modes.

  • ConditionalMultiScaleReward: base-B digit decomposition with conditional constraints across scales; conditional hierarchy — coarse-scale structure predicts fine-scale constraints.

  • This class: prime factorization with progressive constraint types (support -> caps -> coprimality -> LCM). Knowledge composition — each tier requires understanding the prior tier’s structure.

Parameters:
  • height (int)

  • ndim (int)

R0: float
__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

_factor_exponents_up_to_cap(v, cap)

Trial-divide each element by allowed primes, returning residue and exponents.

Parameters:
  • v (torch.Tensor) – (…,) LongTensor of non-negative values to factorize.

  • cap (int) – Maximum number of times each prime may divide a value.

Returns:

(residue, exps), where residue is a (…,) tensor of values after stripping allowed primes (1 if fully factored) and exps is a [num_primes, …] tensor of exponent counts per prime (leading axis is primes).

Return type:

tuple[torch.Tensor, torch.Tensor]
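A minimal sketch of capped trial division, assuming plain Python lists instead of tensors and an explicit primes argument (in the class the primes are an attribute). The function name and the choice to leave the value 0 untouched are assumptions of this sketch.

```python
def factor_exponents_sketch(values, primes, cap):
    """Divide each value by each prime at most `cap` times.

    Returns (residue, exps) where exps[p][i] counts how often primes[p]
    was divided out of values[i].
    """
    residue = list(values)
    exps = [[0] * len(values) for _ in primes]
    for p_idx, p in enumerate(primes):
        for _ in range(cap):
            for i, r in enumerate(residue):
                # Skip 0 explicitly: 0 % p == 0 would otherwise always "divide".
                if r > 0 and r % p == 0:
                    residue[i] = r // p
                    exps[p_idx][i] += 1
    return residue, exps


# 12 = 2^2 * 3 factors fully; 7 has no allowed prime; 8 = 2^3 exceeds the cap.
residue, exps = factor_exponents_sketch([12, 7, 8], primes=[2, 3], cap=2)
```

A residue of 1 means the value factors over the allowed primes within the exponent cap; any other residue signals a leftover factor.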

_generate_rule_targets(n_rules, seed)

Generate n_rules distinct LCM targets from a deterministic enum.

Enumerates cap-tuples (cap_p ∈ {1, …, top_cap}) over allowed primes — cap=0 (prime absent from LCM) is excluded so every rule has a non-trivial target. Total tuples = top_cap^n_primes; permutes by seed and picks the first n_rules. The prime universe must satisfy top_cap^n_primes >= n_rules.

Note: cap=0 is intentionally excluded. With cap=0, the LCM head constraint “no active dim has prime p as factor” combined with coprime-pair and exponent-cap constraints typically yields zero or near-zero modes — degenerate rules.

Parameters:
  • n_rules (int)

  • seed (int)

Return type:

list[int]
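The enumeration described above can be sketched as follows. This is an illustration under assumptions: primes and top_cap are passed explicitly, the permutation is a seeded random.Random shuffle, and the function name is hypothetical; the actual permutation scheme is not specified in this documentation.

```python
import random
from itertools import product
from math import prod


def generate_rule_targets_sketch(primes, top_cap, n_rules, seed):
    """Pick n_rules distinct LCM targets from cap-tuples in {1, ..., top_cap}."""
    # cap = 0 is excluded, so every allowed prime appears in every target.
    cap_tuples = list(product(range(1, top_cap + 1), repeat=len(primes)))
    if len(cap_tuples) < n_rules:
        raise ValueError("need top_cap ** n_primes >= n_rules")
    random.Random(seed).shuffle(cap_tuples)  # deterministic for a given seed
    return [
        prod(p ** e for p, e in zip(primes, caps))
        for caps in cap_tuples[:n_rules]
    ]


# With primes [2, 3] and top_cap 2 the candidate targets are 6, 12, 18 and 36.
targets = generate_rule_targets_sketch([2, 3], top_cap=2, n_rules=3, seed=0)
```

Distinct cap-tuples give distinct targets by unique factorization, so the selected rules never share an LCM target.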

_lcm_ok(exps, target_lcm)

Check whether max exponents across dims match target LCM’s factorization.

Parameters:
  • exps (torch.Tensor) – [num_primes, …, num_active_dims] exponent counts.

  • target_lcm (int) – The target LCM value to match.

Returns:

(…,) bool mask, True where the LCM of active-dim values equals target.

Return type:

torch.Tensor
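The LCM of several values has, for each prime, the maximum exponent found across those values. A minimal single-state sketch of the check, assuming nested lists exps[p][d] and values that factor completely over the given primes; the function name and the explicit primes argument are hypothetical.

```python
def lcm_ok_sketch(exps, primes, target_lcm):
    """True if the max exponent per prime across dims matches target_lcm."""
    if target_lcm <= 0:
        raise ValueError("target_lcm must be positive")
    remaining = target_lcm
    for p, per_dim in zip(primes, exps):
        target_exp = 0
        while remaining % p == 0:
            remaining //= p
            target_exp += 1
        if max(per_dim) != target_exp:
            return False
    # Any leftover factor means the target does not factor over `primes`.
    return remaining == 1


# Active-dim values (4, 6): 4 = 2^2 and 6 = 2 * 3, so the LCM is 12.
exps = [[2, 1], [0, 1]]  # rows: primes 2 and 3; columns: dims
```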

_pairwise_coprime_ok(v)

Check that configured dimension pairs share no common allowed prime.

Parameters:

v (torch.Tensor) – (…, num_active_dims) coordinate values.

Returns:

(…,) bool mask, True where all coprime pair constraints hold.

Return type:

torch.Tensor

_rule_target_exps
_selector(x_active)

Map state’s active-dim values to a rule index in [0, n_rules).

Selector is sum_i x_i mod n_rules (over active dims, shifted values). Returns shape (…,) long tensor.

Parameters:

x_active (torch.Tensor)

Return type:

torch.Tensor

_validate_rule_coverage()

Ensure each rule’s LCM target is achievable.

A target is achievable iff (a) it factors over allowed primes (already checked at __init__), and (b) every required exponent is <= top cap. Both conditions are necessary; with active_dims >= n_primes_with_nonzero_exp and coprime_pairs constraints, sufficiency requires enumeration. Here we check only the necessary conditions and verify that the rules produce non-degenerate distinct targets.

Return type:

None

active_dims: list[int]
coprime_pairs
coprime_start_tier: int
exponent_caps: list[int] = []
head_seed: int
n_rules: int
primes: list[int]
rule_targets: list[int | None]
target_lcms: list[int | None] = []
tier_indicators(states_tensor)

Per-tier independent pass/fail indicators.

Returns a list of boolean tensors (one per tier), each of shape states_tensor.shape[:-1]. indicators[t] is True for states that satisfy tier t’s constraints independently (not cumulatively).

Parameters:

states_tensor (torch.Tensor)

Return type:

list[torch.Tensor]
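The independent indicators returned here differ from the cumulative condition in the reward form R(s) = R0 + Σ_t tier_weights[t] · 1[constraints_0..t all satisfied]. A minimal per-state sketch of how the two relate, assuming scalar booleans instead of tensors; the function name is hypothetical.

```python
def cumulative_tier_reward(indicators, tier_weights, R0=0.0):
    """Add tier_weights[t] only while tiers 0..t have all passed."""
    reward = R0
    passed_all = True
    for ok, weight in zip(indicators, tier_weights):
        passed_all = passed_all and ok
        if passed_all:
            reward += weight
    return reward


weights = [1.0, 2.0, 4.0]
# Tier 2 passes on its own here, but tier 1 fails, so only tier 0 is credited.
partial = cumulative_tier_reward([True, False, True], weights)
```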

tier_weights: list[float]
class gfn.gym.hypergrid.OriginalReward(height, ndim, **kwargs)

Bases: GridReward

The reward function from the original GFlowNet paper (Bengio et al., 2021; https://arxiv.org/abs/2106.04399).

Parameters:
  • height (int)

  • ndim (int)

__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

class gfn.gym.hypergrid.SparseReward(height, ndim, **kwargs)

Bases: GridReward

Sparse reward function from the GAFN paper (Pan et al., 2022; https://arxiv.org/abs/2210.03308).

Parameters:
  • height (int)

  • ndim (int)

__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

targets
class gfn.gym.hypergrid.UniformRandomReward(height, ndim, **kwargs)

Bases: GridReward

Each state is independently a mode with probability mode_prob.

Uses a deterministic hash on state coordinates so mode membership is reproducible without storing or enumerating all states. There is no exploitable spatial or algebraic structure.

Reward form:

R(s) = R0 + R_mode   if hash(s, seed) < mode_prob
R(s) = R0             otherwise
Key kwargs:
  • R0: float, base reward for non-mode states (default 0.1).

  • R_mode: float, additional reward for mode states (default 2.0).

  • mode_prob: float in (0, 1), probability each state is a mode (default 0.01).

  • seed: int, hash seed for reproducibility (default 42).

Parameters:
  • height (int)

  • ndim (int)

R0: float
R_mode: float
__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

mode_prob: float
seed: int
gfn.gym.hypergrid._MAX_POOL_WORKERS = 8
gfn.gym.hypergrid._first_k_dims(k, ndim)

Return indices [0, 1, …, min(k, ndim)-1] for the first k dimensions.

Parameters:
  • k (int)

  • ndim (int)

Return type:

list[int]

gfn.gym.hypergrid._gf2_random_fullrank(n_checks, n_vars, seed)

Generate a full-rank random binary matrix A and target vector c over GF(2).

Uses a deterministic seed for reproducibility across runs.

Parameters:
  • n_checks (int) – Number of independent GF(2) equations (rows of A).

  • n_vars (int) – Number of binary variables (columns of A).

  • seed (int) – Deterministic RNG seed.

Returns:

(A, c) where A is (n_checks, n_vars) int tensor and c is (n_checks,) int tensor, both with values in {0, 1}. A is guaranteed to have full row-rank over GF(2).

Return type:

tuple[torch.Tensor, torch.Tensor]
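One way to obtain a full-rank random system is rejection sampling: draw a random binary matrix, compute its GF(2) rank, and retry until the rank equals n_checks. The sketch below assumes that strategy and plain Python lists; the sampling method actually used by the module is not stated here, and both function names are hypothetical.

```python
import random


def gf2_rank_sketch(matrix):
    """Rank of a 0/1 matrix over GF(2) via forward elimination with XOR."""
    rows = [[v & 1 for v in row] for row in matrix]
    n_cols = len(rows[0]) if rows else 0
    rank = 0
    for col in range(n_cols):
        pivot = next((i for i in range(rank, len(rows)) if rows[i][col]), None)
        if pivot is None:
            continue
        rows[rank], rows[pivot] = rows[pivot], rows[rank]
        for i in range(rank + 1, len(rows)):
            if rows[i][col]:
                rows[i] = [x ^ y for x, y in zip(rows[i], rows[rank])]
        rank += 1
    return rank


def gf2_random_fullrank_sketch(n_checks, n_vars, seed):
    """Sample (A, c) with A of full row rank over GF(2), deterministically."""
    if n_checks > n_vars:
        raise ValueError("full row rank requires n_checks <= n_vars")
    rng = random.Random(seed)
    while True:  # rejection sampling: retry until the rows are independent
        A = [[rng.randint(0, 1) for _ in range(n_vars)] for _ in range(n_checks)]
        if gf2_rank_sketch(A) == n_checks:
            break
    c = [rng.randint(0, 1) for _ in range(n_checks)]
    return A, c


A, c = gf2_random_fullrank_sketch(n_checks=3, n_vars=5, seed=0)
```

Full row rank guarantees that A x = c is solvable for every target c, so each set of checks always admits at least one satisfying bit pattern.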

gfn.gym.hypergrid._gf2_rank(A)

Compute the rank of a binary matrix over GF(2).

Parameters:

A (torch.Tensor)

Return type:

int

gfn.gym.hypergrid._hypergrid_worker(task)

Module-level worker for HyperGrid._generate_combinations_in_batches.

Returns the requested slice of the Cartesian product as a concrete list. Lives at module level (rather than as a bound method) so it can be pickled to a spawned multiprocessing.Pool worker — bound methods of HyperGrid are not picklable because the env’s States subclass is created locally inside make_states_class.

Parameters:

task – A tuple (values, ndim, start_idx, end_idx), where values is the list of coordinate values, ndim is the number of dimensions, and [start_idx, end_idx) is the index range within the full Cartesian product.

Returns:

A list of length end_idx - start_idx containing tuples of length ndim. Returning a concrete list (rather than an itertools.islice) keeps the result picklable across workers and future-proofs against the Python 3.14 removal of itertools pickle support.

gfn.gym.hypergrid._preset_seed(name)

Deterministic seed from a preset name.

Parameters:

name (str)

Return type:

int

gfn.gym.hypergrid._state_hash_uniform(states_tensor, seed)

Deterministic hash mapping each grid state to a float in [0, 1).

Uses a polynomial rolling hash over coordinate values computed in int64 arithmetic. Suitable for pseudo-random but reproducible per-state decisions (e.g., mode assignment, corruption masks).

Parameters:
  • states_tensor (torch.Tensor) – (…, ndim) integer tensor of coordinates.

  • seed (int) – Integer seed for determinism.

Returns:

Tensor of shape states_tensor.shape[:-1] with values in [0.0, 1.0).

Return type:

torch.Tensor
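A polynomial rolling hash of this kind can be sketched for a single state as follows. The multiplier is an arbitrary odd 64-bit constant chosen for illustration, 64-bit wraparound is emulated with a mask, and both function names are hypothetical; the constants used by the module are not documented here. The second function applies the hash to the UniformRandomReward rule R(s) = R0 + R_mode if hash(s, seed) < mode_prob, using the documented default kwargs.

```python
MASK64 = (1 << 64) - 1


def state_hash_uniform_sketch(state, seed, multiplier=6364136223846793005):
    """Map a tuple of coordinates to a reproducible float in [0, 1)."""
    h = seed & MASK64
    for coord in state:
        # Rolling update; the mask emulates fixed-width 64-bit arithmetic.
        h = (h * multiplier + coord + 1) & MASK64
    # Keep the top 53 bits so the result is an exact float below 1.0.
    return (h >> 11) / float(1 << 53)


def uniform_random_reward_sketch(state, R0=0.1, R_mode=2.0, mode_prob=0.01, seed=42):
    """R0 + R_mode for hashed-in mode states, R0 otherwise."""
    is_mode = state_hash_uniform_sketch(state, seed) < mode_prob
    return R0 + R_mode if is_mode else R0


u = state_hash_uniform_sketch((3, 4), seed=42)
r = uniform_random_reward_sketch((3, 4))
```

The hash depends on coordinate order, so permuted states such as (0, 1) and (1, 0) receive different values.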

gfn.gym.hypergrid.get_bitwise_xor_presets(ndim, height)

Return five difficulty presets for BitwiseXORReward.

Difficulty is controlled by the number of constrained dimensions M. More constrained dims means more independent GF(2) checks per bit position, leading to fewer modes.

Each preset uses 3 tiers with increasing numbers of GF(2) checks:
  • Tier 0 (curriculum): few checks, many states pass

  • Tier 1 (intermediate): moderate checks

  • Tier 2 (mode): strictest checks, defines the modes

Mode counts for ndim=10, height=16 (B=4 bit-planes). Per-bit-position solutions = 2^(M − cum_top_checks); raised to the B-th power, then multiplied by 16^(ndim − M) free-dim configurations:

  • easy (M=3, cum=2): ~4.3B modes (16 × 16^7)

  • medium (M=5, cum=4): ~16.8M modes (16 × 16^5)

  • hard (M=8, cum=6): ~65K modes (256 × 16^2)

  • challenging (M=10, cum=7): ~4K modes (4096 × 1)

  • impossible (M=12→10, cum=9): 16 modes (16 × 1; M capped to ndim)
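The mode counts listed above follow directly from the stated rule. A small sketch that reproduces the arithmetic for ndim=10 and height=16 (B=4 bit-planes); the function and argument names are hypothetical.

```python
def xor_mode_count(ndim, n_bit_planes, M, cum_top_checks):
    """(2 ** (M - cum)) ** B solutions on constrained dims, times free dims."""
    M = min(M, ndim)  # M is capped to ndim
    per_bit_position = 2 ** (M - cum_top_checks)
    free_configs = (2 ** n_bit_planes) ** (ndim - M)  # height ** (ndim - M)
    return per_bit_position ** n_bit_planes * free_configs


counts = {
    "easy": xor_mode_count(10, 4, M=3, cum_top_checks=2),
    "medium": xor_mode_count(10, 4, M=5, cum_top_checks=4),
    "hard": xor_mode_count(10, 4, M=8, cum_top_checks=6),
    "challenging": xor_mode_count(10, 4, M=10, cum_top_checks=7),
    "impossible": xor_mode_count(10, 4, M=12, cum_top_checks=9),
}
```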

Notes:
  • Uses fixed seeds per preset name for reproducibility.

  • Parity checks are random full-rank GF(2) matrices.

  • Bit ranges are capped to ceil(log2(height)) - 1.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_conditional_multiscale_presets(ndim, height)

Return difficulty presets for ConditionalMultiScaleReward.

All presets use base=4 (requiring H to be a power of 4). The number of available digit levels is L = log_4(H). Difficulty is controlled by:

  • Number of tiers (more tiers = deeper compositional hierarchy)

  • Number of active dims (more = exponentially sparser modes)

  • Cross-dim modular constraints (further sparsification)

Mode counts are computed via:

modes_T = (f^T * 4^{L-T})^d * H^(ndim-d) / prod_t m_t

with f = filter_width = 2 (i.e. B//2), so each tier halves modes per coord.

Digit ordering is coarse-to-fine: tier 0 constrains the most significant digit. Near the origin (small coordinates), high digits are 0 and trivially pass the filter. Deeper tiers constrain progressively finer digits, creating distance-correlated difficulty.

Two preset families:

Difficulty presets (single-rule legacy, assuming H=256, i.e. L=4):
  • easy: 2 tiers, 3 active dims -> ~2M modes at tier 2

  • medium: 3 tiers, 4 active dims -> ~1M modes at tier 3

  • hard: 3 tiers, 6 active dims, cross-dim -> ~260K modes at tier 3

  • challenging: 4 tiers, 8 active dims, cross-dim -> ~65K modes at tier 4

  • impossible: 4 tiers, 12 active dims, cross-dim -> ~4K modes at tier 4

K-rule trunk+heads presets (sparsity matched, K rules sharing tier-0 trunk):
  • K1, K16, K64: 3 tiers, 6 active dims, cross_dim_mods=[2,2,2]. Total modes invariant in K (~32K total at H=64, density ~5e-7); rules partition the canonical mode set. Designed for ndim=6, H=64.
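The mode-count formula modes_T = (f^T * 4^{L-T})^d * H^(ndim-d) / prod_t m_t can be evaluated directly. The sketch below checks it against the K-rule figures (H=64, ndim=6, 3 tiers, 6 active dims, cross_dim_mods=[2,2,2]); function and argument names are hypothetical.

```python
from math import prod


def digit_levels(height, base=4):
    """L = log_base(height), requiring height to be an exact power of base."""
    levels, h = 0, height
    while h > 1:
        if h % base:
            raise ValueError("height must be a power of base")
        h //= base
        levels += 1
    return levels


def multiscale_mode_count(height, ndim, n_tiers, n_active, cross_dim_mods=(),
                          base=4, filter_width=2):
    """Evaluate (f^T * base^(L-T))^d * H^(ndim-d) / prod(m_t)."""
    L = digit_levels(height, base)
    per_coord = filter_width ** n_tiers * base ** (L - n_tiers)
    total = per_coord ** n_active * height ** (ndim - n_active)
    return total / prod(cross_dim_mods)


modes = multiscale_mode_count(64, ndim=6, n_tiers=3, n_active=6,
                              cross_dim_mods=[2, 2, 2])
density = modes / 64 ** 6
```

The result is 8^6 / 8 = 32768 modes, about 4.8e-7 of the 64^6 states, matching the quoted ~32K total and ~5e-7 density.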

If height provides fewer digit levels than a preset requires, the preset’s tier_weights and cross_dim_mods are auto-truncated with a warning.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_corrupted_presets(ndim, height)

Return five difficulty presets for CorruptedReward.

Each preset wraps a conditional_multiscale “medium” base and applies increasing corruption. A single corruption_rate parameter controls the fraction of per-tier structure that is randomized.

Difficulty progression:
  • easy: 10% corruption -> mostly structured

  • medium: 30% corruption -> noticeable randomness

  • hard: 50% corruption -> half structured, half random

  • challenging: 70% corruption -> mostly random

  • impossible: 90% corruption -> near-total randomness

Note: requires height to be a power of 4 (same as the base reward).

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_cosine_presets(ndim, height)

Return five presets for CosineReward.

R1 scales the oscillatory product, and mode_gamma (used only for mode detection thresholding) tightens what is considered a “mode-like” maximum.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_deceptive_presets(ndim, height)

Return five presets for DeceptiveReward.

Increase R2 to accentuate the thin band, and set a small but non-zero R0. R1 controls the center emphasis vs. the cancelled outer region.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_multiplicative_coprime_presets(ndim, height)

Return five difficulty presets for MultiplicativeCoprimeReward.

Each preset uses progressive tier structure where each tier adds a new constraint type:

  • Tier 0: Prime support only (coords must factor over allowed primes)

  • Tier 1: + Exponent caps (tighten factorization)

  • coprime_start_tier+: + Coprime pairs (cross-dim coupling)

  • Final tier: + LCM target (global compositional constraint)

Coordinates are shifted +1 internally (origin -> all-ones), so short trajectories immediately encounter small prime-factorable numbers.

Primes exceeding height and exponent caps exceeding log_p(height) are auto-filtered/capped in the reward constructor.

Notes:
  • active_dims indexes are relative to state dims; we pick the first k.

  • coprime_pairs are pairs within active_dims index space.

  • Tier weights are geometric.

  • Use target_lcms="auto" to derive from primes and exponent_caps.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_original_presets(ndim, height)

Return five presets for OriginalReward.

These presets primarily control the relative importance of the outer ring (R1) and thin band (R2). Exploration difficulty (distance from s0) is more a function of (D, H) than of these weights; tune D and H externally to match your distance bands.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_reward_presets(reward_fn_str, ndim, height)

Return presets for a given reward function name.

Usage

presets = get_reward_presets("bitwise_xor", D, H)
kwargs = presets["hard"]
env = HyperGrid(ndim=D, height=H, reward_fn_str="bitwise_xor", reward_fn_kwargs=kwargs)

Parameters:
  • reward_fn_str (str)

  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_sparse_presets(ndim, height)

Return five presets for SparseReward.

SparseReward has built-in targets; it ignores most kwargs. Presets are provided for API symmetry and future extensibility.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_uniform_random_presets(ndim, height)

Return five difficulty presets for UniformRandomReward.

Difficulty is controlled by mode_prob: lower probability means sparser modes, which are harder for GFlowNets to discover.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.lcm(a, b)

Returns the lowest common multiple between a and b.

gfn.gym.hypergrid.lcm_multiple(numbers)

Find the lowest common multiple across a list of numbers.

gfn.gym.hypergrid.logger
gfn.gym.hypergrid.smallest_multiplier_to_integers(float_vector, precision=3)

Used to calculate a scale factor to avoid imprecise floating point arithmetic.
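One plausible reading of these three helpers is sketched below; it is an assumption, not the module's confirmed behaviour. The sketch rounds each float to precision decimal places, converts it to an exact fraction, and takes the lowest common multiple of the denominators as the scale factor. All function names here are stand-ins.

```python
from fractions import Fraction
from functools import reduce
from math import gcd


def lcm_pair(a, b):
    """Lowest common multiple of two integers (0 if either is 0)."""
    return abs(a * b) // gcd(a, b) if a and b else 0


def lcm_of_list(numbers):
    """Lowest common multiple across a list of integers."""
    return reduce(lcm_pair, numbers, 1)


def smallest_integer_scale(float_vector, precision=3):
    """Smallest k such that k * x is an integer for every rounded x."""
    scale = 10 ** precision
    # Fraction reduces automatically, e.g. 250/1000 becomes 1/4.
    fractions = [Fraction(round(x * scale), scale) for x in float_vector]
    return lcm_of_list([f.denominator for f in fractions])


k = smallest_integer_scale([0.5, 0.25, 0.1])
scaled = [round(x * k, 9) for x in [0.5, 0.25, 0.1]]
```

Multiplying by the scale factor turns the values into whole numbers, so later comparisons can use exact integer arithmetic.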