Loss Functions

GFlowNets can be trained with different losses, each of which requires a different parametrization. In this library, such a parametrization is called a GFlowNet. A GFlowNet includes one or more Estimators, at least one of which implements a to_probability_distribution function. It must also implement a loss function that takes States, Transitions, or Trajectories Container instances as input, depending on the loss.

Available Losses

Trajectory Balance (TB)

Class: TBGFlowNet

The most commonly used loss. Enforces flow conservation along entire trajectories by requiring that the product of forward transition probabilities (times Z) equals the product of backward probabilities times the reward.
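
As a rough illustration of the constraint described above, the sketch below computes the squared trajectory-balance residual in log space for a single trajectory. It uses plain Python rather than the library's classes; the function name tb_loss and its arguments are hypothetical and are not part of TBGFlowNet's API.

```python
import math


def tb_loss(log_z, log_pf_steps, log_pb_steps, log_reward):
    """Squared trajectory-balance residual for one trajectory (illustrative only)."""
    # Forward side: log Z plus the summed forward log-probabilities.
    forward = log_z + sum(log_pf_steps)
    # Backward side: log reward plus the summed backward log-probabilities.
    backward = log_reward + sum(log_pb_steps)
    return (forward - backward) ** 2


# Toy two-step trajectory: forward probabilities 0.5 and 0.5,
# backward probabilities 1.0 and 1.0, terminal reward 1.0.
log_pf = [math.log(0.5), math.log(0.5)]
log_pb = [math.log(1.0), math.log(1.0)]

# With Z = 4 the two sides match (4 * 0.25 == 1 * 1), so the loss vanishes.
balanced = tb_loss(math.log(4.0), log_pf, log_pb, math.log(1.0))
# With Z = 1 the forward side is too small and the loss is positive.
unbalanced = tb_loss(0.0, log_pf, log_pb, math.log(1.0))
```

Working in log space turns the products into sums, which is why logZ rather than Z is the learnable quantity.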

Requires: Forward policy (PF), backward policy (PB), learnable log-partition function (logZ).

When to use: Default choice for most problems. Works well across discrete, continuous, and graph environments. Straightforward to implement and debug.

Tip: logZ typically benefits from a higher learning rate than the policy parameters (e.g., lr_Z=0.1 vs lr=1e-3). Use separate optimizer parameter groups via gflownet.pf_pb_parameters() and gflownet.logz_parameters().

See: train_hypergrid_simple.py (basic usage), train_box.py (continuous), train_graph_ring.py (graphs).


Detailed Balance (DB)

Class: DBGFlowNet

Imposes a stricter, state-level balance constraint. Instead of balancing entire trajectories, enforces that flow is conserved at every individual transition.
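
A minimal sketch of the per-transition condition, assuming the usual log-space form F(s) PF(s'|s) = F(s') PB(s|s'). The function db_loss is hypothetical and does not reflect DBGFlowNet's actual signature; the handling of terminating transitions is left out.

```python
import math


def db_loss(log_f_s, log_pf, log_pb, log_f_next):
    """Squared detailed-balance residual for one transition s -> s' (illustrative only)."""
    # Flow leaving s along the edge must equal flow entering s' along the same edge.
    return (log_f_s + log_pf - log_f_next - log_pb) ** 2


# Balanced transition: F(s) = 4, PF = 0.25, F(s') = 1, PB = 1, so 4 * 0.25 == 1 * 1.
balanced_db = db_loss(math.log(4.0), math.log(0.25), math.log(1.0), math.log(1.0))
# Unbalanced transition: the state flow F(s) is underestimated as 1.
unbalanced_db = db_loss(math.log(1.0), math.log(0.25), math.log(1.0), math.log(1.0))
```

Because each transition contributes its own residual, every step of a trajectory yields a separate learning signal, at the cost of having to learn logF.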

Requires: Forward policy (PF), backward policy (PB), log state-flow estimator (logF) via ScalarEstimator.

When to use: When you want a fine-grained, per-transition learning signal. Can converge faster than TB on some problems but requires an additional estimator (logF).

Modified variant: ModifiedDBGFlowNet drops the explicit logF estimator. In forward-looking mode, rewards must be defined on edges; the current implementation treats the edge reward as the difference between the successor and current state rewards, so only enable this when that matches your environment.

See: train_hypergrid_simple.py (with --loss db), train_bit_sequences.py.


Sub-Trajectory Balance (SubTB)

Class: SubTBGFlowNet

Generalizes TB by considering all sub-trajectories within a trajectory. Each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined here. Other strategies exist and are implemented in src/gfn/losses/sub_trajectory_balance.py.
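
To make the geometric weighting concrete, the sketch below enumerates every sub-trajectory of a short trajectory and gives each one a weight proportional to lam raised to its length, normalized within the trajectory. The function name and the parameter name lam are assumptions made for illustration; the exact normalization used in the library may differ.

```python
def geometric_subtb_weights(n_steps, lam):
    """Normalized geometric weights for all sub-trajectories (illustrative only).

    A trajectory with n_steps transitions visits states 0..n_steps; the
    sub-trajectory (i, j) runs from state i to state j and has length j - i.
    """
    raw = {
        (i, j): lam ** (j - i)
        for i in range(n_steps)
        for j in range(i + 1, n_steps + 1)
    }
    total = sum(raw.values())
    # Normalize so the weights within one trajectory sum to 1.
    return {span: weight / total for span, weight in raw.items()}


weights = geometric_subtb_weights(n_steps=3, lam=0.9)
```

With lam below 1, short sub-trajectories dominate; as lam grows, the full trajectory (the TB-like term) receives relatively more weight.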

Requires: Forward policy (PF), backward policy (PB), log state-flow estimator (logF).

When to use: When TB is underperforming and you want a richer learning signal from each trajectory. Adds computational cost but can improve sample efficiency.

Note: When using geometric-based weighting, the 'mean' reduction is not supported; requests for a mean reduction are coerced to a sum (a warning is emitted when debug is enabled).

See: train_box.py (with --loss subtb), train_with_compile.py.


Flow Matching (FM)

Class: FMGFlowNet

The original GFlowNet loss. Matches incoming and outgoing flows at each state.
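
The flow-matching condition at a single state can be sketched as follows, assuming the common form in which the total inflow equals the state's reward plus its total outflow. The helper names are hypothetical, the edge flows are passed in as plain lists of log-flows, and the small stabilizing constant often added inside the logarithms is omitted.

```python
import math


def logsumexp(values):
    """Numerically stable log of a sum of exponentials."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def fm_loss(log_in_flows, log_out_flows, reward=0.0):
    """Squared log-ratio of inflow to (reward + outflow) at one state (illustrative only)."""
    log_in = logsumexp(log_in_flows)
    # Assumes reward + outflow is strictly positive.
    out_total = reward + sum(math.exp(v) for v in log_out_flows)
    return (log_in - math.log(out_total)) ** 2


# Balanced state: inflow 1.5 + 0.5 = 2, outflow 1.0 plus reward 1.0 = 2.
balanced_fm = fm_loss([math.log(1.5), math.log(0.5)], [math.log(1.0)], reward=1.0)
# Unbalanced state: inflow 2, outflow 1, no reward.
unbalanced_fm = fm_loss([math.log(1.0), math.log(1.0)], [math.log(1.0)])
```

Evaluating this for a state requires enumerating all of its parents and children, which is one reason the loss is slow to compute.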

Requires: Only a log-flow estimator (logF) via DiscretePolicyEstimator — no explicit forward/backward policies.

When to use: Rarely recommended. Slow to compute and hard to optimize. Included primarily for completeness and for comparison with other losses.

See: train_discreteebm.py, train_ising.py.


Log Partition Variance (ZVar)

Class: LogPartitionVarianceGFlowNet

Minimizes the variance of the log-partition function estimate across trajectories. Introduced in this paper.
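
One way to picture this loss: each trajectory yields its own estimate of logZ from the TB identity, and the loss is the variance of those estimates over a batch. The sketch below assumes the population variance and uses hypothetical function names; it is not the library's implementation.

```python
import math
from statistics import pvariance


def log_z_estimate(log_pf_steps, log_pb_steps, log_reward):
    """Per-trajectory estimate of log Z implied by trajectory balance."""
    return log_reward + sum(log_pb_steps) - sum(log_pf_steps)


def zvar_loss(trajectories):
    """Variance of the per-trajectory log Z estimates over a batch (illustrative only)."""
    return pvariance([log_z_estimate(*t) for t in trajectories])


# Each trajectory is (forward log-probs, backward log-probs, log reward).
traj_a = ([math.log(0.5), math.log(0.5)], [0.0, 0.0], math.log(1.0))  # implies Z = 4
traj_b = ([math.log(0.25)], [0.0], math.log(1.0))                     # implies Z = 4
traj_c = ([math.log(0.25)], [0.0], math.log(2.0))                     # implies Z = 8

consistent = zvar_loss([traj_a, traj_b])
inconsistent = zvar_loss([traj_a, traj_c])
```

When all trajectories agree on the same logZ the variance is zero, so no separate logZ parameter has to be learned.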

Requires: Forward policy (PF), backward policy (PB).

When to use: An alternative to TB that avoids learning an explicit logZ parameter. Can be useful when logZ estimation is unstable.

See: train_hypergrid.py (with --loss zvar).


Relative Trajectory Balance (RTB)

Class: RelativeTrajectoryBalanceGFlowNet

A variant of TB designed for posterior fine-tuning from a pre-trained prior. Uses a fixed reference policy that does not receive gradients.
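
As a hedged sketch, a common statement of the relative trajectory balance condition is that Z times the posterior trajectory probability equals the reward times the prior trajectory probability. The function below computes the corresponding squared log-space residual. The name rtb_loss and its arguments are hypothetical, the log_z argument is an assumption taken from that formulation, and the sketch does not show how the library uses the backward policy listed under Requires.

```python
import math


def rtb_loss(log_z, log_pf_post_steps, log_pf_prior_steps, log_reward):
    """Squared relative-trajectory-balance residual for one trajectory (illustrative only)."""
    # The prior log-probabilities are treated as constants: no gradients flow to the prior.
    residual = log_z + sum(log_pf_post_steps) - sum(log_pf_prior_steps) - log_reward
    return residual ** 2


# The prior assigns the trajectory probability 0.25, the reward is 2, and Z = 1,
# so the matching posterior probability is 0.25 * 2 / 1 = 0.5.
prior_steps = [math.log(0.5), math.log(0.5)]
matched = rtb_loss(0.0, [math.log(0.5), math.log(1.0)], prior_steps, math.log(2.0))
# A posterior that simply copies the prior ignores the reward and is penalized.
copied = rtb_loss(0.0, prior_steps, prior_steps, math.log(2.0))
```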

Requires: Forward policy (PF, trainable), backward policy (PB), fixed prior policy (PF_prior).

When to use: When you have a pre-trained model (e.g., from MLE) and want to fine-tune it to match a posterior distribution.

See: train_diffusion_rtb.py (two-stage prior→posterior pipeline).


Choosing a Loss Function

| Loss  | Estimators needed | Learning signal    | Computational cost | Recommended for                       |
|-------|-------------------|--------------------|--------------------|---------------------------------------|
| TB    | PF, PB, logZ      | Per-trajectory     | Low                | Most problems (default choice)        |
| DB    | PF, PB, logF      | Per-transition     | Medium             | Problems where per-state signal helps |
| SubTB | PF, PB, logF      | Per-sub-trajectory | High               | When TB underperforms                 |
| FM    | logF only         | Per-state flow     | High               | Completeness / comparison             |
| ZVar  | PF, PB            | Per-trajectory     | Low                | When logZ learning is unstable        |
| RTB   | PF, PB, PF_prior  | Per-trajectory     | Medium             | Posterior fine-tuning                 |

For a single-script comparison of TB, DB, and FM on the same environment, see train_hypergrid_simple.py. For all six losses in a single script, see train_hypergrid.py.

Common Training Patterns

Separate Learning Rates

Most losses benefit from different learning rates for different parameter groups. For example, with TB:

import torch

optimizer = torch.optim.Adam([
    {"params": gflownet.pf_pb_parameters(), "lr": 1e-3},  # PF and PB policy parameters
    {"params": gflownet.logz_parameters(), "lr": 1e-1},  # logZ, with a higher learning rate
])

For DB/SubTB, add a separate parameter group for the logF estimator, with its own learning rate.

On-Policy vs Off-Policy

When training on-policy (no replay buffer, no exploration noise), set save_logprobs=True during sampling and recalculate_all_logprobs=False during loss computation to avoid redundant forward passes. For off-policy training, log-probs must be recalculated — see the Off-Policy Training guide.