gfn.gflownet ============ .. py:module:: gfn.gflownet Submodules ---------- .. toctree:: :maxdepth: 1 /autoapi/gfn/gflownet/base/index /autoapi/gfn/gflownet/detailed_balance/index /autoapi/gfn/gflownet/flow_matching/index /autoapi/gfn/gflownet/losses/index /autoapi/gfn/gflownet/mle/index /autoapi/gfn/gflownet/sub_trajectory_balance/index /autoapi/gfn/gflownet/trajectory_balance/index Classes ------- .. autoapisummary:: gfn.gflownet.DBGFlowNet gfn.gflownet.FMGFlowNet gfn.gflownet.GFlowNet gfn.gflownet.HalfSquaredLoss gfn.gflownet.LinexLoss gfn.gflownet.LogPartitionVarianceGFlowNet gfn.gflownet.ModifiedDBGFlowNet gfn.gflownet.PFBasedGFlowNet gfn.gflownet.RegressionLoss gfn.gflownet.RelativeLogPartitionVarianceGFlowNet gfn.gflownet.RelativeTBBase gfn.gflownet.RelativeTrajectoryBalanceGFlowNet gfn.gflownet.ShiftedCoshLoss gfn.gflownet.SquaredLoss gfn.gflownet.SubTBGFlowNet gfn.gflownet.TBGFlowNet gfn.gflownet.TrajectoryBasedGFlowNet gfn.gflownet.TrustPCLGFlowNet Package Contents ---------------- .. py:class:: DBGFlowNet(pf, pb, logF, forward_looking = False, constant_pb = False, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None) Bases: :py:obj:`gfn.gflownet.base.PFBasedGFlowNet`\ [\ :py:obj:`gfn.containers.Transitions`\ ] GFlowNet for the Detailed Balance loss. Corresponds to $\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3$, where $\mathcal{O}_1$ is the set of functions from the internal states (no $s_f$) to $\mathbb{R}^+$ (which we parametrize with logs, to avoid the non-negativity constraint), and $\mathcal{O}_2$ is the set of forward probability functions consistent with the DAG. $\mathcal{O}_3$ is the set of backward probability functions consistent with the DAG, or a singleton thereof, if `self.pb` is a fixed `DiscretePBEstimator`. The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266). .. attribute:: pf The forward policy estimator. .. 
attribute:: pb The backward policy estimator. .. attribute:: logF A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states. .. attribute:: forward_looking Whether to use the forward-looking GFN loss. When True, rewards must be defined over edges; this implementation treats the edge reward as the difference between the successor and current state rewards, so only valid if the environment follows that assumption. .. attribute:: constant_pb Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. .. attribute:: log_reward_clip_min If finite, clips log rewards to this value. .. py:attribute:: forward_looking :value: False .. py:method:: get_pfs_and_pbs(transitions, recalculate_all_logprobs = True) Evaluates forward and backward logprobs for each transition in the batch. More specifically, it evaluates $\log P_F(s' \mid s)$ and $\log P_B(s \mid s')$ for each transition in the batch. If recalculate_all_logprobs=True, we re-evaluate the logprobs of the transitions using the current self.pf. Otherwise, the following applies: - If transitions have log_probs attribute, use them - this is usually for on-policy learning. - Else (transitions have none of them), re-evaluate the logprobs using the current self.pf - this is usually for off-policy learning with replay buffer. :param transitions: The Transitions object to evaluate. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :returns: A tuple of tensors of shape (n_transitions,) containing the log_pf and log_pb for each transition. .. py:method:: get_scores(env, transitions, recalculate_all_logprobs = True, *, log_rewards = None) Calculates the scores for a batch of transitions. The scores for each transition are defined as: $\log \left( \frac{F(s)P_F(s' \mid s)}{F(s') P_B(s \mid s')} \right)$. :param env: The environment where the transitions are sampled from. :param transitions: The Transitions object to evaluate. 
:param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :param log_rewards: Optional custom log rewards tensor of shape (n_transitions,). When None, uses the environment rewards from the transitions. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025). **Not supported when** ``forward_looking=True``: raises ``ValueError`` in that case because the forward-looking objective still calls ``env.log_reward()`` for intermediate state adjustments, so custom rewards cannot fully replace environment rewards. :returns: A tensor of shape (n_transitions,) representing the scores for each transition. .. py:attribute:: logF .. py:method:: logF_named_parameters() Returns a dictionary of named parameters containing 'logF' in their name. :returns: A dictionary of named parameters containing 'logF' in their name. .. py:method:: logF_parameters() Returns a list of parameters containing 'logF' in their name. :returns: A list of parameters containing 'logF' in their name. .. py:method:: loss(env, transitions, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None) Computes the detailed balance loss. The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266). :param env: The environment where the transitions are sampled from. :param transitions: The Transitions object to compute the loss with. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :param reduction: The reduction method to use ('mean', 'sum', or 'none'). Run with self.debug=False for improved performance. :param log_rewards: Optional custom log rewards tensor of shape (n_transitions,). When None, uses the environment rewards. :returns: The computed detailed balance loss as a tensor. The shape depends on the reduction method. .. py:method:: to_training_samples(trajectories) Converts trajectories to transitions for detailed balance loss. 
:param trajectories: The Trajectories object to convert. :returns: A Transitions object containing all transitions from the trajectories. .. py:class:: FMGFlowNet(logF, alpha = 1.0, debug = False, loss_fn = None) Bases: :py:obj:`gfn.gflownet.base.GFlowNet`\ [\ :py:obj:`gfn.containers.StatesContainer`\ [\ :py:obj:`gfn.states.DiscreteStates`\ ]\ ] GFlowNet for the Flow Matching loss with an edge flow estimator. $\mathcal{O}_{edge}$ is the set of functions from the non-terminating edges to $\mathbb{R}^+$, which is equivalent to the set of functions from the internal nodes (i.e., without $s_f$) to $(\mathbb{R})^{n_{\text{actions}}}$, without the exit action (there is no need for positivity if we parametrize log-flows). The flow matching loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266). .. attribute:: logF A DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator for estimating the log flow of the edges (states -> next_states). .. attribute:: alpha A scalar weight for the reward matching loss. Flow Matching does not rely on PF/PB probability recomputation. Any trajectory sampling provided by this class is for diagnostics/visualization and can only use the default (non-recurrent) PolicyMixin interface. .. py:attribute:: alpha :value: 1.0 .. py:method:: flow_matching_loss(env, states, reduction = 'mean') Computes the flow matching loss for the (non-initial) states. The Flow Matching loss is defined as the log-sum of incoming flows minus the log-sum of outgoing flows. The states should not include $s_0$. The batch shape should be `(n_states,)`. Currently, only discrete environments are supported. :param env: The discrete environment where the states are sampled from. :param states: The DiscreteStates object to evaluate (should not include $s_0$). :param reduction: The reduction method to use ('mean', 'sum', or 'none'). :returns: The computed flow matching loss as a tensor. The shape depends on the reduction method. .. py:attribute:: logF ..
py:method:: loss(env, states_container, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None) Computes the flow matching loss for a batch of states. The flow matching loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266). Unlike the original implementation, we allow more flexibility by treating the intermediate and terminating states separately. :param env: The discrete environment where the states are sampled from. :param states_container: The StatesContainer object containing both intermediate and terminating states. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs (unused for FM). :param reduction: The reduction method to use ('mean', 'sum', or 'none'). :param log_rewards: Optional custom log rewards tensor of shape (n_terminating_states,). When None, uses the environment rewards from the states container. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025). :returns: The computed flow matching loss as a tensor. The shape depends on the reduction method. .. py:method:: reward_matching_loss(env, terminating_states, log_rewards, reduction = 'mean') Computes the reward matching loss for the terminating states. :param env: The discrete environment where the states are sampled from (unused). :param terminating_states: The DiscreteStates object containing terminating states. :param log_rewards: The log rewards for the terminating states. :param reduction: The reduction method to use ('mean', 'sum', or 'none'). :returns: The computed reward matching loss as a tensor. The shape depends on the reduction method. .. py:method:: sample_trajectories(env, n, conditions = None, save_logprobs = False, save_estimator_outputs = False, **policy_kwargs) Samples trajectories using the edge flow estimator.
:param env: The discrete environment to sample trajectories from. :param n: Number of trajectories to sample. :param conditions: Optional conditions tensor for conditional environments. :param save_logprobs: Whether to save the log-probabilities of the actions. :param save_estimator_outputs: Whether to save the estimator outputs. :param \*\*policy_kwargs: Additional keyword arguments for the sampler. :returns: A Trajectories object containing the sampled trajectories. .. py:method:: to_training_samples(trajectories) Converts trajectories to a StatesContainer for flow matching loss. :param trajectories: The Trajectories object to convert. :returns: A StatesContainer object containing all states from the trajectories. .. py:class:: GFlowNet(debug = False, loss_fn = None) Bases: :py:obj:`abc.ABC`, :py:obj:`torch.nn.Module`, :py:obj:`Generic`\ [\ :py:obj:`TrainingSampleType`\ ] Abstract base class for GFlowNets. A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266). .. py:method:: assert_finite_gradients() Asserts that the gradients are finite. .. py:method:: assert_finite_parameters() Asserts that the parameters are finite. .. py:attribute:: debug :value: False .. py:attribute:: log_reward_clip_min .. py:method:: loss(env, training_objects, recalculate_all_logprobs = True) :abstractmethod: Computes the loss given the training objects. :param env: The environment where the training objects are sampled from. :param training_objects: The objects to compute the loss with. :param recalculate_all_logprobs: If True, always recalculate logprobs even if they exist. If False, use existing logprobs when available. :returns: The computed loss as a tensor. .. py:attribute:: loss_fn .. py:method:: loss_from_trajectories(env, trajectories, recalculate_all_logprobs = True) Helper method to compute loss directly from trajectories. 
This method converts trajectories to the appropriate training samples and computes the loss with the correct arguments based on the type of GFlowNet subclass. :param env: The environment where the training objects are sampled from. :param trajectories: The trajectories to compute the loss with. :param recalculate_all_logprobs: If True, always recalculate logprobs even if they exist. If False, use existing logprobs when available. :returns: The computed loss as a tensor. .. py:method:: sample_terminating_states(env, n) Rolls out the policy and returns the terminating states. :param env: The environment to sample terminating states from. :param n: Number of terminating states to sample. :returns: The sampled terminating states as a States object. .. py:method:: sample_trajectories(env, n, conditions = None, save_logprobs = False, save_estimator_outputs = False, **policy_kwargs) :abstractmethod: Samples a specific number of complete trajectories from the environment. :param env: The environment to sample trajectories from. :param n: Number of trajectories to sample. :param conditions: Optional conditions tensor for conditional environments. :param save_logprobs: Whether to save the logprobs of the actions (useful for on-policy learning). :param save_estimator_outputs: Whether to save the estimator outputs (useful for off-policy learning with a tempered policy). :returns: A Trajectories object containing the sampled trajectories. .. py:method:: to_training_samples(trajectories) :abstractmethod: Converts trajectories to training samples. :param trajectories: The Trajectories object to convert. :returns: The training samples, type depends on the type of GFlowNet subclass. .. py:class:: HalfSquaredLoss Bases: :py:obj:`RegressionLoss` Half squared loss: :math:`g(t) = \tfrac{1}{2} t^2`. The :math:`\tfrac{1}{2}` factor ensures the gradient equals the residual itself: :math:`g'(t) = t` rather than :math:`2t`. 
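The gradient claim can be checked numerically. The following is a minimal sketch in plain Python, independent of the library's implementation; the helper names ``half_squared_loss``, ``squared_loss``, and ``numerical_derivative`` are hypothetical and exist only for this illustration.

```python
def half_squared_loss(t: float) -> float:
    """g(t) = 0.5 * t**2, the half squared loss described above."""
    return 0.5 * t * t


def squared_loss(t: float) -> float:
    """g(t) = t**2, included only for comparison."""
    return t * t


def numerical_derivative(g, t: float, eps: float = 1e-6) -> float:
    """Central finite-difference estimate of g'(t)."""
    return (g(t + eps) - g(t - eps)) / (2.0 * eps)


# Hypothetical residual values chosen for illustration.
residuals = [-2.0, -0.5, 0.0, 1.5, 3.0]

# For the half squared loss the derivative should match the residual t;
# for the plain squared loss it should match 2 * t.
half_grads = [numerical_derivative(half_squared_loss, t) for t in residuals]
full_grads = [numerical_derivative(squared_loss, t) for t in residuals]

for t, gh, gf in zip(residuals, half_grads, full_grads):
    print(f"t={t:+.1f}  half-squared grad={gh:+.4f}  squared grad={gf:+.4f}")
```

In this check the :math:`\tfrac{1}{2}` factor removes the factor of 2 from the gradient, so the derivative of the half squared loss coincides with the residual.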
This is the standard least-squares convention (minimizing :math:`\tfrac{1}{2}\|r\|^2` so the normal equations have no factor of 2), and matches the RTB formulation in Venkatraman et al. (2024). This is the default loss for :class:`RelativeTrajectoryBalanceGFlowNet` and :class:`RelativeLogPartitionVarianceGFlowNet`. .. py:method:: __call__(residuals) Apply the loss elementwise. :param residuals: Balance condition residuals (any shape). :returns: Non-negative tensor of the same shape. .. py:class:: LinexLoss(alpha = 1.0) Bases: :py:obj:`RegressionLoss` Linear-exponential (Linex) loss: :math:`g(t) = \frac{1}{\alpha^2}(e^{\alpha t} - \alpha t - 1)`. The ``alpha`` parameter controls the asymmetry: - ``alpha = 1``: corresponds to the **forward KL** divergence. **Zero-avoiding** (mass-covering / exploration-favoring): penalizes the learner for missing mass where the target has support, encouraging broader mode coverage at the cost of some spurious mass. - ``alpha = 0.5``: corresponds to the **alpha-divergence** with ``alpha = 0.5``. **Balanced**: neither purely zero-forcing nor zero-avoiding. - ``alpha < 0``: becomes zero-forcing (mode-seeking), similar to but distinct from squared loss. The :math:`1/\alpha^2` normalization ensures ``g''(0) = 1`` for all ``alpha``, matching the curvature of squared loss near zero. .. rubric:: References Hu et al. "Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks" (ICLR 2025, arXiv:2410.02596). The Linex loss originates from Bayesian decision theory: Varian (1975), Zellner (1986). .. py:method:: __call__(residuals) Apply the loss elementwise. :param residuals: Balance condition residuals (any shape). :returns: Non-negative tensor of the same shape. .. py:method:: __eq__(other) .. py:method:: __hash__() .. py:method:: __repr__() .. py:attribute:: alpha :value: 1.0 .. 
py:class:: LogPartitionVarianceGFlowNet(pf, pb, constant_pb = False, log_reward_clip_min = float('-inf'), debug = False, loss_fn = None) Bases: :py:obj:`gfn.gflownet.base.TrajectoryBasedGFlowNet` GFlowNet for the Log Partition Variance loss. The log partition variance loss is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446). .. attribute:: pf The forward policy estimator. .. attribute:: pb The backward policy estimator. .. attribute:: constant_pb Whether to ignore pb e.g., the GFlowNet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. .. attribute:: log_reward_clip_min If finite, clips log rewards to this value. .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None) Computes the log partition variance loss. The log partition variance loss is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446). :param env: The environment where the trajectories are sampled from (unused). :param trajectories: The Trajectories object to compute the loss with. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :param reduction: The reduction method to use ('mean', 'sum', or 'none'). :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025). :returns: The computed log partition variance loss as a tensor. The shape depends on the reduction method. .. py:class:: ModifiedDBGFlowNet(pf, pb, constant_pb = False, debug = False) Bases: :py:obj:`gfn.gflownet.base.PFBasedGFlowNet`\ [\ :py:obj:`gfn.containers.Transitions`\ ] The Modified Detailed Balance GFlowNet. Only applicable to environments where all states are terminating. 
See section 3.2 of [Bayesian Structure Learning with Generative Flow Networks](https://arxiv.org/abs/2202.13903) for more details. .. attribute:: pf The forward policy estimator. .. attribute:: pb The backward policy estimator, or None if the gflownet DAG is a tree, and pb is therefore always 1. .. attribute:: constant_pb Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case. .. py:method:: get_scores(transitions, recalculate_all_logprobs = True) Calculates DAG-GFN-style modified detailed balance scores. Note that this method is only applicable to environments where all states are terminating, i.e., the sink state is reachable from all states. If recalculate_all_logprobs=True, we re-evaluate the logprobs of the transitions using the current self.pf. Otherwise, the following applies: - If transitions have log_probs attribute, use them - this is usually for on-policy learning. - Else, re-evaluate the log_probs using the current self.pf - this is usually for off-policy learning with replay buffer. :param transitions: The Transitions object to evaluate. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :returns: A tensor of shape (n_transitions,) containing the scores for each transition. .. py:method:: loss(env, transitions, recalculate_all_logprobs = True, reduction = 'mean') Computes the modified detailed balance loss. :param env: The environment where the transitions are sampled from (unused). :param transitions: The Transitions object to compute the loss with. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :param reduction: The reduction method to use ('mean', 'sum', or 'none'). :returns: The computed modified detailed balance loss as a tensor. The shape depends on the reduction method. .. 
py:method:: to_training_samples(trajectories) Converts trajectories to transitions for modified detailed balance loss. :param trajectories: The Trajectories object to convert. :returns: A Transitions object containing all transitions from the trajectories. .. py:class:: PFBasedGFlowNet(pf, pb, constant_pb = False, log_reward_clip_min = float('-inf'), debug = False, loss_fn = None) Bases: :py:obj:`GFlowNet`\ [\ :py:obj:`TrainingSampleType`\ ], :py:obj:`abc.ABC` A GFlowNet that uses forward (PF) and backward (PB) policy networks. .. attribute:: pf The forward policy estimator. .. attribute:: pb The backward policy estimator, or None if it can be ignored (e.g., the gflownet DAG is a tree, and pb is therefore always 1). .. attribute:: constant_pb Whether to ignore the backward policy estimator. .. attribute:: log_reward_clip_min If finite, clips log rewards to this value. .. py:attribute:: constant_pb :value: False .. py:attribute:: log_reward_clip_min .. py:attribute:: pb .. py:attribute:: pf .. py:method:: pf_pb_named_parameters() Returns a dictionary of named parameters containing 'pf' or 'pb' in their name. :returns: A dictionary of named parameters containing 'pf' or 'pb' in their name. .. py:method:: pf_pb_parameters() Returns a list of parameters containing 'pf' or 'pb' in their name. :returns: A list of parameters containing 'pf' or 'pb' in their name. .. py:method:: sample_trajectories(env, n, conditions = None, save_logprobs = False, save_estimator_outputs = False, **policy_kwargs) Samples trajectories using the forward policy network. :param env: The environment to sample trajectories from. :param n: Number of trajectories to sample. :param conditions: Optional conditions tensor for conditional environments. :param save_logprobs: Whether to save the logprobs of the actions. :param save_estimator_outputs: Whether to save the estimator outputs. :param \*\*policy_kwargs: Additional keyword arguments for the sampler. 
:returns: A Trajectories object containing the sampled trajectories. .. py:class:: RegressionLoss Bases: :py:obj:`abc.ABC` Abstract base for regression losses on GFlowNet balance residuals. Subclasses implement ``__call__`` mapping a residual tensor to a non-negative loss tensor of the same shape. .. py:method:: __call__(residuals) :abstractmethod: Apply the loss elementwise. :param residuals: Balance condition residuals (any shape). :returns: Non-negative tensor of the same shape. .. py:method:: __eq__(other) .. py:method:: __hash__() .. py:method:: __repr__() .. py:class:: RelativeLogPartitionVarianceGFlowNet(pf, prior_pf, *, beta = 1.0, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None) Bases: :py:obj:`RelativeTBBase` RTB variant that eliminates the learned logZ via variance minimization. Analogous to how :class:`LogPartitionVarianceGFlowNet` relates to :class:`TBGFlowNet`, this class mean-centers the RTB residuals within each batch so that no explicit ``logZ`` parameter is needed. The loss minimizes .. math:: \operatorname{Var}_{\tau}\!\bigl[\log p_\phi(\tau) - \log p_\theta(\tau) - \beta\,\log r(x_T)\bigr], which equals the RTB loss evaluated at the batch-optimal :math:`\log Z^* = -\overline{s}` (the negative batch mean of scores). .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None) Computes the Relative LPV loss on a batch of trajectories. .. py:class:: RelativeTBBase(pf, prior_pf, *, beta = 1.0, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None) Bases: :py:obj:`gfn.gflownet.base.TrajectoryBasedGFlowNet` Shared base for Relative Trajectory Balance variants. Manages the prior forward policy and ``beta`` scaling. Subclasses only need to implement :meth:`loss` (deciding how to handle ``logZ`` and reduction). .. 
py:method:: _compute_rtb_scores(env, trajectories, log_rewards = None, recalculate_all_logprobs = True) RTB residuals: ``log_pf_post - log_pf_prior - beta * log_rewards``. :param env: The environment (unused, kept for API consistency). :param trajectories: The Trajectories object to evaluate. :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025). :param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :returns: Shape ``(N,)`` per-trajectory scores. .. py:method:: get_scores(trajectories, recalculate_all_logprobs = True, env = None, *, log_rewards = None) RTB residuals (without logZ): ``log_pf_post - log_pf_prior - beta * log_R``. This is the public interface to the RTB balance residuals, analogous to :meth:`TrajectoryBasedGFlowNet.get_scores` for standard TB. :returns: Shape ``(N,)`` per-trajectory scores. .. py:property:: prior_pf :type: gfn.estimators.Estimator The fixed prior forward policy (not registered as a submodule). .. py:class:: RelativeTrajectoryBalanceGFlowNet(pf, prior_pf, *, logZ = None, init_logZ = 0.0, beta = 1.0, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None) Bases: :py:obj:`RelativeTBBase` GFlowNet for the Relative Trajectory Balance (RTB) loss. This objective matches a posterior sampler to a prior diffusion (or other sequential) model by minimizing .. math:: \left(\log Z_\phi + \log p_\phi(\tau) - \log p_\theta(\tau) - \beta \log r(x_T)\right)^2, where :math:`p_\theta` is a fixed prior process, :math:`p_\phi` is the learnable posterior, :math:`r` is a positive reward/constraint on the terminal state :math:`x_T`, and :math:`\log Z_\phi` is a learned scalar normalizer. .. py:attribute:: logZ .. 
py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None) Computes the RTB loss on a batch of trajectories. .. py:class:: ShiftedCoshLoss Bases: :py:obj:`RegressionLoss` Shifted hyperbolic cosine: :math:`g(t) = e^t + e^{-t} - 2 = 2(\cosh(t) - 1)`. This is the **only** loss in the family that is simultaneously **zero-forcing** (penalizes spurious mass) and **zero-avoiding** (penalizes missing modes). It is symmetric: ``g(t) = g(-t)``. Near ``t = 0`` it behaves like ``t^2`` (same curvature as squared loss), but for large ``|t|`` it grows exponentially, providing stronger gradients for poorly-matched trajectories. Hu et al. (ICLR 2025) found this loss generally outperforms squared error on convergence speed and mode coverage across HyperGrid, bit-sequence, and sEH molecule benchmarks. .. rubric:: References Hu et al. "Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks" (ICLR 2025, arXiv:2410.02596). .. py:method:: __call__(residuals) Apply the loss elementwise. :param residuals: Balance condition residuals (any shape). :returns: Non-negative tensor of the same shape. .. py:class:: SquaredLoss Bases: :py:obj:`RegressionLoss` Standard squared loss: :math:`g(t) = t^2`. Corresponds to the **reverse KL** divergence (Malkin et al. 2022). This is **zero-forcing** (mode-seeking): it penalizes the learner for placing probability mass where the target has none, but does not penalize missing modes. This can lead to mode collapse in multi-modal targets. This is the default loss for TB, DB, SubTB, LPV, and FM classes, reproducing the standard behavior from the literature. .. py:method:: __call__(residuals) Apply the loss elementwise. :param residuals: Balance condition residuals (any shape). :returns: Non-negative tensor of the same shape. .. 
py:class:: SubTBGFlowNet(pf, pb, logF, weighting = 'geometric_within', lamda = 0.9, log_reward_clip_min = -float('inf'), forward_looking = False, constant_pb = False, debug = False, loss_fn = None) Bases: :py:obj:`gfn.gflownet.base.TrajectoryBasedGFlowNet` GFlowNet for the Sub-Trajectory Balance loss. An implementation of the sub-trajectory balance loss as described in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782). .. attribute:: pf The forward policy estimator. .. attribute:: pb The backward policy estimator, or None if the gflownet DAG is a tree, and pb is therefore always 1. .. attribute:: logF A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states. .. attribute:: weighting The sub-trajectories weighting scheme. - "DB": Considers all one-step transitions of each trajectory in the batch and weights them equally (regardless of the length of the trajectory). Should be equivalent to the DetailedBalance loss. - "ModifiedDB": Considers all one-step transitions of each trajectory in the batch and weights them inversely proportional to the trajectory length. This ensures that the loss is not dominated by long trajectories. Each trajectory contributes equally to the loss. - "TB": Considers only the full trajectory. Should be equivalent to the TrajectoryBalance loss. - "equal_within": Each sub-trajectory of each trajectory is weighted equally within the trajectory. Then each trajectory is weighted equally within the batch. - "equal": Each sub-trajectory of each trajectory is weighted equally within the set of all sub-trajectories. - "geometric_within": Each sub-trajectory of each trajectory is weighted proportionally to (lamda ** len(sub_trajectory)), within each trajectory. This corresponds to the weighting used in the paper. - "geometric": Each sub-trajectory of each trajectory is weighted proportionally to (lamda ** len(sub_trajectory)), within the set of all sub-trajectories. ..
attribute:: lamda Discount factor for longer trajectories (used in geometric weighting). .. attribute:: log_reward_clip_min If finite, clips log rewards to this value. .. attribute:: forward_looking Whether to use the forward-looking GFN loss. .. attribute:: constant_pb Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. .. py:method:: calculate_log_state_flows(env, trajectories, log_pf_trajectories) Calculates log flows of each state in the trajectories. :param env: The environment object. :param trajectories: The batch of trajectories. :param log_pf_trajectories: Tensor of shape (max_length, batch_size) containing the logprobs of the forward actions of the trajectories. :returns: A tensor of shape (max_length, batch_size) containing the log flows of each state in the trajectories. .. py:method:: calculate_masks(log_state_flows, trajectories) Calculates masks indicating sink and terminal states. :param log_state_flows: Tensor of shape (max_length, batch_size) containing the log flows of the states. :param trajectories: The batch of trajectories. :returns: A tuple of two mask tensors (sink_states_mask, is_terminal_mask), each of shape (max_length, batch_size). .. py:method:: calculate_preds(log_pf_traj_cum, log_state_flows, i) Calculates the predictions tensor for the current sub-trajectory length. :param log_pf_traj_cum: Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs of the forward actions for each trajectory. :param log_state_flows: Tensor of shape (max_length, batch_size) containing the estimated log flow of the states. :param i: The sub-trajectory length. :returns: The predictions tensor of shape (max_length + 1 - i, batch_size). .. py:method:: calculate_targets(trajectories, preds, log_pb_traj_cum, log_state_flows, is_terminal_mask, sink_states_mask, i, log_rewards = None) Calculates the targets tensor for the current sub-trajectory length. 
:param trajectories: The batch of trajectories. :param preds: Tensor of shape (max_length + 1 - i, batch_size) containing the predictions for the current sub-trajectory length. :param log_pb_traj_cum: Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs of the backward actions for each trajectory. :param log_state_flows: Tensor of shape (max_length, batch_size) containing the estimated log flow of the states. :param is_terminal_mask: A mask of shape (max_length, batch_size) indicating whether the state is terminal. :param sink_states_mask: A mask of shape (max_length, batch_size) indicating whether the state is a sink state. :param i: The sub-trajectory length. :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025). :returns: The targets tensor of shape (max_length + 1 - i, batch_size). .. py:method:: cumulative_logprobs(trajectories, log_p_trajectories) Calculates the cumulative logprobs for all trajectories. :param trajectories: The batch of trajectories. :param log_p_trajectories: Tensor of shape (max_length, batch_size) containing the logprobs of the forward or backward actions of the trajectories. :returns: A tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs for each trajectory. .. py:attribute:: forward_looking :value: False .. py:method:: get_equal_contributions(trajectories) Calculates contributions for the 'equal' weighting method. :param trajectories: The batch of trajectories. :returns: The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size). .. py:method:: get_equal_within_contributions(trajectories) Calculates contributions for the 'equal_within' weighting method. :param trajectories: The batch of trajectories. 
:returns: The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size). .. py:method:: get_geometric_within_contributions(trajectories) Calculates contributions for the 'geometric_within' weighting method. :param trajectories: The batch of trajectories. :returns: The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size). .. py:method:: get_modified_db_contributions(trajectories) Calculates contributions for the 'ModifiedDB' weighting method. :param trajectories: The batch of trajectories. :returns: The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size). .. py:method:: get_scores(trajectories, recalculate_all_logprobs = True, env = None, *, log_rewards = None) Computes sub-trajectory balance scores for all submitted trajectories. :param trajectories: The batch of trajectories to evaluate. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :param env: The environment where the trajectories are sampled from. :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025). :returns: - scores: A list of tensors, each representing the scores of all sub-trajectories of length k, for k in [1, ..., max_length], where the score of a sub-trajectory $\tau_{n:n+k} = (s_n, ..., s_{n+k})$ is $\log P_F(\tau_{n:n+k}) + \log F(s_n) - \log P_B(\tau_{n:n+k}) - \log F(s_{n+k})$. The score tensor for sub-trajectories of length k has shape (max_length - k + 1, batch_size). - flattening_masks: A list of tensors indicating what should be masked out from each element of the first list (scores), given that not all sub-trajectories of length k exist for each trajectory. The entries of those tensors are True if the corresponding sub-trajectory does not exist. :rtype: A tuple (scores, flattening_masks) ..
py:method:: get_tb_contributions(trajectories) Calculates contributions for the 'TB' weighting method. :param trajectories: The batch of trajectories. :returns: The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size). .. py:attribute:: lamda :value: 0.9 .. py:attribute:: logF .. py:method:: logF_named_parameters() Returns a dictionary of named parameters containing 'logF' in their name. :returns: A dictionary of named parameters containing 'logF' in their name. .. py:method:: logF_parameters() Returns a list of parameters containing 'logF' in their name. :returns: A list of parameters containing 'logF' in their name. .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None) Computes the sub-trajectory balance loss. :param env: The environment where the trajectories are sampled from. :param trajectories: The batch of trajectories to compute the loss with. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :param reduction: The reduction method to use ('mean', 'sum', or 'none'). Note: for geometric-based sub-trajectory weighting, 'mean' is not supported and is coerced to 'sum' (a warning is emitted when debug=True). :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. When provided, this overrides the terminal reward term used by the loss. In particular, for ``forward_looking=True``, the state-flow computation may still depend on ``env.log_reward(...)``, so custom ``log_rewards`` do not fully replace environment rewards in that mode. Useful for intrinsic rewards affecting the terminal boundary term (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025). :returns: The computed sub-trajectory balance loss as a tensor. The shape depends on the reduction method. .. py:attribute:: weighting :value: 'geometric_within' .. 
py:class:: TBGFlowNet(pf, pb, logZ = None, init_logZ = 0.0, constant_pb = False, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None) Bases: :py:obj:`gfn.gflownet.base.TrajectoryBasedGFlowNet` GFlowNet for the Trajectory Balance loss. $\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3$, where $\mathcal{O}_1 = \mathbb{R}$ represents the possible values for logZ, and $\mathcal{O}_2$ is the set of forward probability functions consistent with the DAG. $\mathcal{O}_3$ is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.pb is a fixed DiscretePBEstimator. See [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259) for more details. .. attribute:: pf The forward policy estimator. .. attribute:: pb The backward policy estimator, or None if the gflownet DAG is a tree, and pb is therefore always 1. .. attribute:: logZ A learnable parameter or a ScalarEstimator instance (for conditional GFNs). .. attribute:: constant_pb Whether to ignore pb, e.g., if the GFlowNet DAG is a tree, and pb is therefore always 1. Must be set explicitly by the user, which ensures that pb is an Estimator except in this special case. .. attribute:: log_reward_clip_min If finite, clips log rewards to this value. .. py:attribute:: logZ .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None) Computes the trajectory balance loss. The trajectory balance loss is described in section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259). :param env: The environment where the trajectories are sampled from (unused). :param trajectories: The Trajectories object to compute the loss with. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :param reduction: The reduction method to use ('mean', 'sum', or 'none').
:param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025). :returns: The computed trajectory balance loss as a tensor. The shape depends on the reduction method. .. py:class:: TrajectoryBasedGFlowNet(pf, pb, constant_pb = False, log_reward_clip_min = float('-inf'), debug = False, loss_fn = None) Bases: :py:obj:`PFBasedGFlowNet`\ [\ :py:obj:`gfn.containers.Trajectories`\ ], :py:obj:`abc.ABC` A GFlowNet that operates on complete trajectories. .. attribute:: pf The forward policy module. .. attribute:: pb The backward policy module, or None if the gflownet DAG is a tree, and pb is therefore always 1. .. attribute:: constant_pb Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. .. attribute:: log_reward_clip_min If finite, clips log rewards to this value. .. py:method:: get_pfs_and_pbs(trajectories, recalculate_all_logprobs = True) Evaluates forward and backward logprobs for each trajectory in the batch. More specifically, it evaluates $\log P_F(s' \mid s)$ and $\log P_B(s \mid s')$ for each transition in each trajectory in the batch. If recalculate_all_logprobs=True, we re-evaluate the logprobs of the trajectories using the current self.pf. Otherwise, the following applies: - If the trajectories have a logprobs attribute, use it - this is usually for on-policy learning. - Else if the trajectories have an estimator_outputs attribute, transform it into logprobs - this is usually for off-policy learning with a tempered policy. - Else (the trajectories have neither), re-evaluate the logprobs using the current self.pf - this is usually for off-policy learning with a replay buffer. :param trajectories: The Trajectories object to evaluate. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
:returns: A tuple of tensors of shape (max_length, batch_size) containing the log_pf and log_pb for each action in each trajectory. .. py:method:: get_scores(trajectories, recalculate_all_logprobs = True, env = None, *, log_rewards = None) Calculates scores for a batch of trajectories. The scores for each trajectory are defined as: $\log \left( \frac{P_F(\tau)}{P_B(\tau \mid x) R(x)} \right)$. :param trajectories: The Trajectories object to evaluate. :param recalculate_all_logprobs: Whether to re-evaluate all logprobs. :param env: The environment (unused in base TB, but required by some subclasses such as RTB and SubTB). :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards from the trajectories. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025). :returns: A tensor of shape (batch_size,) containing the scores for each trajectory. .. py:method:: logz_named_parameters() Returns named parameters containing 'logZ' in their name. Works for any subclass that registers a logZ parameter (e.g. :class:`TBGFlowNet`, :class:`RelativeTrajectoryBalanceGFlowNet`). Returns an empty dict for subclasses without logZ. .. py:method:: logz_parameters() Returns parameters containing 'logZ' in their name. Works for any subclass that registers a logZ parameter (e.g. :class:`TBGFlowNet`, :class:`RelativeTrajectoryBalanceGFlowNet`). Returns an empty list for subclasses without logZ. .. py:method:: to_training_samples(trajectories) Returns the input trajectories as training samples. :param trajectories: The Trajectories object to use as training samples. :returns: The same Trajectories object. .. py:method:: trajectory_log_probs_backward(trajectories) Evaluates backward logprobs only for each trajectory in the batch. .. 
py:method:: trajectory_log_probs_forward(trajectories, recalculate_all_logprobs = True) Evaluates forward logprobs only for each trajectory in the batch. .. py:class:: TrustPCLGFlowNet(policy, reference_policy, *, alpha = 1.0, init_v_soft_s0 = 0.0, logZ = None, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None) Bases: :py:obj:`RelativeTrajectoryBalanceGFlowNet` Trust-PCL view of Relative Trajectory Balance. Deleu et al. (2025) proved that RTB is mathematically equivalent to Trust-PCL, an off-policy RL method with KL regularization toward a reference policy. This class provides an **RL-native interface** to the same algorithm, using reinforcement learning terminology. The equivalence (Proposition 3.1 of Deleu et al.): .. math:: \mathcal{L}_{\text{Trust-PCL}}(\phi, \psi) = \alpha^2 \,\mathcal{L}_{\text{RTB}}(\phi, \psi) where :math:`\alpha = 1/\beta` is the Trust-PCL temperature. **Parameter correspondence:**

+---------------------+------------------------------+---------------------------+
| Concept             | RTB name                     | Trust-PCL name            |
+=====================+==============================+===========================+
| Temperature         | ``beta``                     | ``alpha = 1/beta``        |
+---------------------+------------------------------+---------------------------+
| Learned scalar      | ``logZ``                     | ``v_soft_s0 = alpha*logZ``|
+---------------------+------------------------------+---------------------------+
| Trainable model     | ``pf`` (posterior)           | ``policy``                |
+---------------------+------------------------------+---------------------------+
| Fixed reference     | ``prior_pf``                 | ``reference_policy``      |
+---------------------+------------------------------+---------------------------+

**Interpretation of the learned scalar:** In RTB, ``logZ`` estimates the log-partition function :math:`\log \int p_\theta(x)\,r(x)\,dx`. In Trust-PCL, the same quantity is the **soft value function** at the initial state: :math:`V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi`.
This connects GFlowNet training to entropy-regularized RL, where the soft value satisfies the soft Bellman equation. **Why this class exists:** The underlying computation is identical to :class:`RelativeTrajectoryBalanceGFlowNet` (the loss is just scaled by :math:`\alpha^2`). This class exists to: 1. Provide an RL-native constructor (``policy``, ``reference_policy``, ``alpha``, ``init_v_soft_s0``) for researchers familiar with Trust-PCL / SAC / entropy-regularized RL. 2. Expose :attr:`alpha` and :attr:`v_soft_s0` properties for interpretability and monitoring. 3. Serve as a pedagogical bridge between the GFlowNet and RL communities. .. rubric:: References Deleu et al. "Relative Trajectory Balance is equivalent to Trust-PCL" (2025, arXiv:2509.01632). Nachum et al. "Trust-PCL: An Off-Policy Trust Region Method for Continuous Control" (NeurIPS 2017, arXiv:1707.01891). Venkatraman et al. "Amortizing intractable inference in diffusion models for vision, language, and control" (NeurIPS 2024, arXiv:2405.20971). .. py:property:: alpha :type: torch.Tensor :math:`\alpha = 1/\beta`. Controls the strength of KL regularization toward the reference policy. At convergence, the learned policy satisfies: .. math:: \pi_\phi(a|s) \propto \pi_{\text{ref}}(a|s) \exp\!\bigl(Q^{\text{soft}}(s,a) / \alpha\bigr) Higher alpha → policy stays closer to the reference (more regularization). Lower alpha → policy deviates more toward reward-maximizing behavior. :type: Trust-PCL temperature .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None) Computes the Trust-PCL loss: :math:`\alpha^2 \cdot \mathcal{L}_{\text{RTB}}`. The scaling by :math:`\alpha^2` is the only difference from :meth:`RelativeTrajectoryBalanceGFlowNet.loss`. It ensures gradient magnitudes match the Trust-PCL formulation. .. py:property:: v_soft_s0 :type: torch.Tensor :math:`V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi`. 
This is the expected return under the optimal entropy-regularized policy, starting from :math:`s_0`: .. math:: V^{\text{soft}}(s_0) = \mathbb{E}_{\pi_\phi}\!\left[ \sum_t r(s_t, a_t) + \alpha \sum_t \log \frac{\pi_{\text{ref}}(a_t|s_t)} {\pi_\phi(a_t|s_t)} \right] The KL regularization term :math:`\alpha \log(\pi_{\text{ref}} / \pi_\phi)` in the sum emerges from the ratio of prior to posterior log-probabilities in the RTB balance condition. Monitoring this value during training shows how the expected (regularized) return evolves. At convergence it equals :math:`\alpha \log \int p_\theta(x)\,r(x)\,dx`. :type: Soft value function at the initial state
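The parameter correspondence documented for ``TrustPCLGFlowNet`` (``alpha = 1/beta``, ``v_soft_s0 = alpha*logZ``, loss scaled by :math:`\alpha^2`) can be checked numerically on a single trajectory. The following is a minimal sketch in plain Python, not the library's implementation. The helper names (``rtb_to_trust_pcl``, ``rtb_residual``, ``trust_pcl_residual``) are hypothetical, and the per-trajectory residual forms are assumptions in which the terminal log-reward plays the role of the return. Under those assumptions, the squared Trust-PCL residual equals :math:`\alpha^2` times the squared RTB residual.

```python
import math


def rtb_to_trust_pcl(beta, log_z):
    """Map RTB quantities (beta, logZ) to Trust-PCL ones (alpha, v_soft_s0)."""
    alpha = 1.0 / beta          # Trust-PCL temperature
    v_soft_s0 = alpha * log_z   # soft value at the initial state
    return alpha, v_soft_s0


def rtb_residual(log_z, log_pf, log_prior, log_reward, beta):
    # Assumed RTB residual for one trajectory:
    # logZ + log pf(tau) - log prior_pf(tau) - beta * log r(x)
    return log_z + log_pf - log_prior - beta * log_reward


def trust_pcl_residual(v_soft_s0, log_pf, log_prior, log_reward, alpha):
    # Assumed Trust-PCL residual for the same trajectory:
    # V_soft(s0) - log r(x) + alpha * (log policy(tau) - log reference_policy(tau))
    return v_soft_s0 - log_reward + alpha * (log_pf - log_prior)


# Toy numbers for one trajectory (illustrative only).
beta, log_z = 2.0, 1.5
log_pf, log_prior, log_reward = -3.0, -2.5, 0.7

alpha, v_soft_s0 = rtb_to_trust_pcl(beta, log_z)
loss_rtb = rtb_residual(log_z, log_pf, log_prior, log_reward, beta) ** 2
loss_trust_pcl = trust_pcl_residual(
    v_soft_s0, log_pf, log_prior, log_reward, alpha
) ** 2

# The two losses differ only by the factor alpha**2.
print(math.isclose(loss_trust_pcl, alpha ** 2 * loss_rtb))
```

Multiplying the assumed RTB residual by ``alpha`` gives the assumed Trust-PCL residual term by term, which is why the squared losses differ by exactly ``alpha ** 2``. The actual residual used by the library may place ``beta`` differently, so treat this only as an illustration of the scaling.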