gfn.gflownet.losses
Pluggable regression losses for GFlowNet balance conditions.
All GFlowNet training objectives (TB, DB, SubTB, FM, RTB, etc.) minimize a loss on a balance-condition residual, which is zero when the balance condition holds. The standard approach squares this residual, which corresponds to minimizing the reverse KL divergence between the learned and target distributions.
This module provides alternative loss functions that correspond to different divergence measures, following Hu et al. “Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks” (ICLR 2025, arXiv:2410.02596).
Each loss g(t) is applied elementwise to the residual t and satisfies:
- g(0) = 0 (zero loss at balance)
- g(t) >= 0 for all t (non-negative)
- g'(0) = 0 (stationary point at balance)
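The three conditions can be spot-checked numerically. The sketch below is a minimal, dependency-free illustration: it assumes scalar Python versions of the losses documented on this page (the library applies them elementwise to torch.Tensor residuals), and `check_balance_properties`, the grid, the step size, and the tolerances are hypothetical choices made for this example, not part of gfn.

```python
import math

# Scalar versions of the documented losses (hypothetical; the library applies
# them elementwise to torch tensors, plain floats keep this sketch self-contained).
LOSSES = {
    "squared": lambda t: t * t,
    "half_squared": lambda t: 0.5 * t * t,
    "linex_alpha_1": lambda t: math.exp(t) - t - 1.0,
    "shifted_cosh": lambda t: math.exp(t) + math.exp(-t) - 2.0,
}


def check_balance_properties(g, grid=None, h=1e-5, tol=1e-8):
    """Numerically check g(0) = 0, g(t) >= 0 on a grid, and g'(0) = 0."""
    if grid is None:
        grid = [i / 10 for i in range(-50, 51)]  # t in [-5, 5]
    zero_at_balance = abs(g(0.0)) < tol
    non_negative = all(g(t) >= -tol for t in grid)
    # Central-difference approximation of g'(0).
    slope_at_zero = (g(h) - g(-h)) / (2.0 * h)
    stationary = abs(slope_at_zero) < 1e-4
    return zero_at_balance, non_negative, stationary


for name, g in LOSSES.items():
    print(name, check_balance_properties(g))
```

For contrast, g(t) = t satisfies g(0) = 0 but is neither non-negative nor stationary at zero, so the checker returns (True, False, False) for it.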
Hu et al. (Theorem 4.1) show that each regression loss g induces an f-divergence between the learned flow and the target, with f-divergence generator \(f(u) = u \int_1^u \frac{g'(\log s)}{s^2}\, ds\).
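As a worked check of the generator formula, the sketch below integrates it numerically with Simpson's rule and compares the result with closed forms derived by hand: squared loss gives \(f(u) = 2(u - 1 - \log u)\), Linex with alpha = 1 gives \(f(u) = u \log u - u + 1\), and Linex with alpha = 0.5 gives \(f(u) = 2(\sqrt{u} - 1)^2\). Up to constant factors and terms linear in u - 1 (which do not change an f-divergence between normalized distributions), these match the standard generators for the two directions of KL divergence (\(-\log u\) and \(u \log u\)) and for squared Hellinger distance, the alpha = 0.5 member of the alpha-divergence family. `f_generator` and the other names are hypothetical helpers written for this illustration.

```python
import math


def f_generator(g_prime, u, n=2000):
    """Evaluate f(u) = u * integral from 1 to u of g'(log s) / s**2 ds (Simpson's rule)."""
    if n % 2:
        n += 1  # Simpson's rule needs an even number of subintervals

    def integrand(s):
        return g_prime(math.log(s)) / (s * s)

    h = (u - 1.0) / n  # negative when u < 1, which yields the signed integral
    total = integrand(1.0) + integrand(u)
    for i in range(1, n):
        total += (4 if i % 2 else 2) * integrand(1.0 + i * h)
    return u * total * h / 3.0


def squared_prime(t):  # g(t) = t^2, so g'(t) = 2t
    return 2.0 * t


def linex1_prime(t):  # Linex, alpha = 1: g'(t) = e^t - 1
    return math.expm1(t)


def linex_half_prime(t):  # Linex, alpha = 0.5: g'(t) = 2 (e^(t/2) - 1)
    return 2.0 * math.expm1(0.5 * t)


# Closed forms of f obtained by integrating by hand.
CASES = {
    "squared": (squared_prime, lambda u: 2.0 * (u - 1.0 - math.log(u))),
    "linex_alpha_1": (linex1_prime, lambda u: u * math.log(u) - u + 1.0),
    "linex_alpha_0.5": (linex_half_prime, lambda u: 2.0 * (math.sqrt(u) - 1.0) ** 2),
}

for name, (g_prime, f_exact) in CASES.items():
    for u in (0.25, 0.5, 1.0, 2.0, 4.0):
        print(f"{name:16s} u={u:<5} numeric={f_generator(g_prime, u):.8f} "
              f"exact={f_exact(u):.8f}")
```

At u = 1 every generator returns 0, as any f-divergence generator must, because the integral runs from 1 to 1.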
Zero-forcing losses (like squared error) penalize the learner for placing mass where the target has none — they tend to undercover modes. Zero-avoiding losses penalize the learner for missing mass where the target has some — they tend to overcover and explore more modes.
Usage:
from gfn.gflownet import TBGFlowNet
from gfn.gflownet.losses import ShiftedCoshLoss
# pf and pb are assumed to be the forward and backward policy estimators, built beforehand.
gflownet = TBGFlowNet(pf=pf, pb=pb, loss_fn=ShiftedCoshLoss())
Classes
- HalfSquaredLoss – Half squared loss: \(g(t) = \tfrac{1}{2} t^2\).
- LinexLoss – Linear-exponential (Linex) loss: \(g(t) = \frac{1}{\alpha^2}(e^{\alpha t} - \alpha t - 1)\).
- RegressionLoss – Abstract base for regression losses on GFlowNet balance residuals.
- ShiftedCoshLoss – Shifted hyperbolic cosine: \(g(t) = e^t + e^{-t} - 2 = 2(\cosh(t) - 1)\).
- SquaredLoss – Standard squared loss: \(g(t) = t^2\).
Module Contents
- class gfn.gflownet.losses.HalfSquaredLoss
Bases: RegressionLoss
Half squared loss: \(g(t) = \tfrac{1}{2} t^2\).
The \(\tfrac{1}{2}\) factor ensures the gradient equals the residual itself: \(g'(t) = t\) rather than \(2t\). This is the standard least-squares convention (minimizing \(\tfrac{1}{2}\|r\|^2\) so the normal equations have no factor of 2), and matches the RTB formulation in Venkatraman et al. (2024).
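The gradient convention can be confirmed with a central finite difference: the estimated derivative of the half squared loss equals the residual, while that of the plain squared loss is twice the residual. This is a scalar, dependency-free sketch; the function names are illustrative and not part of gfn.

```python
def half_squared(t: float) -> float:
    return 0.5 * t * t


def squared(t: float) -> float:
    return t * t


def numerical_grad(g, t: float, h: float = 1e-6) -> float:
    """Central-difference estimate of g'(t)."""
    return (g(t + h) - g(t - h)) / (2.0 * h)


for r in (-2.0, -0.5, 0.0, 1.5, 3.0):
    # Expect about r for the half squared loss and about 2r for the squared loss.
    print(f"r={r:+.1f}  half: {numerical_grad(half_squared, r):+.6f}  "
          f"squared: {numerical_grad(squared, r):+.6f}")
```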
This is the default loss for RelativeTrajectoryBalanceGFlowNet and RelativeLogPartitionVarianceGFlowNet.
- __call__(residuals)
Apply the loss elementwise.
- Parameters:
residuals (torch.Tensor) – Balance condition residuals (any shape).
- Returns:
Non-negative tensor of the same shape.
- Return type:
torch.Tensor
- class gfn.gflownet.losses.LinexLoss(alpha=1.0)
Bases: RegressionLoss
Linear-exponential (Linex) loss: \(g(t) = \frac{1}{\alpha^2}(e^{\alpha t} - \alpha t - 1)\).
The alpha parameter controls the asymmetry:
- alpha = 1: corresponds to the forward KL divergence. Zero-avoiding (mass-covering / exploration-favoring): penalizes the learner for missing mass where the target has support, encouraging broader mode coverage at the cost of some spurious mass.
- alpha = 0.5: corresponds to the alpha-divergence with alpha = 0.5. Balanced: neither purely zero-forcing nor zero-avoiding.
- alpha < 0: becomes zero-forcing (mode-seeking), similar to but distinct from squared loss.
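A scalar sketch of the Linex formula makes the asymmetry concrete. It assumes plain Python floats instead of tensors, uses `math.expm1` to reduce cancellation when alpha * t is near zero, and returns the alpha -> 0 limit \(\tfrac{1}{2} t^2\) at alpha = 0; that fallback is an assumption of this sketch, not documented behavior of LinexLoss. `linex` and `curvature_at_zero` are hypothetical names.

```python
import math


def linex(t: float, alpha: float = 1.0) -> float:
    """Scalar Linex loss g(t) = (exp(alpha*t) - alpha*t - 1) / alpha**2."""
    if alpha == 0.0:
        # Assumed fallback: the formula divides by zero here, and its
        # limit as alpha -> 0 is the half squared loss t**2 / 2.
        return 0.5 * t * t
    a_t = alpha * t
    # expm1(x) = exp(x) - 1, computed without cancellation near x = 0.
    return (math.expm1(a_t) - a_t) / alpha**2


def curvature_at_zero(g, h: float = 1e-4) -> float:
    """Second central difference approximating g''(0)."""
    return (g(h) - 2.0 * g(0.0) + g(-h)) / (h * h)


for alpha in (1.0, 0.5, -1.0):
    print(
        f"alpha={alpha:+.1f}  g(+2)={linex(2.0, alpha):.4f}  "
        f"g(-2)={linex(-2.0, alpha):.4f}  "
        f"g''(0)~{curvature_at_zero(lambda t: linex(t, alpha)):.6f}"
    )
```

For alpha = 1 a residual of +2 costs more than one of -2. Flipping the sign of alpha mirrors the loss, since linex(t, -alpha) equals linex(-t, alpha) by the formula, and the estimated curvature at zero is 1 for every alpha, in line with the \(1/\alpha^2\) normalization.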
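The \(1/\alpha^2\) normalization ensures g''(0) = 1 for all alpha, matching the curvature of HalfSquaredLoss near zero (SquaredLoss has g''(0) = 2). As alpha approaches 0 the loss tends to \(\tfrac{1}{2} t^2\); the formula itself is undefined at alpha = 0.
References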
Hu et al. “Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks” (ICLR 2025, arXiv:2410.02596).
The Linex loss originates from Bayesian decision theory: Varian (1975), Zellner (1986).
- Parameters:
alpha (float) – Asymmetry parameter (default 1.0).
- __call__(residuals)
Apply the loss elementwise.
- Parameters:
residuals (torch.Tensor) – Balance condition residuals (any shape).
- Returns:
Non-negative tensor of the same shape.
- Return type:
torch.Tensor
- __eq__(other)
- Parameters:
other (object)
- Return type:
bool
- __hash__()
- Return type:
int
- __repr__()
- Return type:
str
- alpha = 1.0
- class gfn.gflownet.losses.RegressionLoss
Bases: abc.ABC
Abstract base for regression losses on GFlowNet balance residuals.
Subclasses implement __call__, mapping a residual tensor to a non-negative loss tensor of the same shape.
- abstract __call__(residuals)
Apply the loss elementwise.
- Parameters:
residuals (torch.Tensor) – Balance condition residuals (any shape).
- Returns:
Non-negative tensor of the same shape.
- Return type:
torch.Tensor
- __eq__(other)
- Parameters:
other (object)
- Return type:
bool
- __hash__()
- Return type:
int
- __repr__()
- Return type:
str
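The contract above can be illustrated without torch. The sketch below is hypothetical: `ScalarRegressionLoss` and `LinexSketch` are stand-ins that operate on lists of floats rather than tensors, and the value semantics given to `__eq__`, `__hash__`, and `__repr__` (same class and same parameters compare equal) are an assumption, since this page documents only their signatures. One reason such classes define `__hash__` alongside `__eq__` is a Python rule: a class that overrides `__eq__` without defining `__hash__` becomes unhashable.

```python
import abc
import math


class ScalarRegressionLoss(abc.ABC):
    """Hypothetical stand-in for the documented RegressionLoss contract."""

    @abc.abstractmethod
    def g(self, t: float) -> float:
        """Loss for one residual; should satisfy g(0) = 0 and g(t) >= 0."""

    def __call__(self, residuals):
        # Elementwise: the output has the same length as the input.
        return [self.g(float(t)) for t in residuals]

    # Assumed value semantics: same class and same parameters compare equal.
    def __eq__(self, other: object) -> bool:
        if type(self) is not type(other):
            return NotImplemented
        return vars(self) == vars(other)

    # Overriding __eq__ sets __hash__ to None unless it is defined explicitly.
    def __hash__(self) -> int:
        return hash((type(self), tuple(sorted(vars(self).items()))))

    def __repr__(self) -> str:
        params = ", ".join(f"{k}={v!r}" for k, v in sorted(vars(self).items()))
        return f"{type(self).__name__}({params})"


class LinexSketch(ScalarRegressionLoss):
    def __init__(self, alpha: float = 1.0):
        self.alpha = alpha

    def g(self, t: float) -> float:
        a_t = self.alpha * t
        return (math.expm1(a_t) - a_t) / self.alpha**2


loss = LinexSketch(alpha=0.5)
print(loss([0.0, 1.0, -1.0]))
print(loss == LinexSketch(alpha=0.5), repr(loss))
```

Because `g` is abstract on the base class, instantiating `ScalarRegressionLoss` directly raises `TypeError`.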
- class gfn.gflownet.losses.ShiftedCoshLoss
Bases: RegressionLoss
Shifted hyperbolic cosine: \(g(t) = e^t + e^{-t} - 2 = 2(\cosh(t) - 1)\).
This is the only loss in the family that is simultaneously zero-forcing (penalizes spurious mass) and zero-avoiding (penalizes missing modes). It is symmetric: g(t) = g(-t).
Near t = 0 it behaves like t^2 (same curvature as squared loss), but for large |t| it grows exponentially, providing stronger gradients for poorly matched trajectories.
Hu et al. (ICLR 2025) found this loss generally outperforms squared error on convergence speed and mode coverage across HyperGrid, bit-sequence, and sEH molecule benchmarks.
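The sketch below, in scalar Python with illustrative function names, contrasts a literal evaluation of \(e^t + e^{-t} - 2\) with the algebraically equal form \(4 \sinh^2(t/2)\), which follows from \(\cosh t - 1 = 2 \sinh^2(t/2)\). For tiny residuals the literal form is dominated by rounding error, while the sinh form keeps the true value of about \(t^2\). It also prints the gradient \(g'(t) = 2 \sinh t\) next to the squared-loss gradient \(2t\). Whether gfn uses either form internally is not stated on this page.

```python
import math


def shifted_cosh_naive(t: float) -> float:
    """g(t) = e^t + e^-t - 2, evaluated literally."""
    return math.exp(t) + math.exp(-t) - 2.0


def shifted_cosh_stable(t: float) -> float:
    """Same g(t) via the identity cosh(t) - 1 = 2 * sinh(t / 2) ** 2."""
    return 4.0 * math.sinh(0.5 * t) ** 2


def shifted_cosh_grad(t: float) -> float:
    """g'(t) = e^t - e^-t = 2 * sinh(t)."""
    return 2.0 * math.sinh(t)


for t in (1e-9, 0.5, 3.0):
    print(
        f"t={t:g}  naive={shifted_cosh_naive(t):.6e}  stable={shifted_cosh_stable(t):.6e}  "
        f"cosh grad={shifted_cosh_grad(t):.4f}  squared grad={2.0 * t:.4f}"
    )
```

Both forms overflow for very large |t| (`math.exp` raises `OverflowError` for arguments above roughly 709 in double precision), so extreme residuals may need clipping.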
References
Hu et al. “Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks” (ICLR 2025, arXiv:2410.02596).
- __call__(residuals)
Apply the loss elementwise.
- Parameters:
residuals (torch.Tensor) – Balance condition residuals (any shape).
- Returns:
Non-negative tensor of the same shape.
- Return type:
torch.Tensor
- class gfn.gflownet.losses.SquaredLoss
Bases: RegressionLoss
Standard squared loss: \(g(t) = t^2\).
Corresponds to the reverse KL divergence (Malkin et al. 2022). This is zero-forcing (mode-seeking): it penalizes the learner for placing probability mass where the target has none, but does not penalize missing modes. This can lead to mode collapse in multi-modal targets.
This is the default loss for TB, DB, SubTB, LPV, and FM classes, reproducing the standard behavior from the literature.
- __call__(residuals)
Apply the loss elementwise.
- Parameters:
residuals (torch.Tensor) – Balance condition residuals (any shape).
- Returns:
Non-negative tensor of the same shape.
- Return type:
torch.Tensor