gfn.gflownet.trajectory_balance¶
Implementations of the [Trajectory Balance loss](https://arxiv.org/abs/2201.13259) and the [Log Partition Variance loss](https://arxiv.org/abs/2302.05446).
Classes¶
LogPartitionVarianceGFlowNet | GFlowNet for the Log Partition Variance loss.
RelativeLogPartitionVarianceGFlowNet | RTB variant that eliminates the learned logZ via variance minimization.
RelativeTBBase | Shared base for Relative Trajectory Balance variants.
RelativeTrajectoryBalanceGFlowNet | GFlowNet for the Relative Trajectory Balance (RTB) loss.
TBGFlowNet | GFlowNet for the Trajectory Balance loss.
TrustPCLGFlowNet | Trust-PCL view of Relative Trajectory Balance.
Module Contents¶
- class gfn.gflownet.trajectory_balance.LogPartitionVarianceGFlowNet(pf, pb, constant_pb=False, log_reward_clip_min=float('-inf'), debug=False, loss_fn=None)¶
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
GFlowNet for the Log Partition Variance loss.
The log partition variance loss is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).
- Parameters:
pb (gfn.estimators.Estimator | None)
constant_pb (bool)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- pf¶
The forward policy estimator.
- pb¶
The backward policy estimator.
- constant_pb¶
Whether to ignore pb, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1. Must be set explicitly by the user, which ensures that pb is an Estimator except in this special case.
- log_reward_clip_min¶
If finite, log rewards are clipped from below at this value (smaller log rewards are raised to it).
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the log partition variance loss.
The log partition variance loss is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).
- Parameters:
env (gfn.env.Env) – The environment where the trajectories are sampled from (unused).
trajectories (gfn.containers.Trajectories) – The Trajectories object to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The computed log partition variance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
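The computation described above can be sketched numerically. This is a minimal illustration in plain Python floats rather than torch tensors, assuming the usual trajectory-balance score `log P_F(tau) - log P_B(tau | x) - log R(x)` and the default squared-error loss; the helper names `tb_scores` and `log_partition_variance` are hypothetical and are not part of the library.

```python
import math


def tb_scores(log_pf, log_pb, log_rewards, clip_min=-math.inf):
    """Per-trajectory scores: log P_F(tau) - log P_B(tau | x) - log R(x).

    Log rewards below `clip_min` are raised to `clip_min` first
    (assumed to mirror the role of `log_reward_clip_min`).
    """
    clipped = [max(r, clip_min) for r in log_rewards]
    return [pf - pb - r for pf, pb, r in zip(log_pf, log_pb, clipped)]


def log_partition_variance(scores):
    """Mean squared deviation of the scores from their batch mean."""
    mean = sum(scores) / len(scores)
    return sum((s - mean) ** 2 for s in scores) / len(scores)


# Hypothetical per-trajectory log-probabilities and log rewards.
log_pf = [-2.0, -3.0, -1.5]
log_pb = [-1.0, -1.0, -0.5]
log_rewards = [0.5, -1.0, -50.0]

scores = tb_scores(log_pf, log_pb, log_rewards, clip_min=-10.0)
loss = log_partition_variance(scores)

# Adding a constant to every score (i.e. any choice of logZ) leaves the loss unchanged.
shifted_loss = log_partition_variance([s + 3.7 for s in scores])
```

Because the batch mean is subtracted, the result does not depend on any constant offset, which is why this loss needs no logZ parameter.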
- class gfn.gflownet.trajectory_balance.RelativeLogPartitionVarianceGFlowNet(pf, prior_pf, *, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)¶
Bases: RelativeTBBase
RTB variant that eliminates the learned logZ via variance minimization.
Analogous to how LogPartitionVarianceGFlowNet relates to TBGFlowNet, this class mean-centers the RTB residuals within each batch so that no explicit logZ parameter is needed.
The loss minimizes
\[\operatorname{Var}_{\tau}\!\bigl[\log p_\phi(\tau) - \log p_\theta(\tau) - \beta\,\log r(x_T)\bigr],\]
which equals the RTB loss evaluated at the batch-optimal \(\log Z^* = -\overline{s}\) (the negative batch mean of the scores \(s\)).
- Parameters:
prior_pf (gfn.estimators.Estimator)
beta (float)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the Relative LPV loss on a batch of trajectories.
- Parameters:
env (gfn.env.Env)
trajectories (gfn.containers.Trajectories)
recalculate_all_logprobs (bool)
reduction (str)
log_rewards (torch.Tensor | None)
- Return type:
torch.Tensor
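The claim that the batch variance equals the RTB loss at the batch-optimal logZ can be checked on a toy batch. The sketch below uses plain Python floats and hypothetical helper names (`rtb_scores`, `rtb_loss`); it assumes a squared-error loss with mean reduction and is not the library's implementation.

```python
def rtb_scores(log_post, log_prior, log_rewards, beta=1.0):
    """Residuals s = log p_phi(tau) - log p_theta(tau) - beta * log r(x_T)."""
    return [p - q - beta * r for p, q, r in zip(log_post, log_prior, log_rewards)]


def rtb_loss(scores, log_z):
    """Mean of (logZ + s)^2 over the batch."""
    return sum((log_z + s) ** 2 for s in scores) / len(scores)


# Hypothetical batch of three trajectories.
scores = rtb_scores(
    log_post=[-4.0, -2.5, -3.0],
    log_prior=[-3.0, -3.5, -2.0],
    log_rewards=[1.0, 0.0, -1.0],
    beta=2.0,
)

mean_score = sum(scores) / len(scores)
optimal_log_z = -mean_score  # log Z* = -mean(s)

# Population variance of the scores within the batch.
variance = sum((s - mean_score) ** 2 for s in scores) / len(scores)

loss_at_optimum = rtb_loss(scores, optimal_log_z)
loss_elsewhere = rtb_loss(scores, optimal_log_z + 0.3)
```

Any other value of logZ adds the squared offset to the variance, so the mean-centered loss is a lower bound on the RTB loss for the same batch.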
- class gfn.gflownet.trajectory_balance.RelativeTBBase(pf, prior_pf, *, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)¶
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
Shared base for Relative Trajectory Balance variants.
Manages the prior forward policy and beta scaling. Subclasses only need to implement loss() (deciding how to handle logZ and reduction).
- Parameters:
prior_pf (gfn.estimators.Estimator)
beta (float)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- _compute_rtb_scores(env, trajectories, log_rewards=None, recalculate_all_logprobs=True)¶
RTB residuals: log_pf_post - log_pf_prior - beta * log_rewards.
- Parameters:
env (gfn.env.Env | None) – The environment (unused, kept for API consistency).
trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
- Returns:
Shape (N,) per-trajectory scores.
- Return type:
torch.Tensor
- get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)¶
RTB residuals (without logZ): log_pf_post - log_pf_prior - beta * log_R.
This is the public interface to the RTB balance residuals, analogous to TrajectoryBasedGFlowNet.get_scores() for standard TB.
- Returns:
Shape (N,) per-trajectory scores.
- Parameters:
trajectories (gfn.containers.Trajectories)
recalculate_all_logprobs (bool)
env (gfn.env.Env | None)
log_rewards (torch.Tensor | None)
- Return type:
torch.Tensor
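To make the residual formula concrete, the following sketch computes one score per trajectory from per-step log-probabilities. It assumes that a trajectory's log-probability is the sum of its per-step log-probabilities under each policy; the function names and the list-of-lists layout are hypothetical stand-ins for the library's Trajectories container and estimators.

```python
def trajectory_log_prob(step_log_probs):
    """Log-probability of a trajectory as the sum of its per-step log-probs."""
    return sum(step_log_probs)


def rtb_residuals(post_steps, prior_steps, log_rewards, beta=1.0):
    """One residual per trajectory: log_pf_post - log_pf_prior - beta * log_R."""
    residuals = []
    for post, prior, log_r in zip(post_steps, prior_steps, log_rewards):
        log_pf_post = trajectory_log_prob(post)
        log_pf_prior = trajectory_log_prob(prior)
        residuals.append(log_pf_post - log_pf_prior - beta * log_r)
    return residuals


# Two hypothetical trajectories of different lengths (2 and 3 steps).
post_steps = [[-0.5, -0.25], [-1.0, -0.5, -0.5]]
prior_steps = [[-1.0, -1.0], [-0.5, -0.5, -0.5]]
log_rewards = [2.0, -1.0]

residuals = rtb_residuals(post_steps, prior_steps, log_rewards, beta=0.5)
```

A residual of zero means the trajectory already satisfies the balance condition with logZ equal to zero; a nonzero residual is what the loss penalizes after logZ is added.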
- property prior_pf: gfn.estimators.Estimator¶
The fixed prior forward policy (not registered as a submodule).
- Return type:
gfn.estimators.Estimator
- class gfn.gflownet.trajectory_balance.RelativeTrajectoryBalanceGFlowNet(pf, prior_pf, *, logZ=None, init_logZ=0.0, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)¶
Bases: RelativeTBBase
GFlowNet for the Relative Trajectory Balance (RTB) loss.
This objective matches a posterior sampler to a prior diffusion (or other sequential) model by minimizing
\[\left(\log Z_\phi + \log p_\phi(\tau) - \log p_\theta(\tau) - \beta \log r(x_T)\right)^2,\]
where \(p_\theta\) is a fixed prior process, \(p_\phi\) is the learnable posterior, \(r\) is a positive reward/constraint on the terminal state \(x_T\), and \(\log Z_\phi\) is a learned scalar normalizer.
- Parameters:
prior_pf (gfn.estimators.Estimator)
logZ (torch.nn.Parameter | gfn.estimators.ScalarEstimator | None)
init_logZ (float)
beta (float)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- logZ¶
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the RTB loss on a batch of trajectories.
- Parameters:
env (gfn.env.Env)
trajectories (gfn.containers.Trajectories)
recalculate_all_logprobs (bool)
reduction (str)
log_rewards (torch.Tensor | None)
- Return type:
torch.Tensor
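The role of the learned scalar normalizer can be illustrated by holding the residuals fixed and running gradient descent on logZ alone. This is a simplified sketch in plain Python: in actual training the posterior policy is updated as well, and the step size, step count, and helper names here are hypothetical choices for illustration.

```python
def rtb_loss(scores, log_z):
    """Mean of (logZ + s)^2 over the batch."""
    return sum((log_z + s) ** 2 for s in scores) / len(scores)


def grad_log_z(scores, log_z):
    """Derivative of the mean squared residual with respect to logZ."""
    return 2.0 * sum(log_z + s for s in scores) / len(scores)


# Hypothetical fixed residuals log p_phi - log p_theta - beta * log r.
scores = [-3.0, 1.0, 0.5]

log_z = 0.0  # same starting value as the documented default init_logZ=0.0
learning_rate = 0.25
history = [rtb_loss(scores, log_z)]

for _ in range(60):
    log_z -= learning_rate * grad_log_z(scores, log_z)
    history.append(rtb_loss(scores, log_z))

target = -sum(scores) / len(scores)  # negative batch mean of the residuals
```

With the residuals held fixed, logZ converges to the negative mean of the scores, and the remaining loss is the variance of the residuals, which only the policy can reduce.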
- class gfn.gflownet.trajectory_balance.TBGFlowNet(pf, pb, logZ=None, init_logZ=0.0, constant_pb=False, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)¶
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
GFlowNet for the Trajectory Balance loss.
\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) represents the possible values for logZ, \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG, and \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof if self.pb is a fixed DiscretePBEstimator.
See [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259) for more details.
- Parameters:
pb (gfn.estimators.Estimator | None)
logZ (torch.nn.Parameter | gfn.estimators.ScalarEstimator | None)
init_logZ (float)
constant_pb (bool)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- pf¶
The forward policy estimator.
- pb¶
The backward policy estimator, or None if the gflownet DAG is a tree, and pb is therefore always 1.
- logZ¶
A learnable parameter or a ScalarEstimator instance (for conditional GFNs).
- constant_pb¶
Whether to ignore pb, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1. Must be set explicitly by the user, which ensures that pb is an Estimator except in this special case.
- log_reward_clip_min¶
If finite, log rewards are clipped from below at this value (smaller log rewards are raised to it).
- logZ¶
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the trajectory balance loss.
The trajectory balance loss is described in section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).
- Parameters:
env (gfn.env.Env) – The environment where the trajectories are sampled from (unused).
trajectories (gfn.containers.Trajectories) – The Trajectories object to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The computed trajectory balance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
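The effect of the reduction argument on the returned shape can be sketched as follows. This is an illustration in plain Python, assuming the default squared-error form `(logZ + score)^2` per trajectory (a custom loss_fn could change that); the function `tb_loss` below is a hypothetical stand-in, not the library method.

```python
def tb_loss(scores, log_z, reduction="mean"):
    """Squared trajectory-balance residuals with a configurable reduction."""
    per_trajectory = [(log_z + s) ** 2 for s in scores]
    if reduction == "mean":
        return sum(per_trajectory) / len(per_trajectory)
    if reduction == "sum":
        return sum(per_trajectory)
    if reduction == "none":
        return per_trajectory  # one value per trajectory
    raise ValueError(f"Unknown reduction: {reduction!r}")


# Hypothetical scores log P_F - log P_B - log R for three trajectories.
scores = [-1.0, 0.0, 2.0]
log_z = 1.0

loss_none = tb_loss(scores, log_z, reduction="none")
loss_sum = tb_loss(scores, log_z, reduction="sum")
loss_mean = tb_loss(scores, log_z, reduction="mean")
```

With 'none', the result keeps one entry per trajectory, which is useful for per-sample weighting or prioritized replay; 'mean' and 'sum' return a single scalar.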
- class gfn.gflownet.trajectory_balance.TrustPCLGFlowNet(policy, reference_policy, *, alpha=1.0, init_v_soft_s0=0.0, logZ=None, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)¶
Bases: RelativeTrajectoryBalanceGFlowNet
Trust-PCL view of Relative Trajectory Balance.
Deleu et al. (2025) proved that RTB is mathematically equivalent to Trust-PCL, an off-policy RL method with KL regularization toward a reference policy. This class provides an RL-native interface to the same algorithm, using reinforcement learning terminology.
The equivalence (Proposition 3.1 of Deleu et al.):
\[\mathcal{L}_{\text{Trust-PCL}}(\phi, \psi) = \alpha^2 \,\mathcal{L}_{\text{RTB}}(\phi, \psi)\]
where \(\alpha = 1/\beta\) is the Trust-PCL temperature.
Parameter correspondence:
Interpretation of the learned scalar:
In RTB, logZ estimates the log-partition function \(\log \int p_\theta(x)\,r(x)\,dx\). In Trust-PCL, the same quantity is the soft value function at the initial state: \(V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi\). This connects GFlowNet training to entropy-regularized RL, where the soft value satisfies the soft Bellman equation.
Why this class exists:
The underlying computation is identical to RelativeTrajectoryBalanceGFlowNet (the loss is just scaled by \(\alpha^2\)). This class exists to:
Provide an RL-native constructor (policy, reference_policy, alpha, init_v_soft_s0) for researchers familiar with Trust-PCL / SAC / entropy-regularized RL.
Expose alpha and v_soft_s0 properties for interpretability and monitoring.
Serve as a pedagogical bridge between the GFlowNet and RL communities.
References
Deleu et al. “Relative Trajectory Balance is equivalent to Trust-PCL” (2025, arXiv:2509.01632).
Nachum et al. “Trust-PCL: An Off-Policy Trust Region Method for Continuous Control” (NeurIPS 2017, arXiv:1707.01891).
Venkatraman et al. “Amortizing intractable inference in diffusion models for vision, language, and control” (NeurIPS 2024, arXiv:2405.20971).
- Parameters:
policy (gfn.estimators.Estimator)
reference_policy (gfn.estimators.Estimator)
alpha (float)
init_v_soft_s0 (float)
logZ (torch.nn.Parameter | gfn.estimators.ScalarEstimator | None)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- property alpha: torch.Tensor¶
Trust-PCL temperature \(\alpha = 1/\beta\).
Controls the strength of KL regularization toward the reference policy. At convergence, the learned policy satisfies:
\[\pi_\phi(a|s) \propto \pi_{\text{ref}}(a|s) \exp\!\bigl(Q^{\text{soft}}(s,a) / \alpha\bigr)\]
Higher alpha → policy stays closer to the reference (more regularization). Lower alpha → policy deviates more toward reward-maximizing behavior.
- Return type:
torch.Tensor
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the Trust-PCL loss: \(\alpha^2 \cdot \mathcal{L}_{\text{RTB}}\).
The scaling by \(\alpha^2\) is the only difference from RelativeTrajectoryBalanceGFlowNet.loss(). It ensures gradient magnitudes match the Trust-PCL formulation.
- Parameters:
env (gfn.env.Env)
trajectories (gfn.containers.Trajectories)
recalculate_all_logprobs (bool)
reduction (str)
log_rewards (torch.Tensor | None)
- Return type:
torch.Tensor
- property v_soft_s0: torch.Tensor¶
Soft value function at the initial state: \(V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi\).
This is the expected return under the optimal entropy-regularized policy, starting from \(s_0\):
\[V^{\text{soft}}(s_0) = \mathbb{E}_{\pi_\phi}\!\left[ \sum_t r(s_t, a_t) + \alpha \sum_t \log \frac{\pi_{\text{ref}}(a_t|s_t)} {\pi_\phi(a_t|s_t)} \right]\]
The KL regularization term \(\alpha \log(\pi_{\text{ref}} / \pi_\phi)\) in the sum emerges from the ratio of prior to posterior log-probabilities in the RTB balance condition.
Monitoring this value during training shows how the expected (regularized) return evolves. At convergence it equals \(\alpha \log \int p_\theta(x)\,r(x)\,dx\).
- Return type:
torch.Tensor
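The relations \(\alpha = 1/\beta\), \(V^{\text{soft}}(s_0) = \alpha \log Z\), and \(\mathcal{L}_{\text{Trust-PCL}} = \alpha^2 \mathcal{L}_{\text{RTB}}\) can be checked numerically. The sketch below uses plain Python floats and hypothetical helper names; the "RL-unit" residual `V + alpha * (log pi - log pi_ref) - log r` is obtained by multiplying the RTB residual by alpha and is shown only as an illustration of the scaling, not as the library's code path.

```python
def rtb_residuals(log_post, log_prior, log_rewards, beta, log_z):
    """logZ + log p_phi(tau) - log p_theta(tau) - beta * log r(x_T)."""
    return [log_z + p - q - beta * r
            for p, q, r in zip(log_post, log_prior, log_rewards)]


def trust_pcl_residuals(log_post, log_prior, log_rewards, alpha, v_soft_s0):
    """V_soft(s0) + alpha * (log pi - log pi_ref) - log r, i.e. alpha times the RTB residual."""
    return [v_soft_s0 + alpha * (p - q) - r
            for p, q, r in zip(log_post, log_prior, log_rewards)]


def mean_square(values):
    return sum(v ** 2 for v in values) / len(values)


# Hypothetical batch and hyperparameters.
log_post = [-4.0, -2.5]
log_prior = [-3.0, -3.5]
log_rewards = [1.0, 0.0]
beta = 4.0
log_z = 0.8

alpha = 1.0 / beta          # Trust-PCL temperature
v_soft_s0 = alpha * log_z   # soft value at the initial state

rtb = mean_square(rtb_residuals(log_post, log_prior, log_rewards, beta, log_z))
trust_pcl = mean_square(
    trust_pcl_residuals(log_post, log_prior, log_rewards, alpha, v_soft_s0)
)
```

Since each Trust-PCL residual is alpha times the corresponding RTB residual, the two losses share the same minimizers and differ only in gradient scale.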