gfn.gym.hypergrid

Adapted from https://github.com/Tikquuss/GflowNets_Tutorial

Attributes

logger

Classes

BitwiseXORReward

Tiered, compositional reward based on bitwise XOR/parity constraints.

ConditionalHyperGrid

HyperGrid environment with condition-aware rewards.

ConditionalMultiScaleReward

Tiered reward via conditional digit constraints across spatial scales.

CosineReward

Cosine reward function.

DeceptiveReward

Deceptive reward function from Adaptive Teachers (Kim et al., 2025).

GridReward

Base class for reward functions that can be pickled.

HyperGrid

HyperGrid environment from the GFlowNets paper.

MultiplicativeCoprimeReward

Tiered reward based on prime-support and coprimality/lcm composition.

OriginalReward

The reward function from the original GFlowNet paper (Bengio et al., 2021; https://arxiv.org/abs/2106.04399).

SparseReward

Sparse reward function from the GAFN paper (Pan et al., 2022; https://arxiv.org/abs/2210.03308).

Functions

_first_k_dims(k, ndim)

Return indices [0, 1, ..., min(k, ndim)-1] for the first k dimensions.

get_bitwise_xor_presets(ndim, height)

Return five difficulty presets for BitwiseXORReward.

get_conditional_multiscale_presets(ndim, height)

Return five difficulty presets for ConditionalMultiScaleReward.

get_cosine_presets(ndim, height)

Return five presets for CosineReward.

get_deceptive_presets(ndim, height)

Return five presets for DeceptiveReward.

get_multiplicative_coprime_presets(ndim, height)

Return five difficulty presets for MultiplicativeCoprimeReward.

get_original_presets(ndim, height)

Return five presets for OriginalReward.

get_reward_presets(reward_fn_str, ndim, height)

Return presets for a given reward name: 'bitwise_xor', 'multiplicative_coprime', 'conditional_multiscale'.

get_sparse_presets(ndim, height)

Return five presets for SparseReward.

lcm(a, b)

Returns the lowest common multiple of a and b.

lcm_multiple(numbers)

Find the lowest common multiple across a list of numbers.

smallest_multiplier_to_integers(float_vector[, precision])

Used to calculate a scale factor to avoid imprecise floating point arithmetic.

Module Contents

class gfn.gym.hypergrid.BitwiseXORReward(height, ndim, **kwargs)

Bases: GridReward

Tiered, compositional reward based on bitwise XOR/parity constraints.

This class implements the “Bitwise/XOR fractal” environment family, in which tiers progressively constrain bit-planes across a subset of dimensions via linear parity checks over GF(2). It supports easy sharding by high-bit prefixes, and difficulty control by adjusting which bit-planes and how many dimensions are constrained per tier.

GF(2) is the finite field with two elements {0, 1}, where addition and multiplication are performed modulo 2. In this context, vector addition is equivalent to bitwise XOR, and matrix-vector products (A @ b) are evaluated entrywise modulo 2.

Reward form:

R(s) = R0 + Σ_t tier_weights[t] · 1[ state satisfies all constraints up to tier t ]

Key kwargs (with reasonable defaults):
  • R0: float, base reward (default 0.0)

  • tier_weights: list[float], strictly increasing weights for each tier

  • dims_constrained: Optional[list[int]] subset of dims to constrain (default: all dims)

  • bits_per_tier: list[tuple[int,int]]; for each tier t, inclusive bit range (low_bit, high_bit). Example: [(0,5), (0,7), (0,9)].

  • parity_checks: Optional[list[dict]]; per tier, optional parity system:
    Each entry may contain:

    { "A": IntTensor[num_checks, m], "c": IntTensor[num_checks] }

    where m = len(dims_constrained). Constraints apply identically to every bit-plane specified for that tier: (A @ b) mod 2 == c, where b holds the bit values across the constrained dimensions at the tested bit-plane. If omitted for a tier, a single even-parity check across all constrained dims is used by default: sum(b) mod 2 == 0.

Difficulty presets align with step ranges by controlling the highest bit used and the number of constrained dimensions. Typical distance from origin for valid modes scales roughly like (constrained_dims · 2^{highest_bit}).
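The reward form and the default even-parity rule can be illustrated for a single state in plain Python. This is a minimal sketch, not the library's tensor implementation: the name xor_tier_reward and its signature are hypothetical, only the default even-parity check is handled (no custom A/c systems), and tiers are treated as cumulative, so a failed tier stops the accumulation.

```python
def xor_tier_reward(state, bits_per_tier, tier_weights,
                    dims_constrained=None, R0=0.0):
    """Illustrative tiered XOR reward for one state (hypothetical helper)."""
    # Default: constrain every dimension.
    dims = list(range(len(state))) if dims_constrained is None else dims_constrained
    reward = R0
    for (low_bit, high_bit), weight in zip(bits_per_tier, tier_weights):
        # Even parity must hold on every bit-plane in the inclusive range.
        tier_ok = all(
            sum((state[d] >> k) & 1 for d in dims) % 2 == 0
            for k in range(low_bit, high_bit + 1)
        )
        if not tier_ok:
            break  # cumulative tiers: later tiers cannot be earned
        reward += weight
    return reward


# (3, 3): bit-planes 0 and 1 both have even parity, so both tiers are earned.
print(xor_tier_reward((3, 3), [(0, 0), (0, 1)], [1.0, 2.0]))
# (1, 3): bit-plane 1 has odd parity, so only the first tier is earned.
print(xor_tier_reward((1, 3), [(0, 0), (0, 1)], [1.0, 2.0]))
```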

Comparison with other compositional rewards:
  • MultiplicativeCoprimeReward: number-theoretic (prime factorization); same constraint type per tier (tighter exponent caps), no cross-scale dependency.

  • ConditionalMultiScaleReward: base-B digit decomposition with conditional constraints across scales; each tier’s rule is a function of prior tiers, introducing qualitatively different structure at each level.

  • This class: GF(2) linear algebra on bit-planes; same parity check type per tier, but applied to progressively wider bit windows. Constraints at each bit-plane are independent of other planes.

Parameters:
  • height (int)

  • ndim (int)

R0: float
__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

_apply_parity_checks(bits_plane, tier_idx)

Apply GF(2) linear parity checks at a single bit-plane.

bits_plane: (…, m) with m=len(dims_constrained), integer in {0,1}. Returns mask (…,) bool.

Parameters:
  • bits_plane (torch.Tensor)

  • tier_idx (int)

Return type:

torch.Tensor

_even_parity_mask(bits)

bits: (…, m) int/bool -> returns (…,) bool for even parity.

Parameters:

bits (torch.Tensor)

Return type:

torch.Tensor

bits_per_tier: list[tuple[int, int]]
dims_constrained: list[int]
parity_checks
tier_weights: list[float]
class gfn.gym.hypergrid.ConditionalHyperGrid(*args, **kwargs)

Bases: HyperGrid

HyperGrid environment with condition-aware rewards.

Let the condition c be a real value in [0, 1]. It defines the reward as a linear interpolation between the uniform reward and the original reward. Special cases:
  • c = 0: Uniform reward (all terminal states get reward R0+R1+R2).

  • c = 1: Original HyperGrid reward (the original multi-modal reward landscape).
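A minimal scalar sketch of this interpolation follows. The helper name conditional_reward and the values R0=0.1, R1=0.5, R2=2.0 are assumptions made for illustration; the environment itself works on batched tensors.

```python
def conditional_reward(c, original_reward, R0=0.1, R1=0.5, R2=2.0):
    """Interpolate between a uniform reward and the original reward.

    Hypothetical helper: R0, R1, R2 defaults are illustrative assumptions.
    """
    if not 0.0 <= c <= 1.0:
        raise ValueError("condition must lie in [0, 1]")
    uniform = R0 + R1 + R2  # the value every state receives when c = 0
    return (1.0 - c) * uniform + c * original_reward


print(conditional_reward(0.0, 0.1))  # uniform end of the range
print(conditional_reward(1.0, 0.1))  # original reward is returned unchanged
```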

_log_partition_cache: dict[torch.Tensor, float]
_max_reward: float
_original_reward_fn
_true_dist_cache: dict[torch.Tensor, torch.Tensor]
condition_dim: int = 1
is_conditional: bool = True
log_partition(condition)

Compute the log partition for the given condition.

Parameters:

condition (torch.Tensor) – The condition to compute the log partition for. condition.shape should be (1,)

Returns:

The log partition function, as a float.

Return type:

float

reward(states)

Compute rewards for the conditional environment.

The condition is continuous from 0 to 1:
  • 0: Fully uniform reward (all states get R0+R1+R2).

  • 1: Fully original HyperGrid reward.

  • In between: Linear interpolation between uniform and original.

Parameters:

states (gfn.states.DiscreteStates) – The states to compute rewards for. states.tensor.shape should be (*batch_shape, *state_shape)

Returns:

A tensor of shape (*batch_shape,) containing the rewards.

Return type:

torch.Tensor

sample_conditions(batch_shape)

Sample conditions for the environment.

Parameters:

batch_shape (int | tuple[int, Ellipsis])

Return type:

torch.Tensor

true_dist(condition)

Compute the true distribution for the given condition.

Parameters:
  • condition (torch.Tensor) – The condition to compute the true distribution for. condition.shape should be (1,), as for log_partition.

Returns:

The true distribution for the given condition as a 1-dimensional tensor.

Return type:

torch.Tensor

class gfn.gym.hypergrid.ConditionalMultiScaleReward(height, ndim, **kwargs)

Bases: GridReward

Tiered reward via conditional digit constraints across spatial scales.

Each coordinate is decomposed in base B into L = log_B(H) digits. Tier t constrains digit t-1 via a shifted filter that depends on all finer-scale digits, creating a hierarchy where learning lower-scale structure is prerequisite for predicting higher-scale constraints.

Per-dimension constraint at tier t:

(d_{t-1}(i) + sigma_t(i)) mod B < f_t

where sigma_t(i) = sum_{k=0}^{t-2} a_{t,k} * d_k(i) mod B is a linear function of finer-scale digits with seed-derived coefficients a_{t,k}.

Optional cross-dimensional constraint at tier t:

sum_i d_{t-1}(i) ≡ 0 (mod m_t)

Reward form (cumulative — tier t requires all tiers 1..t):

R(s) = R0 + sum_t tier_weights[t] * 1[s satisfies tiers 1..t]

Mode count (exact closed form):

Without cross-dim: modes_T = (prod_{t=1}^T f_t)^d * B^{(L-T)*d}

With cross-dim: modes_T = (prod_{t=1}^T f_t)^d * B^{(L-T)*d} / prod_{t=1}^T m_t

Partition function (analytic, no enumeration):

Z = R0 * H^d + sum_t w_t * modes_t

Key kwargs:
  • R0: float, base reward (default 0.0).

  • tier_weights: list[float], reward weight per tier.

  • base: int, digit base B (default 4). H must be a power of B.

  • filter_width: int, number of passing digit values per tier (default B//2). Constant across tiers to avoid mode collapse at deep tiers.

  • seed: int, PRNG seed for generating shift coefficients (default 42).

  • cross_dim_mods: Optional[list[int|None]], per-tier modular cross-dim constraint. m_t must divide filter_width for exact mode counts. Default: no cross-dim constraints.

  • active_dims: Optional[list[int]], subset of dims to constrain (default: all dims).

Comparison with other compositional rewards:
  • BitwiseXORReward: GF(2) parity checks on bit-planes; each tier widens the bit window but uses the same rule type. No cross-scale dependency.

  • MultiplicativeCoprimeReward: prime factorization with tightening exponent caps. Same constraint type at every tier.

  • This class: each tier introduces a qualitatively different constraint whose form depends on what was learned at prior tiers (conditional structure across scales).

Parameters:
  • height (int)

  • ndim (int)

R0: float
__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

_extract_digits(x, num_levels)

Extract base-B digits from x.

Parameters:
  • x (torch.Tensor) – (…, m) integer tensor with coordinate values.

  • num_levels (int) – how many digit levels to extract.

Returns:

List of num_levels tensors, each (…, m), with digit values in [0, B). digits[k] is the k-th digit (scale k), i.e., floor(x / B^k) mod B.

Return type:

list[torch.Tensor]
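The digit decomposition and the per-dimension tier constraint can be sketched for a single integer coordinate. This is an illustrative stand-in, not the tensor code: extract_digits and passes_tier are hypothetical names, and the coefficients passed as coeffs play the role of the seed-derived a_{t,k} with made-up values.

```python
def extract_digits(x, base, num_levels):
    """Return [d_0, d_1, ...] with d_k = floor(x / base**k) mod base."""
    return [(x // base**k) % base for k in range(num_levels)]


def passes_tier(x, tier, base, filter_width, coeffs):
    """Check (d_{t-1} + sigma_t) mod B < f_t for a 1-indexed tier.

    coeffs[k] stands in for a_{t,k}, k = 0..tier-2 (illustrative values).
    """
    digits = extract_digits(x, base, tier)
    # sigma_t is a linear function of the finer-scale digits d_0..d_{t-2}.
    sigma = sum(a * d for a, d in zip(coeffs, digits[: tier - 1])) % base
    return (digits[tier - 1] + sigma) % base < filter_width


print(extract_digits(27, base=4, num_levels=3))  # 27 = 3 + 2*4 + 1*16
print(passes_tier(27, tier=2, base=4, filter_width=2, coeffs=[1]))
```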

active_dims: list[int]
analytic_log_partition()

Compute log(Z) analytically.

Z = R0 * H^d + sum_t w_t * modes_t

Return type:

float

analytic_mode_count(tier=None)

Compute exact mode count for a given tier (1-indexed) or all tiers.

Parameters:

tier (int | None) – 1-indexed tier number. If None, returns count for the highest tier (most constrained).

Returns:

Number of states satisfying all constraints up to the given tier.

Return type:

int
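The closed-form mode count and partition function stated for this class can be evaluated with integer arithmetic. The sketch below assumes a constant filter width across tiers and uses hypothetical function names; here d is taken to be the number of constrained dimensions, which is an assumption about how the formulas are applied.

```python
import math


def mode_count(tier, filter_width, base, num_levels, ndim, cross_dim_mods=None):
    """modes_T = (f^T * B^(L-T))^d / prod_t m_t, for a 1-indexed tier T."""
    count = (filter_width**tier * base ** (num_levels - tier)) ** ndim
    for m in (cross_dim_mods or [])[:tier]:
        if m:  # None means no cross-dim constraint at that tier
            count //= m
    return count


def partition_function(R0, tier_weights, filter_width, base, num_levels, ndim):
    """Z = R0 * H^d + sum_t w_t * modes_t, with H = B^L."""
    height = base**num_levels
    z = R0 * height**ndim
    for t, w in enumerate(tier_weights, start=1):
        z += w * mode_count(t, filter_width, base, num_levels, ndim)
    return z


def analytic_log_partition(R0, tier_weights, filter_width, base, num_levels, ndim):
    return math.log(partition_function(
        R0, tier_weights, filter_width, base, num_levels, ndim))


print(mode_count(tier=2, filter_width=2, base=4, num_levels=4, ndim=3))
```

Because the count is computed with exact integers, it stays exact for grids far too large to enumerate.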

base: int
cross_dim_mods: list[int | None] = []
filter_width: int
mode_threshold(target_sparsity=0.1)

Return the reward threshold for mode counting at the adaptive tier.

States with reward >= this value are counted as modes. The tier is chosen via mode_tier(target_sparsity) so that mode coverage adapts to dimensionality.

Parameters:

target_sparsity (float) – Passed to mode_tier().

Returns:

Reward threshold (R0 + sum of weights up to the mode tier).

Return type:

float

mode_tier(target_sparsity=0.1)

Return the lowest tier whose mode coverage is below target_sparsity.

Coverage at tier t = (f/B)^(t*d) (before cross-dim constraints). This adapts the mode definition to dimensionality: at low d, deeper tiers are needed for modes to be sparse; at high d, even tier 1 is already a needle in a haystack.

Parameters:

target_sparsity (float) – Fraction of total state space below which modes are considered “interesting” (default 0.10 = 10%).

Returns:

1-indexed tier number. Clamped to [1, num_tiers].

Return type:

int
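The tier selection rule described above (coverage (f/B)^(t*d), lowest tier below the target, clamped to [1, num_tiers]) can be sketched as follows. The function name and argument list are hypothetical, and cross-dim constraints are ignored, as in the coverage formula.

```python
def pick_mode_tier(filter_width, base, ndim, num_tiers, target_sparsity=0.1):
    """Return the lowest 1-indexed tier with coverage below target_sparsity."""
    for tier in range(1, num_tiers + 1):
        coverage = (filter_width / base) ** (tier * ndim)
        if coverage < target_sparsity:
            return tier
    # No tier is sparse enough: clamp to the deepest available tier.
    return num_tiers


print(pick_mode_tier(2, 4, ndim=2, num_tiers=3))  # coverage 0.25, then 0.0625
print(pick_mode_tier(2, 4, ndim=4, num_tiers=3))  # tier 1 is already sparse
```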

num_levels: int = 0
seed: int
shift_coeffs: list[list[int]] = []
tier_weights: list[float]
class gfn.gym.hypergrid.CosineReward(height, ndim, **kwargs)

Bases: GridReward

Cosine reward function.

Parameters:
  • height (int)

  • ndim (int)

__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

class gfn.gym.hypergrid.DeceptiveReward(height, ndim, **kwargs)

Bases: GridReward

Deceptive reward function from Adaptive Teachers (Kim et al., 2025).

Note that the reward definition in the paper (eq. (9)) is incorrect, and we follow the official implementation (https://github.com/alstn12088/adaptive-teacher/blob/8cfcb2298fce3f46eb36ead03791eeee75b7d066/grid/env.py#L27) while modifying it to use EPS = 1e-12 to handle inequalities with floating points.

Parameters:
  • height (int)

  • ndim (int)

__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

class gfn.gym.hypergrid.GridReward(height, ndim, **kwargs)

Bases: abc.ABC

Base class for reward functions that can be pickled.

Parameters:
  • height (int)

  • ndim (int)

_EPS = 1e-12
abstract __call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

height
kwargs
ndim
class gfn.gym.hypergrid.HyperGrid(ndim=2, height=8, reward_fn_str='original', reward_fn_kwargs=None, device='cpu', calculate_partition=False, store_all_states=False, debug=False, validate_modes=True, mode_stats='none', mode_stats_samples=20000)

Bases: gfn.env.DiscreteEnv

HyperGrid environment from the GFlowNets paper.

The states are represented as 1-d tensors of length ndim with values in {0, 1, …, height - 1}.

Parameters:
  • ndim (int)

  • height (int)

  • reward_fn_str (str)

  • reward_fn_kwargs (dict | None)

  • device (Literal['cpu', 'cuda'] | torch.device)

  • calculate_partition (bool)

  • store_all_states (bool)

  • debug (bool)

  • validate_modes (bool)

  • mode_stats (Literal['none', 'approx', 'exact'])

  • mode_stats_samples (int)

ndim

The dimension of the grid.

height

The height of the grid.

reward_fn

The reward function.

calculate_partition

Whether to calculate the log partition function.

store_all_states

Whether to store all states.

validate_modes

Whether to check that at least one state reaches the mode threshold at init; raises if not.

mode_stats

One of {“none”, “approx”, “exact”}. If not “none”, computes (exact or approximate) n_modes and n_mode_states. “exact” requires store_all_states=True and enumerates all states.

mode_stats_samples

Number of random samples when mode_stats="approx".

States: type[gfn.states.DiscreteStates]
_all_states_tensor = None
_enumerate_all_states_tensor(batch_size=20000)

Enumerate all grid states, optionally storing them and computing log Z.

Iterates over the full Cartesian product {0, ..., H-1}^D in batches (via multiprocessing) to avoid materializing all H^D states at once.

Parameters:

batch_size (int) – Number of states per batch.

_exists_bitwise_xor(thr)

Feasibility and constructive check for BitwiseXORReward.

Steps:
  • For each tier, verify that the GF(2) parity system has at least one solution using Gaussian elimination modulo 2. If any tier is infeasible, no mode exists.

  • The all-zero configuration satisfies even-parity constraints, so if all tiers are feasible we evaluate that point against the threshold with tolerance.

Parameters:

thr (float)

Return type:

bool

_exists_cosine(thr)

Analytic upper-bound check for CosineReward.

Idea:
  • The per-dimension factor is (cos(50·ax) + 1) · N(0,1)(5·ax) with ax in [0, 0.5], where N(0,1)(·) denotes the standard normal density. We estimate its maximum over the discrete grid by evaluating all candidate ax and taking the maximum value m.

  • The full reward upper bound is R0 + m^D * R1. If this is at least the mode target and the given threshold, a mode-level state must exist.

  • We also compute a theoretical per-dimension peak (at ax≈0) to form a slightly conservative target scaled by mode_gamma.

Parameters:

thr (float)

Return type:

bool
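The upper-bound computation can be sketched in plain Python. The names below are hypothetical, and the normalization ax = |i/(H-1) - 0.5| is an assumption borrowed from the reward formula documented for this environment; the mode_gamma scaling is left out.

```python
import math


def per_dim_factor(ax):
    """(cos(50*ax) + 1) times the standard normal density at 5*ax."""
    pdf = math.exp(-0.5 * (5.0 * ax) ** 2) / math.sqrt(2.0 * math.pi)
    return (math.cos(50.0 * ax) + 1.0) * pdf


def cosine_upper_bound(height, ndim, R0, R1):
    """R0 + m^D * R1, with m the best per-dimension factor on the grid."""
    # Assumed normalization of grid index i to a distance from the center.
    candidates = [abs(i / (height - 1) - 0.5) for i in range(height)]
    m = max(per_dim_factor(ax) for ax in candidates)
    return R0 + (m**ndim) * R1


# With an odd height the center cell has ax = 0, where the factor peaks.
print(cosine_upper_bound(height=9, ndim=2, R0=0.1, R1=0.5))
```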

_exists_fallback_random(thr)

Random sampling fallback.

Draw a modest batch of random states on CPU and accept if any exceed the threshold with a small tolerance. This is a last resort to avoid expensive enumeration for large grids.

Parameters:

thr (float)

Return type:

bool

_exists_multiplicative_coprime(thr)

Number-theoretic constructive check for MultiplicativeCoprimeReward.

Constructs a candidate state by factoring the target LCM (if any) over the allowed primes, assigning each prime power to a separate active dimension, and verifying coprimality and grid-bound constraints.

Parameters:

thr (float)

Return type:

bool

_exists_original_or_deceptive(thr)

Constructive check for OriginalReward and DeceptiveReward.

Intuition:
  • These rewards form rings/bands around the center when each coordinate is normalized to [0,1]. The mode lies on a thin band at specific normalized distances from the center.

  • We translate those fractional band boundaries into integer indices via small inside/outside nudges (using EPS_INDEX_CMP) and test one candidate index from any non-empty feasible interval.

  • If the reward at that candidate exceeds the threshold (with EPS_REWARD_CMP tolerance), we return True.

Parameters:

thr (float)

Return type:

bool

_exists_sparse(thr)

Constructive check for SparseReward.

This reward assigns positive mass only to a finite set of target configurations. When H>=2 and D>=1, a known target is the zero vector except for certain coordinates fixed at 1 or H-2. We probe a canonical target and confirm the threshold is not above its reward.

Parameters:

thr (float)

Return type:

bool

_generate_combinations_in_batches(ndim, max_val, batch_size)

Yield batches of the Cartesian product {0, …, max_val}^ndim.

Uses multiprocessing to avoid materializing the full product (size (max_val+1)^ndim) in memory.

Parameters:
  • ndim (int) – Number of dimensions (tuple length).

  • max_val (int) – Maximum coordinate value (inclusive).

  • batch_size (int) – Number of tuples per batch.

Yields:

An iterator of tuples for each batch.
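The batching idea can be shown with a single-process sketch built on itertools. It does not reproduce the multiprocessing used here; product_in_batches is a hypothetical name and the iteration order is simply that of itertools.product.

```python
import itertools


def product_in_batches(ndim, max_val, batch_size):
    """Yield lists of tuples from {0..max_val}^ndim, batch_size at a time."""
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    it = itertools.product(range(max_val + 1), repeat=ndim)
    while True:
        # islice pulls the next chunk lazily, so the full product
        # is never held in memory at once.
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            return
        yield batch


for batch in product_in_batches(ndim=2, max_val=1, batch_size=3):
    print(batch)
```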

_log_partition = None
_mode_reward_threshold()

Returns the reward threshold used to define a mode.

By default, a state is considered in a mode if its reward is at least the schema-defined threshold derived from the configured reward.

Return type:

float

_mode_stats_kind: str = 'none'
_modes_exist_quick_check()

Lightweight check that a mode-level state exists.

In simple terms, this answers: “Is there at least one state whose reward reaches the mode threshold?” without enumerating all states. It proceeds in three stages:
  1. If the grid is small (or pre-enumerated), it computes rewards exactly and checks against the threshold.

  2. Otherwise, it dispatches to reward-specific constructive tests that are sufficient to guarantee at least one state reaches the threshold.

  3. As a last resort, it samples a small batch of random states.

Return type:

bool

_modes_exist_quick_check_info()

Same as _modes_exist_quick_check but returns (ok, message).

Return type:

tuple[bool, str]

_n_mode_states_estimate: float | None = None
_n_mode_states_exact: int | None = None
static _solve_gf2_has_solution(A, c)

Return True if A x = c over GF(2) has at least one solution.

Performs Gaussian elimination modulo 2 (XOR arithmetic) without constructing a specific solution. A row that reduces to all-zero coefficients with a non-zero RHS (0 = 1) indicates inconsistency.

Parameters:
  • A (torch.Tensor)

  • c (torch.Tensor)

Return type:

bool
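A consistency check of this kind can be sketched with plain lists of 0/1 integers instead of tensors. The function name gf2_has_solution is hypothetical, and this is an illustration of Gaussian elimination modulo 2 rather than the method's actual code.

```python
def gf2_has_solution(A, c):
    """Return True if A x = c has a solution over GF(2).

    A is a list of rows (lists of ints), c a list of ints; entries are
    reduced modulo 2 before elimination.
    """
    # Build the augmented matrix [A | c] with entries in {0, 1}.
    rows = [[a & 1 for a in row] + [ci & 1] for row, ci in zip(A, c)]
    n_cols = len(A[0]) if A else 0
    pivot_row = 0
    for col in range(n_cols):
        pivot = next(
            (r for r in range(pivot_row, len(rows)) if rows[r][col]), None
        )
        if pivot is None:
            continue  # no pivot in this column
        rows[pivot_row], rows[pivot] = rows[pivot], rows[pivot_row]
        for r in range(len(rows)):
            if r != pivot_row and rows[r][col]:
                # Row addition over GF(2) is elementwise XOR.
                rows[r] = [x ^ y for x, y in zip(rows[r], rows[pivot_row])]
        pivot_row += 1
    # A row reading 0 = 1 signals inconsistency.
    return not any(not any(row[:-1]) and row[-1] for row in rows)


print(gf2_has_solution([[1, 1], [1, 1]], [0, 1]))  # contradictory checks
print(gf2_has_solution([[1, 1], [0, 1]], [1, 1]))
```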

_true_dist = None
_worker(task)

Return a slice of the Cartesian product for one batch.

Parameters:

task (tuple) – (values, ndim, start_idx, end_idx) where values is the list of coordinate values, ndim is the number of dimensions, and [start_idx, end_idx) is the range within the full product.

Return type:

itertools.islice

all_indices()

Generate all possible indices for the grid.

Returns:

A list of all possible indices for the grid.

Return type:

List[Tuple[int, Ellipsis]]

property all_states: gfn.states.DiscreteStates | None

Returns a tensor of all hypergrid states as a DiscreteStates instance.

Return type:

gfn.states.DiscreteStates | None

backward_step(states, actions)

Performs a backward step in the environment.

Parameters:
  • states – The current states.

  • actions – The actions to undo.

Returns:

The previous states.

Return type:

gfn.states.DiscreteStates

calculate_partition = False
get_states_indices(states)

Get the indices of the states in the canonical ordering.

Parameters:

states (gfn.states.DiscreteStates | torch.Tensor) – The states to get the indices of.

Returns:

The indices of the states in the canonical ordering.

Return type:

torch.Tensor

get_terminating_states_indices(states)

Get the indices of the terminating states in the canonical ordering.

Parameters:

states (gfn.states.DiscreteStates) – The states to get the indices of.

Returns:

The indices of the terminating states in the canonical ordering.

Return type:

torch.Tensor

height = 8
log_partition(condition=None)

Returns the log partition of the reward function.

Return type:

float | None

make_random_states(batch_shape, conditions=None, device=None, debug=False)

Creates a batch of random states.

Parameters:
  • batch_shape (Tuple[int, Ellipsis]) – The shape of the batch.

  • conditions (torch.Tensor | None) – Optional tensor of shape (*batch_shape, condition_dim) containing condition vectors for conditional GFlowNets.

  • device (torch.device | None) – The device to use.

  • debug (bool) – If True, emit States with debug guards (not compile-friendly).

Returns:

A DiscreteStates object with random states.

Return type:

gfn.states.DiscreteStates

make_states_class()

Returns the DiscreteStates class for the HyperGrid environment.

Return type:

type[gfn.states.DiscreteStates]

mode_mask(states)

Boolean mask indicating which states are in a mode.

A state is flagged as a mode if its reward is greater than or equal to the threshold based on reward_fn_kwargs (R0+R1+R2 by default).

Parameters:

states (gfn.states.DiscreteStates)

Return type:

torch.Tensor

modes_found(states)

Returns the set of canonical state indices for mode states in the batch.

Each mode state is identified by its unique canonical index (from get_states_indices), not by a quadrant-based grouping. This allows correct mode-state tracking for all reward functions.

Parameters:

states (gfn.states.DiscreteStates)

Return type:

set[int]

property n_mode_states: int | float | None

Number of states inside a mode (exact, approx, or None).

  • If mode_stats="exact", returns an exact integer count.

  • If mode_stats="approx", returns a floating-point estimate.

  • If store_all_states is True (but mode_stats was “none”), computes on demand from all_states.

  • Otherwise, returns None.

Return type:

int | float | None

property n_modes: int | float | None

Returns the total number of mode states for this environment.

Equivalent to n_mode_states. Each individual grid cell whose reward meets the mode threshold counts as one mode.

Return type:

int | float | None

property n_states: int

Returns the number of states in the environment.

Return type:

int

property n_terminating_states: int

Returns the number of terminating states in the environment.

Return type:

int

ndim = 2
reward(states)

Computes the reward for a batch of final states.

In the normal setting, the reward is:

R(s) = R_0 + 0.5 \prod_{d=1}^D \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.25, 0.5] \right) + 2 \prod_{d=1}^D \mathbf{1}\left( \left\lvert \frac{s^d}{H-1} - 0.5 \right\rvert \in (0.3, 0.4) \right)

Parameters:
  • states – The final states to compute rewards for.

Returns:

The reward of the final states.

Return type:

torch.Tensor
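For a single state, the reward in the normal setting can be evaluated with a short stdlib sketch. The function name is hypothetical and R0=0.1 is an assumed value; the coefficients 0.5 and 2 and the two intervals come from the formula above. The real method operates on batched DiscreteStates.

```python
def original_reward(state, height, R0=0.1, R1=0.5, R2=2.0):
    """Scalar version of the ring/band reward (R0 default is an assumption)."""
    # Normalized distance of each coordinate from the grid center.
    ax = [abs(s / (height - 1) - 0.5) for s in state]
    outer_ring = all(0.25 < a <= 0.5 for a in ax)  # first indicator product
    thin_band = all(0.3 < a < 0.4 for a in ax)     # second indicator product
    return R0 + (R1 if outer_ring else 0.0) + (R2 if thin_band else 0.0)


# With height=8, coordinates 1 and 6 fall inside the thin band.
print(original_reward((1, 6), height=8))
print(original_reward((0, 7), height=8))  # outer ring only
print(original_reward((3, 1), height=8))  # one coordinate near the center
```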

reward_fn
reward_fn_kwargs = None
step(states, actions)

Performs a step in the environment.

Parameters:
  • states – The current states.

  • actions – The actions to apply.

Returns:

The next states.

Return type:

gfn.states.DiscreteStates

store_all_states = False
property terminating_states: gfn.states.DiscreteStates | None

Returns all terminating states of the environment.

Return type:

gfn.states.DiscreteStates | None

true_dist(condition=None)

Returns the pmf over all states in the hypergrid.

Return type:

torch.Tensor | None
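For small grids, the pmf and log partition can be obtained by brute-force enumeration, as in the hedged sketch below. The helper name, the reward_fn callable, and the use of itertools.product order as the state ordering are all assumptions; the environment's canonical ordering may differ.

```python
import itertools
import math


def enumerate_true_dist(ndim, height, reward_fn):
    """Return (states, pmf, log_Z) by enumerating {0..height-1}^ndim."""
    states = list(itertools.product(range(height), repeat=ndim))
    rewards = [reward_fn(s) for s in states]
    z = sum(rewards)  # partition function: total reward mass
    pmf = [r / z for r in rewards]
    return states, pmf, math.log(z)


# Toy reward used only for illustration: 1 + sum of coordinates.
states, pmf, log_z = enumerate_true_dist(2, 2, lambda s: 1.0 + sum(s))
print(states)
print(pmf)
```

Enumeration costs height**ndim reward evaluations, so this approach is only practical for small ndim and height.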

class gfn.gym.hypergrid.MultiplicativeCoprimeReward(height, ndim, **kwargs)

Bases: GridReward

Tiered reward based on prime-support and coprimality/lcm composition.

Each tier enforces that per-dimension values use only a small shared prime set with bounded exponents, plus optional cross-dimension constraints (pairwise coprime pairs and/or target lcm). Higher tiers tighten exponent caps or add additional global targets. This encourages information sharing to learn the latent prime/exponent structure.

Reward form:

R(s) = R0 + Σ_t tier_weights[t] · 1[ constraints_0..t all satisfied ]

Key kwargs:
  • R0: float, base reward (default 0.0)

  • tier_weights: list[float]

  • primes: list[int], e.g., [2,3,5,7,11]

  • exponent_caps: list[int], same length as tier_weights. Cap for every prime at tier t (uniform cap across primes for simplicity).

  • active_dims: Optional[list[int]]; constraints only apply to these dims (default: all dims). Other dims are ignored in constraints.

  • coprime_pairs: Optional[list[tuple[int,int]]]; indices relative to active_dims.

  • target_lcms: Optional[list[int | None]]; per-tier target lcm across active dims.

Notes:
  • Values of 0 are treated as invalid for prime-support constraints (they cannot be factorized); a value of 1 is valid with all-zero exponents.

  • The implementation removes primes up to the current tier cap and checks residue == 1. Exponent counts are accumulated to evaluate LCM targets.

Comparison with other compositional rewards:
  • BitwiseXORReward: GF(2) parity checks on bit-planes; same constraint type per tier (wider bit window), no cross-scale dependency.

  • ConditionalMultiScaleReward: base-B digit decomposition with conditional constraints across scales; each tier’s rule depends on prior tiers.

  • This class: prime factorization with bounded exponents and optional coprimality/LCM targets. Same constraint type per tier (tighter caps), but cross-dimension coupling via coprime pairs and LCM targets.

Parameters:
  • height (int)

  • ndim (int)

R0: float
__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

_factor_exponents_up_to_cap(v, cap)

Trial-divide each element by allowed primes, returning residue and exponents.

Parameters:
  • v (torch.Tensor) – (…,) LongTensor of non-negative values to factorize.

  • cap (int) – Maximum number of times each prime may divide a value.

Returns:
  • residue: (…,) values after stripping allowed primes (1 if fully factored).

  • exps: [num_primes, …] exponent counts per prime (leading axis is primes).
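Capped trial division can be illustrated on a single integer. This sketch uses a hypothetical function name and returns a flat list of exponents instead of a tensor with a leading prime axis.

```python
def factor_up_to_cap(value, primes, cap):
    """Strip each allowed prime at most `cap` times.

    Returns (residue, exps): residue is 1 when the value factors fully
    over `primes` within the cap; exps[j] counts divisions by primes[j].
    """
    residue = value
    exps = []
    for p in primes:
        e = 0
        # The residue > 0 guard keeps 0 from dividing forever; 0 then
        # ends with residue 0, which can never equal 1 (invalid value).
        while e < cap and residue > 0 and residue % p == 0:
            residue //= p
            e += 1
        exps.append(e)
    return residue, exps


print(factor_up_to_cap(12, [2, 3, 5], cap=2))  # 12 = 2^2 * 3
print(factor_up_to_cap(8, [2, 3], cap=2))      # 2^3 exceeds the cap
```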

_lcm_ok(exps, target_lcm)

Check whether max exponents across dims match target LCM’s factorization.

Parameters:
  • exps (torch.Tensor) – [num_primes, …, num_active_dims] exponent counts.

  • target_lcm (int) – The target LCM value to match.

Returns:

(…,) bool mask, True where the LCM of active-dim values equals target.

Return type:

torch.Tensor

_pairwise_coprime_ok(v)

Check that configured dimension pairs share no common allowed prime.

Parameters:

v (torch.Tensor) – (…, num_active_dims) coordinate values.

Returns:

(…,) bool mask, True where all coprime pair constraints hold.

Return type:

torch.Tensor
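The two cross-dimension checks can be sketched on plain Python values. Both helper names are hypothetical; the LCM check follows the description of comparing per-prime maximum exponents against the target, and the exponent table layout used here (one list per dimension) is an assumption made for readability.

```python
def pairwise_coprime_ok(values, coprime_pairs, primes):
    """True if no configured pair shares a common allowed prime."""
    return all(
        not any(values[i] % p == 0 and values[j] % p == 0 for p in primes)
        for i, j in coprime_pairs
    )


def lcm_ok(exps_per_dim, primes, target_lcm):
    """True if the per-prime max exponents multiply out to target_lcm.

    exps_per_dim[i][j] is the exponent of primes[j] in dimension i.
    """
    lcm_value = 1
    for j, p in enumerate(primes):
        lcm_value *= p ** max(e[j] for e in exps_per_dim)
    return lcm_value == target_lcm


print(pairwise_coprime_ok((4, 9, 6), [(0, 1)], [2, 3]))  # 4 and 9 share nothing
print(pairwise_coprime_ok((4, 9, 6), [(0, 2)], [2, 3]))  # 4 and 6 share 2
print(lcm_ok([[2, 0], [1, 1]], [2, 3], 12))              # lcm(4, 6) = 12
```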

active_dims: list[int]
coprime_pairs
exponent_caps: list[int]
primes: list[int]
target_lcms
tier_weights: list[float]
class gfn.gym.hypergrid.OriginalReward(height, ndim, **kwargs)

Bases: GridReward

The reward function from the original GFlowNet paper (Bengio et al., 2021; https://arxiv.org/abs/2106.04399).

Parameters:
  • height (int)

  • ndim (int)

__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

class gfn.gym.hypergrid.SparseReward(height, ndim, **kwargs)

Bases: GridReward

Sparse reward function from the GAFN paper (Pan et al., 2022; https://arxiv.org/abs/2210.03308).

Parameters:
  • height (int)

  • ndim (int)

__call__(states_tensor)
Parameters:

states_tensor (torch.Tensor)

Return type:

torch.Tensor

targets
gfn.gym.hypergrid._first_k_dims(k, ndim)

Return indices [0, 1, …, min(k, ndim)-1] for the first k dimensions.

Parameters:
  • k (int)

  • ndim (int)

Return type:

list[int]

gfn.gym.hypergrid.get_bitwise_xor_presets(ndim, height)

Return five difficulty presets for BitwiseXORReward.

The presets target approximate L1 distance bands by selecting the highest constrained bit and number of constrained dimensions. Typical distance scales like m · 2^b, where m is the number of constrained dims and b the highest bit.

Bands (steps from s0):
  • easy: ~50-100

  • medium: ~250-500

  • hard: ~1k-2.5k

  • challenging: ~2.5k-5k

  • impossible: 5k+

Notes:
  • You may tweak m (dims) and bit windows to fine-tune distances for your D, H.

  • Tier weights are geometric to encourage reaching higher tiers.

  • Parity checks default to even parity across constrained dims per bit-plane.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_conditional_multiscale_presets(ndim, height)

Return five difficulty presets for ConditionalMultiScaleReward.

All presets use base=4 (requiring H to be a power of 4). The number of available digit levels is L = log_4(H). Difficulty is controlled by:

  • Number of tiers (more tiers = deeper compositional hierarchy)

  • Number of active dims (more = exponentially sparser modes)

  • Cross-dim modular constraints (further sparsification)

Mode counts are computed via:

modes_T = (f^T * 4^{L-T})^d / prod_t m_t

with f = filter_width = 2 (i.e. B//2), so each tier halves modes per coord.

Presets (assuming H=256, i.e. L=4 digit levels):
  • easy: 2 tiers, 3 active dims -> ~2M modes at tier 2

  • medium: 3 tiers, 4 active dims -> ~1M modes at tier 3

  • hard: 3 tiers, 6 active dims, cross-dim -> ~260K modes at tier 3

  • challenging: 4 tiers, 8 active dims, cross-dim -> ~65K modes at tier 4

  • impossible: 4 tiers, 12 active dims, cross-dim -> ~4K modes at tier 4

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_cosine_presets(ndim, height)

Return five presets for CosineReward.

R1 scales the oscillatory product, and mode_gamma (used only for mode detection thresholding) tightens what is considered a “mode-like” maximum.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_deceptive_presets(ndim, height)

Return five presets for DeceptiveReward.

Increase R2 to accentuate the thin band, and set a small but non-zero R0. R1 controls the center emphasis vs. the cancelled outer region.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_multiplicative_coprime_presets(ndim, height)

Return five difficulty presets for MultiplicativeCoprimeReward.

Bands (steps from s0):
  • easy: ~50-100 (small primes, small exponents, few active dims)

  • medium: ~250-500 (adds one prime, caps=2, more dims, light coupling)

  • hard: ~1k-2.5k (primes up to 11, caps=3, more dims, LCM target)

  • challenging: ~2.5k-5k (primes up to 13, caps=3-4, 10-12 dims, tighter)

  • impossible: 5k+ (primes up to 29, caps=4, 12-16 dims, multiple targets)

Notes:
  • Distances are approximate; increase primes and exponent caps to push further.

  • active_dims indexes are relative to state dims; we pick the first k for simplicity.

  • coprime_pairs are pairs within the active_dims index space.

  • Tier weights are geometric.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_original_presets(ndim, height)

Return five presets for OriginalReward.

These presets primarily control the relative importance of the outer ring (R1) and thin band (R2). Exploration difficulty (distance from s0) is more a function of (D, H) than of these weights; tune D and H externally to match your distance bands.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_reward_presets(reward_fn_str, ndim, height)

Return presets for a given reward name: ‘bitwise_xor’, ‘multiplicative_coprime’, ‘conditional_multiscale’.

Usage

    presets = get_reward_presets("bitwise_xor", D, H)
    kwargs = presets["hard"]
    env = HyperGrid(ndim=D, height=H, reward_fn_str="bitwise_xor", reward_fn_kwargs=kwargs)

Parameters:
  • reward_fn_str (str)

  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.get_sparse_presets(ndim, height)

Return five presets for SparseReward.

SparseReward has built-in targets; it ignores most kwargs. Presets are provided for API symmetry and future extensibility.

Parameters:
  • ndim (int)

  • height (int)

Return type:

dict

gfn.gym.hypergrid.lcm(a, b)

Returns the lowest common multiple of a and b.

gfn.gym.hypergrid.lcm_multiple(numbers)

Find the lowest common multiple across a list of numbers.
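One common way to compute these quantities uses the identity lcm(a, b) = a·b / gcd(a, b) and folds it over a list. The sketch below is a stand-in under that assumption, with hypothetical names, and is not necessarily how the module's functions are written.

```python
from functools import reduce
from math import gcd


def lcm_pair(a, b):
    """Lowest common multiple of two positive integers."""
    return a * b // gcd(a, b)


def lcm_of_list(numbers):
    """Fold lcm_pair over a list; the empty list yields 1."""
    return reduce(lcm_pair, numbers, 1)


print(lcm_pair(4, 6))
print(lcm_of_list([2, 3, 4]))
```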

gfn.gym.hypergrid.logger
gfn.gym.hypergrid.smallest_multiplier_to_integers(float_vector, precision=3)

Used to calculate a scale factor to avoid imprecise floating point arithmetic.
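The docstring does not spell out the algorithm, so the following is only one plausible reading: round each float to precision decimal places, express it as an exact fraction, and take the lowest common multiple of the denominators. The function name and this interpretation are assumptions.

```python
from fractions import Fraction
from math import gcd


def smallest_integer_multiplier(float_vector, precision=3):
    """Smallest k such that k * x is an integer for every rounded x.

    Assumed behavior, for illustration only.
    """
    scale = 10**precision
    result = 1
    for x in float_vector:
        # Fraction reduces e.g. 500/1000 to 1/2 automatically.
        denom = Fraction(round(x * scale), scale).denominator
        result = result * denom // gcd(result, denom)
    return result


print(smallest_integer_multiplier([0.5, 0.25]))
print(smallest_integer_multiplier([0.1, 0.5, 2.0]))
```

Scaling rewards by such a factor lets later sums and comparisons run on integers, which avoids accumulated floating-point error.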