gfn.gym.helpers.box_utils

Backward-compatibility shim for Box environment utilities.

Import from box_cartesian_utils or box_polar_utils directly for new code.

Classes

BoxCartesianDistribution

Cartesian increment distribution for Box environment.

BoxCartesianPBDistribution

Backward Cartesian distribution for Box environment.

BoxCartesianPBEstimator

Simplified PB estimator using Cartesian increments with back-to-source.

BoxCartesianPBMLP

MLP for the Box backward policy. The first output is the back-to-source (BTS) logit.

BoxCartesianPFEstimator

Simplified PF estimator using Cartesian increments.

BoxCartesianPFMLP

MLP for the Box forward policy. The first output is the exit logit.

BoxPBEstimator

Estimator for P_B for the Box environment.

BoxPBMLP

An MLP for the backward policy of the Box environment.

BoxPBUniform

Uniform backward policy for the polar Box environment.

BoxPFEstimator

Estimator for P_F for the Box environment.

BoxPFMLP

An MLP for the forward policy of the Box environment.

BoxStateFlowModule

An MLP for the state flow function of the Box environment.

DistributionWrapper

A wrapper that combines QuarterDisk and QuarterCircleWithExit.

QuarterCircle

Represents distributions on quarter circles.

QuarterCircleWithExit

Extends QuarterCircle with an exit action.

QuarterDisk

Represents a distribution on the northeastern quarter disk.

UniformBoxCartesianPBModule

Fixed (non-learned) backward policy module for Cartesian Box.

Functions

split_PF_module_output(output, n_comp_max)

Splits the module output into the expected parameter sets.

Module Contents

class gfn.gym.helpers.box_utils.BoxCartesianDistribution(states, exit_logits, mixture_logits, alpha, beta, delta, epsilon=1e-06, temperature=1.0)

Bases: _BetaMixtureMixin, torch.distributions.Distribution

Cartesian increment distribution for Box environment.

Uses MixtureSameFamily(Categorical, Beta) per dimension to sample increments. This is much simpler than the polar parameterization: it samples a relative increment r in [0, 1] per dimension and converts it to an absolute increment using action = min_incr + r * max_range.
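As a rough illustration of the conversion formula above, the following plain-Python sketch draws a relative increment r from a mixture of Beta distributions for a single dimension and maps it to an absolute increment. The function names and the numeric values for min_incr and max_range are assumptions made for illustration only; the class itself operates on batched tensors.

```python
import random

def sample_relative_increment(weights, alphas, betas, rng):
    """Draw r in [0, 1] from a mixture of Beta distributions (one dimension)."""
    # Pick a mixture component, then sample from its Beta distribution.
    k = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    return rng.betavariate(alphas[k], betas[k])

def to_absolute_increment(r, min_incr, max_range):
    """Map a relative increment r in [0, 1] to min_incr + r * max_range."""
    return min_incr + r * max_range

rng = random.Random(0)
r = sample_relative_increment([0.5, 0.5], [2.0, 5.0], [5.0, 2.0], rng)
# Hypothetical values: minimum step 0.1, remaining range 0.6 in this dimension.
increment = to_absolute_increment(r, min_incr=0.1, max_range=0.6)
```

With r = 0 the increment equals min_incr, and with r = 1 it equals min_incr + max_range, so the sampled step always lies in that interval.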

Parameters:
  • states (gfn.states.States)

  • exit_logits (torch.Tensor)

  • mixture_logits (torch.Tensor)

  • alpha (torch.Tensor)

  • beta (torch.Tensor)

  • delta (float)

  • epsilon (float)

  • temperature (float)

delta

Minimum step size.

epsilon

Small value for numerical stability.

alpha
arg_constraints: dict
at_boundary
beta
delta
epsilon = 1e-06
exit_logits_scaled
is_s0
log_prob(actions)

Compute log probability using Cartesian per-dimension approach.

Parameters:

actions (torch.Tensor)

Return type:

torch.Tensor

log_weights
max_range
min_incr
n_dim
sample(sample_shape=Size())

Sample actions using Cartesian per-dimension increments.

Parameters:

sample_shape (torch.Size)

Return type:

torch.Tensor

states
class gfn.gym.helpers.box_utils.BoxCartesianPBDistribution(states, bts_logits, mixture_logits, alpha, beta, delta, epsilon=1e-06, temperature=1.0)

Bases: _BetaMixtureMixin, torch.distributions.Distribution

Backward Cartesian distribution for Box environment.

In torchgfn’s design, the source state is the origin [0, 0]. The BTS (back-to-source) action moves directly to s0 by setting action = state.

WHY THE LEARNED BTS BERNOULLI IS CRITICAL

The Trajectory Balance (TB) loss is:

L = (log P_F(τ) + log Z - log P_B(τ) - log R(x_T))^2

The BTS step (x_1 → s0) is the last step of every backward trajectory. Before this fix, BTS was always deterministic (forced): log P_B(BTS | x_1) = 0. A constant zero drops out of the gradient, so P_B received no gradient from the BTS step — regardless of trajectory length.

For 1-step trajectories (s0 → x_T, then immediately BTS back to s0), P_B was entirely gradient-free: log P_B(τ) = 0 for every such trajectory, making P_B invisible to the TB loss. This is particularly harmful because the reward landscape is dominated by states reachable from s0 in one step (the high-reward ring at |x - 0.5| ∈ (0.3, 0.4)).

With a learned Bernoulli P(BTS | x):

  • log P_B(BTS | x_1) = log P(BTS=1 | x_1): gradient flows into P_B.

  • log P_B(~BTS | x_1) = log P(BTS=0 | x_1): this term also receives gradient.

This closes the TB loop fully: P_B now has an incentive to assign higher probability to BTS from states that are indeed close to the reward modes, and lower probability from states that are far away.
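The gradient argument above can be checked numerically. The sketch below is a plain-Python illustration, not the library's implementation: it evaluates the TB loss for a single 1-step trajectory and compares the finite-difference gradient with respect to a BTS logit when log P_B is fixed at 0 (forced BTS) and when log P_B = log sigmoid(logit) (learned Bernoulli). All numeric values and helper names are hypothetical.

```python
import math

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def tb_loss(log_pf, log_z, log_pb, log_reward):
    """Trajectory Balance loss for one trajectory."""
    return (log_pf + log_z - log_pb - log_reward) ** 2

def grad_wrt_logit(loss_fn, logit, h=1e-5):
    """Central finite-difference derivative of loss_fn at logit."""
    return (loss_fn(logit + h) - loss_fn(logit - h)) / (2 * h)

# Hypothetical numbers for a 1-step trajectory s0 -> x_T -> (BTS) -> s0.
log_pf, log_z, log_reward = -1.2, 0.3, 0.5

# Forced BTS: log P_B = 0 whatever the logit is, so the gradient vanishes.
forced = lambda logit: tb_loss(log_pf, log_z, 0.0, log_reward)
# Learned Bernoulli: log P_B = log sigmoid(logit), so the logit gets a gradient.
learned = lambda logit: tb_loss(log_pf, log_z, log_sigmoid(logit), log_reward)

grad_forced = grad_wrt_logit(forced, 0.0)
grad_learned = grad_wrt_logit(learned, 0.0)
```

Here grad_forced is exactly zero, while grad_learned is nonzero, which is the effect the learned Bernoulli is meant to provide.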

FORCED vs. STOCHASTIC BTS

When any dimension of the state is <= delta, the valid backward increment range [delta, state[d]] is empty for that dimension. BTS is the only valid action, so it is forced (log_prob = 0, deterministic). For all other states, BTS is an optional choice sampled from the learned Bernoulli.
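The forced-versus-stochastic rule can be sketched for a single state as follows. This is a minimal plain-Python illustration under the rule stated above (BTS is forced when any coordinate is <= delta); the function names and the sigmoid parameterization of the Bernoulli are assumptions, and the class applies the same idea to batched tensors.

```python
import math
import random

def is_bts_forced(state, delta):
    """True when any dimension is <= delta, so no backward step fits there."""
    return any(x <= delta for x in state)

def choose_bts(state, delta, bts_logit, rng):
    """Return True if the backward step should jump straight to s0."""
    if is_bts_forced(state, delta):
        return True  # only valid action, taken deterministically
    p_bts = 1.0 / (1.0 + math.exp(-bts_logit))  # learned Bernoulli probability
    return rng.random() < p_bts

rng = random.Random(0)
# Same strongly negative logit for both states: only the first one is forced.
near_origin = choose_bts([0.05, 0.8], delta=0.1, bts_logit=-50.0, rng=rng)
far_away = choose_bts([0.6, 0.8], delta=0.1, bts_logit=-50.0, rng=rng)
```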

alpha
any_dim_near_origin
arg_constraints: dict
beta
bts_logits_scaled
delta
dim_near_origin
epsilon = 1e-06
log_prob(actions)

Compute log probability of backward actions.

  • Forced BTS (any_dim_near_origin): log_prob = 0 (deterministic).

  • Stochastic BTS: log_prob = log P(BTS=1 | s) from Bernoulli.

  • Non-BTS: log_prob = log P(BTS=0 | s) + log_p_beta + log_jacobian.
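The three cases listed above can be combined as in the following single-state sketch. It assumes the Bernoulli is parameterized by a logit (so P(BTS=1 | s) = sigmoid(logit)) and takes log_p_beta and log_jacobian as already-computed inputs; the function name and signature are hypothetical and do not mirror the class internals.

```python
import math

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def backward_log_prob(is_bts, forced, bts_logit, log_p_beta=0.0, log_jacobian=0.0):
    """Log probability of one backward action under the three-case rule."""
    if forced:
        return 0.0  # forced BTS is deterministic
    if is_bts:
        return log_sigmoid(bts_logit)  # log P(BTS=1 | s)
    # log P(BTS=0 | s) = log(1 - sigmoid(logit)) = log_sigmoid(-logit)
    return log_sigmoid(-bts_logit) + log_p_beta + log_jacobian

lp_forced = backward_log_prob(is_bts=True, forced=True, bts_logit=0.7)
lp_bts = backward_log_prob(is_bts=True, forced=False, bts_logit=0.7)
lp_not_bts = backward_log_prob(is_bts=False, forced=False, bts_logit=0.7)
```

With the Beta and Jacobian terms set to zero, exp(lp_bts) + exp(lp_not_bts) equals 1, which confirms the two stochastic branches form a valid Bernoulli.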

Parameters:

actions (torch.Tensor)

Return type:

torch.Tensor

log_weights
max_range
min_incr
n_dim
sample(sample_shape=Size())

Sample backward actions.

BTS (action = state) is forced when near the origin; otherwise it is sampled from the learned Bernoulli. Non-BTS samples come from the Beta mixture.

Parameters:

sample_shape (torch.Size)

Return type:

torch.Tensor

states
Parameters:
  • states (gfn.states.States)

  • bts_logits (torch.Tensor)

  • mixture_logits (torch.Tensor)

  • alpha (torch.Tensor)

  • beta (torch.Tensor)

  • delta (float)

  • epsilon (float)

  • temperature (float)

class gfn.gym.helpers.box_utils.BoxCartesianPBEstimator(env, module, n_components, min_concentration=0.1, max_concentration=100.0, numerical_epsilon=1e-06, debug=False)

Bases: gfn.estimators.Estimator, gfn.estimators.PolicyMixin

Simplified PB estimator using Cartesian increments with back-to-source.

Parameters:
  • env (gfn.gym.box.BoxPolar)

  • module (torch.nn.Module)

  • n_components (int)

  • min_concentration (float)

  • max_concentration (float)

  • numerical_epsilon (float)

  • debug (bool)

delta
epsilon
property expected_output_dim: int

Expected output dimension: bts_logit + (weights + alpha + beta) * n_dim * n_comp.

Return type:

int
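Reading the formula above as one BTS logit plus three parameter groups (mixture weights, alpha, beta) per dimension and per mixture component gives the following count. This is an interpretation of the docstring rather than a copy of the property, and the function name is hypothetical; the forward estimator's formula has the same shape, with the exit logit in place of the BTS logit.

```python
def cartesian_expected_output_dim(n_dim, n_components):
    """1 logit + (weights, alpha, beta) for each dimension and mixture component."""
    return 1 + 3 * n_dim * n_components

# Two dimensions, four mixture components.
dim = cartesian_expected_output_dim(n_dim=2, n_components=4)
```

A module passed to the estimator would need its last layer to produce this many outputs per state.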

max_concentration = 100.0
min_concentration = 0.1
n_components
n_dim = 2
numerical_epsilon = 1e-06
temperature: float = 1.0
to_probability_distribution(states, module_output)

Convert module output to backward probability distribution.

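Parameters:
  • states (gfn.states.States) – The states.

  • module_output (torch.Tensor) – Output from the module, shape (batch, expected_output_dim).
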
Return type:

torch.distributions.Distribution

classmethod uniform(env, n_components, **kwargs)

Create an estimator with a fixed (non-learned) uniform backward policy.

Parameters:
  • env (gfn.gym.box.BoxPolar) – The Box environment.

  • n_components (int) – Number of mixture components.

  • **kwargs – Extra keyword arguments forwarded to the constructor.

Returns:

A BoxCartesianPBEstimator whose module has no learnable parameters.

Return type:

BoxCartesianPBEstimator

class gfn.gym.helpers.box_utils.BoxCartesianPBMLP(hidden_dim, n_hidden_layers, n_components, n_dim=2, **kwargs)

Bases: _BoxCartesianMLP

MLP for the Box backward policy. The first output is the back-to-source (BTS) logit.

Parameters:
  • hidden_dim (int)

  • n_hidden_layers (int)

  • n_components (int)

  • n_dim (int)

  • kwargs (Any)

class gfn.gym.helpers.box_utils.BoxCartesianPFEstimator(env, module, n_components, min_concentration=0.1, max_concentration=100.0, numerical_epsilon=1e-06, debug=False)

Bases: gfn.estimators.Estimator, gfn.estimators.PolicyMixin

Simplified PF estimator using Cartesian increments.

Much simpler than BoxPFEstimator: it uses a single MLP and BoxCartesianDistribution.

Parameters:
  • env (gfn.gym.box.BoxPolar)

  • module (torch.nn.Module)

  • n_components (int)

  • min_concentration (float)

  • max_concentration (float)

  • numerical_epsilon (float)

  • debug (bool)

delta
epsilon
property expected_output_dim: int

Expected output dimension: exit_logit + (weights + alpha + beta) * n_dim * n_comp.

Return type:

int

max_concentration = 100.0
min_concentration = 0.1
n_components
n_dim = 2
numerical_epsilon = 1e-06
temperature: float = 1.0
to_probability_distribution(states, module_output)

Convert module output to a probability distribution.

Parameters:
  • states (gfn.states.States) – The states.

  • module_output (torch.Tensor) – Output from the module, shape (batch, expected_output_dim).

Returns:

BoxCartesianDistribution instance.

Return type:

torch.distributions.Distribution

class gfn.gym.helpers.box_utils.BoxCartesianPFMLP(hidden_dim, n_hidden_layers, n_components, n_dim=2, **kwargs)

Bases: _BoxCartesianMLP

MLP for the Box forward policy. The first output is the exit logit.

Parameters:
  • hidden_dim (int)

  • n_hidden_layers (int)

  • n_components (int)

  • n_dim (int)

  • kwargs (Any)

class gfn.gym.helpers.box_utils.BoxPBEstimator(env, module, n_components, min_concentration=0.1, max_concentration=2.0, debug=False)

Bases: gfn.estimators.Estimator, gfn.estimators.PolicyMixin

Estimator for P_B for the Box environment.

This estimator uses the QuarterCircle(northeastern=False) distribution.

Parameters:
  • env (gfn.gym.box.BoxPolar)

  • module (torch.nn.Module)

  • n_components (int)

  • min_concentration (float)

  • max_concentration (float)

  • debug (bool)

n_components

The number of components for the mixture.

min_concentration

The minimum concentration for the Beta distributions.

max_concentration

The maximum concentration for the Beta distributions.

delta

The radius of the quarter disk.

delta: float
property expected_output_dim: int

Returns the expected output dimension of the module.

Return type:

int

max_concentration: float
min_concentration: float
module
n_components: int
to_probability_distribution(states, module_output)

Converts the module output to a probability distribution.

Parameters:
  • states (gfn.states.States) – the states for which to convert the module output to a probability distribution.

  • module_output (torch.Tensor) – the output of the module for the states as a tensor of shape (*batch_shape, output_dim).

Returns:

The probability distribution for the states.

Return type:

torch.distributions.Distribution

classmethod uniform(env, n_components=1, **kwargs)

Create an estimator with a fixed (non-learned) uniform backward policy.

Uses alpha=beta=1 (uniform Beta) by setting skip_normalization=True so the constant module output is used directly as concentration parameters.

Parameters:
  • env (gfn.gym.box.BoxPolar) – The Box environment.

  • n_components (int) – Number of mixture components (default 1).

  • **kwargs – Extra keyword arguments forwarded to the constructor.

Returns:

A BoxPBEstimator whose module has no learnable parameters.

Return type:

BoxPBEstimator
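The reason alpha=beta=1 yields a uniform policy is that the Beta(1, 1) density is constant on (0, 1). The short check below uses only the standard library; beta_pdf is a helper written for this illustration.

```python
import math

def beta_pdf(x, alpha, beta):
    """Density of Beta(alpha, beta) at x in (0, 1)."""
    norm = math.gamma(alpha + beta) / (math.gamma(alpha) * math.gamma(beta))
    return norm * x ** (alpha - 1) * (1 - x) ** (beta - 1)

# With alpha = beta = 1 the density is 1 at every point of the interval.
densities = [beta_pdf(x, 1.0, 1.0) for x in (0.1, 0.5, 0.9)]
```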

class gfn.gym.helpers.box_utils.BoxPBMLP(hidden_dim, n_hidden_layers, n_components, **kwargs)

Bases: gfn.utils.modules.MLP

An MLP for the backward policy of the Box environment.

Parameters:
  • hidden_dim (int)

  • n_hidden_layers (int)

  • n_components (int)

  • kwargs (Any)

n_components

The number of components for each distribution parameter.

Type:

int

_input_dim: int
forward(preprocessed_states)

Computes the forward pass of the neural network.

Parameters:

preprocessed_states (torch.Tensor) – The tensor states of shape (*batch_shape, 2).

Returns:

A tensor of shape (*batch_shape, 3 * n_components).

Return type:

torch.Tensor

n_components: int
class gfn.gym.helpers.box_utils.BoxPBUniform

Bases: gfn.utils.modules.UniformModule

Uniform backward policy for the polar Box environment.

Backward-compatible alias for UniformModule(output_dim=3, input_dim=2, fill_value=1.0, skip_normalization=True). Used with QuarterCircle, it leads to a uniform (alpha=beta=1) Beta distribution over parents. Prefer UniformModule for new code.

class gfn.gym.helpers.box_utils.BoxPFEstimator(env, module, n_components_s0, n_components, min_concentration=0.1, max_concentration=2.0, debug=False)

Bases: gfn.estimators.Estimator, gfn.estimators.PolicyMixin

Estimator for P_F for the Box environment.

This estimator uses the DistributionWrapper distribution.

Parameters:
  • env (gfn.gym.box.BoxPolar)

  • module (torch.nn.Module)

  • n_components_s0 (int)

  • n_components (int)

  • min_concentration (float)

  • max_concentration (float)

  • debug (bool)

n_components_s0

The number of components for s0.

n_components

The number of components for non-s0 states.

min_concentration

The minimum concentration for the Beta distributions.

max_concentration

The maximum concentration for the Beta distributions.

delta

The radius of the quarter disk.

epsilon

The epsilon value used to consider a state as being at the border of the square.

_n_comp_max: int
delta: float
epsilon: float
property expected_output_dim: int

Returns the expected output dimension of the module.

Return type:

int

max_concentration: float
min_concentration: float
n_components: int
n_components_s0: int
to_probability_distribution(states, module_output)

Converts the module output to a probability distribution.

Parameters:
  • states (gfn.states.States) – the states for which to convert the module output to a probability distribution.

  • module_output (torch.Tensor) – the output of the module for the states as a tensor of shape (*batch_shape, output_dim).

Returns:

The probability distribution for the states.

Return type:

torch.distributions.Distribution

class gfn.gym.helpers.box_utils.BoxPFMLP(hidden_dim, n_hidden_layers, n_components_s0, n_components, **kwargs)

Bases: gfn.utils.modules.MLP

An MLP for the forward policy of the Box environment.

Parameters:
  • hidden_dim (int)

  • n_hidden_layers (int)

  • n_components_s0 (int)

  • n_components (int)

  • kwargs (Any)

n_components_s0

The number of components for s0.

Type:

int

n_components

The number of components for non-s0 states.

Type:

int

PFs0

The parameters for the s0 distribution.

Type:

nn.Parameter

PFs0: torch.nn.Parameter
_input_dim: int
_n_comp_max: int
forward(preprocessed_states)

Computes the forward pass of the neural network.

Parameters:

preprocessed_states (torch.Tensor) – The tensor states of shape (*batch_shape, 2).

Returns:

A tensor of shape (*batch_shape, 1 + 5 * max_n_components), where max_n_components is the larger of n_components_s0 and n_components.

Return type:

torch.Tensor

n_components: int
n_components_s0: int
class gfn.gym.helpers.box_utils.BoxStateFlowModule(logZ_value, **kwargs)

Bases: gfn.utils.modules.MLP

An MLP for the state flow function of the Box environment.

Parameters:
  • logZ_value (torch.Tensor)

  • kwargs (Any)

logZ_value

The log partition function value.

Type:

nn.Parameter

forward(preprocessed_states)

Computes the forward pass of the neural network.

Parameters:

preprocessed_states (torch.Tensor) – The tensor states of shape (*batch_shape, input_dim).

Returns:

A tensor of shape (*batch_shape, output_dim).

Return type:

torch.Tensor

logZ_value: torch.nn.Parameter
class gfn.gym.helpers.box_utils.DistributionWrapper(states, delta, epsilon, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta, exit_probability, n_components, n_components_s0, debug=False)

Bases: torch.distributions.Distribution

A wrapper that combines QuarterDisk and QuarterCircleWithExit.

Parameters:
  • states (gfn.states.States)

  • delta (float)

  • epsilon (float)

  • mixture_logits (torch.Tensor)

  • alpha_r (torch.Tensor)

  • beta_r (torch.Tensor)

  • alpha_theta (torch.Tensor)

  • beta_theta (torch.Tensor)

  • exit_probability (torch.Tensor)

  • n_components (int)

  • n_components_s0 (int)

  • debug (bool)

idx_is_initial

The indices of the initial states.

idx_not_initial

The indices of the non-initial states.

quarter_disk

The QuarterDisk distribution.

quarter_circ

The QuarterCircleWithExit distribution.
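The role of idx_is_initial and idx_not_initial can be pictured with the sketch below, which partitions a batch of 2D states into those at the source state and the rest, so that each group can be handled by its own distribution. It assumes the source state is the origin and uses plain lists instead of tensors; the function name is hypothetical.

```python
def partition_by_initial(states):
    """Return (indices of states at the origin, indices of all other states)."""
    idx_is_initial = [i for i, s in enumerate(states) if all(x == 0.0 for x in s)]
    idx_not_initial = [i for i, s in enumerate(states) if any(x != 0.0 for x in s)]
    return idx_is_initial, idx_not_initial

batch = [[0.0, 0.0], [0.2, 0.0], [0.0, 0.0], [0.4, 0.7]]
initial_idx, other_idx = partition_by_initial(batch)
```

In this picture, states in the first group would be served by the QuarterDisk distribution and states in the second by QuarterCircleWithExit.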

_output_shape: tuple[int, Ellipsis]
debug = False
idx_is_initial: torch.Tensor
idx_not_initial: torch.Tensor
log_prob(sampled_actions)

Computes the log probability of the sampled actions.

Parameters:

sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.

Returns:

A tensor of shape (*batch_shape) containing the log probabilities.

Return type:

torch.Tensor

quarter_circ: QuarterCircleWithExit | None
quarter_disk: QuarterDisk | None
sample(sample_shape=Size())

Samples from the distribution.

Parameters:

sample_shape (torch.Size) – the shape of the samples to generate.

Returns:

A tensor of shape (sample_shape + self._output_shape) containing the sampled actions.

Return type:

torch.Tensor

class gfn.gym.helpers.box_utils.QuarterCircle(delta, northeastern, centers, mixture_logits, alpha, beta, debug=False)

Bases: torch.distributions.Distribution

Represents distributions on quarter circles.

The distributions are mixtures of Beta distributions over the feasible angle range.

When a state is of norm <= delta, and northeastern=False, then the distribution is a Dirac at the state (i.e. the only possible parent is s_0).

Adapted from https://github.com/saleml/continuous-gfn/blob/master/sampling.py

This is useful for the Box environment.

Parameters:
  • delta (float)

  • northeastern (bool)

  • centers (gfn.states.States)

  • mixture_logits (torch.Tensor)

  • alpha (torch.Tensor)

  • beta (torch.Tensor)

  • debug (bool)

delta

The radius of the quarter disk.

northeastern

Whether the quarter disk is northeastern or southwestern.

n_states

The number of states.

n_components

The number of components in the mixture.

centers

The centers of the distribution.

base_dist

The base distribution.

min_angles

The minimum angles.

max_angles

The maximum angles.

base_dist: torch.distributions.MixtureSameFamily
centers: gfn.states.States
debug = False
delta: float
get_min_and_max_angles()

Computes the minimum and maximum angles for the distribution.

Returns:

A tuple of two tensors of shape (n_states,) containing the minimum and maximum angles, respectively.

Return type:

Tuple[torch.Tensor, torch.Tensor]
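For intuition about what such an angle range looks like, the sketch below derives it geometrically for the northeastern case: a step of length delta from (x, y) at angle theta must keep the new point inside the unit square. This is a derivation from the geometry under that assumption, written for a single state in plain Python, and is not taken from the method's source.

```python
import math

def northeastern_angle_range(x, y, delta):
    """Angles in [0, pi/2] for which (x + delta*cos, y + delta*sin) stays in [0, 1]^2."""
    # Right wall: x + delta*cos(theta) <= 1  =>  theta >= arccos((1 - x) / delta).
    min_angle = 0.0 if x + delta <= 1.0 else math.acos((1.0 - x) / delta)
    # Top wall: y + delta*sin(theta) <= 1  =>  theta <= arcsin((1 - y) / delta).
    max_angle = math.pi / 2 if y + delta <= 1.0 else math.asin((1.0 - y) / delta)
    # The range can be empty when both walls are closer than delta.
    return min_angle, max_angle

interior = northeastern_angle_range(0.5, 0.5, delta=0.1)
near_right_wall = northeastern_angle_range(0.95, 0.5, delta=0.1)
```

For an interior state the full quarter circle is available, while a state close to the right wall loses the small angles that would push it past x = 1.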

log_prob(sampled_actions)

Computes the log probability of the sampled actions.

Parameters:

sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.

Returns:

The log probability of the sampled actions as a tensor of shape batch_shape.

Return type:

torch.Tensor

max_angles: torch.Tensor
min_angles: torch.Tensor
n_components: int
n_states: int
northeastern: bool
sample(sample_shape=Size())

Samples from the distribution.

Parameters:

sample_shape (torch.Size) – the shape of the samples to generate.

Returns:

The sampled actions of shape (sample_shape, n_states, 2).

Return type:

torch.Tensor

class gfn.gym.helpers.box_utils.QuarterCircleWithExit(delta, centers, exit_probability, mixture_logits, alpha, beta, epsilon=0.0001)

Bases: torch.distributions.Distribution

Extends QuarterCircle with an exit action.

When sampling, with probability exit_probability, the exit_action [-inf, -inf] is sampled. The log_prob function is adjusted accordingly.
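One natural reading of "adjusted accordingly" is the usual two-branch mixture: the exit action has log probability log(p_exit), and any other action has log(1 - p_exit) plus the log probability under the distribution without exit. The sketch below illustrates that reading for a single state with 0 < p_exit < 1; the helper names and the callback-style interface are assumptions, not the class API.

```python
import math
import random

EXIT_ACTION = (-math.inf, -math.inf)

def sample_with_exit(exit_probability, sample_base, rng):
    """Return the exit action with the given probability, else a base sample."""
    if rng.random() < exit_probability:
        return EXIT_ACTION
    return sample_base(rng)

def log_prob_with_exit(action, exit_probability, base_log_prob):
    """Log probability of an action under the exit / no-exit mixture."""
    if action == EXIT_ACTION:
        return math.log(exit_probability)
    return math.log1p(-exit_probability) + base_log_prob(action)

rng = random.Random(0)
always_exit = sample_with_exit(1.0, lambda r: (0.1, 0.0), rng)
never_exit = sample_with_exit(0.0, lambda r: (0.1, 0.0), rng)
lp_exit = log_prob_with_exit(EXIT_ACTION, 0.25, lambda a: 0.0)
lp_step = log_prob_with_exit((0.1, 0.0), 0.25, lambda a: 0.0)
```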

Parameters:
  • delta (float)

  • centers (gfn.states.States)

  • exit_probability (torch.Tensor)

  • mixture_logits (torch.Tensor)

  • alpha (torch.Tensor)

  • beta (torch.Tensor)

  • epsilon (float)

delta

The radius of the quarter disk.

epsilon

The epsilon value to consider the state as being at the border of the square.

centers

The centers of the distribution.

dist_without_exit

The distribution without the exit action.

exit_probability

The probability of exiting.

exit_action

The exit action.

n_states

The number of states.

centers: gfn.states.States
delta: float
dist_without_exit: QuarterCircle
epsilon: float
exit_action: torch.Tensor
exit_probability: torch.Tensor
log_prob(sampled_actions)

Computes the log probability of the sampled actions.

Parameters:

sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.

Returns:

The log probability of the sampled actions as a tensor of shape batch_shape.

Return type:

torch.Tensor

n_states: int
sample()

Samples from the distribution.

Returns:

The sampled actions with shape (n_states, 2).

Return type:

torch.Tensor

class gfn.gym.helpers.box_utils.QuarterDisk(delta, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta)

Bases: torch.distributions.Distribution

Represents a distribution on the northeastern quarter disk.

The radius and the angle each follow a mixture of Beta distributions.
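A minimal single-component sketch of this kind of sampling is shown below. It assumes the Beta sample for the radius is scaled by delta and the Beta sample for the angle is scaled by pi/2, which keeps the point inside the northeastern quarter disk; the function name, the scaling convention, and the use of one component instead of a mixture are simplifications for illustration.

```python
import math
import random

def sample_quarter_disk(delta, rng, alpha_r=1.0, beta_r=1.0,
                        alpha_theta=1.0, beta_theta=1.0):
    """Sample a 2D point with radius in [0, delta] and angle in [0, pi/2]."""
    r = delta * rng.betavariate(alpha_r, beta_r)
    theta = (math.pi / 2) * rng.betavariate(alpha_theta, beta_theta)
    return r * math.cos(theta), r * math.sin(theta)

rng = random.Random(0)
points = [sample_quarter_disk(0.25, rng, 2.0, 2.0, 3.0, 1.5) for _ in range(1000)]
```

Every sampled point has nonnegative coordinates and a norm of at most delta.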

Adapted from https://github.com/saleml/continuous-gfn/blob/master/sampling.py

This is useful for the Box environment.

Parameters:
  • delta (float)

  • mixture_logits (torch.Tensor)

  • alpha_r (torch.Tensor)

  • beta_r (torch.Tensor)

  • alpha_theta (torch.Tensor)

  • beta_theta (torch.Tensor)

delta

The radius of the quarter disk.

mixture_logits

The logits of the mixture of Beta distributions.

base_r_dist

The base distribution for the radius.

base_theta_dist

The base distribution for the angle.

n_components

The number of components in the mixture.

base_r_dist: torch.distributions.MixtureSameFamily
base_theta_dist: torch.distributions.MixtureSameFamily
debug = False
delta: float
log_prob(sampled_actions)

Computes the log probability of the sampled actions.

Parameters:

sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.

Returns:

The log probability of the sampled actions as a tensor of shape batch_shape.

Return type:

torch.Tensor

mixture_logits: torch.Tensor
n_components: int
sample(sample_shape=torch.Size())

Samples from the distribution.

Parameters:

sample_shape (torch.Size) – the shape of the samples to generate.

Returns:

The sampled actions of shape (sample_shape, 2).

Return type:

torch.Tensor

class gfn.gym.helpers.box_utils.UniformBoxCartesianPBModule(n_components, n_dim=2)

Bases: gfn.utils.modules.UniformModule

Fixed (non-learned) backward policy module for Cartesian Box.

Backward-compatible alias for UniformModule. Prefer BoxCartesianPBEstimator.uniform(env, n_components) for new code.

Parameters:
  • n_components (int)

  • n_dim (int)

gfn.gym.helpers.box_utils.split_PF_module_output(output, n_comp_max)

Splits the module output into the expected parameter sets.

Parameters:
  • output (torch.Tensor) – the module_output from the P_F model as a tensor of shape (*batch_shape, output_dim).

  • n_comp_max (int) – the larger of n_components and n_components_s0.

Returns:

A tuple containing:

  • exit_probability: A probability unique to QuarterCircleWithExit.

  • mixture_logits: Parameters shared by QuarterDisk and QuarterCircleWithExit.

  • alpha_r: Parameters shared by QuarterDisk and QuarterCircleWithExit.

  • beta_r: Parameters shared by QuarterDisk and QuarterCircleWithExit.

  • alpha_theta: Parameters unique to QuarterDisk.

  • beta_theta: Parameters unique to QuarterDisk.
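To make the layout concrete, the sketch below splits a flat output of length 1 + 5 * n_comp_max into one scalar and five groups of n_comp_max values. The ordering of the groups is assumed to follow the order of the returned items listed above, and the function works on a plain list for a single state; it is an illustration, not the library function.

```python
def split_flat_output(output, n_comp_max):
    """Split a flat list of length 1 + 5 * n_comp_max into six parameter groups."""
    expected = 1 + 5 * n_comp_max
    if len(output) != expected:
        raise ValueError(f"expected {expected} values, got {len(output)}")
    exit_probability = output[0]
    # Five consecutive chunks of n_comp_max values each (assumed ordering).
    groups = [output[1 + i * n_comp_max : 1 + (i + 1) * n_comp_max] for i in range(5)]
    mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta = groups
    return exit_probability, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta

parts = split_flat_output(list(range(11)), n_comp_max=2)
```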