gfn.gflownet.trajectory_balance
===============================

.. py:module:: gfn.gflownet.trajectory_balance

.. autoapi-nested-parse::

   Implementations of the `Trajectory Balance loss <https://arxiv.org/abs/2201.13259>`__
   and the `Log Partition Variance loss <https://arxiv.org/abs/2302.05446>`__.


Classes
-------

.. autoapisummary::

   gfn.gflownet.trajectory_balance.LogPartitionVarianceGFlowNet
   gfn.gflownet.trajectory_balance.RelativeLogPartitionVarianceGFlowNet
   gfn.gflownet.trajectory_balance.RelativeTBBase
   gfn.gflownet.trajectory_balance.RelativeTrajectoryBalanceGFlowNet
   gfn.gflownet.trajectory_balance.TBGFlowNet
   gfn.gflownet.trajectory_balance.TrustPCLGFlowNet


Module Contents
---------------

.. py:class:: LogPartitionVarianceGFlowNet(pf, pb, constant_pb = False, log_reward_clip_min = float('-inf'), debug = False, loss_fn = None)

   Bases: :py:obj:`gfn.gflownet.base.TrajectoryBasedGFlowNet`

   GFlowNet for the Log Partition Variance loss.

   The log partition variance loss is described in section 3.2 of
   `Robust Scheduling with GFlowNets <https://arxiv.org/abs/2302.05446>`__.

   .. attribute:: pf

      The forward policy estimator.

   .. attribute:: pb

      The backward policy estimator.

   .. attribute:: constant_pb

      Whether to ignore pb, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1. Must be set explicitly by the user, to ensure that pb is an Estimator except in this special case.

   .. attribute:: log_reward_clip_min

      If finite, log rewards are clipped from below at this value.

   .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None)

      Computes the log partition variance loss.

      The log partition variance loss is described in section 3.2 of
      `Robust Scheduling with GFlowNets <https://arxiv.org/abs/2302.05446>`__.

      :param env: The environment where the trajectories are sampled from (unused).
      :param trajectories: The Trajectories object to compute the loss with.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
      :param reduction: The reduction method to use ('mean', 'sum', or 'none').
      :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025).

      :returns: The computed log partition variance loss as a tensor. The shape depends on the reduction method.


.. py:class:: RelativeLogPartitionVarianceGFlowNet(pf, prior_pf, *, beta = 1.0, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None)

   Bases: :py:obj:`RelativeTBBase`

   RTB variant that eliminates the learned logZ via variance minimization.

   Analogous to how :class:`LogPartitionVarianceGFlowNet` relates to :class:`TBGFlowNet`, this class mean-centers the RTB residuals within each batch so that no explicit ``logZ`` parameter is needed. The loss minimizes

   .. math::

      \operatorname{Var}_{\tau}\!\bigl[\log p_\phi(\tau) - \log p_\theta(\tau) - \beta\,\log r(x_T)\bigr],

   which equals the RTB loss evaluated at the batch-optimal :math:`\log Z^* = -\overline{s}` (the negative batch mean of scores).

   .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None)

      Computes the Relative LPV loss on a batch of trajectories.


.. py:class:: RelativeTBBase(pf, prior_pf, *, beta = 1.0, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None)

   Bases: :py:obj:`gfn.gflownet.base.TrajectoryBasedGFlowNet`

   Shared base for Relative Trajectory Balance variants.

   Manages the prior forward policy and ``beta`` scaling. Subclasses only need to implement :meth:`loss` (deciding how to handle ``logZ`` and reduction).

   .. py:method:: _compute_rtb_scores(env, trajectories, log_rewards = None, recalculate_all_logprobs = True)

      RTB residuals: ``log_pf_post - log_pf_prior - beta * log_rewards``.

      :param env: The environment (unused, kept for API consistency).
      :param trajectories: The Trajectories object to evaluate.
      :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025).
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.

      :returns: Shape ``(N,)`` per-trajectory scores.

   .. py:method:: get_scores(trajectories, recalculate_all_logprobs = True, env = None, *, log_rewards = None)

      RTB residuals (without logZ): ``log_pf_post - log_pf_prior - beta * log_R``.

      This is the public interface to the RTB balance residuals, analogous to :meth:`TrajectoryBasedGFlowNet.get_scores` for standard TB.

      :returns: Shape ``(N,)`` per-trajectory scores.

   .. py:property:: prior_pf
      :type: gfn.estimators.Estimator

      The fixed prior forward policy (not registered as a submodule).


.. py:class:: RelativeTrajectoryBalanceGFlowNet(pf, prior_pf, *, logZ = None, init_logZ = 0.0, beta = 1.0, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None)

   Bases: :py:obj:`RelativeTBBase`

   GFlowNet for the Relative Trajectory Balance (RTB) loss.

   This objective matches a posterior sampler to a prior diffusion (or other sequential) model by minimizing

   .. math::

      \left(\log Z_\phi + \log p_\phi(\tau) - \log p_\theta(\tau) - \beta \log r(x_T)\right)^2,

   where :math:`p_\theta` is a fixed prior process, :math:`p_\phi` is the learnable posterior, :math:`r` is a positive reward/constraint on the terminal state :math:`x_T`, and :math:`\log Z_\phi` is a learned scalar normalizer.

   .. py:attribute:: logZ

   .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None)

      Computes the RTB loss on a batch of trajectories.


.. py:class:: TBGFlowNet(pf, pb, logZ = None, init_logZ = 0.0, constant_pb = False, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None)

   Bases: :py:obj:`gfn.gflownet.base.TrajectoryBasedGFlowNet`

   GFlowNet for the Trajectory Balance loss.

   :math:`\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3`, where :math:`\mathcal{O}_1 = \mathbb{R}` represents the possible values for logZ, and :math:`\mathcal{O}_2` is the set of forward probability functions consistent with the DAG. :math:`\mathcal{O}_3` is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.pb is a fixed DiscretePBEstimator.

   See `Trajectory balance: Improved credit assignment in GFlowNets <https://arxiv.org/abs/2201.13259>`__ for more details.

   .. attribute:: pf

      The forward policy estimator.

   .. attribute:: pb

      The backward policy estimator, or None if the GFlowNet DAG is a tree and pb is therefore always 1.

   .. attribute:: logZ

      A learnable parameter or a ScalarEstimator instance (for conditional GFNs).

   .. attribute:: constant_pb

      Whether to ignore pb, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1. Must be set explicitly by the user, to ensure that pb is an Estimator except in this special case.

   .. attribute:: log_reward_clip_min

      If finite, log rewards are clipped from below at this value.

   .. py:attribute:: logZ

   .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None)

      Computes the trajectory balance loss.

      The trajectory balance loss is described in section 2.3 of
      `Trajectory balance: Improved credit assignment in GFlowNets <https://arxiv.org/abs/2201.13259>`__.

      :param env: The environment where the trajectories are sampled from (unused).
      :param trajectories: The Trajectories object to compute the loss with.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
      :param reduction: The reduction method to use ('mean', 'sum', or 'none').
      :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025).

      :returns: The computed trajectory balance loss as a tensor. The shape depends on the reduction method.


.. py:class:: TrustPCLGFlowNet(policy, reference_policy, *, alpha = 1.0, init_v_soft_s0 = 0.0, logZ = None, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None)

   Bases: :py:obj:`RelativeTrajectoryBalanceGFlowNet`

   Trust-PCL view of Relative Trajectory Balance.

   Deleu et al. (2025) proved that RTB is mathematically equivalent to Trust-PCL, an off-policy RL method with KL regularization toward a reference policy. This class provides an **RL-native interface** to the same algorithm, using reinforcement learning terminology.

   The equivalence (Proposition 3.1 of Deleu et al.):

   .. math::

      \mathcal{L}_{\text{Trust-PCL}}(\phi, \psi) = \alpha^2 \,\mathcal{L}_{\text{RTB}}(\phi, \psi)

   where :math:`\alpha = 1/\beta` is the Trust-PCL temperature.
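As a quick numerical sanity check of the scaling above, the following sketch builds both residuals by hand for a few made-up trajectories. It does not call the library: the function names ``rtb_loss`` and ``trust_pcl_loss``, the numbers, and the choice to treat ``log r`` as the terminal RL reward are illustrative assumptions. The Trust-PCL residual is written as :math:`\alpha` times the RTB residual, with ``v_soft_s0 = alpha * logZ``.

```python
import math


def rtb_loss(log_z, log_pf_post, log_pf_prior, log_r, beta):
    """Mean squared RTB residual: (logZ + log p_phi - log p_theta - beta * log r)^2."""
    residuals = [
        log_z + lp - lq - beta * lr
        for lp, lq, lr in zip(log_pf_post, log_pf_prior, log_r)
    ]
    return sum(d * d for d in residuals) / len(residuals)


def trust_pcl_loss(v_soft_s0, log_policy, log_reference, log_r, alpha):
    """Same residual in RL notation: (V_soft(s0) + alpha * log(pi / pi_ref) - log r)^2.

    Assumption: the RL reward is the terminal log-reward log r.
    """
    residuals = [
        v_soft_s0 + alpha * (lp - lq) - lr
        for lp, lq, lr in zip(log_policy, log_reference, log_r)
    ]
    return sum(d * d for d in residuals) / len(residuals)


# Made-up per-trajectory quantities (hypothetical values, not library output).
beta = 2.0
alpha = 1.0 / beta
log_z = 0.3
log_pf_post = [-1.2, -0.7, -2.1]   # log p_phi(tau), trainable posterior / policy
log_pf_prior = [-1.0, -1.5, -1.9]  # log p_theta(tau), fixed prior / reference
log_r = [0.4, -0.2, 1.1]           # log r(x_T)

l_rtb = rtb_loss(log_z, log_pf_post, log_pf_prior, log_r, beta)
l_pcl = trust_pcl_loss(alpha * log_z, log_pf_post, log_pf_prior, log_r, alpha)

# The two losses differ only by the factor alpha^2.
assert math.isclose(l_pcl, alpha**2 * l_rtb)
```

Because the two residuals differ by a constant factor, they share the same minimizers; only the gradient scale changes.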
   **Parameter correspondence:**

   +---------------------+------------------------------+---------------------------+
   | Concept             | RTB name                     | Trust-PCL name            |
   +=====================+==============================+===========================+
   | Temperature         | ``beta``                     | ``alpha = 1/beta``        |
   +---------------------+------------------------------+---------------------------+
   | Learned scalar      | ``logZ``                     | ``v_soft_s0 = alpha*logZ``|
   +---------------------+------------------------------+---------------------------+
   | Trainable model     | ``pf`` (posterior)           | ``policy``                |
   +---------------------+------------------------------+---------------------------+
   | Fixed reference     | ``prior_pf``                 | ``reference_policy``      |
   +---------------------+------------------------------+---------------------------+

   **Interpretation of the learned scalar:**

   In RTB, ``logZ`` estimates the log-partition function :math:`\log \int p_\theta(x)\,r(x)\,dx`. In Trust-PCL, the same quantity is the **soft value function** at the initial state: :math:`V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi`. This connects GFlowNet training to entropy-regularized RL, where the soft value satisfies the soft Bellman equation.

   **Why this class exists:**

   The underlying computation is identical to :class:`RelativeTrajectoryBalanceGFlowNet` (the loss is just scaled by :math:`\alpha^2`). This class exists to:

   1. Provide an RL-native constructor (``policy``, ``reference_policy``, ``alpha``, ``init_v_soft_s0``) for researchers familiar with Trust-PCL / SAC / entropy-regularized RL.
   2. Expose :attr:`alpha` and :attr:`v_soft_s0` properties for interpretability and monitoring.
   3. Serve as a pedagogical bridge between the GFlowNet and RL communities.

   .. rubric:: References

   Deleu et al. "Relative Trajectory Balance is equivalent to Trust-PCL" (2025, arXiv:2509.01632).

   Nachum et al. "Trust-PCL: An Off-Policy Trust Region Method for Continuous Control" (NeurIPS 2017, arXiv:1707.01891).

   Venkatraman et al.
   "Amortizing intractable inference in diffusion models for vision, language, and control" (NeurIPS 2024, arXiv:2405.20971).

   .. py:property:: alpha
      :type: torch.Tensor

      Trust-PCL temperature, :math:`\alpha = 1/\beta`.

      Controls the strength of KL regularization toward the reference policy. At convergence, the learned policy satisfies:

      .. math::

         \pi_\phi(a|s) \propto \pi_{\text{ref}}(a|s) \exp\!\bigl(Q^{\text{soft}}(s,a) / \alpha\bigr)

      Higher alpha → policy stays closer to the reference (more regularization). Lower alpha → policy deviates more toward reward-maximizing behavior.

   .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None)

      Computes the Trust-PCL loss: :math:`\alpha^2 \cdot \mathcal{L}_{\text{RTB}}`.

      The scaling by :math:`\alpha^2` is the only difference from :meth:`RelativeTrajectoryBalanceGFlowNet.loss`. It ensures gradient magnitudes match the Trust-PCL formulation.

   .. py:property:: v_soft_s0
      :type: torch.Tensor

      Soft value function at the initial state, :math:`V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi`.

      This is the expected return under the optimal entropy-regularized policy, starting from :math:`s_0`:

      .. math::

         V^{\text{soft}}(s_0) = \mathbb{E}_{\pi_\phi}\!\left[ \sum_t r(s_t, a_t) + \alpha \sum_t \log \frac{\pi_{\text{ref}}(a_t|s_t)} {\pi_\phi(a_t|s_t)} \right]

      The KL regularization term :math:`\alpha \log(\pi_{\text{ref}} / \pi_\phi)` in the sum emerges from the ratio of prior to posterior log-probabilities in the RTB balance condition.

      Monitoring this value during training shows how the expected (regularized) return evolves. At convergence it equals :math:`\alpha \log \int p_\theta(x)\,r(x)\,dx`.
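To make the convergence statement concrete, here is a minimal sketch on a toy problem with three terminal states, each reached in a single step, so a trajectory is just a choice of ``x``. The prior, rewards, and ``beta`` are made up, and nothing here calls the library. The normalizer is derived from the RTB residual given earlier in this module, which vanishes when the posterior is proportional to :math:`p_\theta(x)\,r(x)^\beta`; treating ``log r`` as the terminal RL reward is an assumption of the sketch. For :math:`\beta = 1` the sum below is the discrete counterpart of the integral above.

```python
import math

# Hypothetical toy problem: three terminal states, one step each.
prior = [0.5, 0.3, 0.2]    # p_theta(x), fixed reference policy
reward = [1.0, 4.0, 2.0]   # r(x) > 0
beta = 2.0
alpha = 1.0 / beta

# Zero-residual solution of the RTB condition:
# posterior(x) = prior(x) * r(x)^beta / Z, with Z = sum_x prior(x) * r(x)^beta.
weights = [p * r**beta for p, r in zip(prior, reward)]
z = sum(weights)
log_z = math.log(z)
posterior = [w / z for w in weights]

# RTB residual per state: logZ + log p_phi - log p_theta - beta * log r.
residuals = [
    log_z + math.log(q) - math.log(p) - beta * math.log(r)
    for q, p, r in zip(posterior, prior, reward)
]

# Soft value at s0 in Trust-PCL notation.
v_soft_s0 = alpha * log_z

# Regularized return of the converged policy:
# E[log r] + alpha * E[log(pi_ref / pi)] = E[log r] - alpha * KL(pi || pi_ref).
expected_log_r = sum(q * math.log(r) for q, r in zip(posterior, reward))
kl = sum(q * math.log(q / p) for q, p in zip(posterior, prior))
soft_return = expected_log_r - alpha * kl

assert all(abs(d) < 1e-12 for d in residuals)
assert math.isclose(soft_return, v_soft_s0)
```

In this toy setting the regularized return of the converged policy coincides with ``alpha * logZ``, which is why tracking ``v_soft_s0`` during training can serve as a proxy for that return.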