gfn.gflownet.trajectory_balance

Implementations of the [Trajectory Balance loss](https://arxiv.org/abs/2201.13259) and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446).

Classes

LogPartitionVarianceGFlowNet

GFlowNet for the Log Partition Variance loss.

RelativeLogPartitionVarianceGFlowNet

RTB variant that eliminates the learned logZ via variance minimization.

RelativeTBBase

Shared base for Relative Trajectory Balance variants.

RelativeTrajectoryBalanceGFlowNet

GFlowNet for the Relative Trajectory Balance (RTB) loss.

TBGFlowNet

GFlowNet for the Trajectory Balance loss.

TrustPCLGFlowNet

Trust-PCL view of Relative Trajectory Balance.

Module Contents

class gfn.gflownet.trajectory_balance.LogPartitionVarianceGFlowNet(pf, pb, constant_pb=False, log_reward_clip_min=float('-inf'), debug=False, loss_fn=None)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

GFlowNet for the Log Partition Variance loss.

The log partition variance loss is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator.

constant_pb

Whether to ignore pb, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1. Must be set explicitly by the user, which ensures that pb is an Estimator in every other case.

log_reward_clip_min

If finite, log rewards below this value are clipped to it.

loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the log partition variance loss.

The log partition variance loss is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).

Parameters:
  • env (gfn.env.Env) – The environment where the trajectories are sampled from (unused).

  • trajectories (gfn.containers.Trajectories) – The Trajectories object to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use ('mean', 'sum', or 'none').

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The computed log partition variance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
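The following is a minimal sketch of the computation, not the library's implementation. It assumes the per-trajectory score is log P_F(τ) - log P_B(τ) - log R(x) and uses plain Python floats in place of torch tensors; the function name `log_partition_variance_loss` and its arguments are hypothetical.

```python
from statistics import fmean


def log_partition_variance_loss(log_pf, log_pb, log_rewards, reduction="mean"):
    """Toy log partition variance loss over per-trajectory log-probability sums."""
    # Assumed score convention: log P_F(tau) - log P_B(tau) - log R(x).
    scores = [f - b - r for f, b, r in zip(log_pf, log_pb, log_rewards)]
    # Centering on the batch mean removes the need for a learned logZ.
    center = fmean(scores)
    per_trajectory = [(s - center) ** 2 for s in scores]
    if reduction == "mean":
        return fmean(per_trajectory)
    if reduction == "sum":
        return sum(per_trajectory)
    if reduction == "none":
        return per_trajectory
    raise ValueError(f"unknown reduction: {reduction}")


# Scores that differ only by a constant give zero loss.
balanced = log_partition_variance_loss([-1.0, -2.0], [0.0, 0.0], [0.5, -0.5])
# Scores 1.0 and 3.0 deviate from their mean by 1.0 each.
unbalanced = log_partition_variance_loss([1.0, 3.0], [0.0, 0.0], [0.0, 0.0], reduction="none")
```

With reduction set to "none", the sketch returns one value per trajectory; "mean" and "sum" collapse the batch to a scalar.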

class gfn.gflownet.trajectory_balance.RelativeLogPartitionVarianceGFlowNet(pf, prior_pf, *, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: RelativeTBBase

RTB variant that eliminates the learned logZ via variance minimization.

Analogous to how LogPartitionVarianceGFlowNet relates to TBGFlowNet, this class mean-centers the RTB residuals within each batch so that no explicit logZ parameter is needed.

The loss minimizes

\[\operatorname{Var}_{\tau}\!\bigl[\log p_\phi(\tau) - \log p_\theta(\tau) - \beta\,\log r(x_T)\bigr],\]

which equals the RTB loss evaluated at the batch-optimal \(\log Z^* = -\overline{s}\) (the negative batch mean of scores).
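The identity above can be checked numerically. The sketch below is illustrative only: it takes a made-up batch of scores s as plain floats, assumes the mean-reduced RTB loss is the batch average of (log Z + s)^2, and compares it at log Z* with the population variance of the scores. The helper name `rtb_loss_from_scores` is hypothetical.

```python
from statistics import fmean, pvariance


def rtb_loss_from_scores(scores, log_z):
    """Mean squared RTB residual for a given scalar log_z (assumed form)."""
    return fmean([(log_z + s) ** 2 for s in scores])


# Made-up per-trajectory scores s = log p_phi - log p_theta - beta * log r.
scores = [0.5, -1.0, 2.0, 0.25]

# Batch-optimal normalizer: the negative batch mean of the scores.
log_z_star = -fmean(scores)

loss_at_optimum = rtb_loss_from_scores(scores, log_z_star)
score_variance = pvariance(scores)  # population variance (divides by N)
loss_elsewhere = rtb_loss_from_scores(scores, log_z_star + 0.1)
```

Any other value of log Z gives a strictly larger mean squared residual, which is why the variance form needs no learned normalizer.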

loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the Relative LPV loss on a batch of trajectories.

Return type:

torch.Tensor

class gfn.gflownet.trajectory_balance.RelativeTBBase(pf, prior_pf, *, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

Shared base for Relative Trajectory Balance variants.

Manages the prior forward policy and beta scaling. Subclasses only need to implement loss() (deciding how to handle logZ and reduction).

_compute_rtb_scores(env, trajectories, log_rewards=None, recalculate_all_logprobs=True)

RTB residuals: log_pf_post - log_pf_prior - beta * log_rewards.

Parameters:
  • env (gfn.env.Env | None) – The environment (unused, kept for API consistency).

  • trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

Returns:

Shape (N,) per-trajectory scores.

Return type:

torch.Tensor

get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)

RTB residuals (without logZ): log_pf_post - log_pf_prior - beta * log_R.

This is the public interface to the RTB balance residuals, analogous to TrajectoryBasedGFlowNet.get_scores() for standard TB.

Returns:

Shape (N,) per-trajectory scores.

Return type:

torch.Tensor

property prior_pf: gfn.estimators.Estimator

The fixed prior forward policy (not registered as a submodule).

Return type:

gfn.estimators.Estimator

class gfn.gflownet.trajectory_balance.RelativeTrajectoryBalanceGFlowNet(pf, prior_pf, *, logZ=None, init_logZ=0.0, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: RelativeTBBase

GFlowNet for the Relative Trajectory Balance (RTB) loss.

This objective matches a posterior sampler to a prior diffusion (or other sequential) model by minimizing

\[\left(\log Z_\phi + \log p_\phi(\tau) - \log p_\theta(\tau) - \beta \log r(x_T)\right)^2,\]

where \(p_\theta\) is a fixed prior process, \(p_\phi\) is the learnable posterior, \(r\) is a positive reward/constraint on the terminal state \(x_T\), and \(\log Z_\phi\) is a learned scalar normalizer.
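As an illustration of when the residual vanishes, the sketch below uses a toy setting that is an assumption, not part of the library: one-step trajectories over three terminal states, so that p(τ) is just a probability over states. If the posterior is proportional to the prior times r^β and Z is the matching normalizer, every residual is zero. All names and numbers are hypothetical.

```python
import math

# Toy one-step setting: each trajectory is a single draw of a terminal state.
prior = {"a": 0.5, "b": 0.3, "c": 0.2}   # fixed prior p_theta
reward = {"a": 1.0, "b": 4.0, "c": 2.0}  # positive reward r
beta = 2.0

# Normalizer and posterior that satisfy the balance condition exactly.
z = sum(prior[x] * reward[x] ** beta for x in prior)
posterior = {x: prior[x] * reward[x] ** beta / z for x in prior}


def rtb_residual(log_z, log_post, log_prior, log_reward, beta):
    """log Z + log p_phi(tau) - log p_theta(tau) - beta * log r(x_T)."""
    return log_z + log_post - log_prior - beta * log_reward


residuals = [
    rtb_residual(
        math.log(z),
        math.log(posterior[x]),
        math.log(prior[x]),
        math.log(reward[x]),
        beta,
    )
    for x in prior
]
```

Perturbing either the posterior or log Z away from these values makes at least one residual nonzero, which is what the squared loss penalizes.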

logZ
loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the RTB loss on a batch of trajectories.

Return type:

torch.Tensor

class gfn.gflownet.trajectory_balance.TBGFlowNet(pf, pb, logZ=None, init_logZ=0.0, constant_pb=False, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

GFlowNet for the Trajectory Balance loss.

\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) represents the possible values for logZ, and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.pb is a fixed DiscretePBEstimator.

See [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259) for more details.

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator, or None if the GFlowNet DAG is a tree and pb is therefore always 1.

logZ

A learnable parameter or a ScalarEstimator instance (for conditional GFNs).

constant_pb

Whether to ignore pb, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1. Must be set explicitly by the user, which ensures that pb is an Estimator in every other case.

log_reward_clip_min

If finite, log rewards below this value are clipped to it.

logZ
loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the trajectory balance loss.

The trajectory balance loss is described in section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).

Parameters:
  • env (gfn.env.Env) – The environment where the trajectories are sampled from (unused).

  • trajectories (gfn.containers.Trajectories) – The Trajectories object to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use ('mean', 'sum', or 'none').

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The computed trajectory balance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
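A rough sketch of the trajectory balance computation follows. It is not the library's code: it assumes the per-trajectory residual is log Z + Σ log P_F - Σ log P_B - log R(x), takes per-step log-probabilities as nested Python lists rather than tensors, and uses the hypothetical name `tb_loss`.

```python
import math
from statistics import fmean


def tb_loss(log_z, step_log_pf, step_log_pb, log_rewards, reduction="mean"):
    """Toy trajectory balance loss from per-step log-probabilities."""
    per_trajectory = []
    for pf_steps, pb_steps, log_r in zip(step_log_pf, step_log_pb, log_rewards):
        # Sum step log-probs to get trajectory log-probs, then form the residual.
        residual = log_z + sum(pf_steps) - sum(pb_steps) - log_r
        per_trajectory.append(residual ** 2)
    if reduction == "mean":
        return fmean(per_trajectory)
    if reduction == "sum":
        return sum(per_trajectory)
    if reduction == "none":
        return per_trajectory
    raise ValueError(f"unknown reduction: {reduction}")


# Tree-shaped toy DAG: two one-step trajectories, so log P_B is 0 everywhere.
# Rewards are 1 and 2, so Z = 3 and the ideal forward policy is (1/3, 2/3).
balanced_loss = tb_loss(
    log_z=math.log(3.0),
    step_log_pf=[[math.log(1 / 3)], [math.log(2 / 3)]],
    step_log_pb=[[0.0], [0.0]],
    log_rewards=[math.log(1.0), math.log(2.0)],
)
```

In the balanced case the loss is zero up to floating-point error; a wrong log Z shifts every residual by the same amount.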

class gfn.gflownet.trajectory_balance.TrustPCLGFlowNet(policy, reference_policy, *, alpha=1.0, init_v_soft_s0=0.0, logZ=None, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: RelativeTrajectoryBalanceGFlowNet

Trust-PCL view of Relative Trajectory Balance.

Deleu et al. (2025) proved that RTB is mathematically equivalent to Trust-PCL, an off-policy RL method with KL regularization toward a reference policy. This class provides an RL-native interface to the same algorithm, using reinforcement learning terminology.

The equivalence (Proposition 3.1 of Deleu et al.):

\[\mathcal{L}_{\text{Trust-PCL}}(\phi, \psi) = \alpha^2 \,\mathcal{L}_{\text{RTB}}(\phi, \psi)\]

where \(\alpha = 1/\beta\) is the Trust-PCL temperature.
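The scaling can be checked on a single trajectory. The sketch below assumes one particular sign convention for the Trust-PCL residual, V + α (log π - log π_ref) - (total reward), with the total reward identified with log r(x_T) and V = α log Z; the function names and numbers are made up for illustration.

```python
def rtb_residual(log_z, log_post, log_prior, log_reward, beta):
    """RTB residual: log Z + log p_phi - log p_theta - beta * log r."""
    return log_z + log_post - log_prior - beta * log_reward


def trust_pcl_residual(v_soft_s0, log_pi, log_ref, total_reward, alpha):
    """Assumed Trust-PCL path residual with KL penalty weight alpha."""
    return v_soft_s0 + alpha * (log_pi - log_ref) - total_reward


alpha = 0.5
beta = 1.0 / alpha

# Made-up trajectory quantities.
log_z, log_post, log_prior, log_reward = 0.3, -1.2, -0.7, 0.4

rtb = rtb_residual(log_z, log_post, log_prior, log_reward, beta)
tpcl = trust_pcl_residual(alpha * log_z, log_post, log_prior, log_reward, alpha)

rtb_loss = rtb ** 2
trust_pcl_loss = tpcl ** 2
```

Here the Trust-PCL residual is exactly α times the RTB residual, so the squared losses differ by the factor α².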

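Parameter correspondence:

  • policy corresponds to pf (the learnable posterior policy).

  • reference_policy corresponds to prior_pf (the fixed prior policy).

  • alpha corresponds to 1 / beta.

  • v_soft_s0 corresponds to alpha * logZ.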

Interpretation of the learned scalar:

In RTB, logZ estimates the log-partition function \(\log \int p_\theta(x)\,r(x)\,dx\). In Trust-PCL, the same quantity is the soft value function at the initial state: \(V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi\). This connects GFlowNet training to entropy-regularized RL, where the soft value satisfies the soft Bellman equation.

Why this class exists:

The underlying computation is identical to RelativeTrajectoryBalanceGFlowNet (the loss is just scaled by \(\alpha^2\)). This class exists to:

  1. Provide an RL-native constructor (policy, reference_policy, alpha, init_v_soft_s0) for researchers familiar with Trust-PCL / SAC / entropy-regularized RL.

  2. Expose alpha and v_soft_s0 properties for interpretability and monitoring.

  3. Serve as a pedagogical bridge between the GFlowNet and RL communities.

References

Deleu et al. “Relative Trajectory Balance is equivalent to Trust-PCL” (2025, arXiv:2509.01632).

Nachum et al. “Trust-PCL: An Off-Policy Trust Region Method for Continuous Control” (NeurIPS 2017, arXiv:1707.01891).

Venkatraman et al. “Amortizing intractable inference in diffusion models for vision, language, and control” (NeurIPS 2024, arXiv:2405.20971).

property alpha: torch.Tensor

Trust-PCL temperature: \(\alpha = 1/\beta\).

Controls the strength of KL regularization toward the reference policy. At convergence, the learned policy satisfies:

\[\pi_\phi(a|s) \propto \pi_{\text{ref}}(a|s) \exp\!\bigl(Q^{\text{soft}}(s,a) / \alpha\bigr)\]

Higher alpha → policy stays closer to the reference (more regularization). Lower alpha → policy deviates more toward reward-maximizing behavior.

Return type:

torch.Tensor
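The effect of the temperature can be seen in a two-action toy example. This sketch is illustrative and independent of the library: it takes made-up reference probabilities and soft Q-values, and the helper `regularized_policy` is a hypothetical name that normalizes π_ref(a|s) exp(Q(s,a)/α) over actions.

```python
import math


def regularized_policy(ref_probs, q_values, alpha):
    """Normalize pi_ref(a) * exp(Q(a) / alpha) over a finite action set."""
    # Subtract the largest exponent before exponentiating for numerical stability.
    shift = max(q / alpha for q in q_values)
    weights = [p * math.exp(q / alpha - shift) for p, q in zip(ref_probs, q_values)]
    total = sum(weights)
    return [w / total for w in weights]


ref_probs = [0.5, 0.5]  # uniform reference policy
q_values = [0.0, 1.0]   # the second action has the higher soft Q-value

high_alpha = regularized_policy(ref_probs, q_values, alpha=10.0)
low_alpha = regularized_policy(ref_probs, q_values, alpha=0.1)
```

With alpha=10.0 the result stays near the uniform reference, while with alpha=0.1 nearly all probability moves to the higher-value action.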

loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the Trust-PCL loss: \(\alpha^2 \cdot \mathcal{L}_{\text{RTB}}\).

The scaling by \(\alpha^2\) is the only difference from RelativeTrajectoryBalanceGFlowNet.loss(). It ensures gradient magnitudes match the Trust-PCL formulation.

Return type:

torch.Tensor

property v_soft_s0: torch.Tensor

Soft value function at the initial state: \(V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi\).

This is the expected return under the optimal entropy-regularized policy, starting from \(s_0\):

\[V^{\text{soft}}(s_0) = \mathbb{E}_{\pi_\phi}\!\left[ \sum_t r(s_t, a_t) + \alpha \sum_t \log \frac{\pi_{\text{ref}}(a_t|s_t)} {\pi_\phi(a_t|s_t)} \right]\]

The KL regularization term \(\alpha \log(\pi_{\text{ref}} / \pi_\phi)\) in the sum emerges from the ratio of prior to posterior log-probabilities in the RTB balance condition.

Monitoring this value during training shows how the expected (regularized) return evolves. At convergence it equals \(\alpha \log \int p_\theta(x)\,r(x)\,dx\).

Return type:

torch.Tensor