gfn.gym.hypergrid¶
Adapted from https://github.com/Tikquuss/GflowNets_Tutorial
Attributes¶
Classes¶
BitwiseXORReward | Tiered, compositional reward based on bitwise XOR/parity constraints.
ConditionalHyperGrid | HyperGrid environment with condition-aware rewards.
ConditionalMultiScaleReward | Tiered reward via conditional digit constraints across spatial scales.
CosineReward | Cosine reward function.
DeceptiveReward | Deceptive reward function from Adaptive Teachers (Kim et al., 2025).
GridReward | Base class for reward functions that can be pickled.
HyperGrid | HyperGrid environment from the GFlowNets paper.
MultiplicativeCoprimeReward | Tiered reward based on prime-support and coprimality/lcm composition.
OriginalReward | The reward function from the original GFlowNet paper (Bengio et al., 2021; https://arxiv.org/abs/2106.04399).
SparseReward | Sparse reward function from the GAFN paper (Pan et al., 2022; https://arxiv.org/abs/2210.03308).
Functions¶
_first_k_dims | Return indices [0, 1, ..., min(k, ndim)-1] for the first k dimensions.
get_bitwise_xor_presets | Return five difficulty presets for BitwiseXORReward.
get_conditional_multiscale_presets | Return five difficulty presets for ConditionalMultiScaleReward.
get_cosine_presets | Return five presets for CosineReward.
get_deceptive_presets | Return five presets for DeceptiveReward.
get_multiplicative_coprime_presets | Return five difficulty presets for MultiplicativeCoprimeReward.
get_original_presets | Return five presets for OriginalReward.
get_reward_presets | Return presets for a given reward name: 'bitwise_xor', 'multiplicative_coprime', 'conditional_multiscale'.
get_sparse_presets | Return five presets for SparseReward.
lcm | Returns the lowest common multiple between a and b.
lcm_multiple | Find the lowest common multiple across a list of numbers.
smallest_multiplier_to_integers | Used to calculate a scale factor to avoid imprecise floating point arithmetic.
Module Contents¶
- class gfn.gym.hypergrid.BitwiseXORReward(height, ndim, **kwargs)¶
Bases: GridReward
Tiered, compositional reward based on bitwise XOR/parity constraints.
This class implements the “Bitwise/XOR fractal” environment family, in which tiers progressively constrain bit-planes across a subset of dimensions via linear parity checks over GF(2). It supports easy sharding by high-bit prefixes, and difficulty control by adjusting which bit-planes and how many dimensions are constrained per tier.
GF(2) is the finite field with two elements {0, 1}, where addition and multiplication are performed modulo 2. In this context, vector addition is equivalent to bitwise XOR, and matrix-vector products (A @ b) are evaluated entrywise modulo 2.
- Reward form:
R(s) = R0 + Σ_t tier_weights[t] · 1[ state satisfies all constraints up to tier t ]
- Key kwargs (with reasonable defaults):
R0: float, base reward (default 0.0)
tier_weights: list[float], strictly increasing weights for each tier
dims_constrained: Optional[list[int]] subset of dims to constrain (default: all dims)
bits_per_tier: list[tuple[int,int]]; for each tier t, inclusive bit range (low_bit, high_bit). Example: [(0,5), (0,7), (0,9)].
parity_checks: Optional[list[dict]]; per tier, optional parity system. Each entry may contain { "A": IntTensor[num_checks, m], "c": IntTensor[num_checks] }, where m = len(dims_constrained). Constraints apply identically to every bit-plane specified for that tier: (A @ b) mod 2 == c, where b holds the bit values across constrained dimensions at the tested bit-plane. If omitted for a tier, a single even-parity check across all constrained dims is used by default: sum(b) mod 2 == 0.
Difficulty presets align with step ranges by controlling the highest bit used and the number of constrained dimensions. Typical distance from origin for valid modes scales roughly like (constrained_dims · 2^{highest_bit}).
- Comparison with other compositional rewards:
MultiplicativeCoprimeReward: number-theoretic (prime factorization); same constraint type per tier (tighter exponent caps), no cross-scale dependency.
ConditionalMultiScaleReward: base-B digit decomposition with conditional constraints across scales; each tier’s rule is a function of prior tiers, introducing qualitatively different structure at each level.
This class: GF(2) linear algebra on bit-planes; same parity check type per tier, but applied to progressively wider bit windows. Constraints at each bit-plane are independent of other planes.
- Parameters:
height (int)
ndim (int)
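The reward form above can be illustrated with a minimal pure-Python sketch. It assumes the default even-parity check on every bit-plane and a cumulative tier rule (a tier counts only if all earlier tiers pass). The function name `xor_tier_reward` and its signature are hypothetical and are not part of the library, which operates on torch tensors.

```python
def bit(x, k):
    """Return bit k of the non-negative integer x."""
    return (x >> k) & 1


def xor_tier_reward(state, R0, tier_weights, bits_per_tier, dims_constrained=None):
    """Hypothetical sketch: R0 plus the weights of all consecutively satisfied tiers."""
    dims = range(len(state)) if dims_constrained is None else dims_constrained
    total = R0
    for weight, (low_bit, high_bit) in zip(tier_weights, bits_per_tier):
        # Default rule: even parity across constrained dims on every bit-plane
        # in the inclusive range [low_bit, high_bit].
        planes_ok = all(
            sum(bit(state[i], k) for i in dims) % 2 == 0
            for k in range(low_bit, high_bit + 1)
        )
        if not planes_ok:
            break  # cumulative: later tiers cannot count once one tier fails
        total += weight
    return total


weights = [1.0, 2.0]
windows = [(0, 1), (0, 2)]
full = xor_tier_reward((3, 3), 0.0, weights, windows)     # both tiers pass
partial = xor_tier_reward((4, 0), 0.0, weights, windows)  # bit 2 has odd parity
```

Here `(4, 0)` passes the first tier (bits 0 and 1 are all zero) but fails the second, because bit-plane 2 contains a single 1.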
- R0: float¶
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- _apply_parity_checks(bits_plane, tier_idx)¶
Apply GF(2) linear parity checks at a single bit-plane.
bits_plane: (…, m) with m=len(dims_constrained), integer in {0,1}. Returns mask (…,) bool.
- Parameters:
bits_plane (torch.Tensor)
tier_idx (int)
- Return type:
torch.Tensor
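As an illustration of a GF(2) parity system on a single bit-plane, the sketch below checks `(A @ b) mod 2 == c` for one bit vector using plain lists. The name `parity_checks_hold` is hypothetical; the library method works on batched tensors and returns a boolean mask.

```python
def parity_checks_hold(A, c, b):
    """Return True if every row of A satisfies (row . b) mod 2 == c_row."""
    return all(
        sum(a_ij * b_j for a_ij, b_j in zip(row, b)) % 2 == c_i
        for row, c_i in zip(A, c)
    )


A = [[1, 1, 0],   # check 1: b0 XOR b1 == 0
     [0, 1, 1]]   # check 2: b1 XOR b2 == 1
c = [0, 1]
ok = parity_checks_hold(A, c, [1, 1, 0])
bad = parity_checks_hold(A, c, [1, 0, 0])
```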
- _even_parity_mask(bits)¶
bits: (…, m) int/bool -> returns (…,) bool for even parity.
- Parameters:
bits (torch.Tensor)
- Return type:
torch.Tensor
- bits_per_tier: list[tuple[int, int]]¶
- dims_constrained: list[int]¶
- parity_checks¶
- tier_weights: list[float]¶
- class gfn.gym.hypergrid.ConditionalHyperGrid(*args, **kwargs)¶
Bases: HyperGrid
HyperGrid environment with condition-aware rewards.
Let condition ‘c’ be a real value in [0, 1]. It defines the reward as a linear interpolation between the uniform reward and the original reward. Special cases are:
- c = 0: Uniform reward (all terminal states get reward=R0+R1+R2)
- c = 1: Original HyperGrid reward (original multi-modal reward landscape)
- _log_partition_cache: dict[torch.Tensor, float]¶
- _max_reward: float¶
- _original_reward_fn¶
- _true_dist_cache: dict[torch.Tensor, torch.Tensor]¶
- condition_dim: int = 1¶
- is_conditional: bool = True¶
- log_partition(condition)¶
Compute the log partition for the given condition.
- Parameters:
condition (torch.Tensor) – The condition to compute the log partition for. condition.shape should be (1,)
- Returns:
The log partition function, as a float.
- Return type:
float
- reward(states)¶
Compute rewards for the conditional environment.
A condition is continuous from 0 to 1:
- 0: Fully uniform reward (all states get R0+R1+R2)
- 1: Fully original HyperGrid reward
- In between: Linear interpolation between uniform and original
- Parameters:
states (gfn.states.DiscreteStates) – The states to compute rewards for. states.tensor.shape should be (*batch_shape, *state_shape)
- Returns:
A tensor of shape (*batch_shape,) containing the rewards.
- Return type:
torch.Tensor
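A minimal sketch of the interpolation described above, for a single scalar reward. The exact expression is an assumption chosen to match the two documented endpoints (c = 0 gives the uniform value, c = 1 gives the original reward). `conditional_reward` and `max_reward` are illustrative names.

```python
def conditional_reward(original_reward, condition, max_reward):
    """Blend the uniform reward (max_reward) with the original reward."""
    return (1.0 - condition) * max_reward + condition * original_reward


max_reward = 0.1 + 0.5 + 2.0  # R0 + R1 + R2, example values only
uniform = conditional_reward(0.1, 0.0, max_reward)
original = conditional_reward(0.1, 1.0, max_reward)
halfway = conditional_reward(0.1, 0.5, max_reward)
```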
- sample_conditions(batch_shape)¶
Sample conditions for the environment.
- Parameters:
batch_shape (int | tuple[int, Ellipsis])
- Return type:
torch.Tensor
- true_dist(condition)¶
Compute the true distribution for the given condition.
- Parameters:
condition (torch.Tensor) – The condition to compute the true distribution for. condition.shape is presumably (1,), matching log_partition.
- Returns:
The true distribution for the given condition as a 1-dimensional tensor.
- Return type:
torch.Tensor
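For a GFlowNet target, the true distribution is the reward normalized over all terminating states, and the partition function is the sum of those rewards. The sketch below shows both on a plain list of per-state rewards. The function names mirror the methods above but are standalone illustrations, not the library implementation.

```python
import math


def log_partition(rewards):
    """log Z, where Z is the sum of rewards over all terminating states."""
    return math.log(sum(rewards))


def true_dist(rewards):
    """Normalize rewards into a probability mass function."""
    z = sum(rewards)
    return [r / z for r in rewards]


rewards = [0.1, 0.6, 2.6, 0.1]  # example per-state rewards
pmf = true_dist(rewards)
log_z = log_partition(rewards)
```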
- class gfn.gym.hypergrid.ConditionalMultiScaleReward(height, ndim, **kwargs)¶
Bases: GridReward
Tiered reward via conditional digit constraints across spatial scales.
Each coordinate is decomposed in base B into L = log_B(H) digits. Tier t constrains digit t-1 via a shifted filter that depends on all finer-scale digits, creating a hierarchy where learning lower-scale structure is prerequisite for predicting higher-scale constraints.
- Per-dimension constraint at tier t:
(d_{t-1}(i) + sigma_t(i)) mod B < f_t
where sigma_t(i) = sum_{k=0}^{t-2} a_{t,k} * d_k(i) mod B is a linear function of finer-scale digits with seed-derived coefficients a_{t,k}.
- Optional cross-dimensional constraint at tier t:
sum_i d_{t-1}(i) ≡ 0 (mod m_t)
- Reward form (cumulative — tier t requires all tiers 1..t):
R(s) = R0 + sum_t tier_weights[t] * 1[s satisfies tiers 1..t]
- Mode count (exact closed form):
Without cross-dim: modes_T = (prod_{t=1}^T f_t)^d * B^{(L-T)*d}
With cross-dim: modes_T = (prod_t f_t)^d * B^{(L-T)*d} / prod_t m_t
- Partition function (analytic, no enumeration):
Z = R0 * H^d + sum_t w_t * modes_t
- Key kwargs:
R0: float, base reward (default 0.0).
tier_weights: list[float], reward weight per tier.
base: int, digit base B (default 4). H must be a power of B.
filter_width: int, number of passing digit values per tier (default B//2). Constant across tiers to avoid mode collapse at deep tiers.
seed: int, PRNG seed for generating shift coefficients (default 42).
cross_dim_mods: Optional[list[int|None]], per-tier modular cross-dim constraint. m_t must divide filter_width for exact mode counts. Default: no cross-dim constraints.
active_dims: Optional[list[int]], subset of dims to constrain (default: all dims).
- Comparison with other compositional rewards:
BitwiseXORReward: GF(2) parity checks on bit-planes; each tier widens the bit window but uses the same rule type. No cross-scale dependency.
MultiplicativeCoprimeReward: prime factorization with tightening exponent caps. Same constraint type at every tier.
This class: each tier introduces a qualitatively different constraint whose form depends on what was learned at prior tiers (conditional structure across scales).
- Parameters:
height (int)
ndim (int)
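The per-dimension constraint and cumulative reward can be sketched in pure Python as follows. This is an illustration under stated assumptions: a constant filter width `f`, no cross-dimensional constraint, and shift coefficients passed in explicitly as a dict (the class derives them from a seed). All names here are hypothetical.

```python
def digits(x, base, num_levels):
    """Base-`base` digits of x, least significant first: floor(x / B^k) mod B."""
    return [(x // base**k) % base for k in range(num_levels)]


def tier_ok(x, t, base, f, coeffs):
    """Check (d_{t-1} + sigma_t) mod B < f for one coordinate and 1-indexed tier t."""
    d = digits(x, base, t)
    # sigma_t is a linear function of the finer-scale digits d_0 .. d_{t-2}.
    sigma = sum(a * dk for a, dk in zip(coeffs[t], d[: t - 1])) % base
    return (d[t - 1] + sigma) % base < f


def multiscale_reward(state, R0, tier_weights, base, f, coeffs):
    """R0 plus the weights of all consecutively satisfied tiers."""
    total = R0
    for t, weight in enumerate(tier_weights, start=1):
        if not all(tier_ok(x, t, base, f, coeffs) for x in state):
            break
        total += weight
    return total


coeffs = {1: [], 2: [3]}  # a_{2,0} = 3 is an arbitrary example coefficient
r_both = multiscale_reward((0, 0), 0.0, [1.0, 2.0], 4, 2, coeffs)
r_tier1 = multiscale_reward((1, 0), 0.0, [1.0, 2.0], 4, 2, coeffs)
```

With B = 4 and f = 2, the coordinate 1 passes tier 1 (digit 0 is 1) but fails tier 2, since its shifted digit is (0 + 3·1) mod 4 = 3.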
- R0: float¶
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- _extract_digits(x, num_levels)¶
Extract base-B digits from x.
- Parameters:
x (torch.Tensor) – (…, m) integer tensor with coordinate values.
num_levels (int) – how many digit levels to extract.
- Returns:
List of num_levels tensors, each (…, m), with digit values in [0, B). digits[k] is the k-th digit (scale k), i.e., floor(x / B^k) mod B.
- Return type:
list[torch.Tensor]
- active_dims: list[int]¶
- analytic_log_partition()¶
Compute log(Z) analytically.
Z = R0 * H^d + sum_t w_t * modes_t
- Return type:
float
- analytic_mode_count(tier=None)¶
Compute exact mode count for a given tier (1-indexed) or all tiers.
- Parameters:
tier (int | None) – 1-indexed tier number. If None, returns count for the highest tier (most constrained).
- Returns:
Number of states satisfying all constraints up to the given tier.
- Return type:
int
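The closed forms for the mode count and partition function can be evaluated directly. The sketch below assumes a constant filter width f (so the product over tiers is f^T) and applies optional cross-dimensional moduli by integer division. The function names and signatures are illustrative, not the class API.

```python
import math


def mode_count(T, f, B, L, d, cross_dim_mods=()):
    """modes_T = (f^T)^d * B^((L-T)*d), divided by each active modulus m_t."""
    count = (f**T) ** d * B ** ((L - T) * d)
    for m in cross_dim_mods[:T]:
        if m is not None:
            count //= m
    return count


def log_partition(R0, tier_weights, f, B, L, d):
    """log Z with Z = R0 * H^d + sum_t w_t * modes_t, where H = B^L."""
    H = B**L
    Z = R0 * H**d + sum(
        w * mode_count(t, f, B, L, d)
        for t, w in enumerate(tier_weights, start=1)
    )
    return math.log(Z)


# B = 4, H = 16 (L = 2), d = 2, f = 2
tier1 = mode_count(1, 2, 4, 2, 2)
tier2 = mode_count(2, 2, 4, 2, 2)
log_z = log_partition(0.1, [1.0, 2.0], 2, 4, 2, 2)
```

For this small grid, Z = 0.1·256 + 1·64 + 2·16 = 121.6.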
- base: int¶
- cross_dim_mods: list[int | None] = []¶
- filter_width: int¶
- mode_threshold(target_sparsity=0.1)¶
Return the reward threshold for mode counting at the adaptive tier.
States with reward >= this value are counted as modes. The tier is chosen via mode_tier(target_sparsity) so that mode coverage adapts to dimensionality.
- Parameters:
target_sparsity (float) – Passed to mode_tier().
- Returns:
Reward threshold (R0 + sum of weights up to the mode tier).
- Return type:
float
- mode_tier(target_sparsity=0.1)¶
Return the lowest tier whose mode coverage is below target_sparsity.
Coverage at tier t = (f/B)^(t*d) (before cross-dim constraints). This adapts the mode definition to dimensionality: at low d, deeper tiers are needed for modes to be sparse; at high d, even tier 1 is already a needle in a haystack.
- Parameters:
target_sparsity (float) – Fraction of total state space below which modes are considered “interesting” (default 0.10 = 10%).
- Returns:
1-indexed tier number. Clamped to [1, num_tiers].
- Return type:
int
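A minimal sketch of the adaptive tier selection and the matching reward threshold, using the coverage formula (f/B)^(t*d). Treating "below" as a strict inequality is an assumption, and the standalone signatures are illustrative.

```python
def mode_tier(f, B, d, num_tiers, target_sparsity=0.1):
    """Lowest 1-indexed tier whose coverage (f/B)^(t*d) is below the target."""
    for t in range(1, num_tiers + 1):
        if (f / B) ** (t * d) < target_sparsity:
            return t
    return num_tiers  # clamp when no tier is sparse enough


def mode_threshold(R0, tier_weights, tier):
    """R0 plus the weights of tiers 1..tier."""
    return R0 + sum(tier_weights[:tier])


low_dim = mode_tier(f=2, B=4, d=2, num_tiers=3)   # coverage 0.25, then 0.0625
high_dim = mode_tier(f=2, B=4, d=8, num_tiers=3)  # tier 1 is already sparse
threshold = mode_threshold(0.0, [1.0, 2.0, 4.0], low_dim)
```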
- num_levels: int = 0¶
- seed: int¶
- shift_coeffs: list[list[int]] = []¶
- tier_weights: list[float]¶
- class gfn.gym.hypergrid.CosineReward(height, ndim, **kwargs)¶
Bases: GridReward
Cosine reward function.
- Parameters:
height (int)
ndim (int)
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- class gfn.gym.hypergrid.DeceptiveReward(height, ndim, **kwargs)¶
Bases: GridReward
Deceptive reward function from Adaptive Teachers (Kim et al., 2025).
Note that the reward definition in the paper (eq. (9)) is incorrect, and we follow the official implementation (https://github.com/alstn12088/adaptive-teacher/blob/8cfcb2298fce3f46eb36ead03791eeee75b7d066/grid/env.py#L27) while modifying it to use EPS = 1e-12 to handle inequalities with floating points.
- Parameters:
height (int)
ndim (int)
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- class gfn.gym.hypergrid.GridReward(height, ndim, **kwargs)¶
Bases: abc.ABC
Base class for reward functions that can be pickled.
- Parameters:
height (int)
ndim (int)
- _EPS = 1e-12¶
- abstract __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- height¶
- kwargs¶
- ndim¶
- class gfn.gym.hypergrid.HyperGrid(ndim=2, height=8, reward_fn_str='original', reward_fn_kwargs=None, device='cpu', calculate_partition=False, store_all_states=False, debug=False, validate_modes=True, mode_stats='none', mode_stats_samples=20000)¶
Bases: gfn.env.DiscreteEnv
HyperGrid environment from the GFlowNets paper.
The states are represented as 1-d tensors of length ndim with values in {0, 1, …, height - 1}.
- Parameters:
ndim (int)
height (int)
reward_fn_str (str)
reward_fn_kwargs (dict | None)
device (Literal['cpu', 'cuda'] | torch.device)
calculate_partition (bool)
store_all_states (bool)
debug (bool)
validate_modes (bool)
mode_stats (Literal['none', 'approx', 'exact'])
mode_stats_samples (int)
- ndim¶
The dimension of the grid.
- height¶
The height of the grid.
- reward_fn¶
The reward function.
- calculate_partition¶
Whether to calculate the log partition function.
- store_all_states¶
Whether to store all states.
- validate_modes¶
Whether to check that at least one state reaches the mode threshold at init; raises if not.
- mode_stats¶
One of {“none”, “approx”, “exact”}. If not “none”, computes (exact or approximate) n_modes and n_mode_states. “exact” requires store_all_states=True and enumerates all states.
- mode_stats_samples¶
Number of random samples when mode_stats="approx".
- States: type[gfn.states.DiscreteStates]¶
- _all_states_tensor = None¶
- _enumerate_all_states_tensor(batch_size=20000)¶
Enumerate all grid states, optionally storing them and computing log Z.
Iterates over the full Cartesian product {0, ..., H-1}^D in batches (via multiprocessing) to avoid materializing all H^D states at once.
- Parameters:
batch_size (int) – Number of states per batch.
- _exists_bitwise_xor(thr)¶
Feasibility and constructive check for BitwiseXORReward.
Steps:
- For each tier, verify the GF(2) parity system has at least one solution using Gaussian elimination modulo 2. If any tier is infeasible, no mode exists.
- The all-zero configuration satisfies even-parity constraints, so if tiers are feasible we evaluate that point against the threshold with tolerance.
- Parameters:
thr (float)
- Return type:
bool
- _exists_cosine(thr)¶
Analytic upper-bound check for CosineReward.
Idea:
- The per-dimension factor is (cos(50·ax) + 1) · N(0,1)(5·ax) with ax in [0, 0.5]. We estimate its maximum over the discrete grid by evaluating all candidate ax and taking the maximum value m.
- The full reward upper bound is R0 + m^D * R1. If this is at least the mode target and the given threshold, a mode-level state must exist.
- We also compute a theoretical per-dimension peak (at ax≈0) to form a slightly conservative target scaled by mode_gamma.
- Parameters:
thr (float)
- Return type:
bool
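The upper bound described above can be sketched with the standard library alone. The sketch assumes ax = |x/(H-1) - 0.5| for each grid coordinate x, which is consistent with the stated range [0, 0.5] but is not spelled out here. The function names are hypothetical.

```python
import math


def std_normal_pdf(z):
    """Density of the standard normal distribution at z."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def per_dim_max(height):
    """Maximum of (cos(50*ax) + 1) * pdf(5*ax) over the discrete grid."""
    best = 0.0
    for x in range(height):
        ax = abs(x / (height - 1) - 0.5)  # assumed normalization
        best = max(best, (math.cos(50.0 * ax) + 1.0) * std_normal_pdf(5.0 * ax))
    return best


def cosine_upper_bound(height, ndim, R0, R1):
    """R0 + m^D * R1, where m is the per-dimension maximum."""
    return R0 + per_dim_max(height) ** ndim * R1


m = per_dim_max(3)  # the center cell has ax = 0, the analytic peak
bound = cosine_upper_bound(3, 2, 0.1, 0.5)
```

Because the reward factorizes over dimensions, placing every coordinate at the per-dimension maximizer attains this bound.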
- _exists_fallback_random(thr)¶
Random sampling fallback.
Draw a modest batch of random states on CPU and accept if any exceed the threshold with a small tolerance. This is a last resort to avoid expensive enumeration for large grids.
- Parameters:
thr (float)
- Return type:
bool
- _exists_multiplicative_coprime(thr)¶
Number-theoretic constructive check for MultiplicativeCoprimeReward.
Constructs a candidate state by factoring the target LCM (if any) over the allowed primes, assigning each prime power to a separate active dimension, and verifying coprimality and grid-bound constraints.
- Parameters:
thr (float)
- Return type:
bool
- _exists_original_or_deceptive(thr)¶
Constructive check for OriginalReward and DeceptiveReward.
Intuition:
- These rewards form rings/bands around the center when each coordinate is normalized to [0,1]. The mode lies on a thin band at specific normalized distances from the center.
- We translate those fractional band boundaries into integer indices via small inside/outside nudges (using EPS_INDEX_CMP) and test one candidate index from any non-empty feasible interval.
- If the reward at that candidate exceeds the threshold (with EPS_REWARD_CMP tolerance), we return True.
- Parameters:
thr (float)
- Return type:
bool
- _exists_sparse(thr)¶
Constructive check for SparseReward.
This reward assigns positive mass only to a finite set of target configurations. When H>=2 and D>=1, a known target is the zero vector except for certain coordinates fixed at 1 or H-2. We probe a canonical target and confirm the threshold is not above its reward.
- Parameters:
thr (float)
- Return type:
bool
- _generate_combinations_in_batches(ndim, max_val, batch_size)¶
Yield batches of the Cartesian product {0, …, max_val}^ndim.
Uses multiprocessing to avoid materializing the full product (size (max_val+1)^ndim) in memory.
- Parameters:
ndim (int) – Number of dimensions (tuple length).
max_val (int) – Maximum coordinate value (inclusive).
batch_size (int) – Number of tuples per batch.
- Yields:
An iterator of tuples for each batch.
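The batching idea can be shown with a single-process sketch built on `itertools`. This leaves out the multiprocessing used by the method and only illustrates how the product is consumed lazily in fixed-size chunks. `combinations_in_batches` is a hypothetical standalone name.

```python
from itertools import islice, product


def combinations_in_batches(ndim, max_val, batch_size):
    """Yield lists of tuples from {0, ..., max_val}^ndim, batch_size at a time."""
    it = product(range(max_val + 1), repeat=ndim)  # lazy, nothing materialized yet
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


batches = list(combinations_in_batches(ndim=2, max_val=2, batch_size=4))
sizes = [len(b) for b in batches]
```

The last batch is shorter whenever `(max_val+1)^ndim` is not a multiple of `batch_size`.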
- _log_partition = None¶
- _mode_reward_threshold()¶
Returns the reward threshold used to define a mode.
By default, a state is considered in a mode if its reward is at least the schema-defined threshold derived from the configured reward.
- Return type:
float
- _mode_stats_kind: str = 'none'¶
- _modes_exist_quick_check()¶
Lightweight check that a mode-level state exists.
In simple terms, this answers: “Is there at least one state whose reward reaches the mode threshold?” without enumerating all states. It proceeds in three stages:
1) If the grid is small (or pre-enumerated), it computes rewards exactly and checks against the threshold.
2) Otherwise, it dispatches to reward-specific constructive tests that are sufficient to guarantee at least one state reaches the threshold.
3) As a last resort, it samples a small batch of random states.
- Return type:
bool
- _modes_exist_quick_check_info()¶
Same as _modes_exist_quick_check but returns (ok, message).
- Return type:
tuple[bool, str]
- _n_mode_states_estimate: float | None = None¶
- _n_mode_states_exact: int | None = None¶
- static _solve_gf2_has_solution(A, c)¶
Return True if A x = c over GF(2) has at least one solution.
Performs Gaussian elimination modulo 2 (XOR arithmetic) without constructing a specific solution. A row that reduces to all-zero coefficients with a non-zero RHS (0 = 1) indicates inconsistency.
- Parameters:
A (torch.Tensor)
c (torch.Tensor)
- Return type:
bool
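A minimal list-based sketch of this solvability test: row-reduce the augmented matrix [A | c] with XOR and look for a row of the form 0 = 1. The static method takes torch tensors, while this version takes nested lists. `gf2_has_solution` is an illustrative name.

```python
def gf2_has_solution(A, c):
    """Return True if A x = c has at least one solution over GF(2)."""
    # Augmented matrix with every entry reduced modulo 2.
    rows = [[v % 2 for v in list(row) + [ci]] for row, ci in zip(A, c)]
    n_cols = len(A[0]) if A else 0
    pivot_row = 0
    for col in range(n_cols):
        pivot = next(
            (r for r in range(pivot_row, len(rows)) if rows[r][col]), None
        )
        if pivot is None:
            continue  # no pivot in this column
        rows[pivot_row], rows[pivot] = rows[pivot], rows[pivot_row]
        for r in range(len(rows)):
            if r != pivot_row and rows[r][col]:
                # Row addition over GF(2) is elementwise XOR.
                rows[r] = [a ^ b for a, b in zip(rows[r], rows[pivot_row])]
        pivot_row += 1
    # Inconsistent if some row has all-zero coefficients but RHS 1.
    return not any(not any(r[:-1]) and r[-1] for r in rows)


consistent = gf2_has_solution([[1, 1], [1, 1]], [1, 1])
inconsistent = gf2_has_solution([[1, 1], [1, 1]], [0, 1])
```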
- _true_dist = None¶
- _worker(task)¶
Return a slice of the Cartesian product for one batch.
- Parameters:
task (tuple) – (values, ndim, start_idx, end_idx) where values is the list of coordinate values, ndim is the number of dimensions, and [start_idx, end_idx) is the range within the full product.
- Return type:
itertools.islice
- all_indices()¶
Generate all possible indices for the grid.
- Returns:
A list of all possible indices for the grid.
- Return type:
List[Tuple[int, Ellipsis]]
- property all_states: gfn.states.DiscreteStates | None¶
Returns a tensor of all hypergrid states as a DiscreteStates instance.
- Return type:
gfn.states.DiscreteStates | None
- backward_step(states, actions)¶
Performs a backward step in the environment.
- Parameters:
states (gfn.states.DiscreteStates) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The previous states.
- Return type:
gfn.states.DiscreteStates
- calculate_partition = False¶
- get_states_indices(states)¶
Get the indices of the states in the canonical ordering.
- Parameters:
states (gfn.states.DiscreteStates | torch.Tensor) – The states to get the indices of.
- Returns:
The indices of the states in the canonical ordering.
- Return type:
torch.Tensor
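One natural canonical ordering reads each state as a number in base `height`. The sketch below uses that encoding with the first coordinate as the most significant digit. Whether the library uses this exact digit order is an assumption, and `state_index` is a hypothetical name.

```python
from itertools import product


def state_index(state, height):
    """Interpret the coordinates as base-`height` digits, most significant first."""
    idx = 0
    for coord in state:
        idx = idx * height + coord
    return idx


all_indices = sorted(
    state_index(s, 8) for s in product(range(8), repeat=2)
)
```

Any such positional encoding assigns every state a distinct index in [0, height^ndim), which is what makes the index usable as a key into `true_dist`-style vectors.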
- get_terminating_states_indices(states)¶
Get the indices of the terminating states in the canonical ordering.
- Parameters:
states (gfn.states.DiscreteStates) – The states to get the indices of.
- Returns:
The indices of the terminating states in the canonical ordering.
- Return type:
torch.Tensor
- height = 8¶
- log_partition(condition=None)¶
Returns the log partition of the reward function.
- Return type:
float | None
- make_random_states(batch_shape, conditions=None, device=None, debug=False)¶
Creates a batch of random states.
- Parameters:
batch_shape (Tuple[int, Ellipsis]) – The shape of the batch.
conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.
device (torch.device | None) – The device to use.
debug (bool) – If True, emit States with debug guards (not compile-friendly).
- Returns:
A DiscreteStates object with random states.
- Return type:
gfn.states.DiscreteStates
- make_states_class()¶
Returns the DiscreteStates class for the HyperGrid environment.
- Return type:
type[gfn.states.DiscreteStates]
- mode_mask(states)¶
Boolean mask indicating which states are in a mode.
A state is flagged as a mode if its reward is greater than or equal to the threshold derived from reward_fn_kwargs (R0+R1+R2 by default).
- Parameters:
states (gfn.states.DiscreteStates)
- Return type:
torch.Tensor
- modes_found(states)¶
Returns the set of canonical state indices for mode states in the batch.
Each mode state is identified by its unique canonical index (from get_states_indices), not by a quadrant-based grouping. This allows correct mode-state tracking for all reward functions.
- Parameters:
states (gfn.states.DiscreteStates)
- Return type:
set[int]
- property n_mode_states: int | float | None¶
Number of states inside a mode (exact, approx, or None).
If mode_stats="exact", returns an exact integer count.
If mode_stats="approx", returns a floating-point estimate.
If store_all_states is True (but mode_stats was “none”), computes on demand from all_states.
Otherwise, returns None.
- Return type:
int | float | None
- property n_modes: int | float | None¶
Returns the total number of mode states for this environment.
Equivalent to n_mode_states. Each individual grid cell whose reward meets the mode threshold counts as one mode.
- Return type:
int | float | None
- property n_states: int¶
Returns the number of states in the environment.
- Return type:
int
- property n_terminating_states: int¶
Returns the number of terminating states in the environment.
- Return type:
int
- ndim = 2¶
- reward(states)¶
Computes the reward for a batch of final states.
In the normal setting, the reward is:
`R(s) = R_0 + 0.5 \prod_{d=1}^D \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5] \right) + 2 \prod_{d=1}^D \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)`
- Parameters:
states (gfn.states.DiscreteStates) – The final states.
- Returns:
The reward of the final states.
- Return type:
torch.Tensor
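The reward formula for the normal setting can be evaluated for one state with a short pure-Python sketch. The coefficients 0.5 and 2 come from the formula. The default R0 = 0.1 and the function name `original_reward` are assumptions made for illustration.

```python
from itertools import product


def original_reward(state, height, R0=0.1, R1=0.5, R2=2.0):
    """R0 + R1 * [all coords in outer ring] + R2 * [all coords in thin band]."""
    ax = [abs(x / (height - 1) - 0.5) for x in state]
    outer = all(0.25 < a <= 0.5 for a in ax)
    band = all(0.3 < a < 0.4 for a in ax)
    return R0 + R1 * outer + R2 * band


corner = original_reward((0, 0), 8)  # outer ring only
mode = original_reward((1, 1), 8)    # outer ring and thin band
center = original_reward((3, 3), 8)  # neither

# Count states reaching the mode threshold R0 + R1 + R2 (with a small tolerance).
n_modes = sum(
    1
    for s in product(range(8), repeat=2)
    if original_reward(s, 8) > 0.1 + 0.5 + 2.0 - 1e-9
)
```

With H = 8 and D = 2, only coordinates 1 and 6 fall in the thin band, so four states reach the mode threshold.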
- reward_fn¶
- reward_fn_kwargs = None¶
- step(states, actions)¶
Performs a step in the environment.
- Parameters:
states (gfn.states.DiscreteStates) – The current states.
actions (gfn.actions.Actions) – The actions to take.
- Returns:
The next states.
- Return type:
gfn.states.DiscreteStates
- store_all_states = False¶
- property terminating_states: gfn.states.DiscreteStates | None¶
Returns all terminating states of the environment.
- Return type:
gfn.states.DiscreteStates | None
- true_dist(condition=None)¶
Returns the pmf over all states in the hypergrid.
- Return type:
torch.Tensor | None
- class gfn.gym.hypergrid.MultiplicativeCoprimeReward(height, ndim, **kwargs)¶
Bases: GridReward
Tiered reward based on prime-support and coprimality/lcm composition.
Each tier enforces that per-dimension values use only a small shared prime set with bounded exponents, plus optional cross-dimension constraints (pairwise coprime pairs and/or target lcm). Higher tiers tighten exponent caps or add additional global targets. This encourages information sharing to learn the latent prime/exponent structure.
- Reward form:
R(s) = R0 + Σ_t tier_weights[t] · 1[ constraints_0..t all satisfied ]
- Key kwargs:
R0: float, base reward (default 0.0)
tier_weights: list[float]
primes: list[int], e.g., [2,3,5,7,11]
exponent_caps: list[int], same length as tier_weights. Cap for every prime at tier t (uniform cap across primes for simplicity).
active_dims: Optional[list[int]]; constraints only apply to these dims (default: all dims). Other dims are ignored in constraints.
coprime_pairs: Optional[list[tuple[int,int]]]; indices relative to active_dims.
target_lcms: Optional[list[int | None]]; per-tier target lcm across active dims.
Notes:
- Values 0 are treated as invalid for prime-support constraints (cannot factorize); value 1 is valid with all-zero exponents.
- Implementation removes primes up to the current tier cap and checks residue == 1. Exponent counts are accumulated to evaluate LCM targets.
- Comparison with other compositional rewards:
BitwiseXORReward: GF(2) parity checks on bit-planes; same constraint type per tier (wider bit window), no cross-scale dependency.
ConditionalMultiScaleReward: base-B digit decomposition with conditional constraints across scales; each tier’s rule depends on prior tiers.
This class: prime factorization with bounded exponents and optional coprimality/LCM targets. Same constraint type per tier (tighter caps), but cross-dimension coupling via coprime pairs and LCM targets.
- Parameters:
height (int)
ndim (int)
- R0: float¶
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- _factor_exponents_up_to_cap(v, cap)¶
Trial-divide each element by allowed primes, returning residue and exponents.
- Parameters:
v (torch.Tensor) – (…,) LongTensor of non-negative values to factorize.
cap (int) – Maximum number of times each prime may divide a value.
- Returns:
A pair (residue, exps).
residue: (…,) values after stripping allowed primes (1 if fully factored).
exps: [num_primes, …] exponent counts per prime (leading axis is primes).
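Capped trial division for a single integer can be sketched as follows. A value passes the prime-support constraint when the residue equals 1. The scalar signature, the explicit `primes` argument, and the helper `prime_support_ok` are hypothetical; the method itself works on tensors.

```python
def factor_up_to_cap(v, primes, cap):
    """Strip each allowed prime at most `cap` times; return (residue, exponents)."""
    exps = []
    for p in primes:
        e = 0
        # The v > 0 guard keeps 0 from being divided forever (0 % p == 0).
        while v > 0 and e < cap and v % p == 0:
            v //= p
            e += 1
        exps.append(e)
    return v, exps


def prime_support_ok(v, primes, cap):
    """True if v factors completely over `primes` with exponents <= cap."""
    residue, _ = factor_up_to_cap(v, primes, cap)
    return residue == 1


twelve = factor_up_to_cap(12, [2, 3], cap=2)  # 12 = 2^2 * 3
eight = factor_up_to_cap(8, [2, 3], cap=2)    # 8 = 2^3 exceeds the cap
```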
- _lcm_ok(exps, target_lcm)¶
Check whether max exponents across dims match target LCM’s factorization.
- Parameters:
exps (torch.Tensor) – [num_primes, …, num_active_dims] exponent counts.
target_lcm (int) – The target LCM value to match.
- Returns:
(…,) bool mask, True where the LCM of active-dim values equals target.
- Return type:
torch.Tensor
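The LCM test reduces to comparing exponent vectors: the LCM of several values has, for each prime, the maximum exponent found across those values. The sketch below uses one exponent list per dimension, a simpler layout than the [num_primes, …, num_active_dims] tensor above, and assumes a positive target. `lcm_matches` is a hypothetical name.

```python
def lcm_matches(exps_per_dim, primes, target_lcm):
    """True if the per-prime max exponents equal the target's factorization."""
    if target_lcm < 1:
        return False
    target_exps = []
    rest = target_lcm
    for p in primes:
        e = 0
        while rest % p == 0:
            rest //= p
            e += 1
        target_exps.append(e)
    if rest != 1:
        return False  # target has a prime factor outside the allowed set
    max_exps = [max(column) for column in zip(*exps_per_dim)]
    return max_exps == target_exps


# Values 4 = 2^2 and 3 = 3^1 over primes [2, 3]
exps = [[2, 0], [0, 1]]
hit = lcm_matches(exps, [2, 3], 12)
miss = lcm_matches(exps, [2, 3], 6)
```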
- _pairwise_coprime_ok(v)¶
Check that configured dimension pairs share no common allowed prime.
- Parameters:
v (torch.Tensor) – (…, num_active_dims) coordinate values.
- Returns:
(…,) bool mask, True where all coprime pair constraints hold.
- Return type:
torch.Tensor
- active_dims: list[int]¶
- coprime_pairs¶
- exponent_caps: list[int]¶
- primes: list[int]¶
- target_lcms¶
- tier_weights: list[float]¶
- class gfn.gym.hypergrid.OriginalReward(height, ndim, **kwargs)¶
Bases: GridReward
The reward function from the original GFlowNet paper (Bengio et al., 2021; https://arxiv.org/abs/2106.04399).
- Parameters:
height (int)
ndim (int)
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- class gfn.gym.hypergrid.SparseReward(height, ndim, **kwargs)¶
Bases: GridReward
Sparse reward function from the GAFN paper (Pan et al., 2022; https://arxiv.org/abs/2210.03308).
- Parameters:
height (int)
ndim (int)
- __call__(states_tensor)¶
- Parameters:
states_tensor (torch.Tensor)
- Return type:
torch.Tensor
- targets¶
- gfn.gym.hypergrid._first_k_dims(k, ndim)¶
Return indices [0, 1, …, min(k, ndim)-1] for the first k dimensions.
- Parameters:
k (int)
ndim (int)
- Return type:
list[int]
- gfn.gym.hypergrid.get_bitwise_xor_presets(ndim, height)¶
Return five difficulty presets for BitwiseXORReward.
The presets target approximate L1 distance bands by selecting the highest constrained bit and number of constrained dimensions. Typical distance scales like m · 2^b, where m is the number of constrained dims and b the highest bit.
- Bands (steps from s0):
easy: ~50-100
medium: ~250-500
hard: ~1k-2.5k
challenging: ~2.5k-5k
impossible: 5k+
Notes:
- You may tweak m (dims) and bit windows to fine-tune distances for your D, H.
- Tier weights are geometric to encourage reaching higher tiers.
- Parity checks default to even parity across constrained dims per bit-plane.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_conditional_multiscale_presets(ndim, height)¶
Return five difficulty presets for ConditionalMultiScaleReward.
All presets use base=4 (requiring H to be a power of 4). The number of available digit levels is L = log_4(H). Difficulty is controlled by:
Number of tiers (more tiers = deeper compositional hierarchy)
Number of active dims (more = exponentially sparser modes)
Cross-dim modular constraints (further sparsification)
- Mode counts are computed via:
modes_T = (f^T * 4^{L-T})^d / prod_t m_t
with f = filter_width = 2 (i.e. B//2), so each tier halves modes per coord.
- Presets (assuming H=256, i.e. L=4 digit levels):
easy: 2 tiers, 3 active dims -> ~2M modes at tier 2
medium: 3 tiers, 4 active dims -> ~1M modes at tier 3
hard: 3 tiers, 6 active dims, cross-dim -> ~260K modes at tier 3
challenging: 4 tiers, 8 active dims, cross-dim -> ~65K modes at tier 4
impossible: 4 tiers, 12 active dims, cross-dim -> ~4K modes at tier 4
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_cosine_presets(ndim, height)¶
Return five presets for CosineReward.
R1 scales the oscillatory product, and mode_gamma (used only for mode detection thresholding) tightens what is considered a “mode-like” maximum.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_deceptive_presets(ndim, height)¶
Return five presets for DeceptiveReward.
Increase R2 to accentuate the thin band, and set a small but non-zero R0. R1 controls the center emphasis vs. the cancelled outer region.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_multiplicative_coprime_presets(ndim, height)¶
Return five difficulty presets for MultiplicativeCoprimeReward.
- Bands (steps from s0):
easy: ~50-100 (small primes, small exponents, few active dims)
medium: ~250-500 (adds one prime, caps=2, more dims, light coupling)
hard: ~1k-2.5k (primes up to 11, caps=3, more dims, LCM target)
challenging: ~2.5k-5k (primes up to 13, caps=3-4, 10-12 dims, tighter)
impossible: 5k+ (primes up to 29, caps=4, 12-16 dims, multiple targets)
Notes:
- Distances are approximate; increase primes and exponent caps to push further.
- active_dims indices are relative to state dims; we pick the first k for simplicity.
- coprime_pairs are pairs within active_dims index space.
- Tier weights are geometric.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_original_presets(ndim, height)¶
Return five presets for OriginalReward.
These presets primarily control the relative importance of the outer ring (R1) and thin band (R2). Exploration difficulty (distance from s0) is more a function of (D, H) than of these weights; tune D and H externally to match your distance bands.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_reward_presets(reward_fn_str, ndim, height)¶
Return presets for a given reward name: 'bitwise_xor', 'multiplicative_coprime', 'conditional_multiscale'.
Usage¶
presets = get_reward_presets("bitwise_xor", D, H)
kwargs = presets["hard"]
env = HyperGrid(ndim=D, height=H, reward_fn_str="bitwise_xor", reward_fn_kwargs=kwargs)
- Parameters:
reward_fn_str (str)
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.get_sparse_presets(ndim, height)¶
Return five presets for SparseReward.
SparseReward has built-in targets; it ignores most kwargs. Presets are provided for API symmetry and future extensibility.
- Parameters:
ndim (int)
height (int)
- Return type:
dict
- gfn.gym.hypergrid.lcm(a, b)¶
Returns the lowest common multiple between a and b.
- gfn.gym.hypergrid.lcm_multiple(numbers)¶
Find the lowest common multiple across a list of numbers.
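Both helpers can be sketched with the standard library, assuming positive integer inputs. This is one common way to compute them via the greatest common divisor, not necessarily how the module does it.

```python
from functools import reduce
from math import gcd


def lcm(a, b):
    """Lowest common multiple of two positive integers."""
    return a * b // gcd(a, b)


def lcm_multiple(numbers):
    """Lowest common multiple of a non-empty list of positive integers."""
    return reduce(lcm, numbers)


pair = lcm(4, 6)
many = lcm_multiple([2, 3, 4])
```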
- gfn.gym.hypergrid.logger¶
- gfn.gym.hypergrid.smallest_multiplier_to_integers(float_vector, precision=3)¶
Used to calculate a scale factor to avoid imprecise floating point arithmetic.
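One plausible reading of this helper: round each float to `precision` decimals, express it as an exact fraction, and return the LCM of the denominators, so that multiplying the vector by the result gives integers. The sketch below implements that reading. It is an assumption about the behavior, and `smallest_multiplier` is an illustrative name.

```python
from fractions import Fraction
from functools import reduce
from math import gcd


def smallest_multiplier(float_vector, precision=3):
    """Smallest integer k such that k * round(x, precision) is an integer for all x."""
    denominators = [
        # Going through str() avoids binary floating-point noise in the fraction.
        Fraction(str(round(x, precision))).denominator
        for x in float_vector
    ]
    return reduce(lambda a, b: a * b // gcd(a, b), denominators, 1)


quarters = smallest_multiplier([0.5, 0.25])
tenths = smallest_multiplier([0.1, 0.5, 2.0])
```

Scaling by this factor lets later comparisons run on exact integers instead of floats.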