gfn.gym.helpers.box_cartesian_utils
Cartesian increment estimators and distributions for the Box environment.
Classes
| Class | Description |
|---|---|
| BoxCartesianDistribution | Cartesian increment distribution for Box environment. |
| BoxCartesianPBDistribution | Backward Cartesian distribution for Box environment. |
| BoxCartesianPBEstimator | Simplified PB estimator using Cartesian increments with back-to-source. |
| BoxCartesianPBMLP | MLP for Box backward policy. First output is back-to-start logit. |
| BoxCartesianPFEstimator | Simplified PF estimator using Cartesian increments. |
| BoxCartesianPFMLP | MLP for Box forward policy. First output is exit logit. |
| UniformBoxCartesianPBModule | Fixed (non-learned) backward policy module for Cartesian Box. |
| _BetaMixtureMixin | Shared Beta mixture sampling/log_prob for forward and backward distributions. |
| _BoxCartesianMLP | Base MLP for Box Cartesian policies (forward and backward). |
Module Contents
- class gfn.gym.helpers.box_cartesian_utils.BoxCartesianDistribution(states, exit_logits, mixture_logits, alpha, beta, delta, epsilon=1e-06, temperature=1.0)
Bases: _BetaMixtureMixin, torch.distributions.Distribution
Cartesian increment distribution for Box environment.
Each dimension's relative increment r ∈ [0, 1] follows a Categorical-weighted Beta mixture, i.e. the distribution represented by MixtureSameFamily(Categorical, Beta). This is much simpler than the polar-coordinate parameterization: it samples one relative increment per dimension and converts it to an absolute increment with action = min_incr + r * max_range.
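The per-dimension sampling step can be sketched as follows. This is a minimal sketch under stated assumptions, not the class's implementation: it builds the mixture explicitly with torch.distributions.MixtureSameFamily, assumes parameter tensors of shape (batch, n_dim, n_components), and treats min_incr and max_range as precomputed per-state tensors whose values below are hypothetical placeholders.

```python
import torch
from torch.distributions import Beta, Categorical, MixtureSameFamily


def sample_cartesian_increment(mixture_logits, alpha, beta, min_incr, max_range):
    """Sample one absolute increment per dimension.

    mixture_logits, alpha, beta: (batch, n_dim, n_components)
    min_incr, max_range: (batch, n_dim), assumed precomputed from the states.
    """
    # One Categorical-weighted Beta mixture for every (batch, dim) pair.
    mixture = MixtureSameFamily(Categorical(logits=mixture_logits), Beta(alpha, beta))
    r = mixture.sample()  # relative increment in [0, 1], shape (batch, n_dim)
    # Affine map from the unit interval to the absolute increment.
    return min_incr + r * max_range


torch.manual_seed(0)
batch, n_dim, n_comp = 4, 2, 3
mixture_logits = torch.randn(batch, n_dim, n_comp)
alpha = torch.rand(batch, n_dim, n_comp) * 5 + 0.5
beta = torch.rand(batch, n_dim, n_comp) * 5 + 0.5
# Hypothetical bounds: every increment lies in [0.1, 0.5].
min_incr = torch.full((batch, n_dim), 0.1)
max_range = torch.full((batch, n_dim), 0.4)
actions = sample_cartesian_increment(mixture_logits, alpha, beta, min_incr, max_range)
print(actions.shape)  # torch.Size([4, 2])
```

Sampling in the unit interval and rescaling keeps the network's outputs state-independent; only min_incr and max_range need to know where the state sits in the box.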
- Parameters:
states (gfn.states.States)
exit_logits (torch.Tensor)
mixture_logits (torch.Tensor)
alpha (torch.Tensor)
beta (torch.Tensor)
delta (float)
epsilon (float)
temperature (float)
- delta
Minimum step size.
- epsilon = 1e-06
Small value for numerical stability.
- alpha
- arg_constraints: dict
- at_boundary
- beta
- exit_logits_scaled
- is_s0
- log_prob(actions)
Compute the log probability of actions under the per-dimension Cartesian parameterization.
- Parameters:
actions (torch.Tensor)
- Return type:
torch.Tensor
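Because the Beta mixture lives on the relative increment r while log_prob receives absolute actions, the density needs a change-of-variables correction. A minimal sketch, assuming the affine map action = min_incr + r * max_range and independent dimensions (log densities summed over dimensions); a single Beta per dimension stands in for the mixture to keep it short, and the exit branch is omitted.

```python
import torch
from torch.distributions import Beta, TransformedDistribution
from torch.distributions.transforms import AffineTransform


def increment_log_prob(actions, alpha, beta, min_incr, max_range):
    """Log density of absolute increments under action = min_incr + r * max_range.

    All inputs have shape (batch, n_dim); the result has shape (batch,).
    """
    r = (actions - min_incr) / max_range      # back to the unit interval
    log_p_r = Beta(alpha, beta).log_prob(r)   # density of the relative increment
    log_jacobian = -torch.log(max_range)      # |dr / d action| = 1 / max_range
    return (log_p_r + log_jacobian).sum(dim=-1)


alpha = torch.tensor([[2.0, 3.0]])
beta = torch.tensor([[5.0, 1.5]])
min_incr = torch.tensor([[0.1, 0.1]])
max_range = torch.tensor([[0.4, 0.6]])
actions = torch.tensor([[0.25, 0.5]])

manual = increment_log_prob(actions, alpha, beta, min_incr, max_range)
# Cross-check against torch's own change-of-variables machinery.
reference = TransformedDistribution(
    Beta(alpha, beta), [AffineTransform(loc=min_incr, scale=max_range)]
).log_prob(actions).sum(dim=-1)
print(torch.allclose(manual, reference))  # True
```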
- log_weights
- max_range
- min_incr
- n_dim
- sample(sample_shape=Size())
Sample actions using Cartesian per-dimension increments.
- Parameters:
sample_shape (torch.Size)
- Return type:
torch.Tensor
- states
- class gfn.gym.helpers.box_cartesian_utils.BoxCartesianPBDistribution(states, bts_logits, mixture_logits, alpha, beta, delta, epsilon=1e-06, temperature=1.0)
Bases: _BetaMixtureMixin, torch.distributions.Distribution
Backward Cartesian distribution for Box environment.
In torchgfn’s design, the source state is the origin [0, 0]. The BTS (back-to-source) action moves directly to s0 by setting action = state.
WHY THE LEARNED BTS BERNOULLI IS CRITICAL
The Trajectory Balance (TB) loss is:
L = (log P_F(τ) + log Z - log P_B(τ) - log R(x_T))^2
The BTS step (x_1 → s0) is the last step of every backward trajectory. Previously, BTS was always deterministic (forced): log P_B(BTS | x_1) = 0. A constant zero drops out of the gradient, so P_B received no gradient from the BTS step, regardless of trajectory length.
For 1-step trajectories (s0 → x_T, then immediately BTS back to s0), P_B was entirely gradient-free: log P_B(τ) = 0 for every such trajectory, making P_B invisible to the TB loss. This is particularly harmful because the reward landscape is dominated by states reachable from s0 in one step (the high-reward ring at |x - 0.5| ∈ (0.3, 0.4)).
With a learned Bernoulli P(BTS | x):
- log P_B(BTS | x_1) = log P(BTS=1 | x_1), so gradient flows into P_B.
- log P_B(~BTS | x_1) = log P(BTS=0 | x_1), which also receives gradient.
This closes the TB loop: P_B now has an incentive to assign higher BTS probability to states that are close to the reward modes, and lower BTS probability to states that are far from them.
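The gradient argument can be checked directly with autograd. This toy sketch is not the estimator's code: it compares a forced BTS step, whose log-probability is the constant 0, with a learned BTS step whose log-probability is log sigmoid(bts_logit); all scalar values are arbitrary placeholders.

```python
import torch
import torch.nn.functional as F


def tb_loss(log_pf, log_z, log_pb, log_reward):
    """Trajectory Balance loss for a single trajectory."""
    return (log_pf + log_z - log_pb - log_reward) ** 2


# Arbitrary placeholder values for one 1-step trajectory.
log_pf = torch.tensor(-1.2)
log_z = torch.tensor(0.3)
log_reward = torch.tensor(-2.0)

# Forced BTS: log P_B(BTS | x_1) is a constant 0 with no parameters behind it,
# so no P_B parameter appears in the autograd graph of this loss.
forced_log_pb = torch.tensor(0.0)
forced_loss = tb_loss(log_pf, log_z, forced_log_pb, log_reward)
print(forced_log_pb.requires_grad)  # False

# Learned BTS: log P_B(BTS | x_1) = log sigmoid(bts_logit).
bts_logit = torch.tensor(0.7, requires_grad=True)
learned_loss = tb_loss(log_pf, log_z, F.logsigmoid(bts_logit), log_reward)
learned_loss.backward()
print(bts_logit.grad)  # nonzero: the BTS step now trains P_B
```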
FORCED vs. STOCHASTIC BTS
When any dimension of the state satisfies state[d] <= delta, the valid backward increment range [delta, state[d]] for that dimension is empty (or a single point when state[d] == delta). BTS is then the only valid action, so it is forced (log_prob = 0, deterministic). For all other states, BTS is optional and is sampled from the learned Bernoulli.
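The forced/stochastic split and the BTS action itself can be sketched with a few tensor operations. A minimal sketch, assuming states are a (batch, n_dim) tensor and that a backward step subtracts the action from the state; the variable names mirror this class's dim_near_origin and any_dim_near_origin attributes, but the computation shown is an assumption.

```python
import torch

delta = 0.1
states = torch.tensor([
    [0.05, 0.60],  # first dimension <= delta: BTS is forced
    [0.40, 0.70],  # both dimensions > delta: BTS is optional
])

# Per-dimension check: is [delta, state[d]] empty or degenerate?
dim_near_origin = states <= delta
# A single such dimension leaves BTS as the only valid backward action.
any_dim_near_origin = dim_near_origin.any(dim=-1)

# BTS sets action = state, so the backward step lands exactly on s0 = [0, 0].
bts_actions = states.clone()
previous_states = states - bts_actions
print(any_dim_near_origin)  # tensor([ True, False])
```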
- alpha
- any_dim_near_origin
- arg_constraints: dict
- beta
- bts_logits_scaled
- delta
- dim_near_origin
- epsilon = 1e-06
- log_prob(actions)
Compute log probability of backward actions.
Forced BTS (any_dim_near_origin): log_prob = 0 (deterministic).
Stochastic BTS: log_prob = log P(BTS=1 | s) from Bernoulli.
Non-BTS: log_prob = log P(BTS=0 | s) + log_p_beta + log_jacobian.
- Parameters:
actions (torch.Tensor)
- Return type:
torch.Tensor
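The three log_prob cases combine per state with torch.where. A minimal sketch, assuming the Beta-mixture log-density and its Jacobian term are already computed per state (passed in as hypothetical tensors log_p_beta and log_jacobian) and that P(BTS=1 | s) = sigmoid(bts_logits).

```python
import torch
import torch.nn.functional as F


def pb_log_prob(is_bts, forced_bts, bts_logits, log_p_beta, log_jacobian):
    """Combine forced BTS, stochastic BTS and non-BTS log-probabilities.

    All inputs have shape (batch,); is_bts and forced_bts are boolean masks.
    """
    log_p_bts = F.logsigmoid(bts_logits)       # log P(BTS=1 | s)
    log_p_not_bts = F.logsigmoid(-bts_logits)  # log P(BTS=0 | s)
    non_bts = log_p_not_bts + log_p_beta + log_jacobian
    out = torch.where(is_bts, log_p_bts, non_bts)
    # Forced BTS is deterministic, so it contributes log 1 = 0.
    return torch.where(forced_bts, torch.zeros_like(out), out)


is_bts = torch.tensor([True, True, False])
forced_bts = torch.tensor([True, False, False])
bts_logits = torch.zeros(3)  # P(BTS=1 | s) = 0.5 everywhere
log_p_beta = torch.tensor([0.0, 0.0, 1.5])
log_jacobian = torch.tensor([0.0, 0.0, -0.5])
log_probs = pb_log_prob(is_bts, forced_bts, bts_logits, log_p_beta, log_jacobian)
print(log_probs)  # forced row is 0, stochastic BTS row is log(0.5)
```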
- log_weights
- max_range
- min_incr
- n_dim
- sample(sample_shape=Size())
Sample backward actions.
BTS (action = state) is forced when any dimension satisfies state[d] <= delta; otherwise it is sampled from the learned Bernoulli. Non-BTS samples come from the Beta mixture.
- Parameters:
sample_shape (torch.Size)
- Return type:
torch.Tensor
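Sampling mirrors the log-probability cases. A minimal sketch, assuming a (batch, n_dim) state tensor, P(BTS=1 | s) = sigmoid(bts_logits), and the backward increment range [delta, state[d]] from the FORCED vs. STOCHASTIC BTS note, so a non-BTS action is delta + r * (state - delta); a single Beta per dimension replaces the mixture for brevity, and the function name is hypothetical.

```python
import torch
from torch.distributions import Bernoulli, Beta


def sample_backward(states, bts_logits, alpha, beta, delta):
    """Sample backward actions; states/alpha/beta: (batch, n_dim), bts_logits: (batch,)."""
    # BTS is forced wherever some dimension has state[d] <= delta.
    forced = (states <= delta).any(dim=-1)
    # Otherwise the BTS decision comes from the learned Bernoulli.
    bts = Bernoulli(logits=bts_logits).sample().bool() | forced
    # Non-BTS: a relative increment r in [0, 1] mapped into [delta, state[d]].
    # The clamp keeps the (discarded) candidate well-defined for forced rows.
    r = Beta(alpha, beta).sample()
    increments = delta + r * (states - delta).clamp(min=0.0)
    # BTS rows use action = state, which sends them back to s0.
    return torch.where(bts.unsqueeze(-1), states, increments)


torch.manual_seed(0)
delta = 0.1
states = torch.tensor([[0.05, 0.60], [0.50, 0.80], [0.90, 0.30]])
bts_logits = torch.tensor([0.0, -30.0, -30.0])  # rows 2 and 3: BTS very unlikely
alpha = torch.full((3, 2), 2.0)
beta = torch.full((3, 2), 2.0)
actions = sample_backward(states, bts_logits, alpha, beta, delta)
print(actions)
```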
- states
- Parameters:
states (gfn.states.States)
bts_logits (torch.Tensor)
mixture_logits (torch.Tensor)
alpha (torch.Tensor)
beta (torch.Tensor)
delta (float)
epsilon (float)
temperature (float)
- class gfn.gym.helpers.box_cartesian_utils.BoxCartesianPBEstimator(env, module, n_components, min_concentration=0.1, max_concentration=100.0, numerical_epsilon=1e-06, debug=False)
Bases: gfn.estimators.Estimator, gfn.estimators.PolicyMixin
Simplified PB estimator using Cartesian increments with back-to-source.
- Parameters:
env (gfn.gym.box.BoxPolar)
module (torch.nn.Module)
n_components (int)
min_concentration (float)
max_concentration (float)
numerical_epsilon (float)
debug (bool)
- delta
- epsilon
- property expected_output_dim: int
Expected output dimension: bts_logit + (weights + alpha + beta) * n_dim * n_comp, i.e. 1 + 3 * n_dim * n_components.
- Return type:
int
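For example, with n_dim = 2 and n_components = 4 the layout gives 1 + 3 * 2 * 4 = 25 outputs. A small helper reproducing the documented formula (the function name is hypothetical):

```python
def expected_output_dim(n_dim: int, n_components: int) -> int:
    """One leading logit (BTS for PB, exit for PF) plus mixture weights,
    alpha and beta for every (dimension, component) pair."""
    return 1 + 3 * n_dim * n_components


print(expected_output_dim(2, 4))  # 25
```

The same count applies to the forward estimator, whose leading logit is the exit logit instead of the BTS logit.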
- max_concentration = 100.0
- min_concentration = 0.1
- n_components
- n_dim = 2
- numerical_epsilon = 1e-06
- temperature: float = 1.0
- to_probability_distribution(states, module_output)
Convert module output to backward probability distribution.
- Parameters:
states (gfn.states.States)
module_output (torch.Tensor)
- Return type:
torch.distributions.Distribution
- classmethod uniform(env, n_components, **kwargs)
Create an estimator with a fixed (non-learned) uniform backward policy.
- Parameters:
env (gfn.gym.box.BoxPolar) – The Box environment.
n_components (int) – Number of mixture components.
**kwargs – Extra keyword arguments forwarded to the constructor.
- Returns:
A BoxCartesianPBEstimator whose module has no learnable parameters.
- Return type:
BoxCartesianPBEstimator
- class gfn.gym.helpers.box_cartesian_utils.BoxCartesianPBMLP(hidden_dim, n_hidden_layers, n_components, n_dim=2, **kwargs)
Bases: _BoxCartesianMLP
MLP for Box backward policy. First output is back-to-start logit.
- Parameters:
hidden_dim (int)
n_hidden_layers (int)
n_components (int)
n_dim (int)
kwargs (Any)
- class gfn.gym.helpers.box_cartesian_utils.BoxCartesianPFEstimator(env, module, n_components, min_concentration=0.1, max_concentration=100.0, numerical_epsilon=1e-06, debug=False)
Bases: gfn.estimators.Estimator, gfn.estimators.PolicyMixin
Simplified PF estimator using Cartesian increments.
Much simpler than BoxPFEstimator: it uses a single MLP and a BoxCartesianDistribution.
- Parameters:
env (gfn.gym.box.BoxPolar)
module (torch.nn.Module)
n_components (int)
min_concentration (float)
max_concentration (float)
numerical_epsilon (float)
debug (bool)
- delta
- epsilon
- property expected_output_dim: int
Expected output dimension: exit_logit + (weights + alpha + beta) * n_dim * n_comp, i.e. 1 + 3 * n_dim * n_components.
- Return type:
int
- max_concentration = 100.0
- min_concentration = 0.1
- n_components
- n_dim = 2
- numerical_epsilon = 1e-06
- temperature: float = 1.0
- to_probability_distribution(states, module_output)
Convert module output to a probability distribution.
- Parameters:
states (gfn.states.States) – The states.
module_output (torch.Tensor) – Output from the module, shape (batch, expected_output_dim).
- Returns:
BoxCartesianDistribution instance.
- Return type:
torch.distributions.Distribution
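The conversion starts by slicing the flat module output according to the layout [logit, mixture_logits…, alpha…, beta…] described for _BoxCartesianMLP. A minimal sketch; several details are assumptions for illustration only: the inner (n_dim, n_components) ordering of each block, dividing the logits by temperature, and squashing raw alpha/beta into [min_concentration, max_concentration] with a scaled sigmoid. The actual estimator may differ on any of these.

```python
import torch


def split_module_output(module_output, n_dim, n_components,
                        min_concentration=0.1, max_concentration=100.0,
                        temperature=1.0):
    """Split a (batch, 1 + 3 * n_dim * n_components) output into parameters."""
    batch = module_output.shape[0]
    # Leading logit: exit (PF) or BTS (PB); temperature scaling is assumed.
    logit = module_output[:, 0] / temperature
    # Assumed block order [mixture_logits, alpha, beta], each laid out as (n_dim, n_components).
    mixture_logits, raw_alpha, raw_beta = (
        module_output[:, 1:].reshape(batch, 3, n_dim, n_components).unbind(dim=1)
    )
    mixture_logits = mixture_logits / temperature
    # Hypothetical squashing of raw values into [min_concentration, max_concentration].
    span = max_concentration - min_concentration
    alpha = min_concentration + span * torch.sigmoid(raw_alpha)
    beta = min_concentration + span * torch.sigmoid(raw_beta)
    return logit, mixture_logits, alpha, beta


torch.manual_seed(0)
module_output = torch.randn(8, 1 + 3 * 2 * 4)
logit, mixture_logits, alpha, beta = split_module_output(module_output, n_dim=2, n_components=4)
print(logit.shape, mixture_logits.shape, alpha.shape)  # torch.Size([8]) torch.Size([8, 2, 4]) torch.Size([8, 2, 4])
```

Bounding the concentrations keeps each Beta away from degenerate shapes (very small values pile mass at 0 or 1; very large values collapse to a spike), which is presumably why the estimator exposes min_concentration and max_concentration.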
- class gfn.gym.helpers.box_cartesian_utils.BoxCartesianPFMLP(hidden_dim, n_hidden_layers, n_components, n_dim=2, **kwargs)
Bases: _BoxCartesianMLP
MLP for Box forward policy. First output is exit logit.
- Parameters:
hidden_dim (int)
n_hidden_layers (int)
n_components (int)
n_dim (int)
kwargs (Any)
- class gfn.gym.helpers.box_cartesian_utils.UniformBoxCartesianPBModule(n_components, n_dim=2)
Bases: gfn.utils.modules.UniformModule
Fixed (non-learned) backward policy module for Cartesian Box.
Backward-compatible alias for UniformModule. Prefer BoxCartesianPBEstimator.uniform(env, n_components) for new code.
- Parameters:
n_components (int)
n_dim (int)
- class gfn.gym.helpers.box_cartesian_utils._BetaMixtureMixin
Shared Beta mixture sampling/log_prob for forward and backward distributions.
- _beta_mixture_log_prob(r)
Compute Beta mixture log_prob without MixtureSameFamily overhead.
- Parameters:
r (torch.Tensor)
- Return type:
torch.Tensor
- _sample_beta_mixture()
Sample from Beta mixture without MixtureSameFamily overhead.
- Return type:
torch.Tensor
- alpha: torch.Tensor
- beta: torch.Tensor
- log_weights: torch.Tensor
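Skipping MixtureSameFamily amounts to computing the mixture density with a log-sum-exp over components, and sampling by first drawing a component index and then sampling that component's Beta. A minimal sketch of the idea, assuming log_weights, alpha and beta of shape (batch, n_dim, n_components) to match the mixin's attributes; it is cross-checked against MixtureSameFamily rather than copied from the mixin.

```python
import torch
from torch.distributions import Beta, Categorical, MixtureSameFamily


def beta_mixture_log_prob(r, log_weights, alpha, beta):
    """log sum_k w_k Beta(r | alpha_k, beta_k) for r of shape (batch, n_dim)."""
    # Normalize the weights in log space, then score r under every component.
    log_w = torch.log_softmax(log_weights, dim=-1)
    component_log_prob = Beta(alpha, beta).log_prob(r.unsqueeze(-1))
    return torch.logsumexp(log_w + component_log_prob, dim=-1)


def sample_beta_mixture(log_weights, alpha, beta):
    """Draw a component per (batch, dim), then sample that component's Beta."""
    k = Categorical(logits=log_weights).sample().unsqueeze(-1)  # (batch, n_dim, 1)
    a = alpha.gather(-1, k).squeeze(-1)
    b = beta.gather(-1, k).squeeze(-1)
    return Beta(a, b).sample()


torch.manual_seed(0)
log_weights = torch.randn(5, 2, 3)
alpha = torch.rand(5, 2, 3) * 4 + 0.5
beta = torch.rand(5, 2, 3) * 4 + 0.5
r = sample_beta_mixture(log_weights, alpha, beta)

reference = MixtureSameFamily(Categorical(logits=log_weights), Beta(alpha, beta))
print(torch.allclose(beta_mixture_log_prob(r, log_weights, alpha, beta),
                     reference.log_prob(r), atol=1e-5))  # True
```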
- class gfn.gym.helpers.box_cartesian_utils._BoxCartesianMLP(hidden_dim, n_hidden_layers, n_components, n_dim=2, **kwargs)
Bases: gfn.utils.modules.MLP
Base MLP for Box Cartesian policies (forward and backward).
Output format: [logit, mixture_logits…, alpha…, beta…] where the first logit is exit (PF) or back-to-start (PB), and mixture_logits, alpha, beta each have shape n_dim * n_components.
States are normalized from [0, 1] to [-1, 1] before the forward pass to match the gflownet reference (states2policy normalization).
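The normalization is the affine map x ↦ 2x - 1 applied before the hidden layers. A minimal stand-alone sketch in plain torch.nn, not the class itself: it assumes ReLU activations and a plain Linear stack (the real class builds on gfn.utils.modules.MLP), and only illustrates the input scaling and the 1 + 3 * n_dim * n_components output width. The class name is hypothetical.

```python
import torch
import torch.nn as nn


class TinyBoxCartesianMLP(nn.Module):
    """Hypothetical stand-in for _BoxCartesianMLP (layer choices are assumptions)."""

    def __init__(self, hidden_dim, n_hidden_layers, n_components, n_dim=2):
        super().__init__()
        layers, in_dim = [], n_dim
        for _ in range(n_hidden_layers):
            layers += [nn.Linear(in_dim, hidden_dim), nn.ReLU()]
            in_dim = hidden_dim
        # Output layout: [logit, mixture_logits..., alpha..., beta...]
        layers.append(nn.Linear(in_dim, 1 + 3 * n_dim * n_components))
        self.net = nn.Sequential(*layers)

    def forward(self, states):
        # Map states from [0, 1] to [-1, 1] before the MLP.
        return self.net(2.0 * states - 1.0)


model = TinyBoxCartesianMLP(hidden_dim=32, n_hidden_layers=2, n_components=4)
out = model(torch.rand(10, 2))
print(out.shape)  # torch.Size([10, 25])
```

Centering the inputs around zero is a common conditioning choice for MLP inputs; here it also matches the gflownet reference's states2policy normalization mentioned above.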
- Parameters:
hidden_dim (int)
n_hidden_layers (int)
n_components (int)
n_dim (int)
kwargs (Any)
- forward(preprocessed_states)
Forward pass. Normalizes [0, 1] states to [-1, 1] before the MLP.
- Parameters:
preprocessed_states (torch.Tensor)
- Return type:
torch.Tensor
- n_components
- n_dim = 2