gfn.gym.hypergrid
Adapted from https://github.com/Tikquuss/GflowNets_Tutorial
Attributes
Classes
- BitwiseXORReward: Tiered, compositional reward based on bitwise XOR/parity constraints.
- ConditionalHyperGrid: HyperGrid environment with condition-aware rewards.
- ConditionalMultiScaleReward: Tiered reward via conditional digit constraints across spatial scales.
- CorruptedReward: Wraps a tiered structured reward and applies per-tier corruption.
- CosineReward: Cosine reward function.
- DeceptiveReward: Deceptive reward function from Adaptive Teachers (Kim et al., 2025).
- GridReward: Base class for reward functions that can be pickled.
- HyperGrid: HyperGrid environment from the GFlowNets paper.
- MultiplicativeCoprimeReward: Tiered reward based on prime-support and coprimality/lcm composition.
- OriginalReward: The reward function from the original GFlowNet paper (Bengio et al., 2021;
- SparseReward: Sparse reward function from the GAFN paper (Pan et al., 2022;
- UniformRandomReward: Each state is independently a mode with probability
Functions
- Return indices [0, 1, ..., min(k, ndim)-1] for the first k dimensions.
- Generate a full-rank random binary matrix A and target vector c
- Compute the rank of a binary matrix over GF(2).
- Module-level worker for
- Deterministic seed from a preset name.
- Deterministic hash mapping each grid state to a float in [0, 1).
- Return five difficulty presets for BitwiseXORReward.
- Return difficulty presets for ConditionalMultiScaleReward.
- Return five difficulty presets for CorruptedReward.
- Return five presets for CosineReward.
- Return five presets for DeceptiveReward.
- Return five difficulty presets for MultiplicativeCoprimeReward.
- Return five presets for OriginalReward.
- Return presets for a given reward function name.
- Return five presets for SparseReward.
- Return five difficulty presets for UniformRandomReward.
- Returns the lowest common multiple between a and b.
- Find the lowest common multiple across a list of numbers
- Used to calculate a scale factor to avoid imprecise floating point arithmetic.
Module Contents
- class gfn.gym.hypergrid.BitwiseXORReward(height, ndim, **kwargs)
Bases: GridReward
Tiered, compositional reward based on bitwise XOR/parity constraints.
- Curriculum motivation — rule reuse:
This reward tests whether a GFlowNet can learn a global algebraic rule (GF(2) parity) and reuse it across tiers of increasing strictness. Unlike the other compositional rewards, modes are NOT spatially concentrated near the origin — they are distributed non-locally across the grid according to algebraic structure. This is intentional: it probes the model’s ability to learn abstract, non-spatial compositionality.
The curriculum operates through constraint accumulation: tier 0 applies few parity checks (many modes, easy to discover), tier 1 adds more checks (fewer modes, same rule type), etc. A model that learns the parity computation at tier 0 can reuse that same computation to satisfy tier 1+ constraints, providing a form of compositional transfer for long-horizon credit assignment.
This class implements the “Bitwise/XOR fractal” environment family, in which tiers progressively constrain bit-planes across a subset of dimensions via linear parity checks over GF(2). It supports easy sharding by high-bit prefixes, and difficulty is controlled by adjusting which bit-planes and how many dimensions are constrained per tier.
GF(2) is the finite field with two elements {0, 1}, where addition and multiplication are performed modulo 2. In this context, vector addition is equivalent to bitwise XOR, and matrix-vector products (A @ b) are evaluated entrywise modulo 2.
- Reward form:
R(s) = R0 + Σ_t tier_weights[t] · 1[ state satisfies all constraints up to tier t ]
- Key kwargs (with reasonable defaults):
R0: float, base reward (default 0.0)
tier_weights: list[float], strictly increasing weights for each tier
dims_constrained: Optional[list[int]] subset of dims to constrain (default: all dims)
bits_per_tier: list[tuple[int,int]]; for each tier t, inclusive bit range (low_bit, high_bit). Example: [(0,5), (0,7), (0,9)].
- parity_checks: Optional[list[dict]]; per tier, optional parity system:
- Each entry may contain:
{ “A”: IntTensor[num_checks, m], “c”: IntTensor[num_checks] }
where m = len(dims_constrained). Constraints apply identically to every bit-plane specified for that tier: A @ b ≡ c (mod 2), where b is the vector of bit values across the constrained dimensions at the tested bit-plane. If omitted for a tier, a single even-parity check across all constrained dims is used by default: sum(b) mod 2 == 0.
Difficulty presets align with step ranges by controlling the highest bit used and the number of constrained dimensions. Typical distance from origin for valid modes scales roughly like (constrained_dims · 2^{highest_bit}).
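The per-tier check can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the class's implementation: coordinates are plain ints, each tier's bit range is inclusive as in bits_per_tier, an optional (A, c) system is given as nested 0/1 lists, and the helper names (bit_plane, parity_ok, passes_tier, tiered_reward) are hypothetical.

```python
def bit_plane(coords, bit):
    """Bit value of each constrained coordinate at one bit-plane."""
    return [(x >> bit) & 1 for x in coords]


def parity_ok(bits, A=None, c=None):
    """A @ bits == c over GF(2); with no system, a single even-parity check."""
    if A is None:
        return sum(bits) % 2 == 0
    return all(
        sum(a * b for a, b in zip(row, bits)) % 2 == rhs
        for row, rhs in zip(A, c)
    )


def passes_tier(state, dims_constrained, bit_range, A=None, c=None):
    """True iff every bit-plane in the inclusive (low_bit, high_bit) range passes."""
    coords = [state[d] for d in dims_constrained]
    low, high = bit_range
    return all(parity_ok(bit_plane(coords, k), A, c) for k in range(low, high + 1))


def tiered_reward(state, dims_constrained, bits_per_tier, tier_weights, R0=0.0):
    """R0 + sum_t w_t * 1[tiers 0..t all pass], using default parity checks only."""
    reward, ok = R0, True
    for bit_range, w in zip(bits_per_tier, tier_weights):
        ok = ok and passes_tier(state, dims_constrained, bit_range)
        if ok:
            reward += w
    return reward


# 3 = 0b11, 1 = 0b01, 2 = 0b10: bit-plane 0 is (1, 1, 0), bit-plane 1 is (1, 0, 1).
print(passes_tier((3, 1, 2), [0, 1, 2], (0, 1)))  # True: both planes have even parity
print(tiered_reward((3, 1, 2), [0, 1, 2], [(0, 0), (0, 1), (0, 2)], [1.0, 2.0, 4.0]))  # 7.0
```

The cumulative loop mirrors the reward form above: once a tier fails, no higher tier can contribute.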
- K-rule structure (n_rules >= 1):
Trunk: the per-tier parity stack above, shared across all rules.
Selector: a fixed GF(2) matrix S of shape (k_select, M*B) projects bits to a rule index r = pack(S·b mod 2) ∈ [0, n_rules). Here k_select = ceil(log2(n_rules)); for n_rules=1, k_select=0 and r is always 0.
Head: per-rule parity matrix H_r of shape (head_check_count, M), applied at every bit-plane in head_bit_range. A state is a mode iff the trunk passes AND H_r·b ≡ c_r (mod 2) at every bit-plane in head_bit_range.
- Reward:
R = R0 + Σ_t w_t·1[trunk_0..t pass] + head_weight · 1[trunk all pass ∧ head_{σ(b)} pass]
where σ(b) = r is the rule index chosen by the selector.
At n_rules=1 with head_check_count=0 and head_weight=0 (defaults), the head is empty and the K-rule code path collapses to the legacy reward bit-exactly.
Total mode count is invariant in n_rules when n_rules is a power of 2: the selector adds k_select equations that partition the trunk-passing space into n_rules cosets, and each per-rule head adds head_check_count·n_head_bits parity constraints within its coset.
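To make the selector concrete, here is a small sketch that enumerates every bit configuration of a toy system and counts how the trunk-passing configurations fall into rules. The bit order used by pack_bits and all helper names are assumptions for illustration, and n_rules is restricted to a power of 2 so every packed index lands in [0, n_rules).

```python
import itertools
import math


def pack_bits(bits):
    """Pack a bit list into an int, bits[j] weighted by 2**j (assumed order)."""
    return sum(bit << j for j, bit in enumerate(bits))


def gf2_dot(row, b):
    """Inner product over GF(2)."""
    return sum(a * x for a, x in zip(row, b)) % 2


def rule_histogram(n_bits, n_rules, S, trunk_A, trunk_c):
    """Count trunk-passing configurations per rule r = pack(S·b mod 2)."""
    assert len(S) == math.ceil(math.log2(n_rules))  # S has k_select rows
    counts = [0] * n_rules
    for b in itertools.product((0, 1), repeat=n_bits):
        if all(gf2_dot(row, b) == c for row, c in zip(trunk_A, trunk_c)):
            counts[pack_bits([gf2_dot(row, b) for row in S])] += 1
    return counts


# Trunk: one even-parity check over 4 bits (8 passing configurations).
trunk_A, trunk_c = [[1, 1, 1, 1]], [0]
print(rule_histogram(4, 1, [], trunk_A, trunk_c))  # [8]
print(rule_histogram(4, 4, [[1, 0, 0, 0], [0, 1, 0, 0]], trunk_A, trunk_c))  # [2, 2, 2, 2]
```

Because the two selector rows are linearly independent of the trunk row, each rule receives the same share of the 8 trunk-passing configurations. A selector row equal to the trunk row would instead send all of them to rule 0, leaving rule 1 unreachable, which is the situation _validate_rule_coverage is meant to detect.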
- Comparison with other compositional rewards:
MultiplicativeCoprimeReward: number-theoretic (prime factorization); knowledge composition — learning prime structure enables coprimality and LCM constraints at higher tiers.
ConditionalMultiScaleReward: base-B digit decomposition with conditional constraints across scales; conditional hierarchy — coarse-scale structure predicts fine-scale constraints.
This class: GF(2) linear algebra on bit-planes; rule reuse — the same parity check type is applied with increasing strictness per tier. Modes are non-local (algebraic, not spatial).
- Parameters:
height (int)
ndim (int)
- R0: float
- _B
- _RULE_SEED_STRIDE: int = 1000003
- __call__(states_tensor)
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- _apply_parity_checks(bits_plane, tier_idx)
Apply GF(2) linear parity checks at a single bit-plane.
bits_plane: (…, m) with m=len(dims_constrained), integer in {0,1}. Returns mask (…,) bool.
- Parameters:
bits_plane (torch.Tensor)
tier_idx (int)
- Return type:
torch.Tensor
- _bit_positions
- _dim_idx
- _even_parity_mask(bits)
bits: (…, m) int/bool -> returns (…,) bool for even parity.
- Parameters:
bits (torch.Tensor)
- Return type:
torch.Tensor
- _head_A_per_rule
- _head_c_per_rule
- _per_rule_mode_counts()
Bit-config mode count per rule (without unconstrained-dim factor).
For rule k, returns 2^(M·B − rank([trunk; selector; head_k])) when the combined system is consistent, else 0. Inconsistency means the rule is unreachable — its bit-config mode count is 0.
- Return type:
list[int]
- _select_powers
- _tier_check_counts = []
- _tier_weights_t
- _uniform_partition: bool = True
- _validate_rule_coverage()
Ensure every rule index has >= 1 mode (trunk + selector→rule + head_rule).
- For each rule k, build the combined GF(2) system:
[H_trunk; S; H_k] · b = [c_trunk; pack^-1(k); c_k]
and verify it’s consistent via _solve_gf2_has_solution. With k_select bits encoding the rule index, the selector contributes k_select scalar equations whose RHS is determined by the rule.
Also tracks _uniform_partition: True iff n_rules is a power of 2 AND the combined trunk+selector matrix has rank r_trunk + k_select (i.e. the selector is independent of the trunk, so each rule is reachable with the same per-rule mode count).
- Return type:
None
- analytic_mode_count(per_rule=False)
Total mode count via per-rule GF(2) rank summation.
For each rule k, modes_k = 2^(M·B − rank([trunk; selector; head_k])). Total = Σ_k modes_k. With random full-rank head matrices the per-rule ranks coincide and total = K · modes_0; for small or degenerate configurations they may differ. Multiplied by H^(ndim − M) for unconstrained dims.
- Parameters:
per_rule (bool)
- Return type:
int
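The rank-based count can be checked on a toy configuration. The sketch below computes a GF(2) rank by Gaussian elimination on integer bitmasks and compares 2^(M·B − rank) · H^(ndim − M) with brute-force enumeration. It assumes default even-parity checks only (no selector or head), H = 2^B, and a variable layout (dim i, bit k) → i·B + k; the function names are hypothetical.

```python
import itertools


def gf2_rank(rows):
    """Rank over GF(2) of 0/1 rows, by Gaussian elimination on bitmasks."""
    masks = [sum(bit << j for j, bit in enumerate(row)) for row in rows]
    n_cols = max((len(row) for row in rows), default=0)
    rank = 0
    for col in range(n_cols):
        pivot = next((i for i in range(rank, len(masks)) if (masks[i] >> col) & 1), None)
        if pivot is None:
            continue
        masks[rank], masks[pivot] = masks[pivot], masks[rank]
        for i in range(len(masks)):
            if i != rank and (masks[i] >> col) & 1:
                masks[i] ^= masks[rank]  # XOR is row addition mod 2
        rank += 1
    return rank


def even_parity_rows(m, n_bits, planes):
    """One even-parity equation per bit-plane; variable (dim i, bit k) -> i * n_bits + k."""
    rows = []
    for k in planes:
        row = [0] * (m * n_bits)
        for i in range(m):
            row[i * n_bits + k] = 1
        rows.append(row)
    return rows


def analytic_count(rows, m, n_bits, ndim):
    """2^(M·B - rank) bit configurations times H^(ndim - M); assumes a consistent system."""
    height = 2**n_bits
    return 2 ** (m * n_bits - gf2_rank(rows)) * height ** (ndim - m)


def brute_force_count(n_bits, ndim, dims, planes):
    """Count states whose constrained dims have even parity at every listed plane."""
    height = 2**n_bits
    return sum(
        all(sum((s[d] >> k) & 1 for d in dims) % 2 == 0 for k in planes)
        for s in itertools.product(range(height), repeat=ndim)
    )


rows = even_parity_rows(m=3, n_bits=2, planes=[0, 1])
print(analytic_count(rows, 3, 2, 4), brute_force_count(2, 4, [0, 1, 2], [0, 1]))  # 64 64
```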
- bits_per_tier: list[tuple[int, int]]
- dims_constrained: list[int]
- head_bit_range: tuple[int, int]
- head_check_count: int
- head_seed: int
- head_weight: float
- k_select: int
- n_rules: int
- parity_checks
- tier_indicators(states_tensor)
Per-tier independent pass/fail indicators.
Returns a list of boolean tensors (one per tier), each of shape states_tensor.shape[:-1]. indicators[t] is True for states that satisfy tier t’s GF(2) parity constraints independently (not cumulatively).
- Parameters:
states_tensor (torch.Tensor)
- Return type:
list[torch.Tensor]
- tier_weights: list[float]
- class gfn.gym.hypergrid.ConditionalHyperGrid(*args, **kwargs)
Bases: HyperGrid
HyperGrid environment with condition-aware rewards.
Let condition ‘c’ be a real value in [0, 1]. The condition defines the reward as a linear interpolation between the uniform reward and the original reward. Special cases are:
- c = 0: Uniform reward (all terminal states get reward=R0+R1+R2)
- c = 1: Original HyperGrid reward (original multi-modal reward landscape)
- _log_partition_cache: dict[torch.Tensor, float]
- _max_reward: float
- _original_reward_fn
- _true_dist_cache: dict[torch.Tensor, torch.Tensor]
- condition_dim: int = 1
- is_conditional: bool = True
- log_partition(condition)
Compute the log partition for the given condition.
- Parameters:
condition (torch.Tensor) – The condition to compute the log partition for. condition.shape should be (1,)
- Returns:
The log partition function, as a float.
- Return type:
float
- reward(states)
Compute rewards for the conditional environment.
A condition is continuous from 0 to 1:
- 0: Fully uniform reward (all states get R0+R1+R2)
- 1: Fully original HyperGrid reward
- In between: Linear interpolation between uniform and original
- Parameters:
states (gfn.states.DiscreteStates) – The states to compute rewards for. states.tensor.shape should be (*batch_shape, *state_shape)
- Returns:
A tensor of shape (*batch_shape,) containing the rewards.
- Return type:
torch.Tensor
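The interpolation can be written out directly. This sketch assumes the per-state original reward is already computed, and the R0, R1, R2 values in the checks are illustrative. Because the reward is linear in c, so is the partition function: Z(c) = (1 − c)·(R0 + R1 + R2)·H^D + c·Z_original, which the second helper uses. Names are hypothetical, and this is not the class's cached implementation.

```python
import math


def conditional_reward(original_reward, c, R0, R1, R2):
    """(1 - c) * uniform + c * original, where uniform = R0 + R1 + R2."""
    if not 0.0 <= c <= 1.0:
        raise ValueError("condition must lie in [0, 1]")
    return (1.0 - c) * (R0 + R1 + R2) + c * original_reward


def conditional_log_partition(log_Z_original, c, R0, R1, R2, n_states):
    """log Z(c) for Z(c) = (1 - c) * (R0 + R1 + R2) * n_states + c * Z_original."""
    Z = (1.0 - c) * (R0 + R1 + R2) * n_states + c * math.exp(log_Z_original)
    return math.log(Z)


# Halfway between the uniform value 2.6 and an original reward of 0.1.
print(conditional_reward(0.1, 0.5, R0=0.1, R1=0.5, R2=2.0))
```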
- sample_conditions(batch_shape)
Sample conditions for the environment.
- Parameters:
batch_shape (int | tuple[int, Ellipsis])
- Return type:
torch.Tensor
- true_dist(condition)
Compute the true distribution for the given condition.
- Parameters:
condition (torch.Tensor) – The condition to compute the true distribution for. condition.shape should be (1,)
- Returns:
The true distribution for the given condition as a 1-dimensional tensor.
- Return type:
torch.Tensor
- class gfn.gym.hypergrid.ConditionalMultiScaleReward(height, ndim, **kwargs)
Bases: GridReward
Tiered reward via conditional digit constraints across spatial scales.
- Curriculum motivation — conditional hierarchy:
This reward tests whether a GFlowNet can learn hierarchical, conditional structure: each tier’s constraint depends on what was learned at prior tiers, creating the strongest form of compositional transfer among the three reward types.
Digit ordering is coarse-to-fine: tier 0 constrains the most significant digit (coarsest spatial scale), tier 1 constrains the next digit conditioned on the coarse digit, and so on. This creates natural distance-correlated difficulty: states near the origin have small coordinates (high digits are 0, trivially passing coarse filters), while states far from the origin have nonzero high digits that must satisfy the filter. Learning coarse-scale structure first provides early training signal and directly informs which fine-scale configurations are valid, enabling compositional transfer for long-horizon credit assignment.
Each coordinate is decomposed in base B into L = log_B(H) digits. Tier t constrains digit (L-1-t) — the (t+1)-th most significant digit — via a shifted filter that depends on all coarser-scale digits already constrained, creating a hierarchy where learning coarse-scale structure is prerequisite for predicting fine-scale constraints.
- Per-dimension constraint at tier t (0-indexed):
(d_{L-1-t}(i) + sigma_t(i; r)) mod B < f
where sigma_t(i; r) = sum_{k=0}^{t-1} a_{t,k}^{(r)} * d_{L-1-k}(i) mod B is a linear function of coarser-scale digits with seed-derived coefficients, parameterized by the rule index r. Tier 0 has no shift (sigma_0 = 0) and is shared across all rules — it forms the trunk.
- K-rule structure (n_rules >= 1):
Tier 0 is the shared trunk: constrains the most-significant digit per active dim via filter [0, f). Same across all rules.
The selector projects each state to a rule index r ∈ [0, n_rules) deterministically: r(s) = packed_MSD(s) mod n_rules, where packed_MSD = sum_i m_i * f^i across active dims and m_i is the MSD of active dim i after the tier-0 filter shift (filter_shift[0]); see _selector.
Tiers >= 1 use rule-specific shift coefficients derived from (head_seed, r), so each rule has a different head.
At n_rules=1, the selector always returns 0 and there is one head, with coefficients derived from seed — bit-exact reproduction of the single-rule reward.
- Optional cross-dimensional constraint at tier t (applies to all rules):
sum_i d_{L-1-t}(i) ≡ 0 (mod m_t)
- Reward form (cumulative — tier t requires all tiers 0..t under the rule):
R(s) = R0 + sum_t tier_weights[t] * 1[s satisfies tiers 0..t]
- Mode count (closed form, total across all rules):
Without cross-dim: modes_T = (f^T)^d * B^{(L-T)*d}
With cross-dim: modes_T = (f^T)^d * B^{(L-T)*d} / prod_t m_t
The total mode count is INVARIANT in n_rules: rules partition the canonical mode set. When n_rules divides f^d_active, the partition is uniform and modes_per_rule = total / n_rules.
- Partition function (analytic, no enumeration):
Z = R0 * H^d + sum_t w_t * modes_t
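A brute-force check of the closed-form count and of Z can be run on a toy grid. The sketch assumes a single rule, all dims active, filter_shift = 0, no cross-dim constraints, and shift coefficients drawn from Python's random.Random. The class derives its own coefficients from seed, so the values differ, but the counts do not: for fixed coarser digits, d ↦ (d + sigma) mod B is a bijection, so exactly f values of the constrained digit pass. Helper names are hypothetical.

```python
import itertools
import math
import random


def digits(x, base, num_levels):
    """digits[k] = floor(x / base**k) mod base; k = 0 is the finest scale."""
    return [(x // base**k) % base for k in range(num_levels)]


def passes_tiers(state, base, num_levels, f, coeffs, n_tiers):
    """Cumulative check of tiers 0..n_tiers-1 on every dim (single rule, no cross-dim)."""
    for x in state:
        d = digits(x, base, num_levels)
        for t in range(n_tiers):
            # sigma_t: linear shift from the coarser digits already constrained.
            sigma = sum(coeffs[t][k] * d[num_levels - 1 - k] for k in range(t)) % base
            if (d[num_levels - 1 - t] + sigma) % base >= f:
                return False
    return True


def closed_form_modes(base, num_levels, f, n_tiers, ndim):
    """modes_T = (f^T)^d * B^((L - T) * d)."""
    return (f**n_tiers) ** ndim * base ** ((num_levels - n_tiers) * ndim)


def analytic_log_partition(R0, weights, base, num_levels, f, ndim):
    """log Z with Z = R0 * H^d + sum_t w_t * modes_t (modes_t uses 1-indexed tiers)."""
    H = base**num_levels
    Z = R0 * H**ndim + sum(
        w * closed_form_modes(base, num_levels, f, t + 1, ndim) for t, w in enumerate(weights)
    )
    return math.log(Z)


def brute_force_Z(R0, weights, base, num_levels, f, coeffs, ndim):
    """Sum of R0 + sum_t w_t * 1[tiers 0..t pass] over all H^d states."""
    H = base**num_levels
    return sum(
        R0 + sum(w for t, w in enumerate(weights)
                 if passes_tiers(s, base, num_levels, f, coeffs, t + 1))
        for s in itertools.product(range(H), repeat=ndim)
    )


base, L, f, ndim = 4, 2, 2, 2  # H = 16, so 256 states
rng = random.Random(0)  # illustrative coefficients, not the class's seed scheme
coeffs = [[rng.randrange(base) for _ in range(t)] for t in range(L)]
for T in (1, 2):
    brute = sum(passes_tiers(s, base, L, f, coeffs, T)
                for s in itertools.product(range(base**L), repeat=ndim))
    print(T, brute, closed_form_modes(base, L, f, T, ndim))  # 1 64 64, then 2 16 16
```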
- Key kwargs:
R0: float, base reward (default 0.0).
tier_weights: list[float], reward weight per tier.
base: int, digit base B (default 4). H must be a power of B.
filter_width: int, number of passing digit values per tier (default B//2). Constant across tiers to avoid mode collapse at deep tiers.
seed: int, PRNG seed for generating shift coefficients (default 42).
n_rules: int, number of rules K (default 1). Selector partitions tier-0-passing states into K buckets; uniform partition requires K | f^d_active.
head_seed: int, PRNG seed for per-rule head shift coefficients (default: same as seed; ensures n_rules=1 reproduces single-rule).
cross_dim_mods: Optional[list[int|None]], per-tier modular cross-dim constraint. m_t must divide filter_width for exact mode counts. Default: no cross-dim constraints.
active_dims: Optional[list[int]], subset of dims to constrain (default: all dims).
- Comparison with other compositional rewards:
BitwiseXORReward: GF(2) parity checks on bit-planes; rule reuse — same parity check type with increasing strictness. Non-local modes.
MultiplicativeCoprimeReward: prime factorization with progressive constraint types; knowledge composition — each tier requires understanding the prior tier’s structure.
This class: conditional hierarchy — each tier introduces a constraint whose form depends on what was learned at prior tiers. Coarse-to-fine ordering creates distance-correlated difficulty.
- Parameters:
height (int)
ndim (int)
- R0: float
- _RULE_SEED_STRIDE: int = 1000003
- __call__(states_tensor)
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- _extract_digits(x, num_levels)
Extract base-B digits from x.
- Parameters:
x (torch.Tensor) – (…, m) integer tensor with coordinate values.
num_levels (int) – how many digit levels to extract.
- Returns:
List of num_levels tensors, each (…, m), with digit values in [0, B). digits[k] is the k-th digit (scale k), i.e., floor(x / B^k) mod B.
- Return type:
list[torch.Tensor]
- _selector(msd_digits)
Map state MSDs (…, d_active) -> rule index in [0, n_rules).
Applies the same shift used by tier 0’s filter (sigma_0 = 0 + filter_shift[0]) before packing as base-f. For trunk-passing states the shifted MSD is in [0, f) so the packing is bijective on the trunk-passing set; non-trunk-passing states still get a rule index but cannot become modes (their tier-0 check fails).
- Parameters:
msd_digits (torch.Tensor)
- Return type:
torch.Tensor
- _shift_coeffs_tensor: torch.Tensor
- _uniform_partition: bool = True
- _validate_rule_coverage()
Ensure every rule index is hit by at least one trunk-passing state.
Trunk = tier 0 = each active dim’s shifted MSD in [0, filter_width) AND (if cross_dim_mods[0] is set) MSD-sum mod m_0 == 0. The selector packs MSDs as base-f and mods by n_rules.
The implementation relies on cyclic uniformity: trunk-passing MSD patterns map bijectively to integers [0, f^d_active) via the base-f packing, so pat mod n_rules distributes uniformly when n_rules divides n_surv, the number of trunk-passing patterns. The cross-dim filter at tier 0 thins this set further; for a non-trivial cross_dim_mods[0] we sample to verify coverage. (The old enumeration over all f^d_active patterns was intractable for d_active > ~10.)
Sets _uniform_partition based on whether n_rules divides the trunk-passing count evenly.
- Return type:
None
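The cyclic-uniformity argument can be illustrated by enumerating trunk-passing MSD patterns directly. This sketch assumes the shifted MSDs already lie in [0, f), ignores the optional cross-dim filter, and packs with weight f**i for active dim i (the exact digit order is an assumption); helper names are hypothetical.

```python
import itertools
from collections import Counter


def pack_base_f(shifted_msds, f):
    """Base-f packing; bijective on trunk-passing patterns in [0, f)^d_active."""
    return sum(m * f**i for i, m in enumerate(shifted_msds))


def rule_counts(f, d_active, n_rules):
    """Number of trunk-passing MSD patterns assigned to each rule (cross-dim filter ignored)."""
    counts = Counter(
        pack_base_f(p, f) % n_rules
        for p in itertools.product(range(f), repeat=d_active)
    )
    return [counts[r] for r in range(n_rules)]


print(rule_counts(f=2, d_active=3, n_rules=4))  # [2, 2, 2, 2]
print(rule_counts(f=2, d_active=3, n_rules=3))  # [3, 3, 2]
```

With f = 2 and three active dims there are 8 trunk-passing patterns, so n_rules = 4 splits them evenly, while n_rules = 3 still covers every rule but unevenly; that is the case where _uniform_partition would be False.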
- active_dims: list[int]
- analytic_log_partition()
Compute log(Z) analytically.
Z = R0 * H^ndim + sum_t w_t * modes_t
- Return type:
float
- analytic_mode_count(tier=None, per_rule=False)
Compute exact mode count for a given tier (1-indexed) or all tiers.
Total mode count is invariant in n_rules — rules partition the canonical mode set rather than multiplying it.
- Parameters:
tier (int | None) – 1-indexed tier number. If None, returns count for the highest tier (most constrained).
per_rule (bool) – If True, returns the per-rule count (total // n_rules). Requires uniform partition (n_rules divides the trunk-passing pattern count); raises ValueError otherwise.
- Returns:
Number of states satisfying all constraints up to the given tier (over all rules combined if per_rule=False, per rule otherwise).
- Return type:
int
- base: int
- cross_dim_mods: list[int | None] = []
- filter_shift: list[int]
- filter_width: int
- head_seed: int
- mode_threshold(target_sparsity=0.1)
Return the reward threshold for mode counting at the adaptive tier.
States with reward >= this value are counted as modes. The tier is chosen via mode_tier(target_sparsity) so that mode coverage adapts to dimensionality.
- Parameters:
target_sparsity (float) – Passed to mode_tier().
- Returns:
Reward threshold (R0 + sum of weights up to the mode tier).
- Return type:
float
- mode_tier(target_sparsity=0.1)
Return the lowest tier whose mode coverage is below target_sparsity.
Coverage at tier t = (f/B)^(t*d) (before cross-dim constraints). This adapts the mode definition to dimensionality: at low d, deeper tiers are needed for modes to be sparse; at high d, even tier 1 is already a needle in a haystack.
- Parameters:
target_sparsity (float) – Fraction of total state space below which modes are considered “interesting” (default 0.10 = 10%).
- Returns:
1-indexed tier number. Clamped to [1, num_tiers].
- Return type:
int
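The tier selection and the matching threshold reduce to a few lines. This sketch follows the coverage formula (f/B)^(t·d) stated above, ignores cross-dim constraints, and uses hypothetical function names; the mode_threshold helper here takes the tier explicitly rather than calling mode_tier.

```python
def mode_tier(f, base, ndim, num_tiers, target_sparsity=0.1):
    """Lowest 1-indexed tier t with (f/B)^(t*d) < target_sparsity, clamped to [1, num_tiers]."""
    for t in range(1, num_tiers + 1):
        if (f / base) ** (t * ndim) < target_sparsity:
            return t
    return num_tiers


def mode_threshold(R0, tier_weights, tier):
    """R0 plus the weights of tiers 1..tier (1-indexed)."""
    return R0 + sum(tier_weights[:tier])


# f/B = 0.5: at d = 2, tier 1 covers 25% of states and tier 2 covers 6.25%.
print(mode_tier(2, 4, ndim=2, num_tiers=3))  # 2
print(mode_tier(2, 4, ndim=4, num_tiers=3))  # 1
```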
- n_rules: int
- num_levels: int = 0
- seed: int
- shift_coeffs_per_rule: list[list[list[int]]] = []
- tier_indicators(states_tensor)
Per-tier independent pass/fail indicators.
Returns a list of boolean tensors (one per tier), each of shape states_tensor.shape[:-1]. indicators[t] is True for states that satisfy tier t’s digit constraint independently (not cumulatively). Note: the shift at tier t still depends on coarser digits, so the constraint is state-dependent but evaluated per tier.
- Parameters:
states_tensor (torch.Tensor)
- Return type:
list[torch.Tensor]
- tier_weights: list[float]
- class gfn.gym.hypergrid.CorruptedReward(height, ndim, **kwargs)
Bases: GridReward
Wraps a tiered structured reward and applies per-tier corruption.
Conceptually, at each tier, a fraction corruption_rate of states that earned that tier’s bonus have it “moved” to a random location. This degrades the compositional structure at every level proportionally.
- Per-tier corruption logic:
For each tier t and each state s:
Compute the base reward’s per-tier indicator pass_t(s).
Demote: if pass_t(s) and hash(s, seed + 2*t) < corruption_rate, remove tier t’s contribution.
Promote: if not pass_t(s) and hash(s, seed + 2*t + 1) < replacement_rate_t, add tier t’s contribution.
replacement_rate_t is calibrated at init so that the expected number of promotions matches demotions.
Final reward:
R(s) = R0 + sum_t w_t * corrupted_pass_t(s)
For non-tiered base rewards, falls back to a single-level binary corruption at the mode threshold.
- Key kwargs:
base_reward: str, name of the base reward (default “conditional_multiscale”).
base_kwargs: dict, kwargs for the base reward constructor.
corruption_rate: float in [0, 1), fraction of tier-passing states to demote per tier (default 0.2).
seed: int, hash seed (default 137).
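A sketch of the per-tier corruption for a single state. The calibration follows from matching expectations: if a fraction p_t of states passes tier t, demotions occur at rate p_t·corruption_rate and promotions at rate (1 − p_t)·replacement_rate_t, so replacement_rate_t = corruption_rate·p_t / (1 − p_t). The SHA-256-based unit_hash, the cap at 1.0, and all helper names are assumptions, not the class's hash or code.

```python
import hashlib


def unit_hash(state, seed):
    """Hypothetical deterministic hash of (state, seed) to a float in [0, 1)."""
    digest = hashlib.sha256(repr((tuple(state), seed)).encode()).digest()
    return int.from_bytes(digest[:8], "big") / 2**64


def replacement_rate(corruption_rate, pass_fraction):
    """Promotion rate q_t solving (1 - p_t) * q_t = p_t * corruption_rate, capped at 1."""
    if pass_fraction >= 1.0:
        return 0.0
    return min(1.0, corruption_rate * pass_fraction / (1.0 - pass_fraction))


def corrupted_reward(state, pass_flags, tier_weights, R0, corruption_rate, rates, seed):
    """R0 + sum_t w_t * corrupted_pass_t(s), with an independent hash per tier and direction."""
    reward = R0
    for t, (passed, w) in enumerate(zip(pass_flags, tier_weights)):
        if passed:  # demoted with probability corruption_rate
            earned = unit_hash(state, seed + 2 * t) >= corruption_rate
        else:  # promoted with probability rates[t]
            earned = unit_hash(state, seed + 2 * t + 1) < rates[t]
        if earned:
            reward += w
    return reward


rates = [replacement_rate(0.2, p) for p in (0.25, 0.05)]
print(rates)  # approximately [0.0667, 0.0105]
```

Using seed + 2*t for demotion and seed + 2*t + 1 for promotion keeps the two decisions for a tier independent of each other and of every other tier.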
- Parameters:
height (int)
ndim (int)
- _REWARD_CLASSES: dict[str, type[GridReward]]
- __call__(states_tensor)
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- _call_simple(states_tensor)
Fallback for non-tiered base rewards: binary corruption.
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- _estimate_replacement_rates()
Sample states to estimate per-tier pass fraction, then set replacement rates so promotions ~ demotions in expectation.
- Return type:
None
- _is_tiered
- _replacement_rates: list[float] = []
- base_fn: GridReward
- base_reward_str = ''
- corruption_rate: float
- mode_threshold()
Return the mode threshold derived from the base reward.
- Return type:
float
- seed: int
- class gfn.gym.hypergrid.CosineReward(height, ndim, **kwargs)
Bases: GridReward
Cosine reward function.
- Parameters:
height (int)
ndim (int)
- __call__(states_tensor)
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- class gfn.gym.hypergrid.DeceptiveReward(height, ndim, **kwargs)
Bases: GridReward
Deceptive reward function from Adaptive Teachers (Kim et al., 2025).
Note that the reward definition in the paper (eq. (9)) is incorrect; we follow the official implementation (https://github.com/alstn12088/adaptive-teacher/blob/8cfcb2298fce3f46eb36ead03791eeee75b7d066/grid/env.py#L27), modified to use EPS = 1e-12 so that floating-point inequality comparisons are robust.
- Parameters:
height (int)
ndim (int)
- __call__(states_tensor)
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- class gfn.gym.hypergrid.GridReward(height, ndim, **kwargs)
Bases: abc.ABC
Base class for reward functions that can be pickled.
- Parameters:
height (int)
ndim (int)
- _EPS = 1e-12
- abstract __call__(states_tensor)
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- height
- kwargs
- ndim
- class gfn.gym.hypergrid.HyperGrid(ndim=2, height=8, reward_fn_str='original', reward_fn_kwargs=None, device='cpu', calculate_partition=False, store_all_states=False, debug=False, validate_modes=True, mode_stats='none', mode_stats_samples=20000)
Bases: gfn.env.DiscreteEnv
HyperGrid environment from the GFlowNets paper.
The states are represented as 1-d tensors of length ndim with values in {0, 1, …, height - 1}.
- Parameters:
ndim (int)
height (int)
reward_fn_str (str)
reward_fn_kwargs (dict | None)
device (Literal['cpu', 'cuda'] | torch.device)
calculate_partition (bool)
store_all_states (bool)
debug (bool)
validate_modes (bool)
mode_stats (Literal['none', 'approx', 'exact'])
mode_stats_samples (int)
- ndim
The dimension of the grid.
- height
The height of the grid.
- reward_fn
The reward function.
- calculate_partition
Whether to calculate the log partition function.
- store_all_states
Whether to store all states.
- validate_modes
Whether to check that at least one state reaches the mode threshold at init; raises if not.
- mode_stats
One of {“none”, “approx”, “exact”}. If not “none”, computes (exact or approximate) n_modes and n_mode_states. “exact” requires store_all_states=True and enumerates all states.
- mode_stats_samples
Number of random samples when mode_stats=“approx”.
- States: type[gfn.states.DiscreteStates]
- _all_states_tensor = None
- _enumerate_all_states_tensor(batch_size=20000)
Enumerate all grid states, optionally storing them and computing log Z.
Iterates over the full Cartesian product {0, ..., H-1}^D in batches (via multiprocessing) to avoid materializing all H^D states at once.
- Parameters:
batch_size (int) – Number of states per batch.
- _exists_bitwise_xor(thr)
Deterministic feasibility check for BitwiseXORReward.
Builds the combined GF(2) system [trunk; selector→0; head_0]·b = [c_trunk; 0; c_head_0] for rule 0 and verifies consistency. This works uniformly for n_rules=1 (k_select=0, selector empty) and for K-rule configurations (per-rule coverage was already verified at __init__ by _validate_rule_coverage; this re-checks rule 0 as defense in depth).
Feasibility of this combined GF(2) system is necessary and sufficient for a mode to exist; no random sampling is required. Presets use power-of-two heights so every feasible bit-assignment is a valid state (raw coord < height).
- Parameters:
thr (float)
- Return type:
bool
- _exists_conditional_multiscale(thr)
Constructive existence check for ConditionalMultiScaleReward.
With filter_shift=[0,…,0] (default) the all-zeros state is always a mode: every per-tier filter passes 0 since (0 + 0) mod B = 0 < f. With non-zero filter_shift, we try a few “all-same-v” candidate states chosen so the MSD passes tier 0; one of them typically passes all deeper tiers when the per-rule shift_coeffs map zero lower digits to zero, leaving the constant filter_shift[t] as the only contribution.
- Parameters:
thr (float)
- Return type:
bool
- _exists_cosine(thr)
Analytic upper-bound check for CosineReward.
Idea:
- The per-dimension factor is (cos(50·ax) + 1) · N(0,1)(5·ax) with ax in [0,0.5]. We estimate its maximum over the discrete grid by evaluating all candidate ax and taking the maximum value m.
- The full reward upper bound is R0 + m^D * R1. Because every factor is non-negative and m is attained at some grid coordinate, the bound is achieved by the state that uses that coordinate in every dimension; so if it is at least the mode target and the given threshold, a mode-level state must exist.
- We also compute a theoretical per-dimension peak (at ax≈0) to form a slightly conservative target scaled by mode_gamma.
- Parameters:
thr (float)
- Return type:
bool
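The bound can be checked numerically on a small grid. This sketch assumes the coordinate map ax = |s/(H − 1) − 0.5| and the reward form R0 + R1·Π_d factor(ax_d) implied above; the R0 and R1 values and the helper names are illustrative. Because every factor is non-negative, the maximum of the product is the product of the per-dimension maxima, so the bound equals the brute-force maximum.

```python
import itertools
import math


def per_dim_factor(ax):
    """(cos(50·ax) + 1) · N(0,1)(5·ax), with N(0,1) the standard normal density."""
    return (math.cos(50 * ax) + 1) * math.exp(-0.5 * (5 * ax) ** 2) / math.sqrt(2 * math.pi)


def coord_to_ax(s, height):
    # Assumed normalization: distance from the grid center, in [0, 0.5].
    return abs(s / (height - 1) - 0.5)


def cosine_reward(state, height, R0, R1):
    prod = 1.0
    for s in state:
        prod *= per_dim_factor(coord_to_ax(s, height))
    return R0 + R1 * prod


def cosine_reward_bound(height, ndim, R0, R1):
    """R0 + m^D * R1, where m is the per-dimension maximum over grid coordinates."""
    m = max(per_dim_factor(coord_to_ax(s, height)) for s in range(height))
    return R0 + m**ndim * R1


height, ndim, R0, R1 = 8, 3, 0.1, 1.0  # illustrative values
best = max(cosine_reward(s, height, R0, R1)
           for s in itertools.product(range(height), repeat=ndim))
print(math.isclose(best, cosine_reward_bound(height, ndim, R0, R1)))  # True
```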
- _exists_fallback_random(thr)
Random sampling fallback.
Draw a modest batch of random states on CPU and accept if any exceed the threshold with a small tolerance. This is a last resort to avoid expensive enumeration for large grids.
- Parameters:
thr (float)
- Return type:
bool
- _exists_multiplicative_coprime(thr)
Number-theoretic constructive check for MultiplicativeCoprimeReward.
For each rule, factors the rule’s target LCM over allowed primes, tries permutations of prime-to-active-dim assignments, and checks coprime + grid-bound + selector-match. Returns True iff at least one rule has a witness state whose selector maps back to that rule’s index AND whose reward reaches the mode threshold.
The reward shifts raw coords by +1 internally (raw 0 → internal 1), so witness states are constructed in raw space as p**exp - 1 per active dim, with coprime pair checks evaluated on the post-shift internal values.
At n_rules=1 the selector is trivially 0 and only rule 0 is tried, recovering the legacy behavior.
- Parameters:
thr (float)
- Return type:
bool
- _exists_original_or_deceptive(thr)
Constructive check for OriginalReward and DeceptiveReward.
Intuition:
- These rewards form rings/bands around the center when each coordinate is normalized to [0,1]. The mode lies on a thin band at specific normalized distances from the center.
We translate those fractional band boundaries into integer indices via small inside/outside nudges (using EPS_INDEX_CMP) and test one candidate index from any non-empty feasible interval.
If the reward at that candidate exceeds the threshold (with EPS_REWARD_CMP tolerance), we return True.
- Parameters:
thr (float)
- Return type:
bool
- _exists_random_or_corrupted(thr)
Check for UniformRandomReward or CorruptedReward.
For UniformRandomReward the probe budget is sized so that P(miss all modes | at least one mode exists) < 1e-9, using n = ceil(log(1e-9) / log(1 - mode_prob)). For CorruptedReward a fixed budget of 10 000 is used (mode density is approximately preserved by the promotion/demotion calibration).
A seeded generator derived from the reward seed and grid shape makes the result reproducible across calls with the same configuration.
- Parameters:
thr (float)
- Return type:
bool
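The UniformRandomReward probe budget follows from requiring (1 − mode_prob)^n ≤ 1e-9 for n independent uniform probes, since each probed state is a mode with probability mode_prob. A minimal sketch (the function name is hypothetical):

```python
import math


def probe_budget(mode_prob, miss_prob=1e-9):
    """n = ceil(log(miss_prob) / log(1 - mode_prob)), so (1 - mode_prob)**n <= miss_prob."""
    if not 0.0 < mode_prob < 1.0:
        raise ValueError("mode_prob must lie in (0, 1)")
    return math.ceil(math.log(miss_prob) / math.log(1.0 - mode_prob))


print(probe_budget(0.5))   # 30
print(probe_budget(0.01))  # 2062
```

The budget grows roughly like 20.7 / mode_prob for small mode_prob, which is why sparse configurations need many more probes.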
- _exists_sparse(thr)
Constructive check for SparseReward.
This reward assigns positive mass only to a finite set of target configurations. When H>=2 and D>=1, a known target is the zero vector except for certain coordinates fixed at 1 or H-2. We probe a canonical target and confirm the threshold is not above its reward.
- Parameters:
thr (float)
- Return type:
bool
- _generate_combinations_in_batches(ndim, max_val, batch_size)
Yield batches of the Cartesian product {0, …, max_val}^ndim.
Uses multiprocessing to avoid materializing the full product (size (max_val+1)^ndim) in memory.
Workers are created via the spawn start method and execute the module-level _hypergrid_worker() function so the call is safe inside MPI ranks and CUDA contexts (see the start-method comment near the top of this file). Pool size is capped at MAX_POOL_WORKERS because larger pools just multiply per-rank fork/spawn overhead without shrinking the per-task work, and a 64-core node hosting many co-located MPI ranks can otherwise blow up to thousands of worker processes simultaneously.
- Parameters:
ndim (int) – Number of dimensions (tuple length).
max_val (int) – Maximum coordinate value (inclusive).
batch_size (int) – Number of tuples per batch.
- Yields:
A list of tuples for each batch.
- _get_states_indices_bigint(states_raw)
Compute canonical indices using arbitrary-precision Python ints.
Used by get_states_indices() when height ** ndim > 2 ** 63 and the int64 path would overflow.
Vectorized over the (potentially large) batch dimension via numpy object-dtype broadcasting: the inner Python loop iterates only over the small feature dimension ndim, and each k * h + col operation dispatches a single C-level loop over all rows that calls Python int.__mul__/int.__add__ per element. This is a few times faster than a nested Python loop while still preserving arbitrary-precision correctness.
Returns a numpy object array of shape states_raw.shape[:-1] containing one Python int per state.
- Parameters:
states_raw (torch.Tensor)
- Return type:
numpy.ndarray
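Why the bigint path is needed can be shown with plain Python ints. The sketch computes the base-height index with Horner's rule (the k * h + col step described above) and simulates a silently wrapping int64; all names are hypothetical. At height = 128 and ndim = 10 the leading place value is 128^9 = 2^63, so the state (2, 0, ..., 0) has index 2^64 and wraps to the same value as the all-zeros state.

```python
def canonical_index(state, height):
    """sum(s[j] * height**(ndim-1-j)) via Horner's rule (k = k * h + col)."""
    k = 0
    for col in state:
        k = k * height + col
    return k


def needs_bigint(height, ndim):
    """True when the largest index height**ndim - 1 no longer fits in signed int64."""
    return height**ndim > 2**63


def wrap_int64(k):
    """Value a silently overflowing signed int64 would hold (two's-complement wrap)."""
    return (k + 2**63) % 2**64 - 2**63


zeros = (0,) * 10
two_then_zeros = (2,) + (0,) * 9
print(needs_bigint(128, 10))  # True, since 128**10 == 2**70
print(canonical_index(two_then_zeros, 128) == 2**64)  # True
print(wrap_int64(canonical_index(two_then_zeros, 128)) == wrap_int64(canonical_index(zeros, 128)))  # True: collision
```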
- _log_partition = None
- _mode_reward_threshold()
Returns the reward threshold used to define a mode.
By default, a state is considered in a mode if its reward is at least the schema-defined threshold derived from the configured reward.
- Return type:
float
- _mode_stats_kind: str = 'none'
- _modes_exist_quick_check()
Lightweight check that a mode-level state exists.
In simple terms, this answers: “Is there at least one state whose reward reaches the mode threshold?” without enumerating all states. It proceeds in three stages:
1) If the grid is small (or pre-enumerated), it computes rewards exactly and checks against the threshold.
2) Otherwise, it dispatches to reward-specific constructive tests that are sufficient to guarantee at least one state reaches the threshold.
3) As a last resort, it samples a small batch of random states.
- Return type:
bool
- _modes_exist_quick_check_info()
Same as _modes_exist_quick_check but returns (ok, message).
- Return type:
tuple[bool, str]
- _n_mode_states_estimate: float | None = None
- _n_mode_states_exact: int | None = None
- static _solve_gf2_has_solution(A, c)
Return True if A x = c over GF(2) has at least one solution.
Performs Gaussian elimination modulo 2 (XOR arithmetic) without constructing a specific solution. A row that reduces to all-zero coefficients with a non-zero RHS (0 = 1) indicates inconsistency.
- Parameters:
A (torch.Tensor)
c (torch.Tensor)
- Return type:
bool
- static _solve_gf2_witness(A, c, n_vars)¶
Return a witness solution to A·b = c over GF(2), or None if none exists.
b has length n_vars. A is reduced via Gaussian elimination; free variables are set to 0.
- Parameters:
A (torch.Tensor)
c (torch.Tensor)
n_vars (int)
- Return type:
torch.Tensor | None
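The elimination behind _solve_gf2_has_solution and _solve_gf2_witness can be illustrated with a small pure-Python sketch. It uses lists of 0/1 ints instead of torch tensors, and the names `gf2_solve` and `gf2_has_solution` are hypothetical. It reduces the augmented system to reduced row echelon form, reports inconsistency when a row becomes 0 = 1, and otherwise returns a witness with free variables set to 0.

```python
from typing import List, Optional


def gf2_solve(A: List[List[int]], c: List[int], n_vars: int) -> Optional[List[int]]:
    """Solve A x = c over GF(2); return one solution (free vars = 0) or None."""
    # Augmented rows [a_0, ..., a_{n-1} | c_i], copied so inputs are untouched.
    rows = [list(r[:n_vars]) + [ci & 1] for r, ci in zip(A, c)]
    pivots = []  # (row index, pivot column)
    r = 0
    for col in range(n_vars):
        # Find a row at or below r with a 1 in this column.
        pivot = next((i for i in range(r, len(rows)) if rows[i][col]), None)
        if pivot is None:
            continue  # no pivot: this column is a free variable
        rows[r], rows[pivot] = rows[pivot], rows[r]
        # Clear this column in every other row (addition mod 2 is XOR).
        for i in range(len(rows)):
            if i != r and rows[i][col]:
                rows[i] = [a ^ b for a, b in zip(rows[i], rows[r])]
        pivots.append((r, col))
        r += 1
    # Remaining rows have all-zero coefficients; RHS 1 means 0 = 1.
    if any(row[-1] == 1 for row in rows[r:]):
        return None
    # In reduced form each pivot variable equals its row's RHS when free vars are 0.
    x = [0] * n_vars
    for i, col in pivots:
        x[col] = rows[i][-1]
    return x


def gf2_has_solution(A, c, n_vars):
    return gf2_solve(A, c, n_vars) is not None


print(gf2_solve([[1, 1, 0], [0, 1, 1]], [1, 0], 3))  # [1, 0, 0]
print(gf2_solve([[1, 1], [1, 1]], [0, 1], 2))        # None
```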
- _true_dist = None¶
- all_indices()¶
Generate all possible indices for the grid.
- Returns:
A list of all possible indices for the grid.
- Return type:
List[Tuple[int, Ellipsis]]
- property all_states: gfn.states.DiscreteStates | None¶
Returns a tensor of all hypergrid states as a DiscreteStates instance.
- Return type:
gfn.states.DiscreteStates | None
- backward_step(states, actions)¶
Performs a backward step in the environment.
- Parameters:
states (gfn.states.DiscreteStates) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The previous states.
- Return type:
- calculate_partition = False¶
- get_states_indices(states)¶
Get the canonical ordering indices for a batch of states.
Returns one canonical index per state computed from the base-height encoding sum(s[j] * height^(ndim-1-j)). The maximum index is height^ndim - 1.
Safe regime (height ** ndim <= 2 ** 63): the index fits in signed int64 and we return a torch.Tensor of shape batch_shape with dtype torch.int64 (the historical behaviour).
Overflow regime (height ** ndim > 2 ** 63): the index would overflow int64 and silently wrap, producing collisions between distinct states (a real bug we hit at e.g. ndim=10, height=128, where 128**10 == 2**70). In this regime we fall back to per-row Python int arithmetic and return a numpy.ndarray of dtype object containing arbitrary-precision Python ints. Each element is a unique, hashable canonical index.
The two return types support the same downstream usages we care about (set(...tolist()) for mode tracking, boolean masking with [mask] after converting the mask to numpy if needed). Code paths that need an int64 tensor for tensor indexing (e.g. EnumPreprocessor) implicitly require the safe regime: they’ll see the numpy fallback and fail loudly, which is the correct behavior because such grids are too large to enumerate anyway.
- Parameters:
states (Union[gfn.states.DiscreteStates, torch.Tensor]) – The states to get the indices of.
- Returns:
Indices in canonical ordering.
torch.Tensor[int64] of shape batch_shape in the safe regime; np.ndarray[object] of shape batch_shape containing Python ints in the overflow regime.
- Return type:
Union[torch.Tensor, numpy.ndarray]
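The collision that motivates the overflow regime can be reproduced with plain Python ints by emulating int64 wraparound. The helpers `canonical_index` and `to_int64` are hypothetical names used only for this illustration; at ndim=10, height=128 the states (0, …, 0) and (2, 0, …, 0) have indices 0 and 2 · 128^9 = 2^64, which wrap to the same int64 value.

```python
def canonical_index(state, height):
    """Base-`height` encoding sum(s[j] * height**(ndim-1-j)) as a Python int."""
    idx = 0
    for coord in state:
        idx = idx * height + coord
    return idx


def to_int64(value):
    """Emulate two's-complement int64 wraparound of an arbitrary Python int."""
    return (value + 2**63) % 2**64 - 2**63


height, ndim = 128, 10  # 128 ** 10 == 2 ** 70 > 2 ** 63
a = [0] * ndim
b = [2] + [0] * (ndim - 1)  # index 2 * 128**9 == 2**64
ia, ib = canonical_index(a, height), canonical_index(b, height)
print(ia == ib)                      # False: distinct states
print(to_int64(ia) == to_int64(ib))  # True: they collide after wrapping
```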
- get_terminating_states_indices(states)¶
Get the indices of the terminating states in the canonical ordering.
See get_states_indices() for the return-type contract: a torch.Tensor[int64] for grids small enough to fit in signed int64 (height ** ndim <= 2 ** 63), or a numpy.ndarray[object] of Python ints for larger grids that would otherwise overflow.
- Parameters:
states (gfn.states.DiscreteStates) – The states to get the indices of.
- Returns:
The indices of the terminating states in the canonical ordering.
- Return type:
Union[torch.Tensor, numpy.ndarray]
- height = 8¶
- log_partition(condition=None)¶
Returns the log partition of the reward function.
- Return type:
float | None
- make_random_states(batch_shape, conditions=None, device=None, debug=False)¶
Creates a batch of random states.
- Parameters:
batch_shape (Tuple[int, Ellipsis]) – The shape of the batch.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A DiscreteStates object with random states.
- Return type:
- make_states_class()¶
Returns the DiscreteStates class for the HyperGrid environment.
- Return type:
- mode_mask(states)¶
Boolean mask indicating which states are in a mode.
A state is flagged as a mode if its reward is greater than or equal to the threshold derived from reward_fn_kwargs (R0+R1+R2 by default).
- Parameters:
states (gfn.states.DiscreteStates)
- Return type:
torch.Tensor
- modes_found(states)¶
Returns the set of canonical state indices for mode states in the batch.
Each mode state is identified by its unique canonical index (from get_states_indices), not by a quadrant-based grouping. This allows correct mode-state tracking for all reward functions.
- Parameters:
states (gfn.states.DiscreteStates)
- Return type:
set[int]
- property n_mode_states: int | float | None¶
Number of states inside a mode (exact, approx, or None).
If mode_stats="exact", returns an exact integer count.
If mode_stats="approx", returns a floating-point estimate.
If store_all_states is True (but mode_stats was "none"), computes the count on demand from all_states.
Otherwise, returns None.
- Return type:
int | float | None
- property n_modes: int | float | None¶
Returns the total number of mode states for this environment.
Equivalent to n_mode_states. Each individual grid cell whose reward meets the mode threshold counts as one mode.
- Return type:
int | float | None
- property n_states: int¶
Returns the number of states in the environment.
- Return type:
int
- property n_terminating_states: int¶
Returns the number of terminating states in the environment.
- Return type:
int
- ndim = 2¶
- reward(states)¶
Computes the reward for a batch of final states.
In the normal setting, the reward is
$$R(s) = R_0 + 0.5 \prod_{d=1}^{D} \mathbf{1}\left(\left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5]\right) + 2 \prod_{d=1}^{D} \mathbf{1}\left(\left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4)\right)$$
where s^d is the d-th coordinate of s, H is the height, and D is the number of dimensions.
- Parameters:
states (gfn.states.DiscreteStates) – The final states.
- Returns:
The reward of the final states.
- Return type:
torch.Tensor
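The formula above can be evaluated directly for a single state. This pure-Python sketch uses the 0.5 and 2 weights shown in the formula as the defaults of R1 and R2 (the outer-ring and thin-band weights discussed for get_original_presets); the default R0=0.1 is an assumption for illustration, and `original_reward` is a hypothetical helper rather than the library's implementation.

```python
def original_reward(state, height, R0=0.1, R1=0.5, R2=2.0):
    """R(s) = R0 + R1 * prod 1(|s/(H-1) - 0.5| in (0.25, 0.5])
                 + R2 * prod 1(|s/(H-1) - 0.5| in (0.3, 0.4))."""
    # Distance of each normalized coordinate from the grid centre.
    dist = [abs(s / (height - 1) - 0.5) for s in state]
    outer = all(0.25 < d <= 0.5 for d in dist)  # outer ring in every dim
    band = all(0.3 < d < 0.4 for d in dist)     # thin band in every dim
    return R0 + R1 * outer + R2 * band


# With height 8: coordinate 1 sits in the thin band, 0 only in the outer ring.
print(original_reward((1, 1), 8))  # ~2.6
print(original_reward((3, 3), 8))  # 0.1 (near the centre)
```

Because both indicators are products over dimensions, a state earns R2 only if every coordinate lies in the band, which is what makes the high-reward regions small corners of the grid.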
- reward_fn¶
- reward_fn_kwargs = None¶
- step(states, actions)¶
Performs a step in the environment.
- Parameters:
states (gfn.states.DiscreteStates) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The next states.
- Return type:
- store_all_states = False¶
- property terminating_states: gfn.states.DiscreteStates | None¶
Returns all terminating states of the environment.
- Return type:
gfn.states.DiscreteStates | None
- true_dist(condition=None)¶
Returns the pmf over all states in the hypergrid.
- Return type:
torch.Tensor | None
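For a grid small enough to enumerate, true_dist and log_partition amount to normalizing rewards over every state. A minimal sketch, assuming an arbitrary reward callable; `enumerate_pmf` is a hypothetical helper. States are produced by itertools.product, whose order matches the base-height canonical ordering used by get_states_indices.

```python
import itertools
import math


def enumerate_pmf(reward_fn, height, ndim):
    """Return (log_Z, pmf) where pmf[i] = R(s_i) / Z over all height**ndim states."""
    # product() varies the last coordinate fastest, so position i in this list
    # equals the base-`height` index sum(s[j] * height**(ndim-1-j)).
    states = itertools.product(range(height), repeat=ndim)
    rewards = [reward_fn(s) for s in states]
    Z = sum(rewards)
    return math.log(Z), [r / Z for r in rewards]


# Usage with a toy reward that favours the far corner.
log_Z, pmf = enumerate_pmf(lambda s: 1.0 + sum(s), height=3, ndim=2)
```

This costs height**ndim reward evaluations, which is why the exact distribution is only available when all states can be stored.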
- class gfn.gym.hypergrid.MultiplicativeCoprimeReward(height, ndim, **kwargs)¶
Bases:
GridReward
Tiered reward based on prime-support and coprimality/lcm composition.
- Curriculum motivation — knowledge composition:
This reward tests whether a GFlowNet can learn number-theoretic structure progressively: first discovering which coordinates factor over allowed primes (tier 0), then learning exponent bounds (tier 1), then cross-dimensional coprimality (tier 2), and finally global LCM targets (tier 3).
Each tier builds on knowledge from prior tiers: learning prime factorization at tier 0 is prerequisite for reasoning about exponent caps at tier 1, which in turn enables the coprimality reasoning needed at tier 2. This tests compositional transfer where each level requires a qualitatively different type of constraint, not just more of the same.
Coordinates are shifted by +1 internally (state 0 -> value 1) so that the origin is valid and short trajectories immediately encounter small prime-factorable numbers (2, 3, 4, 5, 6, …), providing early training signal for long-horizon credit assignment.
- Each tier progressively adds new constraint types:
Tier 0: Prime support — coordinates must factor over allowed primes.
Tier 1+: Exponent caps — prime exponents bounded per tier.
coprime_start_tier+: Coprime pairs — cross-dimensional coupling.
target_lcms: LCM targets — global compositional constraint.
- Reward form:
R(s) = R0 + Σ_t tier_weights[t] · 1[ constraints_0..t all satisfied ]
- Key kwargs:
R0: float, base reward (default 0.0)
tier_weights: list[float]
primes: list[int], e.g., [2,3,5,7,11]. Primes exceeding height are auto-filtered with a warning.
exponent_caps: list[int], same length as tier_weights. Cap for every prime at tier t (uniform cap across primes for simplicity). Auto-capped to floor(log_p(height)) for each prime p.
active_dims: Optional[list[int]]; constraints only apply to these dims (default: all dims). Other dims are ignored in constraints.
coprime_pairs: Optional[list[tuple[int,int]]]; indices relative to active_dims.
coprime_start_tier: int, first tier at which coprime constraints apply (default: 0, preserving backward compatibility).
target_lcms: Optional[list[int | None | str]]; per-tier target lcm across active dims. Use “auto” to derive from primes and exponent_caps.
Notes:
- Coordinates are shifted by +1 internally: state value 0 maps to reward value 1, making the origin (0,…,0) trivially valid.
- The implementation divides out each allowed prime up to the current tier's exponent cap and checks that the residue equals 1. Exponent counts are accumulated to evaluate LCM targets.
- Comparison with other compositional rewards:
BitwiseXORReward: GF(2) parity checks on bit-planes; rule reuse — same parity check type with increasing strictness. Non-local modes.
ConditionalMultiScaleReward: base-B digit decomposition with conditional constraints across scales; conditional hierarchy — coarse-scale structure predicts fine-scale constraints.
This class: prime factorization with progressive constraint types (support -> caps -> coprimality -> LCM). Knowledge composition — each tier requires understanding the prior tier’s structure.
- Parameters:
height (int)
ndim (int)
- R0: float¶
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- _factor_exponents_up_to_cap(v, cap)¶
Trial-divide each element by allowed primes, returning residue and exponents.
- Parameters:
v (torch.Tensor) – (…,) LongTensor of non-negative values to factorize.
cap (int) – Maximum number of times each prime may divide a value.
- Returns:
(…,) values after stripping allowed primes (1 if fully factored). exps: [num_primes, …] exponent counts per prime (leading axis is primes).
- Return type:
residue
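The trial division performed by _factor_exponents_up_to_cap can be illustrated per value in pure Python. This is a sketch of the idea, not the vectorized torch code, and `factor_exponents_up_to_cap` here is a hypothetical standalone helper: each allowed prime is divided out at most `cap` times, and a residue of 1 means the value factors entirely over the allowed primes within the cap.

```python
def factor_exponents_up_to_cap(v, primes, cap):
    """Return (residue, exps) after dividing v by each prime at most `cap` times."""
    exps = []
    for p in primes:
        e = 0
        # Strip p while it divides v, but never more than `cap` times.
        while e < cap and v % p == 0:
            v //= p
            e += 1
        exps.append(e)
    return v, exps


primes = [2, 3, 5]
# Coordinates are shifted by +1, so state value 11 is checked as 12 = 2^2 * 3.
print(factor_exponents_up_to_cap(12, primes, cap=2))  # (1, [2, 1, 0])
print(factor_exponents_up_to_cap(16, primes, cap=2))  # (4, [2, 0, 0]): exceeds cap
print(factor_exponents_up_to_cap(14, primes, cap=2))  # (7, [1, 0, 0]): 7 not allowed
```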
- _generate_rule_targets(n_rules, seed)¶
Generate n_rules distinct LCM targets from a deterministic enumeration.
Enumerates cap-tuples (cap_p ∈ {1, …, top_cap}) over allowed primes; cap=0 (prime absent from LCM) is excluded so every rule has a non-trivial target. There are top_cap^n_primes such tuples; they are permuted by seed and the first n_rules are kept. The prime universe must therefore satisfy top_cap^n_primes >= n_rules.
Note: cap=0 is intentionally excluded. With cap=0, the LCM head constraint “no active dim has prime p as factor” combined with coprime-pair and exponent-cap constraints typically yields zero or near-zero modes — degenerate rules.
- Parameters:
n_rules (int)
seed (int)
- Return type:
list[int]
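The target enumeration described above can be sketched as follows. The enumeration order and the use of random.Random(seed).shuffle for the permutation are assumptions made for illustration; the library may permute differently, so the targets chosen for a given seed will not necessarily match. Distinctness follows from unique factorization: different cap-tuples give different products.

```python
import itertools
import math
import random


def generate_rule_targets(primes, top_cap, n_rules, seed):
    """Pick n_rules distinct LCM targets prod(p ** cap_p), each cap_p in 1..top_cap."""
    n_tuples = top_cap ** len(primes)
    if n_tuples < n_rules:
        raise ValueError(f"need top_cap**n_primes >= n_rules, got {n_tuples} < {n_rules}")
    # cap = 0 is excluded, so every allowed prime appears in every target.
    cap_tuples = list(itertools.product(range(1, top_cap + 1), repeat=len(primes)))
    random.Random(seed).shuffle(cap_tuples)  # deterministic permutation (assumed)
    return [
        math.prod(p ** cap for p, cap in zip(primes, caps))
        for caps in cap_tuples[:n_rules]
    ]


targets = generate_rule_targets([2, 3, 5], top_cap=2, n_rules=4, seed=0)
```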
- _lcm_ok(exps, target_lcm)¶
Check whether max exponents across dims match target LCM’s factorization.
- Parameters:
exps (torch.Tensor) – [num_primes, …, num_active_dims] exponent counts.
target_lcm (int) – The target LCM value to match.
- Returns:
(…,) bool mask, True where the LCM of active-dim values equals target.
- Return type:
torch.Tensor
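The check in _lcm_ok rests on a standard fact: the LCM of several numbers takes, for each prime, the maximum exponent of that prime over the numbers. A per-state pure-Python sketch follows; the helper name and the list-of-lists input layout (exps[k][i] is the exponent of primes[k] in active dim i, mirroring the primes-leading axis above) are hypothetical.

```python
def lcm_ok(exps, primes, target_lcm):
    """True if the LCM implied by per-dim prime exponents equals target_lcm."""
    rest = target_lcm
    for k, p in enumerate(primes):
        # Exponent of p in the target.
        e = 0
        while rest % p == 0:
            rest //= p
            e += 1
        # LCM exponent of p is the max over active dims.
        if max(exps[k]) != e:
            return False
    # A leftover factor means the target uses a prime outside the allowed set.
    return rest == 1


# Active-dim values 4 (2^2), 6 (2*3), 5 have LCM 60 = 2^2 * 3 * 5.
primes = [2, 3, 5]
exps = [[2, 1, 0], [0, 1, 0], [0, 0, 1]]
print(lcm_ok(exps, primes, 60), lcm_ok(exps, primes, 30))  # True False
```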
- _pairwise_coprime_ok(v)¶
Check that configured dimension pairs share no common allowed prime.
- Parameters:
v (torch.Tensor) – (…, num_active_dims) coordinate values.
- Returns:
(…,) bool mask, True where all coprime pair constraints hold.
- Return type:
torch.Tensor
- _rule_target_exps¶
- _selector(x_active)¶
Map state’s active-dim values to a rule index in [0, n_rules).
Selector is sum_i x_i mod n_rules (over active dims, shifted values). Returns shape (…,) long tensor.
- Parameters:
x_active (torch.Tensor)
- Return type:
torch.Tensor
- _validate_rule_coverage()¶
Ensure each rule’s LCM target is achievable.
A target is achievable only if (a) it factors over allowed primes (already checked at __init__) and (b) every required exponent is <= top cap. Both conditions are necessary; with active_dims >= n_primes_with_nonzero_exp and coprime_pairs constraints, sufficiency requires enumeration. Here we check only the necessary conditions and verify that the rules produce non-degenerate, distinct targets.
- Return type:
None
- active_dims: list[int]¶
- coprime_pairs¶
- coprime_start_tier: int¶
- exponent_caps: list[int] = []¶
- head_seed: int¶
- n_rules: int¶
- primes: list[int]¶
- rule_targets: list[int | None]¶
- target_lcms: list[int | None] = []¶
- tier_indicators(states_tensor)¶
Per-tier independent pass/fail indicators.
Returns a list of boolean tensors (one per tier), each of shape
states_tensor.shape[:-1]. indicators[t] is True for states that satisfy tier t’s constraints independently (not cumulatively).
- Parameters:
states_tensor (torch.Tensor)
- Return type:
list[torch.Tensor]
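tier_indicators returns independent per-tier flags, whereas the reward form R(s) = R0 + Σ_t tier_weights[t] · 1[constraints_0..t all satisfied] counts a tier only when every earlier tier also passes. The sketch below shows that cumulative combination with plain lists; `cumulative_tier_reward` is a hypothetical helper, not part of the library.

```python
import operator
from itertools import accumulate


def cumulative_tier_reward(indicators, tier_weights, R0=0.0):
    """indicators[t][n]: state n passes tier t on its own. Returns one reward per state."""
    n_states = len(indicators[0])
    rewards = []
    for n in range(n_states):
        # Running AND: tier t counts only if tiers 0..t all pass.
        passed = accumulate((ind[n] for ind in indicators), operator.and_)
        rewards.append(R0 + sum(w for w, ok in zip(tier_weights, passed) if ok))
    return rewards


indicators = [
    [True, True, False],   # tier 0
    [True, False, True],   # tier 1 (state 2 passes, but tier 0 fails)
    [True, False, True],   # tier 2
]
print(cumulative_tier_reward(indicators, [1.0, 2.0, 4.0]))  # [7.0, 1.0, 0.0]
```

State 2 illustrates the difference: it satisfies tiers 1 and 2 independently but earns nothing, because the cumulative product is broken at tier 0.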
- tier_weights: list[float]¶
- class gfn.gym.hypergrid.OriginalReward(height, ndim, **kwargs)¶
Bases:
GridReward
The reward function from the original GFlowNet paper (Bengio et al., 2021; https://arxiv.org/abs/2106.04399).
- Parameters:
height (int)
ndim (int)
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- class gfn.gym.hypergrid.SparseReward(height, ndim, **kwargs)¶
Bases:
GridReward
Sparse reward function from the GAFN paper (Pan et al., 2022; https://arxiv.org/abs/2210.03308).
- Parameters:
height (int)
ndim (int)
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- targets¶
- class gfn.gym.hypergrid.UniformRandomReward(height, ndim, **kwargs)¶
Bases:
GridReward
Each state is independently a mode with probability mode_prob. Uses a deterministic hash on state coordinates so mode membership is reproducible without storing or enumerating all states. There is no exploitable spatial or algebraic structure.
Reward form:
R(s) = R0 + R_mode   if hash(s, seed) < mode_prob
R(s) = R0            otherwise
- Key kwargs:
R0: float, base reward for non-mode states (default 0.1).
R_mode: float, additional reward for mode states (default 2.0).
mode_prob: float in (0, 1), probability each state is a mode (default 0.01).
seed: int, hash seed for reproducibility (default 42).
- Parameters:
height (int)
ndim (int)
- R0: float¶
- R_mode: float¶
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- mode_prob: float¶
- seed: int¶
- gfn.gym.hypergrid._MAX_POOL_WORKERS = 8¶
- gfn.gym.hypergrid._first_k_dims(k, ndim)¶
Return indices [0, 1, …, min(k, ndim)-1] for the first k dimensions.
- Parameters:
k (int)
ndim (int)
- Return type:
list[int]
- gfn.gym.hypergrid._gf2_random_fullrank(n_checks, n_vars, seed)¶
Generate a full-rank random binary matrix A and target vector c over GF(2).
Uses a deterministic seed for reproducibility across runs.
- Parameters:
n_checks (int) – Number of independent GF(2) equations (rows of A).
n_vars (int) – Number of binary variables (columns of A).
seed (int) – Deterministic RNG seed.
- Returns:
(A, c) where A is (n_checks, n_vars) int tensor and c is (n_checks,) int tensor, both with values in {0, 1}. A is guaranteed to have full row-rank over GF(2).
- Return type:
tuple[torch.Tensor, torch.Tensor]
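One straightforward way to obtain the full-row-rank matrix that _gf2_random_fullrank promises is rejection sampling: draw a random binary matrix and redraw until its GF(2) rank equals n_checks. This is a sketch under that assumption (the library may construct A differently), using Python's random module and int bitmask rows instead of torch tensors; `gf2_rank` and `gf2_random_fullrank` here are hypothetical standalone helpers.

```python
import random


def gf2_rank(rows):
    """Rank over GF(2) of rows given as int bitmasks (bit j = column j)."""
    basis = {}  # leading bit -> basis row with that leading bit
    for r in rows:
        while r:
            lead = r.bit_length() - 1
            if lead not in basis:
                basis[lead] = r
                break
            r ^= basis[lead]  # cancel the leading bit and keep reducing
    return len(basis)


def gf2_random_fullrank(n_checks, n_vars, seed):
    """Random (A, c) over GF(2) with A of full row rank, via rejection sampling."""
    if n_checks > n_vars:
        raise ValueError("full row rank needs n_checks <= n_vars")
    rng = random.Random(seed)  # deterministic for a given seed
    while True:
        rows = [rng.getrandbits(n_vars) for _ in range(n_checks)]
        if gf2_rank(rows) == n_checks:
            break
    A = [[(r >> j) & 1 for j in range(n_vars)] for r in rows]
    c = [rng.getrandbits(1) for _ in range(n_checks)]
    return A, c


A, c = gf2_random_fullrank(3, 5, seed=7)
```

Because A has full row rank, A x = c is consistent for every c, so the random target vector never produces an empty mode set.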
- gfn.gym.hypergrid._gf2_rank(A)¶
Compute the rank of a binary matrix over GF(2).
- Parameters:
A (torch.Tensor)
- Return type:
int
- gfn.gym.hypergrid._hypergrid_worker(task)¶
Module-level worker for HyperGrid._generate_combinations_in_batches.
Returns the requested slice of the Cartesian product as a concrete list. Lives at module level (rather than as a bound method) so it can be pickled to a spawned multiprocessing.Pool worker; bound methods of HyperGrid are not picklable because the env’s States subclass is created locally inside make_states_class.
- Parameters:
task – (values, ndim, start_idx, end_idx), where values is the list of coordinate values, ndim is the number of dimensions, and [start_idx, end_idx) is the index range within the full Cartesian product.
- Returns:
A list of length end_idx - start_idx containing tuples of length ndim. Returning a concrete list (rather than an itertools.islice) keeps the result picklable across workers and future-proofs against the Python 3.14 removal of itertools pickle support.
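The slicing contract of _hypergrid_worker can be sketched with itertools. This illustration mirrors the documented task tuple (values, ndim, start_idx, end_idx); the helper `make_tasks` and its chunking scheme are hypothetical. Because both are plain module-level functions returning lists, the tasks could also be handed to multiprocessing.Pool.map.

```python
import itertools


def hypergrid_worker(task):
    """Return the [start_idx, end_idx) slice of product(values, repeat=ndim) as a list."""
    values, ndim, start_idx, end_idx = task
    product = itertools.product(values, repeat=ndim)
    # A concrete list (not an islice object) is what gets pickled back.
    return list(itertools.islice(product, start_idx, end_idx))


def make_tasks(values, ndim, batch_size):
    """Split the len(values) ** ndim combinations into contiguous index ranges."""
    total = len(values) ** ndim
    return [
        (values, ndim, start, min(start + batch_size, total))
        for start in range(0, total, batch_size)
    ]


tasks = make_tasks(list(range(3)), ndim=2, batch_size=4)
# Run serially here; a Pool would call hypergrid_worker on each task instead.
chunks = [hypergrid_worker(t) for t in tasks]
```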
- gfn.gym.hypergrid._preset_seed(name)¶
Deterministic seed from a preset name.
- Parameters:
name (str)
- Return type:
int
- gfn.gym.hypergrid._state_hash_uniform(states_tensor, seed)¶
Deterministic hash mapping each grid state to a float in [0, 1).
Uses a polynomial rolling hash over coordinate values computed in int64 arithmetic. Suitable for pseudo-random but reproducible per-state decisions (e.g., mode assignment, corruption masks).
- Parameters:
states_tensor (torch.Tensor) – (…, ndim) integer tensor of coordinates.
seed (int) – Integer seed for determinism.
- Returns:
Tensor of shape states_tensor.shape[:-1] with values in [0.0, 1.0).
- Return type:
torch.Tensor
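A polynomial rolling hash of the kind described for _state_hash_uniform can be sketched in pure Python with explicit 64-bit wraparound. The multiplier `BASE`, the MurmurHash3 fmix64 finalizer, and the use of the top 53 bits are choices made for this illustration only; the library's constants are not documented here, so the values will differ. The usage function applies the UniformRandomReward rule with its documented defaults.

```python
MASK64 = (1 << 64) - 1
BASE = 1_000_003  # hypothetical odd multiplier


def state_hash_uniform(state, seed):
    """Deterministic map from an integer coordinate tuple to a float in [0, 1)."""
    h = seed & MASK64
    for coord in state:
        # Polynomial rolling step h = h * BASE + coord, wrapped to 64 bits.
        h = (h * BASE + coord) & MASK64
    # MurmurHash3 fmix64 finalizer so that neighbouring states spread out.
    h ^= h >> 33
    h = (h * 0xFF51AFD7ED558CCD) & MASK64
    h ^= h >> 33
    h = (h * 0xC4CEB9FE1A85EC53) & MASK64
    h ^= h >> 33
    # Top 53 bits divided by 2**53 gives a float strictly below 1.0.
    return (h >> 11) / float(1 << 53)


def uniform_random_reward(state, seed=42, mode_prob=0.01, R0=0.1, R_mode=2.0):
    """R(s) = R0 + R_mode if hash(s, seed) < mode_prob, else R0."""
    return R0 + R_mode if state_hash_uniform(state, seed) < mode_prob else R0
```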
- gfn.gym.hypergrid.get_bitwise_xor_presets(ndim, height)¶
Return five difficulty presets for BitwiseXORReward.
Difficulty is controlled by the number of constrained dimensions M. More constrained dims means more independent GF(2) checks per bit position, leading to fewer modes.
- Each preset uses 3 tiers with increasing numbers of GF(2) checks:
Tier 0 (curriculum): few checks, many states pass
Tier 1 (intermediate): moderate checks
Tier 2 (mode): strictest checks, defines the modes
Mode counts for ndim=10, height=16 (B=4 bit-planes). Per-bit-position solutions = 2^(M − cum_top_checks); raised to the B-th power, then multiplied by 16^(ndim − M) free-dim configurations:
easy (M=3, cum=2): ~4.3B modes (16 × 16^7)
medium (M=5, cum=4): ~16.8M modes (16 × 16^5)
hard (M=8, cum=6): ~65K modes (256 × 16^2)
challenging (M=10, cum=7): ~4K modes (4096 × 1)
impossible (M=12→10, cum=9): 16 modes (16 × 1; M capped to ndim)
Notes:
- Uses fixed seeds per preset name for reproducibility.
- Parity checks are random full-rank GF(2) matrices.
- Bit ranges are capped to ceil(log2(height)) - 1.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
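The mode counts listed above follow from a closed form, which this sketch evaluates. The helper name is hypothetical, and the (M, cum) pairs are copied from the table; M is capped to ndim as in the impossible preset, and B is the number of bit-planes, ceil(log2(height)).

```python
def bitwise_xor_mode_count(M, cum_top_checks, ndim, height):
    """(2 ** (M - cum)) ** B * height ** (ndim - M), with B bit-planes."""
    M = min(M, ndim)  # constrained dims cannot exceed ndim
    B = (height - 1).bit_length()  # ceil(log2(height)) bit-planes
    per_bit = 2 ** (M - cum_top_checks)  # solutions per bit position
    return per_bit ** B * height ** (ndim - M)


presets = {"easy": (3, 2), "medium": (5, 4), "hard": (8, 6),
           "challenging": (10, 7), "impossible": (12, 9)}
counts = {name: bitwise_xor_mode_count(M, cum, ndim=10, height=16)
          for name, (M, cum) in presets.items()}
print(counts)
```

The B-th power is what makes each extra check expensive: one more cumulative check halves the per-bit solutions and therefore divides the mode count by 2^B = 16 at height 16.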
- gfn.gym.hypergrid.get_conditional_multiscale_presets(ndim, height)¶
Return difficulty presets for ConditionalMultiScaleReward.
All presets use base=4 (requiring H to be a power of 4). The number of available digit levels is L = log_4(H). Difficulty is controlled by:
Number of tiers (more tiers = deeper compositional hierarchy)
Number of active dims (more = exponentially sparser modes)
Cross-dim modular constraints (further sparsification)
- Mode counts are computed via:
modes_T = (f^T * 4^{L-T})^d * H^(ndim-d) / prod_t m_t
with d the number of active dims, m_t the cross-dim modulus of tier t, and f = filter_width = 2 (i.e. base // 2), so each tier halves modes per coord.
Digit ordering is coarse-to-fine: tier 0 constrains the most significant digit. Near the origin (small coordinates), high digits are 0 and trivially pass the filter. Deeper tiers constrain progressively finer digits, creating distance-correlated difficulty.
Two preset families:
- Difficulty presets (single-rule legacy, assuming H=256, i.e. L=4):
easy: 2 tiers, 3 active dims -> ~2M modes at tier 2
medium: 3 tiers, 4 active dims -> ~1M modes at tier 3
hard: 3 tiers, 6 active dims, cross-dim -> ~260K modes at tier 3
challenging: 4 tiers, 8 active dims, cross-dim -> ~65K modes at tier 4
impossible: 4 tiers, 12 active dims, cross-dim -> ~4K modes at tier 4
- K-rule trunk+heads presets (sparsity matched, K rules sharing tier-0 trunk):
K1, K16, K64: 3 tiers, 6 active dims, cross_dim_mods=[2,2,2]. Total modes invariant in K (~32K total at H=64, density ~5e-7); rules partition the canonical mode set. Designed for ndim=6, H=64.
If height provides fewer digit levels than a preset requires, the preset’s tier_weights and cross_dim_mods are auto-truncated with a warning.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
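The mode-count formula above can be evaluated directly. This sketch is a hypothetical helper that transcribes modes_T = (f^T · 4^(L−T))^d · H^(ndim−d) / ∏_t m_t with f = base // 2. It reproduces the ~32K total and ~5e-7 density quoted for the K-rule presets at ndim=6, H=64.

```python
import math


def multiscale_mode_count(T, d, ndim, height, cross_dim_mods=(), base=4):
    """(f**T * base**(L-T))**d * height**(ndim-d) / prod(m_t), with f = base // 2."""
    L = round(math.log(height, base))  # digit levels; height must be base**L
    if base ** L != height:
        raise ValueError("height must be a power of base")
    f = base // 2  # filter width: half of the digit values pass at each tier
    per_coord = f ** T * base ** (L - T)  # passing values per active coordinate
    return per_coord ** d * height ** (ndim - d) / math.prod(cross_dim_mods)


k_rule = multiscale_mode_count(T=3, d=6, ndim=6, height=64, cross_dim_mods=[2, 2, 2])
density = k_rule / 64 ** 6
print(k_rule, density)  # 32768.0, about 4.8e-07
```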
- gfn.gym.hypergrid.get_corrupted_presets(ndim, height)¶
Return five difficulty presets for CorruptedReward.
Each preset wraps a conditional_multiscale “medium” base and applies increasing corruption. A single corruption_rate parameter controls the fraction of per-tier structure that is randomized.
- Difficulty progression:
easy: 10% corruption -> mostly structured
medium: 30% corruption -> noticeable randomness
hard: 50% corruption -> half structured, half random
challenging: 70% corruption -> mostly random
impossible: 90% corruption -> near-total randomness
Note: requires height to be a power of 4 (same as the base reward).
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_cosine_presets(ndim, height)¶
Return five presets for CosineReward.
R1 scales the oscillatory product, and mode_gamma (used only for mode detection thresholding) tightens what is considered a “mode-like” maximum.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_deceptive_presets(ndim, height)¶
Return five presets for DeceptiveReward.
Increase R2 to accentuate the thin band, and set a small but non-zero R0. R1 controls the center emphasis vs. the cancelled outer region.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_multiplicative_coprime_presets(ndim, height)¶
Return five difficulty presets for MultiplicativeCoprimeReward.
Each preset uses progressive tier structure where each tier adds a new constraint type:
Tier 0: Prime support only (coords must factor over allowed primes)
Tier 1: + Exponent caps (tighten factorization)
coprime_start_tier+: + Coprime pairs (cross-dim coupling)
Final tier: + LCM target (global compositional constraint)
Coordinates are shifted +1 internally (origin -> all-ones), so short trajectories immediately encounter small prime-factorable numbers.
Primes exceeding height and exponent caps exceeding log_p(height) are auto-filtered/capped in the reward constructor.
Notes:
- active_dims indices are relative to state dims; we pick the first k.
- coprime_pairs are pairs within the active_dims index space.
- Tier weights are geometric.
- Use target_lcms="auto" to derive from primes and exponent_caps.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_original_presets(ndim, height)¶
Return five presets for OriginalReward.
These presets primarily control the relative importance of the outer ring (R1) and thin band (R2). Exploration difficulty (distance from s0) is more a function of (D, H) than of these weights; tune D and H externally to match your distance bands.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_reward_presets(reward_fn_str, ndim, height)¶
Return presets for a given reward function name.
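Usage¶
from gfn.gym.hypergrid import HyperGrid, get_reward_presets
D, H = 10, 16  # e.g. the ndim/height used in the BitwiseXORReward mode-count table
presets = get_reward_presets("bitwise_xor", D, H)
kwargs = presets["hard"]
env = HyperGrid(ndim=D, height=H, reward_fn_str="bitwise_xor", reward_fn_kwargs=kwargs)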
- Parameters:
reward_fn_str (str)
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_sparse_presets(ndim, height)¶
Return five presets for SparseReward.
SparseReward has built-in targets; it ignores most kwargs. Presets are provided for API symmetry and future extensibility.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_uniform_random_presets(ndim, height)¶
Return five difficulty presets for UniformRandomReward.
Difficulty is controlled by mode_prob: lower probability means sparser modes, which are harder for GFlowNets to discover.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.lcm(a, b)¶
Returns the lowest common multiple of a and b.
- gfn.gym.hypergrid.lcm_multiple(numbers)¶
Finds the lowest common multiple of a list of numbers.
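Both helpers follow from the identity lcm(a, b) = |a·b| / gcd(a, b), folded pairwise over the list for lcm_multiple. A minimal standard-library sketch; these are illustrative re-implementations rather than the module's source (Python 3.9+ also provides math.lcm directly).

```python
import math
from functools import reduce


def lcm(a, b):
    """Lowest common multiple of two integers via the gcd identity."""
    return abs(a * b) // math.gcd(a, b) if a and b else 0


def lcm_multiple(numbers):
    """LCM of a non-empty iterable of integers, folded pairwise."""
    return reduce(lcm, numbers)


print(lcm(4, 6), lcm_multiple([2, 3, 4, 5]))  # 12 60
```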
- gfn.gym.hypergrid.logger¶
- gfn.gym.hypergrid.smallest_multiplier_to_integers(float_vector, precision=3)¶
Used to calculate a scale factor to avoid imprecise floating point arithmetic.