gfn.gflownet.detailed_balance
=============================

.. py:module:: gfn.gflownet.detailed_balance


Classes
-------

.. autoapisummary::

   gfn.gflownet.detailed_balance.DBGFlowNet
   gfn.gflownet.detailed_balance.ModifiedDBGFlowNet


Functions
---------

.. autoapisummary::

   gfn.gflownet.detailed_balance._call_estimator_with_conditions
   gfn.gflownet.detailed_balance.check_compatibility


Module Contents
---------------

.. py:class:: DBGFlowNet(pf, pb, logF, forward_looking = False, constant_pb = False, log_reward_clip_min = -float('inf'), debug = False, loss_fn = None)

   Bases: :py:obj:`gfn.gflownet.base.PFBasedGFlowNet`\ [\ :py:obj:`gfn.containers.Transitions`\ ]


   GFlowNet for the Detailed Balance loss.

   Corresponds to :math:`\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3`,
   where :math:`\mathcal{O}_1` is the set of functions from the internal states (excluding
   :math:`s_f`) to :math:`\mathbb{R}^+` (parametrized in log space to avoid the non-negativity
   constraint), :math:`\mathcal{O}_2` is the set of forward probability functions consistent
   with the DAG, and :math:`\mathcal{O}_3` is the set of backward probability functions
   consistent with the DAG, or a singleton set if ``self.pb`` is a fixed ``DiscretePBEstimator``.

   The detailed balance loss is described in section 3.2 of
   `GFlowNet Foundations <https://arxiv.org/abs/2111.09266>`__.

   .. attribute:: pf

      The forward policy estimator.

   .. attribute:: pb

      The backward policy estimator.

   .. attribute:: logF

      A ``ScalarEstimator`` or ``ConditionalScalarEstimator`` that estimates the log flow of the states.

   .. attribute:: forward_looking

      Whether to use the forward-looking GFlowNet loss. When ``True``, rewards must be defined
      over edges; this implementation treats the edge reward as the difference between the
      successor and current state rewards, so it is only valid if the environment satisfies
      that assumption.

   .. attribute:: constant_pb

      Whether to ignore the backward policy estimator, e.g., if the GFlowNet DAG is a tree
      and ``pb`` is therefore always 1.

   .. attribute:: log_reward_clip_min

      If finite, log rewards are clipped from below at this value.

   .. py:attribute:: forward_looking
      :value: False


   .. py:method:: get_pfs_and_pbs(transitions, recalculate_all_logprobs = True)

      Evaluates the forward and backward log-probabilities of each transition in the batch,
      i.e., :math:`\log P_F(s' \mid s)` and :math:`\log P_B(s \mid s')`.

      If ``recalculate_all_logprobs=True``, the log-probabilities of the transitions are
      re-evaluated using the current ``self.pf``. Otherwise:

      - If the transitions have a ``log_probs`` attribute, it is used directly; this is
        typical of on-policy learning.
      - Otherwise (no ``log_probs`` are stored), the log-probabilities are re-evaluated
        using the current ``self.pf``; this is typical of off-policy learning with a
        replay buffer.

      :param transitions: The Transitions object to evaluate.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.

      :returns: A tuple of tensors of shape (n_transitions,) containing the log_pf and log_pb
                for each transition.


   .. py:method:: get_scores(env, transitions, recalculate_all_logprobs = True, *, log_rewards = None)

      Calculates the scores for a batch of transitions.

      The score of each transition :math:`s \to s'` is defined as

      .. math::

         \log \left( \frac{F(s) P_F(s' \mid s)}{F(s') P_B(s \mid s')} \right).

      :param env: The environment where the transitions are sampled from.
      :param transitions: The Transitions object to evaluate.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
      :param log_rewards: Optional custom log rewards tensor of shape (n_transitions,). When
                          ``None``, the environment rewards stored in the transitions are used.
                          Useful for intrinsic rewards (see "Towards Improving Exploration
                          through Sibling Augmented GFlowNets", Madan et al., ICLR 2025).
                          **Not supported when** ``forward_looking=True``: a ``ValueError`` is
                          raised in that case, because the forward-looking objective still
                          calls ``env.log_reward()`` for intermediate-state adjustments, so
                          custom rewards cannot fully replace environment rewards.

      :returns: A tensor of shape (n_transitions,) representing the scores for each transition.


   .. py:attribute:: logF


   .. py:method:: logF_named_parameters()

      Returns a dictionary of named parameters containing 'logF' in their name.

      :returns: A dictionary of named parameters containing 'logF' in their name.


   .. py:method:: logF_parameters()

      Returns a list of parameters containing 'logF' in their name.

      :returns: A list of parameters containing 'logF' in their name.


   .. py:method:: loss(env, transitions, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None)

      Computes the detailed balance loss.

      The detailed balance loss is described in section 3.2 of
      `GFlowNet Foundations <https://arxiv.org/abs/2111.09266>`__. Run with
      ``self.debug=False`` for improved performance.

      :param env: The environment where the transitions are sampled from.
      :param transitions: The Transitions object to compute the loss with.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
      :param reduction: The reduction method to use (``'mean'``, ``'sum'``, or ``'none'``).
      :param log_rewards: Optional custom log rewards tensor of shape (n_transitions,). When
                          ``None``, the environment rewards are used.

      :returns: The computed detailed balance loss as a tensor. Its shape depends on the
                reduction method.


   .. py:method:: to_training_samples(trajectories)

      Converts trajectories to transitions for the detailed balance loss.

      :param trajectories: The Trajectories object to convert.

      :returns: A Transitions object containing all transitions from the trajectories.


.. py:class:: ModifiedDBGFlowNet(pf, pb, constant_pb = False, debug = False)

   Bases: :py:obj:`gfn.gflownet.base.PFBasedGFlowNet`\ [\ :py:obj:`gfn.containers.Transitions`\ ]


   The Modified Detailed Balance GFlowNet. Only applicable to environments where all states
   are terminating.

   See section 3.2 of `Bayesian Structure Learning with Generative Flow Networks
   <https://arxiv.org/abs/2202.13903>`__ for more details.

   .. attribute:: pf

      The forward policy estimator.

   .. attribute:: pb

      The backward policy estimator, or ``None`` if the GFlowNet DAG is a tree, in which
      case ``pb`` is always 1.

   .. attribute:: constant_pb

      Whether to ignore the backward policy estimator, e.g., if the GFlowNet DAG is a tree
      and ``pb`` is therefore always 1. It must be set explicitly by the user, which ensures
      that ``pb`` is an Estimator except in this special case.

   .. py:method:: get_scores(transitions, recalculate_all_logprobs = True)

      Calculates DAG-GFN-style modified detailed balance scores.

      Note that this method is only applicable to environments where all states are
      terminating, i.e., the sink state can be reached directly (via the exit action)
      from every state.

      If ``recalculate_all_logprobs=True``, the log-probabilities of the transitions are
      re-evaluated using the current ``self.pf``. Otherwise:

      - If the transitions have a ``log_probs`` attribute, it is used directly; this is
        typical of on-policy learning.
      - Otherwise, the log-probabilities are re-evaluated using the current ``self.pf``;
        this is typical of off-policy learning with a replay buffer.

      :param transitions: The Transitions object to evaluate.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.

      :returns: A tensor of shape (n_transitions,) containing the scores for each transition.


   .. py:method:: loss(env, transitions, recalculate_all_logprobs = True, reduction = 'mean')

      Computes the modified detailed balance loss.

      :param env: The environment where the transitions are sampled from (unused).
      :param transitions: The Transitions object to compute the loss with.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
      :param reduction: The reduction method to use (``'mean'``, ``'sum'``, or ``'none'``).

      :returns: The computed modified detailed balance loss as a tensor. Its shape depends on
                the reduction method.


   .. py:method:: to_training_samples(trajectories)

      Converts trajectories to transitions for the modified detailed balance loss.

      :param trajectories: The Trajectories object to convert.

      :returns: A Transitions object containing all transitions from the trajectories.


.. py:function:: _call_estimator_with_conditions(estimator, name, states, conditions)

   Calls an estimator with or without conditions, using the appropriate handler.

   This centralises the repeated if/else pattern for condition-aware estimator calls.
   The exception handlers add diagnostic context (estimator name and type) if a
   ``TypeError`` is raised due to a conditions mismatch.

   The function is deliberately thin (no branching beyond the conditions check) so that
   it does not introduce extra graph breaks for ``torch.compile``.


.. py:function:: check_compatibility(states, actions, transitions)

   Checks compatibility between states and actions in transitions.

   :param states: The states in the transitions.
   :param actions: The actions in the transitions.
   :param transitions: The transitions object.

   :raises TypeError: If ``transitions`` is not of type ``Transitions``.
   :raises ValueError: If there is a mismatch between the batch shapes of states and actions.
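
The rule that :py:meth:`DBGFlowNet.get_pfs_and_pbs` and :py:meth:`ModifiedDBGFlowNet.get_scores` use to decide whether to reuse stored forward log-probabilities can be sketched as a small selection function. This is a hypothetical illustration, not the library's code: ``choose_log_pf`` and its ``recompute`` callback (standing in for a fresh evaluation with the current ``self.pf``) are assumptions, and plain lists stand in for tensors.

.. code-block:: python

   from typing import Callable, List, Optional


   def choose_log_pf(
       stored_log_probs: Optional[List[float]],
       recalculate_all_logprobs: bool,
       recompute: Callable[[], List[float]],
   ) -> List[float]:
       """Hypothetical sketch of the log_pf selection rule."""
       # Re-evaluate with the current policy when asked to, or when nothing was
       # stored at sampling time (e.g., transitions drawn from a replay buffer).
       if recalculate_all_logprobs or stored_log_probs is None:
           return recompute()
       # Otherwise reuse the log-probabilities recorded during sampling.
       return stored_log_probs


   def fresh() -> List[float]:
       # Stands in for re-evaluating the transitions with the current policy.
       return [-0.2, -0.4]


   stored = [-0.1, -0.3]
   print(choose_log_pf(stored, False, fresh))  # [-0.1, -0.3]
   print(choose_log_pf(stored, True, fresh))   # [-0.2, -0.4]
   print(choose_log_pf(None, False, fresh))    # [-0.2, -0.4]

Reusing stored values saves a forward pass but is only correct while the policy is unchanged since sampling; both signatures default to ``recalculate_all_logprobs = True``.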
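
The score returned by :py:meth:`DBGFlowNet.get_scores` and its reduction in :py:meth:`DBGFlowNet.loss` can be illustrated with a framework-free sketch on plain Python lists rather than tensors. The helpers ``db_scores`` and ``db_loss`` are hypothetical, and the sketch assumes a squared-error loss on the scores (the standard detailed balance choice; the class also accepts a ``loss_fn`` argument, whose default is not documented here).

.. code-block:: python

   from typing import List, Sequence, Union


   def db_scores(
       log_F_s: Sequence[float],
       log_pf: Sequence[float],
       log_F_next: Sequence[float],
       log_pb: Sequence[float],
   ) -> List[float]:
       """Hypothetical helper: detailed balance score of each transition s -> s'.

       score = log F(s) + log P_F(s'|s) - log F(s') - log P_B(s|s')

       For a terminating edge s -> s_f, the usual convention substitutes
       log R(s) for log F(s') and 0 for log P_B; this sketch leaves that
       substitution to the caller.
       """
       return [
           fs + pf - fn - pb
           for fs, pf, fn, pb in zip(log_F_s, log_pf, log_F_next, log_pb)
       ]


   def db_loss(
       scores: Sequence[float], reduction: str = "mean"
   ) -> Union[float, List[float]]:
       """Hypothetical helper: squared scores reduced with 'mean', 'sum' or 'none'."""
       squared = [s * s for s in scores]
       if reduction == "mean":
           return sum(squared) / len(squared)
       if reduction == "sum":
           return sum(squared)
       if reduction == "none":
           return squared
       raise ValueError(f"Unknown reduction: {reduction!r}")


   # Two transitions: the first satisfies detailed balance, the second does not.
   scores = db_scores(
       log_F_s=[0.0, 1.0],
       log_pf=[-0.5, -1.0],
       log_F_next=[-0.5, 0.0],
       log_pb=[0.0, -0.5],
   )
   print(scores)                   # [0.0, 0.5]
   print(db_loss(scores, "mean"))  # 0.125

A score of zero means the transition satisfies detailed balance exactly; the loss is zero only when every transition in the batch does.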
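
For :py:class:`ModifiedDBGFlowNet`, every state can exit directly to :math:`s_f`, so the terminating condition :math:`F(s) P_F(s_f \mid s) = R(s)` holds at every state. Substituting :math:`F(s) = R(s) / P_F(s_f \mid s)` into the detailed balance condition :math:`F(s) P_F(s' \mid s) = F(s') P_B(s \mid s')` removes the learned state flow:

.. math::

   R(s) \, P_F(s' \mid s) \, P_F(s_f \mid s') = R(s') \, P_B(s \mid s') \, P_F(s_f \mid s).

The sketch below computes the log of the ratio of the two sides for each transition. The helper and its argument names are hypothetical, transitions into :math:`s_f` are assumed to be filtered out beforehand, and the sign convention is an assumption rather than necessarily the library's.

.. code-block:: python

   from typing import List, Sequence


   def modified_db_scores(
       log_r_s: Sequence[float],
       log_pf: Sequence[float],
       log_pf_exit_next: Sequence[float],
       log_r_next: Sequence[float],
       log_pb: Sequence[float],
       log_pf_exit: Sequence[float],
   ) -> List[float]:
       """Hypothetical helper: modified DB residual for each s -> s' (s' != s_f).

       score = [log R(s) + log P_F(s'|s) + log P_F(s_f|s')]
             - [log R(s') + log P_B(s|s') + log P_F(s_f|s)]
       """
       return [
           (rs + pf + pf_exit_n) - (rn + pb + pf_exit)
           for rs, pf, pf_exit_n, rn, pb, pf_exit in zip(
               log_r_s, log_pf, log_pf_exit_next, log_r_next, log_pb, log_pf_exit
           )
       ]


   # The first transition is balanced; the second has a mismatched log P_B.
   scores = modified_db_scores(
       log_r_s=[0.0, 0.0],
       log_pf=[-0.25, -0.25],
       log_pf_exit_next=[-0.5, -0.5],
       log_r_next=[0.5, 0.5],
       log_pb=[-0.25, -0.75],
       log_pf_exit=[-1.0, -1.0],
   )
   print(scores)  # [0.0, 0.5]

Because no state-flow estimator appears in the residual, the constructor of :py:class:`ModifiedDBGFlowNet` takes no ``logF`` argument.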
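
The condition-aware dispatch described for ``_call_estimator_with_conditions`` can be sketched as follows. This is a hypothetical reimplementation, not the library's code: ``call_with_conditions`` is an illustrative name, and it assumes an estimator is called as ``estimator(states)`` when unconditioned and ``estimator(states, conditions)`` when conditioned.

.. code-block:: python

   from typing import Any, Callable, List, Optional


   def call_with_conditions(
       estimator: Callable[..., Any],
       name: str,
       states: Any,
       conditions: Optional[Any],
   ) -> Any:
       """Hypothetical sketch: dispatch on whether conditions are given."""
       try:
           if conditions is not None:
               return estimator(states, conditions)
           return estimator(states)
       except TypeError as exc:
           # Re-raise with the estimator's name and type so that a
           # conditions mismatch is easier to diagnose.
           given = "given" if conditions is not None else "None"
           raise TypeError(
               f"{name} estimator ({type(estimator).__name__}) failed with "
               f"conditions={given}: {exc}"
           ) from exc


   def unconditional(states: List[int]) -> List[int]:
       # A toy estimator that accepts no conditions.
       return [s * 2 for s in states]


   print(call_with_conditions(unconditional, "logF", [1, 2], None))  # [2, 4]

Passing conditions to ``unconditional`` raises a ``TypeError`` whose message names the estimator. Note that this sketch wraps every ``TypeError``, including ones raised inside the estimator itself, which is a simplification.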