gfn.gflownet.losses

Pluggable regression losses for GFlowNet balance conditions.

All GFlowNet training objectives (TB, DB, SubTB, FM, RTB, etc.) are built from a balance-condition residual that vanishes when the balance condition holds. The standard approach minimizes the square of this residual, which corresponds to minimizing the reverse KL divergence between the learned and target distributions.

This module provides alternative loss functions that correspond to different divergence measures, following Hu et al. “Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks” (ICLR 2025, arXiv:2410.02596).

Each loss g(t) is applied elementwise to the residual t and satisfies:
  • g(0) = 0 (zero loss at balance)

  • g(t) >= 0 for all t (non-negative)

  • g'(0) = 0 (stationary point at balance)
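As a quick sanity check of these three conditions, the sketch below evaluates the closed-form losses documented in this module as plain Python float functions (not the module's classes). It tests non-negativity only on a finite grid and approximates g'(0) with a central finite difference, so it is evidence rather than proof. `check_properties` and `LOSSES` are hypothetical names introduced for illustration.

```python
import math

# Closed-form losses from this module, as plain float functions (Linex with alpha = 1).
LOSSES = {
    "squared": lambda t: t * t,
    "half_squared": lambda t: 0.5 * t * t,
    "linex_alpha1": lambda t: math.exp(t) - t - 1.0,
    "shifted_cosh": lambda t: math.exp(t) + math.exp(-t) - 2.0,
}


def check_properties(g, h=1e-5, grid=None):
    """Hypothetical helper: True if g(0) = 0, g >= 0 on a grid, and g'(0) is ~0."""
    if grid is None:
        grid = [k / 10 for k in range(-50, 51)]  # t in [-5, 5], step 0.1
    zero_at_balance = abs(g(0.0)) < 1e-12
    non_negative = all(g(t) >= 0.0 for t in grid)
    # Central finite difference approximates g'(0).
    slope_at_zero = (g(h) - g(-h)) / (2 * h)
    stationary = abs(slope_at_zero) < 1e-6
    return zero_at_balance and non_negative and stationary


for name, g in LOSSES.items():
    print(name, check_properties(g))
```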

Hu et al. Theorem 4.1 shows that each regression loss g induces an f-divergence between the learned flow and the target, with f-divergence generator \(f(u) = u \int_1^u \frac{g'(\log s)}{s^2}\, ds\).
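To see the generator formula in action, the sketch below approximates the integral with composite Simpson's rule (a quadrature choice made here for illustration) for SquaredLoss (g'(t) = 2t) and for LinexLoss with alpha = 1 (g'(t) = e^t - 1). It compares the result against closed forms worked out by hand: 2(u - 1 - log u) and u log u - u + 1. Up to a term linear in u - 1, which does not change an f-divergence between normalized distributions, these are 2(-log u) and u log u, the two KL generators, consistent with the reverse-KL and forward-KL correspondences described in this module. `f_generator` is a hypothetical helper.

```python
import math


def f_generator(g_prime, u, n=2000):
    """Approximate f(u) = u * integral_1^u g'(log s) / s^2 ds with Simpson's rule."""
    integrand = lambda s: g_prime(math.log(s)) / (s * s)
    a, b = 1.0, u
    h = (b - a) / n  # n is even; a negative h handles u < 1 (reversed limits)
    total = integrand(a) + integrand(b)
    for k in range(1, n):
        total += (4 if k % 2 else 2) * integrand(a + k * h)
    return u * total * h / 3


# Derivatives g'(t) of two losses from this module.
squared_prime = lambda t: 2 * t           # g(t) = t^2
linex1_prime = lambda t: math.exp(t) - 1  # g(t) = e^t - t - 1 (alpha = 1)

for u in (0.25, 2.0, 5.0):
    # Squared loss: closed form 2(u - 1 - log u).
    print(u, f_generator(squared_prime, u), 2 * (u - 1 - math.log(u)))
    # Linex, alpha = 1: closed form u log u - u + 1.
    print(u, f_generator(linex1_prime, u), u * math.log(u) - u + 1)
```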

Zero-forcing losses (like squared error) penalize the learner for placing mass where the target has none, so they tend to under-cover modes. Zero-avoiding losses penalize the learner for missing mass where the target has some, so they tend to over-cover and explore more modes.

Usage:

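from gfn.gflownet import TBGFlowNet
from gfn.gflownet.losses import ShiftedCoshLoss

# pf and pb are forward and backward policy estimators built beforehand.
# Omitting loss_fn keeps the default squared loss.
gflownet = TBGFlowNet(pf=pf, pb=pb, loss_fn=ShiftedCoshLoss())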
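The loss function is applied to the balance residual before averaging over the batch. As an illustration only, the sketch below assumes the usual trajectory-balance residual log Z + Σ log P_F - log R(x) - Σ log P_B and a batch mean; `tb_loss` is a hypothetical helper, not the TBGFlowNet implementation, and the two lambdas stand in for SquaredLoss and ShiftedCoshLoss.

```python
import torch


def tb_loss(log_z, log_pf, log_pb, log_reward, loss_fn):
    """Hypothetical TB objective: mean of g(residual) over a batch of trajectories.

    log_pf, log_pb: (batch, max_len) per-step log-probabilities, zero-padded.
    log_reward: (batch,) log R(x) of each terminal state.
    """
    residuals = log_z + log_pf.sum(dim=1) - log_reward - log_pb.sum(dim=1)
    return loss_fn(residuals).mean()


squared = lambda r: r ** 2                         # stand-in for SquaredLoss
shifted_cosh = lambda r: 2 * (torch.cosh(r) - 1)   # stand-in for ShiftedCoshLoss

log_z = torch.tensor(0.5, requires_grad=True)
log_pf = torch.log(torch.tensor([[0.5, 0.5], [0.25, 1.0]]))
log_pb = torch.zeros(2, 2)  # e.g. a tree-shaped state space where P_B = 1
log_reward = torch.log(torch.tensor([1.0, 2.0]))

print(tb_loss(log_z, log_pf, log_pb, log_reward, squared))
print(tb_loss(log_z, log_pf, log_pb, log_reward, shifted_cosh))
```

Because 2(cosh t - 1) = t^2 + t^4/12 + ..., the cosh variant is never smaller than the squared one on the same residuals, and the gap widens as residuals grow.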

Classes

HalfSquaredLoss

Half squared loss: \(g(t) = \tfrac{1}{2} t^2\).

LinexLoss

Linear-exponential (Linex) loss: \(g(t) = \frac{1}{\alpha^2}(e^{\alpha t} - \alpha t - 1)\).

RegressionLoss

Abstract base for regression losses on GFlowNet balance residuals.

ShiftedCoshLoss

Shifted hyperbolic cosine: \(g(t) = e^t + e^{-t} - 2 = 2(\cosh(t) - 1)\).

SquaredLoss

Standard squared loss: \(g(t) = t^2\).

Module Contents

class gfn.gflownet.losses.HalfSquaredLoss

Bases: RegressionLoss

Half squared loss: \(g(t) = \tfrac{1}{2} t^2\).

The \(\tfrac{1}{2}\) factor ensures the gradient equals the residual itself: \(g'(t) = t\) rather than \(2t\). This is the standard least-squares convention (minimizing \(\tfrac{1}{2}\|r\|^2\) so the normal equations have no factor of 2), and matches the RTB formulation in Venkatraman et al. (2024).

This is the default loss for RelativeTrajectoryBalanceGFlowNet and RelativeLogPartitionVarianceGFlowNet.
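A minimal sketch, using a stand-in class rather than the library's HalfSquaredLoss, that confirms with autograd that the gradient of ½t² is the residual itself. The class name `HalfSquared` is hypothetical.

```python
import torch


class HalfSquared:
    """Stand-in for HalfSquaredLoss: g(t) = t^2 / 2, applied elementwise."""

    def __call__(self, residuals: torch.Tensor) -> torch.Tensor:
        return 0.5 * residuals.pow(2)


residuals = torch.tensor([-2.0, -0.5, 0.0, 1.5], requires_grad=True)
loss = HalfSquared()(residuals).sum()
loss.backward()

# dg/dt = t, so the gradient of the summed loss should equal the residuals.
print(residuals.grad)
```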

__call__(residuals)

Apply the loss elementwise.

Parameters:

residuals (torch.Tensor) – Balance condition residuals (any shape).

Returns:

Non-negative tensor of the same shape.

Return type:

torch.Tensor

class gfn.gflownet.losses.LinexLoss(alpha=1.0)

Bases: RegressionLoss

Linear-exponential (Linex) loss: \(g(t) = \frac{1}{\alpha^2}(e^{\alpha t} - \alpha t - 1)\).

The alpha parameter controls the asymmetry:

  • alpha = 1: corresponds to the forward KL divergence. Zero-avoiding (mass-covering / exploration-favoring): penalizes the learner for missing mass where the target has support, encouraging broader mode coverage at the cost of some spurious mass.

  • alpha = 0.5: corresponds to the alpha-divergence with alpha = 0.5. Balanced: neither purely zero-forcing nor zero-avoiding.

  • alpha < 0: becomes zero-forcing (mode-seeking), similar to but distinct from squared loss.

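The \(1/\alpha^2\) normalization ensures \(g''(0) = 1\) for every nonzero alpha, matching the curvature of HalfSquaredLoss near zero (SquaredLoss has \(g''(0) = 2\)). As alpha → 0 the Linex loss tends to \(\tfrac{1}{2} t^2\); the formula itself is undefined at alpha = 0.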
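The sketch below is a stand-in, not the library class: it assumes the formula above, uses `torch.expm1` for accuracy near zero, and (as its own design choice) rejects alpha = 0, where the formula is undefined. It prints the penalty on residuals -2 and +2 to show the asymmetry (positive residuals cost more for alpha > 0, negative ones for alpha < 0) and checks g''(0) = 1 with autograd. The class name `Linex` is hypothetical.

```python
import torch


class Linex:
    """Stand-in for LinexLoss: g(t) = (exp(alpha t) - alpha t - 1) / alpha^2."""

    def __init__(self, alpha: float = 1.0):
        if alpha == 0:
            raise ValueError("alpha must be nonzero; alpha -> 0 recovers t^2 / 2")
        self.alpha = alpha

    def __call__(self, residuals: torch.Tensor) -> torch.Tensor:
        z = self.alpha * residuals
        # expm1 keeps precision near z = 0, where exp(z) - 1 cancels badly.
        return (torch.expm1(z) - z) / self.alpha**2


t = torch.tensor([-2.0, 2.0])
for alpha in (1.0, 0.5, -1.0):
    print(alpha, Linex(alpha)(t))  # penalty on -2 vs +2

# Curvature at zero: g''(0) = 1 for every nonzero alpha.
x = torch.tensor(0.0, requires_grad=True)
(g1,) = torch.autograd.grad(Linex(0.5)(x), x, create_graph=True)
(g2,) = torch.autograd.grad(g1, x)
print(g2)
```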

References

Hu et al. “Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks” (ICLR 2025, arXiv:2410.02596).

The Linex loss originates from Bayesian decision theory: Varian (1975), Zellner (1986).

Parameters:

alpha (float) – Asymmetry parameter (default 1.0); the formula is undefined at alpha = 0.

__call__(residuals)

Apply the loss elementwise.

Parameters:

residuals (torch.Tensor) – Balance condition residuals (any shape).

Returns:

Non-negative tensor of the same shape.

Return type:

torch.Tensor

__eq__(other)

Parameters:

other (object)

Return type:

bool

__hash__()

Return type:

int

__repr__()

Return type:

str

alpha = 1.0

class gfn.gflownet.losses.RegressionLoss

Bases: abc.ABC

Abstract base for regression losses on GFlowNet balance residuals.

Subclasses implement __call__ mapping a residual tensor to a non-negative loss tensor of the same shape.
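A minimal sketch of writing a custom loss against this interface. So that the snippet runs without the library installed, `RegressionLossBase` is a local stand-in for `gfn.gflownet.losses.RegressionLoss` (it omits the `__eq__`, `__hash__`, and `__repr__` documented below), and `ScaledSquaredLoss` is a hypothetical subclass; with the real package you would subclass `RegressionLoss` directly.

```python
import abc

import torch


class RegressionLossBase(abc.ABC):
    """Local stand-in mirroring the documented RegressionLoss interface."""

    @abc.abstractmethod
    def __call__(self, residuals: torch.Tensor) -> torch.Tensor:
        """Map residuals to a non-negative loss tensor of the same shape."""


class ScaledSquaredLoss(RegressionLossBase):
    """Hypothetical subclass: g(t) = c * t^2 with c > 0."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def __call__(self, residuals: torch.Tensor) -> torch.Tensor:
        return self.scale * residuals.pow(2)


loss_fn = ScaledSquaredLoss(scale=0.25)
r = torch.randn(4, 3)
out = loss_fn(r)
print(out.shape == r.shape, bool((out >= 0).all()))

try:
    RegressionLossBase()  # abstract: cannot be instantiated
except TypeError as err:
    print("TypeError:", err)
```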

abstract __call__(residuals)

Apply the loss elementwise.

Parameters:

residuals (torch.Tensor) – Balance condition residuals (any shape).

Returns:

Non-negative tensor of the same shape.

Return type:

torch.Tensor

__eq__(other)

Parameters:

other (object)

Return type:

bool

__hash__()

Return type:

int

__repr__()

Return type:

str

class gfn.gflownet.losses.ShiftedCoshLoss

Bases: RegressionLoss

Shifted hyperbolic cosine: \(g(t) = e^t + e^{-t} - 2 = 2(\cosh(t) - 1)\).

Among the losses in this module, it is the only one that is simultaneously zero-forcing (penalizes spurious mass) and zero-avoiding (penalizes missing modes). It is symmetric: g(t) = g(-t).

Near t = 0 it behaves like t^2 (g''(0) = 2, the same curvature as squared loss), but for large |t| it grows exponentially, giving much larger gradients for poorly matched trajectories (large residuals).

Hu et al. (ICLR 2025) found this loss generally outperforms squared error on convergence speed and mode coverage across HyperGrid, bit-sequence, and sEH molecule benchmarks.
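A sketch of this loss as a plain function (not the library's ShiftedCoshLoss), assuming float tensors. It uses the identity cosh t - 1 = 2 sinh²(t/2), so g(t) = 4 sinh²(t/2), which avoids the cancellation in e^t + e^{-t} - 2 for tiny residuals; whether the library computes it this way is not stated here. Autograd then compares g'(t) = 2 sinh t with the squared-loss gradient 2t. Note that, like any exponential loss, it overflows to inf in float32 once |t| reaches roughly 89.

```python
import torch


def shifted_cosh(residuals: torch.Tensor) -> torch.Tensor:
    """g(t) = 2(cosh t - 1), computed as 4 sinh^2(t / 2) to avoid cancellation near 0."""
    return 4 * torch.sinh(residuals / 2).pow(2)


# Residuals symmetric about zero, so g should read the same forwards and backwards.
t = torch.tensor([-5.0, -1.0, 0.0, 1.0, 5.0], dtype=torch.float64, requires_grad=True)
g = shifted_cosh(t)
g.sum().backward()

print(torch.allclose(g, g.flip(0)))  # symmetry: g(t) = g(-t)
print(t.grad)                        # g'(t) = 2 sinh(t): about 148.4 at t = 5
print(2 * t.detach())                # squared-loss gradient 2t: 10 at t = 5
```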

References

Hu et al. “Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks” (ICLR 2025, arXiv:2410.02596).

__call__(residuals)

Apply the loss elementwise.

Parameters:

residuals (torch.Tensor) – Balance condition residuals (any shape).

Returns:

Non-negative tensor of the same shape.

Return type:

torch.Tensor

class gfn.gflownet.losses.SquaredLoss

Bases: RegressionLoss

Standard squared loss: \(g(t) = t^2\).

Corresponds to the reverse KL divergence (Malkin et al. 2022). This is zero-forcing (mode-seeking): it penalizes the learner for placing probability mass where the target has none, but does not penalize missing modes. This can lead to mode collapse in multi-modal targets.

This is the default loss for TB, DB, SubTB, LPV, and FM classes, reproducing the standard behavior from the literature.
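SquaredLoss and HalfSquaredLoss have the same minimizer and, since the generator f is linear in g, induce the same divergence up to a constant factor; only the gradient scale differs. The short check below uses plain tensor expressions rather than the library classes to confirm the factor of 2. With plain SGD, switching between the two is equivalent to rescaling the learning rate by 2; adaptive optimizers such as Adam are largely insensitive to a constant loss scale, apart from the epsilon term.

```python
import torch

torch.manual_seed(0)
r = torch.randn(8, requires_grad=True)

(grad_sq,) = torch.autograd.grad((r ** 2).mean(), r)          # SquaredLoss-style
(grad_half,) = torch.autograd.grad((0.5 * r ** 2).mean(), r)  # HalfSquaredLoss-style

# Same minimizer (r = 0); gradients differ by exactly a factor of 2.
print(torch.allclose(grad_sq, 2 * grad_half))
```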

__call__(residuals)

Apply the loss elementwise.

Parameters:

residuals (torch.Tensor) – Balance condition residuals (any shape).

Returns:

Non-negative tensor of the same shape.

Return type:

torch.Tensor