gfn.gym.helpers.box_utils¶
Backward-compatibility shim for Box environment utilities.
Import from box_cartesian_utils or box_polar_utils directly for new code.
Classes¶
| Class | Description |
|---|---|
| BoxCartesianDistribution | Cartesian increment distribution for the Box environment. |
| BoxCartesianPBDistribution | Backward Cartesian distribution for the Box environment. |
| BoxCartesianPBEstimator | Simplified PB estimator using Cartesian increments with back-to-source. |
| BoxCartesianPBMLP | MLP for the Box backward policy. The first output is the back-to-source (BTS) logit. |
| BoxCartesianPFEstimator | Simplified PF estimator using Cartesian increments. |
| BoxCartesianPFMLP | MLP for the Box forward policy. The first output is the exit logit. |
| BoxPBEstimator | Estimator for P_B for the Box environment. |
| BoxPBMLP | An MLP for the backward policy of the Box environment. |
| BoxPBUniform | Uniform backward policy for the polar Box environment. |
| BoxPFEstimator | Estimator for P_F for the Box environment. |
| BoxPFMLP | An MLP for the forward policy of the Box environment. |
| BoxStateFlowModule | An MLP for the state flow function of the Box environment. |
| DistributionWrapper | A wrapper that combines QuarterDisk and QuarterCircleWithExit. |
| QuarterCircle | Represents distributions on quarter circles. |
| QuarterCircleWithExit | Extends QuarterCircle with an exit action. |
| QuarterDisk | Represents a distribution on the northeastern quarter disk. |
| UniformBoxCartesianPBModule | Fixed (non-learned) backward policy module for Cartesian Box. |
Functions¶
| Function | Description |
|---|---|
| split_PF_module_output | Splits the module output into the expected parameter sets. |
Module Contents¶
- class gfn.gym.helpers.box_utils.BoxCartesianDistribution(states, exit_logits, mixture_logits, alpha, beta, delta, epsilon=1e-06, temperature=1.0)¶
Bases:
_BetaMixtureMixin, torch.distributions.Distribution
Cartesian increment distribution for the Box environment.
Uses MixtureSameFamily(Categorical, Beta) per dimension to sample increments. This is much simpler than the polar parameterization: it samples a relative increment r in [0, 1] per dimension and converts it to an absolute increment with action = min_incr + r * max_range.
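The per-dimension sampling rule can be sketched with `torch.distributions`. This is a minimal illustration, not the class's implementation. It assumes `min_incr` and `max_range` are already-computed `(batch, n_dim)` tensors, and it ignores the exit logit and temperature. The helper name `sample_cartesian_increments` is hypothetical. Subtracting `log(max_range)` is the standard change-of-variables term for the affine map from r to the action; whether the library applies exactly this term is an assumption.

```python
import torch
from torch.distributions import Beta, Categorical, MixtureSameFamily


def sample_cartesian_increments(mixture_logits, alpha, beta, min_incr, max_range):
    # mixture_logits, alpha, beta: (batch, n_dim, n_comp)
    # min_incr, max_range: (batch, n_dim), treated here as given inputs
    # One Beta mixture per (state, dimension): batch_shape is (batch, n_dim).
    dist = MixtureSameFamily(Categorical(logits=mixture_logits), Beta(alpha, beta))
    r = dist.sample()  # relative increments in [0, 1], shape (batch, n_dim)
    actions = min_incr + r * max_range  # absolute increments
    # Change of variables for the affine map r -> min_incr + r * max_range.
    log_prob = (dist.log_prob(r) - torch.log(max_range)).sum(dim=-1)
    return actions, log_prob


torch.manual_seed(0)
batch, n_dim, n_comp = 4, 2, 3
actions, log_prob = sample_cartesian_increments(
    mixture_logits=torch.zeros(batch, n_dim, n_comp),
    alpha=torch.full((batch, n_dim, n_comp), 2.0),
    beta=torch.full((batch, n_dim, n_comp), 2.0),
    min_incr=torch.full((batch, n_dim), 0.1),
    max_range=torch.full((batch, n_dim), 0.5),
)
print(actions.shape, log_prob.shape)  # torch.Size([4, 2]) torch.Size([4])
```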
- Parameters:
states (gfn.states.States)
exit_logits (torch.Tensor)
mixture_logits (torch.Tensor)
alpha (torch.Tensor)
beta (torch.Tensor)
delta (float)
epsilon (float)
temperature (float)
- delta¶
Minimum step size.
- epsilon¶
Small value for numerical stability.
- alpha¶
- arg_constraints: dict¶
- at_boundary¶
- beta¶
- delta¶
- epsilon = 1e-06¶
- exit_logits_scaled¶
- is_s0¶
- log_prob(actions)¶
Compute the log probability of actions using the per-dimension Cartesian factorization.
- Parameters:
actions (torch.Tensor)
- Return type:
torch.Tensor
- log_weights¶
- max_range¶
- min_incr¶
- n_dim¶
- sample(sample_shape=Size())¶
Sample actions using Cartesian per-dimension increments.
- Parameters:
sample_shape (torch.Size)
- Return type:
torch.Tensor
- states¶
- class gfn.gym.helpers.box_utils.BoxCartesianPBDistribution(states, bts_logits, mixture_logits, alpha, beta, delta, epsilon=1e-06, temperature=1.0)¶
Bases:
_BetaMixtureMixin, torch.distributions.Distribution
Backward Cartesian distribution for the Box environment.
In torchgfn’s design, the source state is the origin [0, 0]. The BTS (back-to-source) action moves directly to s0 by setting action = state.
WHY THE LEARNED BTS BERNOULLI IS CRITICAL¶
The Trajectory Balance (TB) loss is:
L = (log P_F(τ) + log Z - log P_B(τ) - log R(x_T))^2
The BTS step (x_1 → s0) is the last step of every backward trajectory. Previously, BTS was always deterministic (forced), so log P_B(BTS | x_1) = 0. Because this term is a constant, it contributes nothing to the gradient: P_B received no gradient from the BTS step, regardless of trajectory length.
For 1-step trajectories (s0 → x_T, then immediately BTS back to s0), P_B was entirely gradient-free: log P_B(τ) = 0 for every such trajectory, making P_B invisible to the TB loss. This is particularly harmful because the reward landscape is dominated by states reachable from s0 in one step (the high-reward ring at |x - 0.5| ∈ (0.3, 0.4)).
With a learned Bernoulli P(BTS | x):
- log P_B(BTS | x_1) = log P(BTS=1 | x_1), so gradient flows into P_B;
- log P_B(~BTS | x_1) = log P(BTS=0 | x_1), which also receives gradient.
This closes the TB loop fully: P_B now has an incentive to assign higher probability to BTS from states that are indeed close to the reward modes, and lower probability from states that are far away.
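The gradient argument above can be checked numerically on a toy 1-step trajectory. The scalar values are made up for illustration. The only learnable quantity is a hypothetical `bts_logit` that stands in for the P_B network's BTS output. This is a sketch of the TB arithmetic, not the library's training loop.

```python
import torch
import torch.nn.functional as F

# Made-up, fixed quantities for one trajectory s0 -> x_1 -> (BTS) s0.
log_pf = torch.tensor(-1.2)      # log P_F(tau)
log_z = torch.tensor(0.5)        # log Z
log_reward = torch.tensor(-0.3)  # log R(x_T)

# Hypothetical learnable logit of P(BTS = 1 | x_1).
bts_logit = torch.tensor(0.0, requires_grad=True)


def tb_loss(log_pb):
    # L = (log P_F(tau) + log Z - log P_B(tau) - log R(x_T))^2
    return (log_pf + log_z - log_pb - log_reward) ** 2


# Forced BTS: log P_B(tau) is the constant 0, so nothing connects it to bts_logit.
loss_forced = tb_loss(torch.tensor(0.0))
print(loss_forced.requires_grad)  # False

# Learned Bernoulli: log P_B(BTS | x_1) = log sigmoid(bts_logit).
loss_learned = tb_loss(F.logsigmoid(bts_logit))
loss_learned.backward()
print(bts_logit.grad)  # nonzero: the TB residual now reaches the BTS logit
```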
FORCED vs. STOCHASTIC BTS¶
When any dimension of the state is <= delta, the valid backward increment range [delta, state[d]] for that dimension is empty (or has zero width). BTS is then the only valid action, so it is forced (deterministic, log_prob = 0). For all other states, BTS is an optional choice sampled from the learned Bernoulli.
- alpha¶
- any_dim_near_origin¶
- arg_constraints: dict¶
- beta¶
- bts_logits_scaled¶
- delta¶
- dim_near_origin¶
- epsilon = 1e-06¶
- log_prob(actions)¶
Compute log probability of backward actions.
Forced BTS (any_dim_near_origin): log_prob = 0 (deterministic).
Stochastic BTS: log_prob = log P(BTS=1 | s) from Bernoulli.
Non-BTS: log_prob = log P(BTS=0 | s) + log_p_beta + log_jacobian.
- Parameters:
actions (torch.Tensor)
- Return type:
torch.Tensor
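The three cases listed for log_prob can be combined with boolean masks. This is a minimal sketch under several assumptions:
- the Bernoulli is parameterized by `bts_logits` through a sigmoid;
- the per-dimension Beta log-density and log-Jacobian are already summed into `log_p_beta` and `log_jacobian`;
- temperature scaling is omitted.

The helper `pb_log_prob` is hypothetical.

```python
import torch
import torch.nn.functional as F


def pb_log_prob(is_bts, any_dim_near_origin, bts_logits, log_p_beta, log_jacobian):
    # All inputs have shape (batch,); is_bts and any_dim_near_origin are bool masks.
    log_p_bts = F.logsigmoid(bts_logits)       # log P(BTS=1 | s)
    log_p_not_bts = F.logsigmoid(-bts_logits)  # log P(BTS=0 | s)
    non_bts = log_p_not_bts + log_p_beta + log_jacobian
    out = torch.where(is_bts, log_p_bts, non_bts)
    # Forced BTS is deterministic, so it contributes log 1 = 0.
    return torch.where(any_dim_near_origin, torch.zeros_like(out), out)


delta = 0.1
states = torch.tensor([[0.05, 0.5], [0.4, 0.6], [0.4, 0.6]])
any_dim_near_origin = (states <= delta).any(dim=-1)  # tensor([True, False, False])
lp = pb_log_prob(
    is_bts=torch.tensor([True, True, False]),
    any_dim_near_origin=any_dim_near_origin,
    bts_logits=torch.zeros(3),
    log_p_beta=torch.tensor([0.0, 0.0, 0.2]),
    log_jacobian=torch.tensor([0.0, 0.0, -0.5]),
)
print(lp)
```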
- log_weights¶
- max_range¶
- min_incr¶
- n_dim¶
- sample(sample_shape=Size())¶
Sample backward actions.
BTS (action = state) is forced when any dimension of the state is <= delta; otherwise it is sampled from the learned Bernoulli. Non-BTS samples come from the Beta mixture.
- Parameters:
sample_shape (torch.Size)
- Return type:
torch.Tensor
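The sampling rule can be sketched as follows. This assumes the Beta-mixture increments have already been drawn and are passed in as `increments`, and that the BTS probability comes from `bts_logits` via a Bernoulli. `sample_pb_actions` is a hypothetical helper, not part of the module.

```python
import torch
from torch.distributions import Bernoulli


def sample_pb_actions(states, bts_logits, increments, delta):
    # states, increments: (batch, n_dim); bts_logits: (batch,)
    forced = (states <= delta).any(dim=-1)  # BTS is the only valid action
    stochastic = Bernoulli(logits=bts_logits).sample().bool()
    is_bts = forced | stochastic
    # BTS sets action = state, which moves the state directly to s0.
    actions = torch.where(is_bts.unsqueeze(-1), states, increments)
    return actions, is_bts


torch.manual_seed(0)
states = torch.tensor([[0.05, 0.5], [0.4, 0.6]])
bts_logits = torch.tensor([-20.0, -20.0])  # BTS very unlikely unless forced
increments = torch.tensor([[0.01, 0.2], [0.2, 0.3]])
actions, is_bts = sample_pb_actions(states, bts_logits, increments, delta=0.1)
print(actions, is_bts)
```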
- states¶
- Parameters:
states (gfn.states.States)
bts_logits (torch.Tensor)
mixture_logits (torch.Tensor)
alpha (torch.Tensor)
beta (torch.Tensor)
delta (float)
epsilon (float)
temperature (float)
- class gfn.gym.helpers.box_utils.BoxCartesianPBEstimator(env, module, n_components, min_concentration=0.1, max_concentration=100.0, numerical_epsilon=1e-06, debug=False)¶
Bases:
gfn.estimators.Estimator, gfn.estimators.PolicyMixin
Simplified PB estimator using Cartesian increments with back-to-source.
- Parameters:
env (gfn.gym.box.BoxPolar)
module (torch.nn.Module)
n_components (int)
min_concentration (float)
max_concentration (float)
numerical_epsilon (float)
debug (bool)
- delta¶
- epsilon¶
- property expected_output_dim: int¶
Expected output dimension: bts_logit + (weights + alpha + beta) * n_dim * n_comp, i.e. 1 + 3 * n_dim * n_comp.
- Return type:
int
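The expected_output_dim formula implies 1 + 3 * n_dim * n_comp outputs per state. The sketch below shows one way to slice such an output. Two parts are assumptions for illustration: the ordering (first logit, then weights, then alpha, then beta) and the sigmoid mapping into [min_concentration, max_concentration]. The estimator's actual slicing and normalization are not documented here. `split_cartesian_output` is a hypothetical name.

```python
import torch


def split_cartesian_output(module_output, n_dim, n_comp, min_conc=0.1, max_conc=100.0):
    # module_output: (batch, 1 + 3 * n_dim * n_comp); ordering is assumed.
    batch = module_output.shape[0]
    first_logit = module_output[:, 0]  # bts_logit (PB) or exit_logit (PF)
    rest = module_output[:, 1:].reshape(batch, 3, n_dim, n_comp)
    mixture_logits, raw_alpha, raw_beta = rest.unbind(dim=1)
    # Assumed mapping into [min_conc, max_conc]; the real normalization may differ.
    scale = max_conc - min_conc
    alpha = min_conc + scale * torch.sigmoid(raw_alpha)
    beta = min_conc + scale * torch.sigmoid(raw_beta)
    return first_logit, mixture_logits, alpha, beta


torch.manual_seed(0)
n_dim, n_comp = 2, 3
out = torch.randn(5, 1 + 3 * n_dim * n_comp)  # 19 columns
first_logit, mixture_logits, alpha, beta = split_cartesian_output(out, n_dim, n_comp)
```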
- max_concentration = 100.0¶
- min_concentration = 0.1¶
- n_components¶
- n_dim = 2¶
- numerical_epsilon = 1e-06¶
- temperature: float = 1.0¶
- to_probability_distribution(states, module_output)¶
Convert module output to backward probability distribution.
- Parameters:
states (gfn.states.States)
module_output (torch.Tensor)
- Return type:
torch.distributions.Distribution
- classmethod uniform(env, n_components, **kwargs)¶
Create an estimator with a fixed (non-learned) uniform backward policy.
- Parameters:
env (gfn.gym.box.BoxPolar) – The Box environment.
n_components (int) – Number of mixture components.
**kwargs – Extra keyword arguments forwarded to the constructor.
- Returns:
A BoxCartesianPBEstimator whose module has no learnable parameters.
- Return type:
BoxCartesianPBEstimator
- class gfn.gym.helpers.box_utils.BoxCartesianPBMLP(hidden_dim, n_hidden_layers, n_components, n_dim=2, **kwargs)¶
Bases:
_BoxCartesianMLP
MLP for the Box backward policy. The first output is the back-to-source (BTS) logit.
- Parameters:
hidden_dim (int)
n_hidden_layers (int)
n_components (int)
n_dim (int)
kwargs (Any)
- class gfn.gym.helpers.box_utils.BoxCartesianPFEstimator(env, module, n_components, min_concentration=0.1, max_concentration=100.0, numerical_epsilon=1e-06, debug=False)¶
Bases:
gfn.estimators.Estimator, gfn.estimators.PolicyMixin
Simplified PF estimator using Cartesian increments.
Much simpler than BoxPFEstimator: it uses a single MLP and BoxCartesianDistribution.
- Parameters:
env (gfn.gym.box.BoxPolar)
module (torch.nn.Module)
n_components (int)
min_concentration (float)
max_concentration (float)
numerical_epsilon (float)
debug (bool)
- delta¶
- epsilon¶
- property expected_output_dim: int¶
Expected output dimension: exit_logit + (weights + alpha + beta) * n_dim * n_comp, i.e. 1 + 3 * n_dim * n_comp.
- Return type:
int
- max_concentration = 100.0¶
- min_concentration = 0.1¶
- n_components¶
- n_dim = 2¶
- numerical_epsilon = 1e-06¶
- temperature: float = 1.0¶
- to_probability_distribution(states, module_output)¶
Convert module output to a probability distribution.
- Parameters:
states (gfn.states.States) – The states.
module_output (torch.Tensor) – Output from the module, shape (batch, expected_output_dim).
- Returns:
BoxCartesianDistribution instance.
- Return type:
torch.distributions.Distribution
- class gfn.gym.helpers.box_utils.BoxCartesianPFMLP(hidden_dim, n_hidden_layers, n_components, n_dim=2, **kwargs)¶
Bases:
_BoxCartesianMLP
MLP for the Box forward policy. The first output is the exit logit.
- Parameters:
hidden_dim (int)
n_hidden_layers (int)
n_components (int)
n_dim (int)
kwargs (Any)
- class gfn.gym.helpers.box_utils.BoxPBEstimator(env, module, n_components, min_concentration=0.1, max_concentration=2.0, debug=False)¶
Bases:
gfn.estimators.Estimator, gfn.estimators.PolicyMixin
Estimator for P_B for the Box environment.
This estimator uses the QuarterCircle(northeastern=False) distribution.
- Parameters:
env (gfn.gym.box.BoxPolar)
module (torch.nn.Module)
n_components (int)
min_concentration (float)
max_concentration (float)
debug (bool)
- n_components¶
The number of components for the mixture.
- min_concentration¶
The minimum concentration for the Beta distributions.
- max_concentration¶
The maximum concentration for the Beta distributions.
- delta¶
The radius of the quarter disk.
- delta: float¶
- property expected_output_dim: int¶
Returns the expected output dimension of the module.
- Return type:
int
- max_concentration: float¶
- min_concentration: float¶
- module¶
- n_components: int¶
- to_probability_distribution(states, module_output)¶
Converts the module output to a probability distribution.
- Parameters:
states (gfn.states.States) – the states for which to convert the module output to a probability distribution.
module_output (torch.Tensor) – the output of the module for the states as a tensor of shape (*batch_shape, output_dim).
- Returns:
The probability distribution for the states.
- Return type:
torch.distributions.Distribution
- classmethod uniform(env, n_components=1, **kwargs)¶
Create an estimator with a fixed (non-learned) uniform backward policy.
Uses alpha=beta=1 (a uniform Beta) by setting skip_normalization=True, so the constant module output is used directly as the concentration parameters.
- Parameters:
env (gfn.gym.box.BoxPolar) – The Box environment.
n_components (int) – Number of mixture components (default 1).
**kwargs – Extra keyword arguments forwarded to the constructor.
- Returns:
A BoxPBEstimator whose module has no learnable parameters.
- Return type:
BoxPBEstimator
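Why alpha=beta=1 gives a uniform policy: Beta(1, 1) is Uniform(0, 1), so its log-density is 0 everywhere on the support. A mixture of identical Beta(1, 1) components is still uniform. The snippet below illustrates only this math. It does not construct the estimator or its module.

```python
import torch
from torch.distributions import Beta, Categorical, MixtureSameFamily

# Constant concentrations, as a fixed module output of 1.0 would provide.
u = torch.tensor([0.1, 0.25, 0.5, 0.9])
log_density = Beta(torch.ones(4), torch.ones(4)).log_prob(u)
print(log_density)  # all zeros: Beta(1, 1) is Uniform(0, 1)

# With several identical components (n_components > 1) the mixture stays uniform.
n_comp = 3
mix = MixtureSameFamily(
    Categorical(logits=torch.zeros(4, n_comp)),
    Beta(torch.ones(4, n_comp), torch.ones(4, n_comp)),
)
mix_log_density = mix.log_prob(u)  # approximately zero
```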
- class gfn.gym.helpers.box_utils.BoxPBMLP(hidden_dim, n_hidden_layers, n_components, **kwargs)¶
Bases:
gfn.utils.modules.MLP
An MLP for the backward policy of the Box environment.
- Parameters:
hidden_dim (int)
n_hidden_layers (int)
n_components (int)
kwargs (Any)
- n_components¶
The number of components for each distribution parameter.
- Type:
int
- _input_dim: int¶
- forward(preprocessed_states)¶
Computes the forward pass of the neural network.
- Parameters:
preprocessed_states (torch.Tensor) – The tensor states of shape (*batch_shape, 2).
- Returns:
A tensor of shape (*batch_shape, 3 * n_components).
- Return type:
torch.Tensor
- n_components: int¶
- class gfn.gym.helpers.box_utils.BoxPBUniform¶
Bases:
gfn.utils.modules.UniformModule
Uniform backward policy for the polar Box environment.
Backward-compatible alias for UniformModule(output_dim=3, input_dim=2, fill_value=1.0, skip_normalization=True). Used with QuarterCircle, it yields a uniform (alpha=beta=1) Beta distribution over parents. Prefer UniformModule for new code.
- class gfn.gym.helpers.box_utils.BoxPFEstimator(env, module, n_components_s0, n_components, min_concentration=0.1, max_concentration=2.0, debug=False)¶
Bases:
gfn.estimators.Estimator, gfn.estimators.PolicyMixin
Estimator for P_F for the Box environment.
This estimator uses the DistributionWrapper distribution.
- Parameters:
env (gfn.gym.box.BoxPolar)
module (torch.nn.Module)
n_components_s0 (int)
n_components (int)
min_concentration (float)
max_concentration (float)
debug (bool)
- n_components_s0¶
The number of components for s0.
- n_components¶
The number of components for non-s0 states.
- min_concentration¶
The minimum concentration for the Beta distributions.
- max_concentration¶
The maximum concentration for the Beta distributions.
- delta¶
The radius of the quarter disk.
- epsilon¶
The epsilon value.
- _n_comp_max: int¶
- delta: float¶
- epsilon: float¶
- property expected_output_dim: int¶
Returns the expected output dimension of the module.
- Return type:
int
- max_concentration: float¶
- min_concentration: float¶
- n_components: int¶
- n_components_s0: int¶
- to_probability_distribution(states, module_output)¶
Converts the module output to a probability distribution.
- Parameters:
states (gfn.states.States) – the states for which to convert the module output to a probability distribution.
module_output (torch.Tensor) – the output of the module for the states as a tensor of shape (*batch_shape, output_dim).
- Returns:
The probability distribution for the states.
- Return type:
torch.distributions.Distribution
- class gfn.gym.helpers.box_utils.BoxPFMLP(hidden_dim, n_hidden_layers, n_components_s0, n_components, **kwargs)¶
Bases:
gfn.utils.modules.MLP
An MLP for the forward policy of the Box environment.
- Parameters:
hidden_dim (int)
n_hidden_layers (int)
n_components_s0 (int)
n_components (int)
kwargs (Any)
- n_components_s0¶
The number of components for s0.
- Type:
int
- n_components¶
The number of components for non-s0 states.
- Type:
int
- PFs0¶
The parameters for the s0 distribution.
- Type:
nn.Parameter
- PFs0: torch.nn.Parameter¶
- _input_dim: int¶
- _n_comp_max: int¶
- forward(preprocessed_states)¶
Computes the forward pass of the neural network.
- Parameters:
preprocessed_states (torch.Tensor) – The tensor states of shape (*batch_shape, 2).
- Returns:
A tensor of shape (*batch_shape, 1 + 5 * max(n_components, n_components_s0)).
- Return type:
torch.Tensor
- n_components: int¶
- n_components_s0: int¶
- class gfn.gym.helpers.box_utils.BoxStateFlowModule(logZ_value, **kwargs)¶
Bases:
gfn.utils.modules.MLP
An MLP for the state flow function of the Box environment.
- Parameters:
logZ_value (torch.Tensor)
kwargs (Any)
- logZ_value¶
The log partition function value.
- Type:
nn.Parameter
- forward(preprocessed_states)¶
Computes the forward pass of the neural network.
- Parameters:
preprocessed_states (torch.Tensor) – The tensor states of shape (*batch_shape, input_dim).
- Returns:
A tensor of shape (*batch_shape, output_dim).
- Return type:
torch.Tensor
- logZ_value: torch.nn.Parameter¶
- class gfn.gym.helpers.box_utils.DistributionWrapper(states, delta, epsilon, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta, exit_probability, n_components, n_components_s0, debug=False)¶
Bases:
torch.distributions.Distribution
A wrapper that combines QuarterDisk and QuarterCircleWithExit.
- Parameters:
states (gfn.states.States)
delta (float)
epsilon (float)
mixture_logits (torch.Tensor)
alpha_r (torch.Tensor)
beta_r (torch.Tensor)
alpha_theta (torch.Tensor)
beta_theta (torch.Tensor)
exit_probability (torch.Tensor)
n_components (int)
n_components_s0 (int)
debug (bool)
- idx_is_initial¶
The indices of the initial states.
- idx_not_initial¶
The indices of the non-initial states.
- quarter_disk¶
The QuarterDisk distribution.
- quarter_circ¶
The QuarterCircleWithExit distribution.
- _output_shape: tuple[int, Ellipsis]¶
- debug = False¶
- idx_is_initial: torch.Tensor¶
- idx_not_initial: torch.Tensor¶
- log_prob(sampled_actions)¶
Computes the log probability of the sampled actions.
- Parameters:
sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.
- Returns:
A tensor of shape (*batch_shape) containing the log probabilities.
- Return type:
torch.Tensor
- quarter_circ: QuarterCircleWithExit | None¶
- quarter_disk: QuarterDisk | None¶
- sample(sample_shape=Size())¶
Samples from the distribution.
- Parameters:
sample_shape (torch.Size) – the shape of the samples to generate.
- Returns:
A tensor of shape (sample_shape + self._output_shape) containing the sampled actions.
- Return type:
torch.Tensor
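The wrapper keeps index sets for initial and non-initial states. The sketch below shows how per-group log-probabilities could be scattered back into one batch tensor. It assumes s0 is the origin [0, 0] and uses plain tensors as stand-ins for the QuarterDisk and QuarterCircleWithExit outputs. `route_log_prob` is a hypothetical helper, and the wrapper's actual internals may differ.

```python
import torch


def route_log_prob(idx_is_initial, idx_not_initial, lp_initial, lp_other):
    # Scatter per-group log-probs back into a single (batch,) tensor.
    batch = idx_is_initial.numel() + idx_not_initial.numel()
    out = torch.empty(batch, dtype=lp_initial.dtype)
    out[idx_is_initial] = lp_initial
    out[idx_not_initial] = lp_other
    return out


states = torch.tensor([[0.0, 0.0], [0.2, 0.3], [0.0, 0.0], [0.5, 0.1]])
is_initial = (states == 0).all(dim=-1)         # s0 is the origin
idx_is_initial = torch.where(is_initial)[0]    # tensor([0, 2])
idx_not_initial = torch.where(~is_initial)[0]  # tensor([1, 3])
# Stand-ins for the QuarterDisk and QuarterCircleWithExit log-probs per group.
lp = route_log_prob(idx_is_initial, idx_not_initial,
                    torch.tensor([-1.0, -2.0]), torch.tensor([-3.0, -4.0]))
print(lp)
```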
- class gfn.gym.helpers.box_utils.QuarterCircle(delta, northeastern, centers, mixture_logits, alpha, beta, debug=False)¶
Bases:
torch.distributions.Distribution
Represents distributions on quarter circles.
The distributions are mixtures of Beta distributions over the feasible angle range.
When a state has norm <= delta and northeastern=False, the distribution is a Dirac at the state (i.e. the only possible parent is s_0).
Adapted from https://github.com/saleml/continuous-gfn/blob/master/sampling.py
This is useful for the Box environment.
- Parameters:
delta (float)
northeastern (bool)
centers (gfn.states.States)
mixture_logits (torch.Tensor)
alpha (torch.Tensor)
beta (torch.Tensor)
debug (bool)
- delta¶
The radius of the quarter disk.
- northeastern¶
Whether the quarter disk is northeastern or southwestern.
- n_states¶
The number of states.
- n_components¶
The number of components in the mixture.
- centers¶
The centers of the distribution.
- base_dist¶
The base distribution.
- min_angles¶
The minimum angles.
- max_angles¶
The maximum angles.
- base_dist: torch.distributions.MixtureSameFamily¶
- centers: gfn.states.States¶
- debug = False¶
- delta: float¶
- get_min_and_max_angles()¶
Computes the minimum and maximum angles for the distribution.
- Returns:
A tuple of two tensors of shape (n_states,) containing the minimum and maximum angles, respectively.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
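The feasible angle range follows from the requirement that a step of length delta stays inside the unit square. For a forward (northeastern) step, x + delta*cos(t) <= 1 and y + delta*sin(t) <= 1. For a backward (southwestern) step, x - delta*cos(t) >= 0 and y - delta*sin(t) >= 0. The sketch below derives the bounds from that geometry and may not match `get_min_and_max_angles` line for line. `angle_range` is a hypothetical scalar helper. In the backward case with x^2 + y^2 < delta^2, the range comes out empty (t_min > t_max), which is consistent with the Dirac-at-s0 case described above.

```python
import math


def angle_range(x, y, delta, northeastern):
    # Angles t in [0, pi/2] such that a step of length delta from (x, y)
    # stays inside the unit square. Geometric sketch, not the library code.
    if northeastern:
        gap_x, gap_y = 1.0 - x, 1.0 - y  # room left towards the top-right border
    else:
        gap_x, gap_y = x, y              # room left towards the origin
    # delta*cos(t) <= gap_x  <=>  t >= acos(gap_x / delta), when gap_x < delta.
    t_min = math.acos(gap_x / delta) if gap_x < delta else 0.0
    # delta*sin(t) <= gap_y  <=>  t <= asin(gap_y / delta), when gap_y < delta.
    t_max = math.asin(gap_y / delta) if gap_y < delta else math.pi / 2
    return t_min, t_max


print(angle_range(0.5, 0.5, 0.25, northeastern=True))    # full quarter circle
print(angle_range(0.95, 0.5, 0.25, northeastern=True))   # right border limits t_min
print(angle_range(0.1, 0.1, 0.25, northeastern=False))   # empty: norm < delta
```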
- log_prob(sampled_actions)¶
Computes the log probability of the sampled actions.
- Parameters:
sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.
- Returns:
The log probability of the sampled actions as a tensor of shape batch_shape.
- Return type:
torch.Tensor
- max_angles: torch.Tensor¶
- min_angles: torch.Tensor¶
- n_components: int¶
- n_states: int¶
- northeastern: bool¶
- sample(sample_shape=Size())¶
Samples from the distribution.
- Parameters:
sample_shape (torch.Size) – the shape of the samples to generate.
- Returns:
The sampled actions of shape (sample_shape, n_states, 2).
- Return type:
torch.Tensor
- class gfn.gym.helpers.box_utils.QuarterCircleWithExit(delta, centers, exit_probability, mixture_logits, alpha, beta, epsilon=0.0001)¶
Bases:
torch.distributions.Distribution
Extends QuarterCircle with an exit action.
When sampling, the exit action [-inf, -inf] is returned with probability exit_probability; log_prob accounts for this.
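Treating exit as a point mass at [-inf, -inf], the log-probability combines the exit probability with the no-exit density. Exit actions get log(p_exit); all other actions get log(1 - p_exit) + log p_circle(a). Below is a minimal sketch. `log_prob_with_exit` is a hypothetical helper, and the epsilon-based border handling of QuarterCircleWithExit is not reproduced.

```python
import torch


def log_prob_with_exit(actions, exit_probability, log_prob_without_exit):
    # actions: (batch, 2), where exit is [-inf, -inf]
    # exit_probability, log_prob_without_exit: (batch,)
    is_exit = torch.isinf(actions).all(dim=-1)
    log_exit = torch.log(exit_probability)
    log_move = torch.log1p(-exit_probability) + log_prob_without_exit
    return torch.where(is_exit, log_exit, log_move)


actions = torch.tensor([[float("-inf"), float("-inf")], [0.1, 0.2]])
p_exit = torch.tensor([0.25, 0.25])
lp_circle = torch.tensor([0.0, -1.0])  # stand-in for QuarterCircle.log_prob
lp = log_prob_with_exit(actions, p_exit, lp_circle)
print(lp)
```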
- Parameters:
delta (float)
centers (gfn.states.States)
exit_probability (torch.Tensor)
mixture_logits (torch.Tensor)
alpha (torch.Tensor)
beta (torch.Tensor)
epsilon (float)
- delta¶
The radius of the quarter disk.
- epsilon¶
The epsilon value to consider the state as being at the border of the square.
- centers¶
The centers of the distribution.
- dist_without_exit¶
The distribution without the exit action.
- exit_probability¶
The probability of exiting.
- exit_action¶
The exit action.
- n_states¶
The number of states.
- centers: gfn.states.States¶
- delta: float¶
- dist_without_exit: QuarterCircle¶
- epsilon: float¶
- exit_action: torch.Tensor¶
- exit_probability: torch.Tensor¶
- log_prob(sampled_actions)¶
Computes the log probability of the sampled actions.
- Parameters:
sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.
- Returns:
The log probability of the sampled actions as a tensor of shape batch_shape.
- Return type:
torch.Tensor
- n_states: int¶
- sample()¶
Samples from the distribution.
- Returns:
The sampled actions with shape (n_states, 2).
- Return type:
torch.Tensor
- class gfn.gym.helpers.box_utils.QuarterDisk(delta, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta)¶
Bases:
torch.distributions.Distribution
Represents a distribution on the northeastern quarter disk.
The radius and the angle each follow a mixture of Beta distributions.
Adapted from https://github.com/saleml/continuous-gfn/blob/master/sampling.py
This is useful for the Box environment.
- Parameters:
delta (float)
mixture_logits (torch.Tensor)
alpha_r (torch.Tensor)
beta_r (torch.Tensor)
alpha_theta (torch.Tensor)
beta_theta (torch.Tensor)
- delta¶
The radius of the quarter disk.
- mixture_logits¶
The logits of the mixture of Beta distributions.
- base_r_dist¶
The base distribution for the radius.
- base_theta_dist¶
The base distribution for the angle.
- n_components¶
The number of components in the mixture.
- base_r_dist: torch.distributions.MixtureSameFamily¶
- base_theta_dist: torch.distributions.MixtureSameFamily¶
- debug = False¶
- delta: float¶
- log_prob(sampled_actions)¶
Computes the log probability of the sampled actions.
- Parameters:
sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.
- Returns:
The log probability of the sampled actions as a tensor of shape batch_shape.
- Return type:
torch.Tensor
- mixture_logits: torch.Tensor¶
- n_components: int¶
- sample(sample_shape=torch.Size())¶
Samples from the distribution.
- Parameters:
sample_shape (torch.Size) – the shape of the samples to generate.
- Returns:
The sampled actions of shape (sample_shape, 2).
- Return type:
torch.Tensor
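A sample from the quarter disk of radius delta can be formed by scaling a Beta-mixture radius by delta and mapping a Beta-mixture variable to an angle in [0, pi/2]. The sketch below rests on two assumptions: `mixture_logits` is shared by the radius and angle mixtures, and the two mixture components are drawn independently. The library's exact mapping may differ. `sample_quarter_disk` is a hypothetical helper.

```python
import math

import torch
from torch.distributions import Beta, Categorical, MixtureSameFamily


def sample_quarter_disk(delta, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta):
    # All parameter tensors: (n, n_comp). Returns actions of shape (n, 2).
    mix = Categorical(logits=mixture_logits)  # shared mixture weights (assumed)
    r = MixtureSameFamily(mix, Beta(alpha_r, beta_r)).sample()          # in [0, 1]
    u = MixtureSameFamily(mix, Beta(alpha_theta, beta_theta)).sample()  # in [0, 1]
    theta = u * math.pi / 2  # angle in [0, pi/2]
    radius = delta * r       # radius in [0, delta]
    return torch.stack([radius * torch.cos(theta), radius * torch.sin(theta)], dim=-1)


torch.manual_seed(0)
n, n_comp = 5, 2
params = torch.full((n, n_comp), 2.0)
actions = sample_quarter_disk(0.25, torch.zeros(n, n_comp), params, params, params, params)
print(actions)
```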
- class gfn.gym.helpers.box_utils.UniformBoxCartesianPBModule(n_components, n_dim=2)¶
Bases:
gfn.utils.modules.UniformModule
Fixed (non-learned) backward policy module for Cartesian Box.
Backward-compatible alias for UniformModule. Prefer BoxCartesianPBEstimator.uniform(env, n_components) for new code.
- Parameters:
n_components (int)
n_dim (int)
- gfn.gym.helpers.box_utils.split_PF_module_output(output, n_comp_max)¶
Splits the module output into the expected parameter sets.
- Parameters:
output (torch.Tensor) – the module_output from the P_F model as a tensor of shape (*batch_shape, output_dim).
n_comp_max (int) – the larger of n_components and n_components_s0.
- Returns:
A tuple containing:
exit_probability: A probability unique to QuarterCircleWithExit.
mixture_logits: Parameters shared by QuarterDisk and QuarterCircleWithExit.
alpha_r: Parameters shared by QuarterDisk and QuarterCircleWithExit.
beta_r: Parameters shared by QuarterDisk and QuarterCircleWithExit.
alpha_theta: Parameters unique to QuarterDisk.
beta_theta: Parameters unique to QuarterDisk.
- Return type:
tuple
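The six returned pieces match the output width documented for BoxPFMLP.forward: one exit entry plus five blocks of n_comp_max values. The sketch below shows the slicing arithmetic under an assumed ordering. It is a hypothetical re-implementation (`split_pf_output_sketch`), not the library function. The exit entry is returned raw, so any squashing into [0, 1] is not shown.

```python
import torch


def split_pf_output_sketch(output, n_comp_max):
    # output: (*batch_shape, 1 + 5 * n_comp_max); ordering below is assumed.
    assert output.shape[-1] == 1 + 5 * n_comp_max
    exit_probability = output[..., 0]  # raw first entry
    mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta = torch.split(
        output[..., 1:], n_comp_max, dim=-1
    )
    return exit_probability, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta


n_comp_max = 3
output = torch.arange(2 * (1 + 5 * n_comp_max), dtype=torch.float32).reshape(2, 16)
parts = split_pf_output_sketch(output, n_comp_max)
print([p.shape for p in parts])
```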