gfn.gflownet.detailed_balance

Classes

DBGFlowNet

GFlowNet for the Detailed Balance loss.

ModifiedDBGFlowNet

The Modified Detailed Balance GFlowNet.

Functions

_call_estimator_with_conditions(estimator, name, ...)

Call an estimator with or without conditions, using the appropriate handler.

check_compatibility(states, actions, transitions)

Checks compatibility between states and actions in transitions.

Module Contents

class gfn.gflownet.detailed_balance.DBGFlowNet(pf, pb, logF, forward_looking=False, constant_pb=False, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]

GFlowNet for the Detailed Balance loss.

Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (no \(s_f\)) to \(\mathbb{R}^+\) (which we parametrize with logs, to avoid the non-negativity constraint), and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.pb is a fixed DiscretePBEstimator.

The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator.

logF

A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states.

forward_looking

Whether to use the forward-looking GFN loss. When True, rewards must be defined over edges. This implementation treats the edge reward as the difference between the successor and current state rewards, so the option is only valid for environments that satisfy that assumption.

constant_pb

Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1.

log_reward_clip_min

If finite, clips log rewards to this value.
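As a rough illustration of log_reward_clip_min, the sketch below assumes the value acts as a lower bound on log rewards (as the name suggests) and that an infinite value disables clipping. The helper clip_log_rewards is hypothetical and works on plain Python floats; it is not part of the library.

```python
import math


def clip_log_rewards(log_rewards, clip_min=-math.inf):
    """Hypothetical helper: lower-bound log rewards at clip_min when it is finite."""
    if math.isfinite(clip_min):
        # Assumption: clipping only raises values that fall below clip_min.
        return [max(lr, clip_min) for lr in log_rewards]
    # With the default of -inf, the log rewards pass through unchanged.
    return list(log_rewards)


clipped = clip_log_rewards([-100.0, -1.0], clip_min=-10.0)
unclipped = clip_log_rewards([-100.0, -1.0])
```

A finite lower bound of this kind keeps very small rewards from producing extremely negative log targets.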

forward_looking = False
get_pfs_and_pbs(transitions, recalculate_all_logprobs=True)

Evaluates forward and backward logprobs for each transition in the batch.

More specifically, it evaluates \(\log P_F(s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in the batch.

If recalculate_all_logprobs=True, we re-evaluate the logprobs of the transitions using the current self.pf. Otherwise, the following applies:

  • If the transitions have a log_probs attribute, use it. This is the usual case for on-policy learning.

  • Otherwise (the transitions carry no stored log_probs), re-evaluate the logprobs using the current self.pf. This is the usual case for off-policy learning with a replay buffer.
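The selection rule above can be sketched in plain Python. The function select_log_pf and its recompute callback are hypothetical names used only for illustration; the real method operates on a Transitions object and tensors.

```python
def select_log_pf(stored_log_probs, recompute, recalculate_all_logprobs=True):
    """Hypothetical sketch: choose between stored and freshly computed logprobs."""
    if recalculate_all_logprobs or stored_log_probs is None:
        # Off-policy case, or a forced refresh: evaluate with the current policy.
        return recompute()
    # On-policy case: reuse the logprobs saved at sampling time.
    return stored_log_probs


fresh = [-0.5, -1.2]
reused = select_log_pf([-0.4, -1.0], lambda: fresh, recalculate_all_logprobs=False)
forced = select_log_pf([-0.4, -1.0], lambda: fresh, recalculate_all_logprobs=True)
missing = select_log_pf(None, lambda: fresh, recalculate_all_logprobs=False)
```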

Parameters:
  • transitions (gfn.containers.Transitions) – The Transitions object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

Returns:

A tuple of tensors of shape (n_transitions,) containing the log_pf and log_pb for each transition.

Return type:

Tuple[torch.Tensor, torch.Tensor]

get_scores(env, transitions, recalculate_all_logprobs=True, *, log_rewards=None)

Calculates the scores for a batch of transitions.

The scores for each transition are defined as: \(\log \left( \frac{F(s)P_F(s' \mid s)}{F(s') P_B(s \mid s')} \right)\).
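A minimal pure-Python sketch of this score for a single transition, computed in log space. The function db_score is a hypothetical helper, and the toy flows are chosen by hand so that \(F(s) P_F(s' \mid s) = F(s') P_B(s \mid s')\), in which case the score should be zero up to floating-point error.

```python
import math


def db_score(log_f_s, log_pf, log_f_next, log_pb):
    """Hypothetical helper: log(F(s) P_F(s'|s)) - log(F(s') P_B(s|s'))."""
    return (log_f_s + log_pf) - (log_f_next + log_pb)


# Toy balanced transition: F(s)=4, P_F(s'|s)=0.5, F(s')=2, P_B(s|s')=1.
balanced = db_score(math.log(4.0), math.log(0.5), math.log(2.0), math.log(1.0))

# Doubling F(s) breaks the balance and shifts the score by log(2).
unbalanced = db_score(math.log(8.0), math.log(0.5), math.log(2.0), math.log(1.0))
```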

Parameters:
  • env (gfn.env.Env) – The environment where the transitions are sampled from.

  • transitions (gfn.containers.Transitions) – The Transitions object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_transitions,). When None, uses the environment rewards from the transitions. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025). Not supported when forward_looking=True: raises ValueError in that case because the forward-looking objective still calls env.log_reward() for intermediate state adjustments, so custom rewards cannot fully replace environment rewards.

Returns:

A tensor of shape (n_transitions,) representing the scores for each transition.

Return type:

torch.Tensor
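The restriction on log_rewards described above can be pictured as a small guard. The sketch below is an assumption about the control flow, not the library's code; resolve_log_rewards and its arguments are hypothetical, and plain lists stand in for tensors.

```python
def resolve_log_rewards(env_log_rewards, custom_log_rewards=None, forward_looking=False):
    """Hypothetical guard: prefer custom log rewards unless forward_looking is set."""
    if custom_log_rewards is not None:
        if forward_looking:
            # Mirrors the documented behaviour: custom rewards are rejected here.
            raise ValueError("custom log_rewards are not supported with forward_looking=True")
        return custom_log_rewards
    # Fall back to the rewards provided by the environment.
    return env_log_rewards


default_rewards = resolve_log_rewards([0.0, -1.0])
custom_rewards = resolve_log_rewards([0.0, -1.0], custom_log_rewards=[0.5, 0.5])
try:
    resolve_log_rewards([0.0], custom_log_rewards=[1.0], forward_looking=True)
    rejected = False
except ValueError:
    rejected = True
```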

logF
logF_named_parameters()

Returns a dictionary of named parameters containing ‘logF’ in their name.

Returns:

A dictionary of named parameters containing ‘logF’ in their name.

Return type:

dict[str, torch.Tensor]

logF_parameters()

Returns a list of parameters containing ‘logF’ in their name.

Returns:

A list of parameters containing ‘logF’ in their name.

Return type:

list[torch.Tensor]
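Both logF helpers select parameters by name. A minimal sketch of that filtering, assuming a plain dict of name to value pairs in place of real module parameters (the names and values below are made up for illustration):

```python
def filter_named(named_params, key="logF"):
    """Hypothetical helper: keep entries whose name contains key."""
    return {name: p for name, p in named_params.items() if key in name}


# Made-up parameter names; the values stand in for tensors.
params = {
    "pf.module.weight": 1,
    "pb.module.weight": 2,
    "logF.module.weight": 3,
    "logF.module.bias": 4,
}
logf_named = filter_named(params)
logf_list = list(logf_named.values())
```

Selecting the logF parameters separately is useful, for example, when they should get their own optimizer parameter group.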

loss(env, transitions, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the detailed balance loss.

The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).

Parameters:
  • env (gfn.env.Env) – The environment where the transitions are sampled from.

  • transitions (gfn.containers.Transitions) – The Transitions object to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’). Run with self.debug=False for improved performance.

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_transitions,). When None, uses the environment rewards.

Returns:

The computed detailed balance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
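To make the effect of reduction concrete, the sketch below assumes a squared-error penalty on the per-transition scores (the class also accepts a loss_fn argument, which may change this). The helper db_loss is hypothetical and uses Python lists in place of tensors.

```python
def db_loss(scores, reduction="mean"):
    """Hypothetical sketch: squared scores reduced by 'mean', 'sum', or 'none'."""
    squared = [s * s for s in scores]
    if reduction == "none":
        return squared  # one value per transition
    if reduction == "sum":
        return sum(squared)
    if reduction == "mean":
        return sum(squared) / len(squared)
    raise ValueError(f"unknown reduction: {reduction!r}")


scores = [1.0, -2.0, 0.0]
per_transition = db_loss(scores, "none")
total = db_loss(scores, "sum")
average = db_loss(scores, "mean")
```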

to_training_samples(trajectories)

Converts trajectories to transitions for detailed balance loss.

Parameters:

trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.

Returns:

A Transitions object containing all transitions from the trajectories.

Return type:

gfn.containers.Transitions
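Conceptually, the conversion flattens each trajectory into its consecutive (state, next state) pairs. The sketch below shows only that idea with hypothetical string states; it ignores actions, rewards, log probabilities, and any special handling of the sink state that the real containers may perform.

```python
def flatten_to_transitions(trajectories):
    """Hypothetical sketch: list every consecutive (s, s') pair in each trajectory."""
    transitions = []
    for states in trajectories:
        # zip pairs each state with its successor within the same trajectory.
        transitions.extend(zip(states[:-1], states[1:]))
    return transitions


pairs = flatten_to_transitions([["s0", "s1", "s2"], ["s0", "s3"]])
```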

class gfn.gflownet.detailed_balance.ModifiedDBGFlowNet(pf, pb, constant_pb=False, debug=False)

Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]

The Modified Detailed Balance GFlowNet.

Only applicable to environments where all states are terminating. See section 3.2 of [Bayesian Structure Learning with Generative Flow Networks](https://arxiv.org/abs/2202.13903) for more details.

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator, or None if the gflownet DAG is a tree, and pb is therefore always 1.

constant_pb

Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by the user, which ensures that pb is an Estimator except in this special case.

get_scores(transitions, recalculate_all_logprobs=True)

Calculates DAG-GFN-style modified detailed balance scores.

Note that this method is only applicable to environments where all states are terminating, i.e., the sink state \(s_f\) can be reached directly from every state.

If recalculate_all_logprobs=True, we re-evaluate the logprobs of the transitions using the current self.pf. Otherwise, the following applies:

  • If the transitions have a log_probs attribute, use it. This is the usual case for on-policy learning.

  • Otherwise, re-evaluate the log_probs using the current self.pf. This is the usual case for off-policy learning with a replay buffer.

Parameters:
  • transitions (gfn.containers.Transitions) – The Transitions object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

Returns:

A tensor of shape (n_transitions,) containing the scores for each transition.

Return type:

torch.Tensor
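For reference, the modified detailed balance condition in the cited paper reads \(R(s') P_B(s \mid s') P_F(s_f \mid s) = R(s) P_F(s' \mid s) P_F(s_f \mid s')\). The sketch below computes the log-ratio of the two sides for one transition. The helper modified_db_score is hypothetical, and the sign convention is an assumption that may differ from the library's. The toy values are chosen so the condition holds exactly.

```python
import math


def modified_db_score(log_r_s, log_r_next, log_pf_step,
                      log_pf_exit_s, log_pf_exit_next, log_pb=0.0):
    """Hypothetical helper: log-ratio of the two sides of modified detailed balance."""
    lhs = log_r_s + log_pf_step + log_pf_exit_next  # R(s) P_F(s'|s) P_F(s_f|s')
    rhs = log_r_next + log_pb + log_pf_exit_s       # R(s') P_B(s|s') P_F(s_f|s)
    return lhs - rhs


# Toy chain s -> s' with R(s)=1, R(s')=3, P_F(s'|s)=0.75,
# P_F(s_f|s)=0.25, P_F(s_f|s')=1, and P_B(s|s')=1.
score = modified_db_score(
    math.log(1.0), math.log(3.0), math.log(0.75), math.log(0.25), math.log(1.0)
)
```

No state-flow estimator appears in this expression, which is consistent with the constructor taking no logF argument.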

loss(env, transitions, recalculate_all_logprobs=True, reduction='mean')

Computes the modified detailed balance loss.

Parameters:
  • env (gfn.env.Env) – The environment where the transitions are sampled from (unused).

  • transitions (gfn.containers.Transitions) – The Transitions object to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).

Returns:

The computed modified detailed balance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor

to_training_samples(trajectories)

Converts trajectories to transitions for modified detailed balance loss.

Parameters:

trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.

Returns:

A Transitions object containing all transitions from the trajectories.

Return type:

gfn.containers.Transitions

gfn.gflownet.detailed_balance._call_estimator_with_conditions(estimator, name, states, conditions)

Call an estimator with or without conditions, using the appropriate handler.

This centralises the repeated if/else pattern for condition-aware estimator calls. The exception handlers add diagnostic context (estimator name and type) if a TypeError is raised due to a conditions mismatch.

The function is deliberately thin (no branching beyond the conditions check) so that it does not introduce extra graph breaks for torch.compile.

Return type:

torch.Tensor
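The pattern described above can be sketched with a small context manager that re-raises a TypeError with extra context. Everything here is hypothetical (conditions_error_context, call_with_optional_conditions, and the error message) and only illustrates the idea; it is not the library's implementation.

```python
from contextlib import contextmanager


@contextmanager
def conditions_error_context(name, estimator, with_conditions):
    """Hypothetical handler: add the estimator name and type to a TypeError."""
    try:
        yield
    except TypeError as exc:
        mode = "with" if with_conditions else "without"
        raise TypeError(
            f"{name} ({type(estimator).__name__}) failed when called {mode} conditions: {exc}"
        ) from exc


def call_with_optional_conditions(estimator, name, states, conditions=None):
    """Hypothetical sketch: a single branch on whether conditions were given."""
    if conditions is not None:
        with conditions_error_context(name, estimator, True):
            return estimator(states, conditions)
    with conditions_error_context(name, estimator, False):
        return estimator(states)


plain = call_with_optional_conditions(lambda s: s * 2, "logF", 3)
conditioned = call_with_optional_conditions(lambda s, c: s + c, "logF", 3, conditions=4)
```

Passing conditions to an estimator that does not accept them then fails with a message that names the estimator, rather than a bare argument-count error.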

gfn.gflownet.detailed_balance.check_compatibility(states, actions, transitions)

Checks compatibility between states and actions in transitions.

Raises:
  • TypeError – If transitions is not of type Transitions.

  • ValueError – If there is a mismatch between states and actions batch shapes.

Return type:

None
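A minimal sketch of the two documented checks, using hypothetical stand-in classes (ToyBatch, ToyTransitions) rather than the library's containers. It assumes the batch shape is exposed as a batch_shape tuple, which is an assumption made for illustration.

```python
from dataclasses import dataclass


@dataclass
class ToyBatch:
    """Hypothetical stand-in for a batch of states or actions."""
    batch_shape: tuple


@dataclass
class ToyTransitions:
    """Hypothetical stand-in for a transitions container."""
    n_transitions: int


def check_shapes(states, actions, transitions):
    """Hypothetical sketch: raise on a wrong container type or mismatched shapes."""
    if not isinstance(transitions, ToyTransitions):
        raise TypeError("transitions must be a ToyTransitions instance")
    if states.batch_shape != actions.batch_shape:
        raise ValueError(
            f"batch shape mismatch: {states.batch_shape} vs {actions.batch_shape}"
        )


# Passes silently when the type and shapes agree.
ok = check_shapes(ToyBatch((4,)), ToyBatch((4,)), ToyTransitions(4))
try:
    check_shapes(ToyBatch((4,)), ToyBatch((3,)), ToyTransitions(4))
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```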