gfn.gflownet.detailed_balance¶
Classes¶
DBGFlowNet | GFlowNet for the Detailed Balance loss.
ModifiedDBGFlowNet | The Modified Detailed Balance GFlowNet.
Functions¶
_call_estimator_with_conditions | Call an estimator with or without conditions, using the appropriate handler.
check_compatibility | Checks compatibility between states and actions in transitions.
Module Contents¶
- class gfn.gflownet.detailed_balance.DBGFlowNet(pf, pb, logF, forward_looking=False, constant_pb=False, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)¶
Bases:
gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]
GFlowNet for the Detailed Balance loss.
Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (no \(s_f\)) to \(\mathbb{R}^+\) (which we parametrize with logs, to avoid the non-negativity constraint), and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.pb is a fixed DiscretePBEstimator.
The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
- Parameters:
pb (gfn.estimators.Estimator | None)
logF (gfn.estimators.ScalarEstimator | gfn.estimators.ConditionalScalarEstimator)
forward_looking (bool)
constant_pb (bool)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- pf¶
The forward policy estimator.
- pb¶
The backward policy estimator.
- logF¶
A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states.
- forward_looking¶
Whether to use the forward-looking GFN loss. When True, rewards must be defined over edges. This implementation treats the edge reward as the difference between the successor and current state rewards, so it is only valid for environments that satisfy that assumption.
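The forward-looking variant can be illustrated with a minimal scalar sketch. It assumes that the edge reward is taken in log space, that the flow is reparametrized as log F(s) = log F~(s) + log R(s), and that the transition is non-terminating. The function name `fl_db_score` and its arguments are hypothetical and are not part of the library API.

```python
import math


def fl_db_score(log_f_tilde_s, log_r_s, log_pf, log_f_tilde_next, log_r_next, log_pb):
    """Forward-looking score for one non-terminating transition (illustrative only).

    Assumption: the edge reward is the difference of the successor and
    current state log rewards.
    """
    edge_log_reward = log_r_next - log_r_s
    return (log_f_tilde_s + log_pf) - (log_f_tilde_next + log_pb + edge_log_reward)


def plain_db_score(log_f_s, log_pf, log_f_next, log_pb):
    """Standard detailed balance score, for comparison."""
    return (log_f_s + log_pf) - (log_f_next + log_pb)


# With log F(s) = log F~(s) + log R(s), both parametrizations give the same score.
fl = fl_db_score(0.3, math.log(2.0), math.log(0.5), -0.1, math.log(5.0), 0.0)
db = plain_db_score(0.3 + math.log(2.0), math.log(0.5), -0.1 + math.log(5.0), 0.0)
```

Under these assumptions the two scores coincide, which is why the approach requires a meaningful reward at every intermediate state.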
- constant_pb¶
Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1.
- log_reward_clip_min¶
If finite, log rewards are clipped to this minimum value.
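As a rough illustration, and assuming the clip acts as a lower bound (as the name log_reward_clip_min suggests), the effect on individual log rewards can be sketched in plain Python. The helper `clip_log_reward` is hypothetical and only mirrors the documented behaviour.

```python
import math


def clip_log_reward(log_reward: float, clip_min: float = -math.inf) -> float:
    """Bound a log reward from below; a clip_min of -inf leaves it unchanged."""
    if math.isfinite(clip_min):
        return max(log_reward, clip_min)
    return log_reward


# Very negative log rewards are raised to the floor; the others pass through.
clipped = [clip_log_reward(x, clip_min=-10.0) for x in (-50.0, -3.0, 0.5)]
# clipped == [-10.0, -3.0, 0.5]
```

A finite floor of this kind keeps near-zero rewards from producing extremely large regression targets.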
- forward_looking = False¶
- get_pfs_and_pbs(transitions, recalculate_all_logprobs=True)¶
Evaluates forward and backward logprobs for each transition in the batch.
More specifically, it evaluates \(\log P_F(s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in the batch.
If recalculate_all_logprobs=True, we re-evaluate the logprobs of the transitions using the current self.pf. Otherwise, the following applies:
- If the transitions have a log_probs attribute, use it. This is the usual case for on-policy learning.
- Otherwise (no log_probs are stored), re-evaluate the logprobs using the current self.pf. This is the usual case for off-policy learning with a replay buffer.
- Parameters:
transitions (gfn.containers.Transitions) – The Transitions object to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
- Returns:
A tuple of tensors of shape (n_transitions,) containing the log_pf and log_pb for each transition.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
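The choice between stored and recomputed forward logprobs described above can be sketched as a small decision function. This is only a schematic of the documented rule: `select_log_pf`, `stored_log_probs`, and `recompute` are hypothetical names, and plain lists stand in for tensors.

```python
from typing import Callable, Optional, Sequence


def select_log_pf(
    stored_log_probs: Optional[Sequence[float]],
    recompute: Callable[[], Sequence[float]],
    recalculate_all_logprobs: bool = True,
) -> Sequence[float]:
    """Return the forward logprobs to use for a batch of transitions."""
    # Recompute when asked to, or when nothing was stored at sampling time.
    if recalculate_all_logprobs or stored_log_probs is None:
        return recompute()
    # Otherwise reuse the logprobs recorded while sampling (on-policy case).
    return stored_log_probs


stored = [-0.1, -0.2]
fresh = lambda: [-0.3, -0.4]  # stand-in for evaluating the current policy

on_policy = select_log_pf(stored, fresh, recalculate_all_logprobs=False)
off_policy = select_log_pf(None, fresh, recalculate_all_logprobs=False)
forced = select_log_pf(stored, fresh, recalculate_all_logprobs=True)
```

Reusing stored logprobs avoids a second forward pass, but it is only appropriate when the samples were drawn from the policy currently being trained.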
- get_scores(env, transitions, recalculate_all_logprobs=True, *, log_rewards=None)¶
Calculates the scores for a batch of transitions.
The scores for each transition are defined as: \(\log \left( \frac{F(s)P_F(s' \mid s)}{F(s') P_B(s \mid s')} \right)\).
- Parameters:
env (gfn.env.Env) – The environment where the transitions are sampled from.
transitions (gfn.containers.Transitions) – The Transitions object to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_transitions,). When None, uses the environment rewards from the transitions. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025). Not supported when forward_looking=True: raises ValueError in that case because the forward-looking objective still calls env.log_reward() for intermediate state adjustments, so custom rewards cannot fully replace environment rewards.
- Returns:
A tensor of shape (n_transitions,) representing the scores for each transition.
- Return type:
torch.Tensor
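For intuition, the score formula above can be evaluated by hand for a single transition. The sketch below uses plain floats instead of tensors, and the helper `db_score` is a hypothetical name, not a library function.

```python
import math


def db_score(log_f_s: float, log_pf: float, log_f_next: float, log_pb: float) -> float:
    """log( F(s) P_F(s'|s) / (F(s') P_B(s|s')) ) for one transition."""
    return (log_f_s + log_pf) - (log_f_next + log_pb)


# Balanced transition: F(s) = 4, P_F = 0.5, F(s') = 2, P_B = 1, so 4 * 0.5 == 2 * 1.
balanced = db_score(math.log(4.0), math.log(0.5), math.log(2.0), math.log(1.0))

# Halving F(s') breaks the balance; the score becomes log 2.
unbalanced = db_score(math.log(4.0), math.log(0.5), math.log(1.0), math.log(1.0))
```

A score of zero means the detailed balance condition holds exactly for that transition; the loss pushes all scores toward zero.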
- logF¶
- logF_named_parameters()¶
Returns a dictionary of named parameters containing ‘logF’ in their name.
- Returns:
A dictionary of named parameters containing ‘logF’ in their name.
- Return type:
dict[str, torch.Tensor]
- logF_parameters()¶
Returns a list of parameters containing ‘logF’ in their name.
- Returns:
A list of parameters containing ‘logF’ in their name.
- Return type:
list[torch.Tensor]
- loss(env, transitions, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the detailed balance loss.
The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
- Parameters:
env (gfn.env.Env) – The environment where the transitions are sampled from.
transitions (gfn.containers.Transitions) – The Transitions object to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’). Run with self.debug=False for improved performance.
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_transitions,). When None, uses the environment rewards.
- Returns:
The computed detailed balance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
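How the reduction argument affects the output shape can be sketched as follows. The example assumes a squared-error penalty on the scores; the actual penalty is determined by loss_fn and may differ. `reduce_loss` is a hypothetical helper operating on plain lists.

```python
def reduce_loss(scores, reduction="mean"):
    """Apply a squared-error penalty to per-transition scores, then reduce."""
    per_transition = [s * s for s in scores]  # assumption: squared error
    if reduction == "none":
        return per_transition  # one value per transition
    if reduction == "sum":
        return sum(per_transition)  # scalar
    if reduction == "mean":
        return sum(per_transition) / len(per_transition)  # scalar
    raise ValueError(f"Unknown reduction: {reduction!r}")


scores = [0.0, 1.0, -2.0]
loss_none = reduce_loss(scores, "none")
loss_sum = reduce_loss(scores, "sum")
loss_mean = reduce_loss(scores, "mean")
```

With 'none' the result keeps shape (n_transitions,), which is convenient for per-sample weighting; 'mean' and 'sum' return a scalar suitable for backpropagation.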
- to_training_samples(trajectories)¶
Converts trajectories to transitions for detailed balance loss.
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.
- Returns:
A Transitions object containing all transitions from the trajectories.
- Return type:
gfn.containers.Transitions
- class gfn.gflownet.detailed_balance.ModifiedDBGFlowNet(pf, pb, constant_pb=False, debug=False)¶
Bases:
gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]
The Modified Detailed Balance GFlowNet.
Only applicable to environments where all states are terminating. See section 3.2 of [Bayesian Structure Learning with Generative Flow Networks](https://arxiv.org/abs/2202.13903) for more details.
- Parameters:
pb (gfn.estimators.Estimator | None)
constant_pb (bool)
debug (bool)
- pf¶
The forward policy estimator.
- pb¶
The backward policy estimator, or None if the gflownet DAG is a tree, and pb is therefore always 1.
- constant_pb¶
Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by user to ensure that pb is an Estimator except under this special case.
- get_scores(transitions, recalculate_all_logprobs=True)¶
Calculates DAG-GFN-style modified detailed balance scores.
Note that this method is only applicable to environments where all states are terminating, i.e., the sink state is directly reachable from every state.
If recalculate_all_logprobs=True, we re-evaluate the logprobs of the transitions using the current self.pf. Otherwise, the following applies:
- If the transitions have a log_probs attribute, use it. This is the usual case for on-policy learning.
- Otherwise, re-evaluate the log_probs using the current self.pf. This is the usual case for off-policy learning with a replay buffer.
- Parameters:
transitions (gfn.containers.Transitions) – The Transitions object to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
- Returns:
A tensor of shape (n_transitions,) containing the scores for each transition.
- Return type:
torch.Tensor
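As a sketch of what a modified detailed balance score looks like for one transition, the function below follows the condition R(s') P_B(s | s') P_F(s_f | s) = R(s) P_F(s' | s) P_F(s_f | s') from the paper referenced above. The sign convention and the name `modified_db_score` are assumptions made for illustration, and plain floats replace tensors.

```python
import math


def modified_db_score(
    log_r_s: float,
    log_pf_step: float,       # log P_F(s' | s)
    log_pf_exit_next: float,  # log P_F(s_f | s')
    log_r_next: float,
    log_pb: float,            # log P_B(s | s')
    log_pf_exit_s: float,     # log P_F(s_f | s)
) -> float:
    """Log-ratio of the two sides of the modified detailed balance condition."""
    lhs = log_r_s + log_pf_step + log_pf_exit_next
    rhs = log_r_next + log_pb + log_pf_exit_s
    return lhs - rhs


# R(s) = 2, P_F(s'|s) = 0.5, P_F(s_f|s') = 1, R(s') = 4, P_B = 1, P_F(s_f|s) = 0.25:
# 2 * 0.5 * 1 == 4 * 1 * 0.25, so the transition is balanced.
score = modified_db_score(
    math.log(2.0), math.log(0.5), math.log(1.0),
    math.log(4.0), math.log(1.0), math.log(0.25),
)
```

No state-flow estimator appears in this score: because every state is terminating, rewards and exit probabilities take the place of logF.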
- loss(env, transitions, recalculate_all_logprobs=True, reduction='mean')¶
Computes the modified detailed balance loss.
- Parameters:
env (gfn.env.Env) – The environment where the transitions are sampled from (unused).
transitions (gfn.containers.Transitions) – The Transitions object to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
- Returns:
The computed modified detailed balance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
- to_training_samples(trajectories)¶
Converts trajectories to transitions for modified detailed balance loss.
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.
- Returns:
A Transitions object containing all transitions from the trajectories.
- Return type:
gfn.containers.Transitions
- gfn.gflownet.detailed_balance._call_estimator_with_conditions(estimator, name, states, conditions)¶
Call an estimator with or without conditions, using the appropriate handler.
This centralises the repeated if/else pattern for condition-aware estimator calls. The exception handlers add diagnostic context (estimator name and type) if a TypeError is raised due to a conditions mismatch.
The function is deliberately thin (no branching beyond the conditions check) so that it does not introduce extra graph breaks for torch.compile.
- Parameters:
estimator (gfn.estimators.Estimator)
name (str)
states (gfn.states.States)
conditions (torch.Tensor | None)
- Return type:
torch.Tensor
- gfn.gflownet.detailed_balance.check_compatibility(states, actions, transitions)¶
Checks compatibility between states and actions in transitions.
- Parameters:
states (gfn.states.States) – The states in the transitions.
actions (gfn.actions.Actions) – The actions in the transitions.
transitions (gfn.containers.Transitions) – The transitions object.
- Raises:
TypeError – If transitions is not of type Transitions.
ValueError – If there is a mismatch between states and actions batch shapes.
- Return type:
None