gfn.gflownet


Classes

DBGFlowNet

GFlowNet for the Detailed Balance loss.

FMGFlowNet

GFlowNet for the Flow Matching loss with an edge flow estimator.

GFlowNet

Abstract base class for GFlowNets.

HalfSquaredLoss

Half squared loss: \(g(t) = \tfrac{1}{2} t^2\).

LinexLoss

Linear-exponential (Linex) loss: \(g(t) = \frac{1}{\alpha^2}(e^{\alpha t} - \alpha t - 1)\).

LogPartitionVarianceGFlowNet

GFlowNet for the Log Partition Variance loss.

ModifiedDBGFlowNet

The Modified Detailed Balance GFlowNet.

PFBasedGFlowNet

A GFlowNet that uses forward (PF) and backward (PB) policy networks.

RegressionLoss

Abstract base for regression losses on GFlowNet balance residuals.

RelativeLogPartitionVarianceGFlowNet

RTB variant that eliminates the learned logZ via variance minimization.

RelativeTBBase

Shared base for Relative Trajectory Balance variants.

RelativeTrajectoryBalanceGFlowNet

GFlowNet for the Relative Trajectory Balance (RTB) loss.

ShiftedCoshLoss

Shifted hyperbolic cosine: \(g(t) = e^t + e^{-t} - 2 = 2(\cosh(t) - 1)\).

SquaredLoss

Standard squared loss: \(g(t) = t^2\).

SubTBGFlowNet

GFlowNet for the Sub-Trajectory Balance loss.

TBGFlowNet

GFlowNet for the Trajectory Balance loss.

TrajectoryBasedGFlowNet

A GFlowNet that operates on complete trajectories.

TrustPCLGFlowNet

Trust-PCL view of Relative Trajectory Balance.

Package Contents

class gfn.gflownet.DBGFlowNet(pf, pb, logF, forward_looking=False, constant_pb=False, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]

GFlowNet for the Detailed Balance loss.

Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (excluding \(s_f\)) to \(\mathbb{R}^+\) (parametrized in log space to avoid the non-negativity constraint), \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG, and \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton set if self.pb is a fixed DiscretePBEstimator.

The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator.

logF

A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states.

forward_looking

Whether to use the forward-looking GFN loss. When True, rewards must be defined over edges; this implementation treats the edge reward as the difference between the successor-state and current-state rewards, so it is only valid for environments that satisfy this assumption.

constant_pb

Whether to ignore the backward policy estimator, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1.

log_reward_clip_min

If finite, log rewards are clipped from below at this value.
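A minimal sketch of the lower clipping that log_reward_clip_min describes, using plain Python floats instead of tensors. The helper clip_log_rewards is hypothetical and not part of gfn.gflownet; on a tensor, torch.clamp(log_rewards, min=clip_min) performs the same operation.

```python
import math


def clip_log_rewards(log_rewards, clip_min=-math.inf):
    """Hypothetical helper: clip each log reward from below at clip_min.

    With the default clip_min=-inf every value passes through unchanged.
    """
    return [max(r, clip_min) for r in log_rewards]


# A zero reward has log reward -inf, which would make balance residuals infinite.
log_rewards = [math.log(2.0), -math.inf, -50.0]
clipped = clip_log_rewards(log_rewards, clip_min=-10.0)  # [log 2, -10.0, -10.0]
```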

forward_looking = False
get_pfs_and_pbs(transitions, recalculate_all_logprobs=True)

Evaluates forward and backward logprobs for each transition in the batch.

More specifically, it evaluates \(\log P_F(s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in the batch.

If recalculate_all_logprobs=True, we re-evaluate the logprobs of the transitions using the current self.pf. Otherwise, the following applies:

  • If the transitions have a log_probs attribute, use it; this is usually the case for on-policy learning.

  • Otherwise (the transitions have no stored logprobs), re-evaluate the logprobs using the current self.pf; this is usually the case for off-policy learning with a replay buffer.
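The selection rule above fits in a few lines. This is an illustrative sketch only: select_log_pf and its arguments are hypothetical names, and the real method works on Transitions objects and evaluates both log_pf and log_pb.

```python
from typing import Callable, Optional, Sequence


def select_log_pf(
    stored_log_probs: Optional[Sequence[float]],
    recalculate_all_logprobs: bool,
    recompute: Callable[[], Sequence[float]],
) -> Sequence[float]:
    """Hypothetical helper mirroring the reuse-or-recompute rule."""
    if recalculate_all_logprobs or stored_log_probs is None:
        # Fresh evaluation under the current forward policy
        return recompute()
    # Reuse the logprobs saved at sampling time (typical on-policy case)
    return stored_log_probs
```

Passing recompute as a callable means the policy is only evaluated when the stored values cannot be reused.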

Parameters:
  • transitions (gfn.containers.Transitions) – The Transitions object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

Returns:

A tuple of tensors of shape (n_transitions,) containing the log_pf and log_pb for each transition.

Return type:

Tuple[torch.Tensor, torch.Tensor]

get_scores(env, transitions, recalculate_all_logprobs=True, *, log_rewards=None)

Calculates the scores for a batch of transitions.

The scores for each transition are defined as: \(\log \left( \frac{F(s)P_F(s' \mid s)}{F(s') P_B(s \mid s')} \right)\).
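A scalar sketch of this score in plain Python, working in log space as the formula does. db_score is a hypothetical name (the method itself operates on batched Transitions). A score of zero means the detailed balance condition \(F(s)P_F(s' \mid s) = F(s')P_B(s \mid s')\) holds for that transition; the loss then applies loss_fn to the scores.

```python
import math


def db_score(log_F_s, log_pf, log_F_next, log_pb):
    """log( F(s) P_F(s'|s) / (F(s') P_B(s|s')) ), computed in log space."""
    return (log_F_s + log_pf) - (log_F_next + log_pb)


# F(s) = 4, P_F(s'|s) = 0.25, F(s') = 2, P_B(s|s') = 0.5: both sides equal 1.
balanced = db_score(math.log(4.0), math.log(0.25), math.log(2.0), math.log(0.5))
# Doubling F(s) shifts the score by log 2.
unbalanced = db_score(math.log(8.0), math.log(0.25), math.log(2.0), math.log(0.5))
```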

Parameters:
  • env (gfn.env.Env) – The environment where the transitions are sampled from.

  • transitions (gfn.containers.Transitions) – The Transitions object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_transitions,). When None, uses the environment rewards from the transitions. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025). Not supported when forward_looking=True: raises ValueError in that case because the forward-looking objective still calls env.log_reward() for intermediate state adjustments, so custom rewards cannot fully replace environment rewards.

Returns:

A tensor of shape (n_transitions,) representing the scores for each transition.

Return type:

torch.Tensor

logF
logF_named_parameters()

Returns a dictionary of named parameters containing ‘logF’ in their name.

Returns:

A dictionary of named parameters containing ‘logF’ in their name.

Return type:

dict[str, torch.Tensor]

logF_parameters()

Returns a list of parameters containing ‘logF’ in their name.

Returns:

A list of parameters containing ‘logF’ in their name.

Return type:

list[torch.Tensor]

loss(env, transitions, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the detailed balance loss.

The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).

Parameters:
  • env (gfn.env.Env) – The environment where the transitions are sampled from.

  • transitions (gfn.containers.Transitions) – The Transitions object to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’). Run with self.debug=False for improved performance.

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_transitions,). When None, uses the environment rewards.

Returns:

The computed detailed balance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor

to_training_samples(trajectories)

Converts trajectories to transitions for detailed balance loss.

Parameters:

trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.

Returns:

A Transitions object containing all transitions from the trajectories.

Return type:

gfn.containers.Transitions

class gfn.gflownet.FMGFlowNet(logF, alpha=1.0, debug=False, loss_fn=None)

Bases: gfn.gflownet.base.GFlowNet[gfn.containers.StatesContainer[gfn.states.DiscreteStates]]

GFlowNet for the Flow Matching loss with an edge flow estimator.

\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). This is equivalent to the set of functions from the internal nodes (i.e., excluding \(s_f\)) to \(\mathbb{R}^{n_{\text{actions}}}\), without the exit action. Positivity need not be enforced because log-flows are parametrized.

The flow matching loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).

Parameters:
logF

A DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator for estimating the log flow of the edges (states -> next_states).

alpha

A scalar weight for the reward matching loss.

Flow Matching does not rely on PF/PB probability recomputation. Any trajectory sampling provided by this class is for diagnostics/visualization and can only use the default (non-recurrent) PolicyMixin interface.

alpha = 1.0
flow_matching_loss(env, states, reduction='mean')

Computes the flow matching loss for the (non-initial) states.

The Flow Matching loss is computed from the residual between the log of the summed incoming flows and the log of the summed outgoing flows of each state. The states should not include \(s_0\), and the batch shape should be (n_states,). Currently, only discrete environments are supported.
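The residual behind this loss can be illustrated for a single state with scalar log edge flows. This is a sketch: logsumexp and fm_residual are hypothetical helpers (torch.logsumexp plays the same role for tensors), and the batched implementation, including how parent edges are gathered, is not reproduced.

```python
import math


def logsumexp(xs):
    """Stable log(sum(exp(x))) for a non-empty list of finite log-values."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def fm_residual(incoming_log_flows, outgoing_log_flows):
    """Log of total inflow minus log of total outflow at one state."""
    return logsumexp(incoming_log_flows) - logsumexp(outgoing_log_flows)


# Inflow 1 + 3 = 4 and outflow 2 + 2 = 4: flow is conserved, so the residual is ~0.
residual = fm_residual([math.log(1.0), math.log(3.0)], [math.log(2.0), math.log(2.0)])
loss = residual ** 2  # SquaredLoss, the documented default for FM
```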

Parameters:
  • env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from.

  • states (gfn.states.DiscreteStates) – The DiscreteStates object to evaluate (should not include \(s_0\)).

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).

Returns:

The computed flow matching loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor

logF
loss(env, states_container, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the flow matching loss for a batch of states.

The flow matching loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266). Unlike the original implementation, we allow more flexibility by treating the intermediary and terminating states separately.

Parameters:
  • env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from.

  • states_container (gfn.containers.StatesContainer[gfn.states.DiscreteStates]) – The StatesContainer object containing both intermediary and terminating states.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs (unused for FM).

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_terminating_states,). When None, uses the environment rewards from the states container. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The computed flow matching loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor

reward_matching_loss(env, terminating_states, log_rewards, reduction='mean')

Computes the reward matching loss for the terminating states.

Parameters:
  • env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from (unused).

  • terminating_states (gfn.states.DiscreteStates) – The DiscreteStates object containing terminating states.

  • conditions – Optional conditions tensor for conditional environments.

  • log_rewards (torch.Tensor | None) – The log rewards for the terminating states.

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).

Returns:

The computed reward matching loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
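In the flow matching formulation, the flow leaving a terminating state \(x\) through its exit edge should equal \(R(x)\). A minimal sketch of the corresponding residual, assuming the exit-edge log-flows have already been extracted from logF (how FMGFlowNet does this is not shown). total_fm_loss is hypothetical and assumes that alpha simply scales the reward matching term, based on the alpha parameter description above.

```python
def reward_matching_loss(exit_log_flows, log_rewards):
    """Mean squared gap between exit-edge log-flows and log rewards (sketch)."""
    residuals = [f - r for f, r in zip(exit_log_flows, log_rewards)]
    return sum(t * t for t in residuals) / len(residuals)


def total_fm_loss(fm_loss, rm_loss, alpha=1.0):
    # Assumed combination: alpha weights the reward matching term.
    return fm_loss + alpha * rm_loss


# Exit flows that already match the rewards give zero reward matching loss.
matched = reward_matching_loss([1.1, 1.6], [1.1, 1.6])
```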

sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)

Samples trajectories using the edge flow estimator.

Parameters:
  • env (gfn.env.DiscreteEnv) – The discrete environment to sample trajectories from.

  • n (int) – Number of trajectories to sample.

  • conditions (torch.Tensor | None) – Optional conditions tensor for conditional environments.

  • save_logprobs (bool) – Whether to save the log-probabilities of the actions.

  • save_estimator_outputs (bool) – Whether to save the estimator outputs.

  • **policy_kwargs (Any) – Additional keyword arguments for the sampler.

Returns:

A Trajectories object containing the sampled trajectories.

Return type:

gfn.containers.Trajectories

to_training_samples(trajectories)

Converts trajectories to a StatesContainer for flow matching loss.

Parameters:

trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.

Returns:

A StatesContainer object containing all states from the trajectories.

Return type:

gfn.containers.StatesContainer[gfn.states.DiscreteStates]

class gfn.gflownet.GFlowNet(debug=False, loss_fn=None)

Bases: abc.ABC, torch.nn.Module, Generic[TrainingSampleType]

Abstract base class for GFlowNets.

A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266).

assert_finite_gradients()

Asserts that the gradients are finite.

assert_finite_parameters()

Asserts that the parameters are finite.

debug = False
log_reward_clip_min
abstract loss(env, training_objects, recalculate_all_logprobs=True)

Computes the loss given the training objects.

Parameters:
  • env (gfn.env.Env) – The environment where the training objects are sampled from.

  • training_objects (Any) – The objects to compute the loss with.

  • recalculate_all_logprobs (bool) – If True, always recalculate logprobs even if they exist. If False, use existing logprobs when available.

Returns:

The computed loss as a tensor.

Return type:

torch.Tensor

loss_fn
loss_from_trajectories(env, trajectories, recalculate_all_logprobs=True)

Helper method to compute loss directly from trajectories.

This method converts trajectories to the appropriate training samples and computes the loss with the correct arguments based on the type of GFlowNet subclass.

Parameters:
  • env (gfn.env.Env) – The environment where the training objects are sampled from.

  • trajectories (gfn.containers.Trajectories) – The trajectories to compute the loss with.

  • recalculate_all_logprobs (bool) – If True, always recalculate logprobs even if they exist. If False, use existing logprobs when available.

Returns:

The computed loss as a tensor.

Return type:

torch.Tensor

sample_terminating_states(env, n)

Rolls out the policy and returns the terminating states.

Parameters:
  • env (gfn.env.Env) – The environment to sample terminating states from.

  • n (int) – Number of terminating states to sample.

Returns:

The sampled terminating states as a States object.

Return type:

gfn.states.States

abstract sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)

Samples a specific number of complete trajectories from the environment.

Parameters:
  • env (gfn.env.Env) – The environment to sample trajectories from.

  • n (int) – Number of trajectories to sample.

  • conditions (torch.Tensor | None) – Optional conditions tensor for conditional environments.

  • save_logprobs (bool) – Whether to save the logprobs of the actions (useful for on-policy learning).

  • save_estimator_outputs (bool) – Whether to save the estimator outputs (useful for off-policy learning with a tempered policy).

  • **policy_kwargs (Any) – Additional keyword arguments for the sampler.

Returns:

A Trajectories object containing the sampled trajectories.

Return type:

gfn.containers.Trajectories

abstract to_training_samples(trajectories)

Converts trajectories to training samples.

Parameters:

trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.

Returns:

The training samples; their type depends on the GFlowNet subclass.

Return type:

TrainingSampleType

class gfn.gflownet.HalfSquaredLoss

Bases: RegressionLoss

Half squared loss: \(g(t) = \tfrac{1}{2} t^2\).

The \(\tfrac{1}{2}\) factor ensures the gradient equals the residual itself: \(g'(t) = t\) rather than \(2t\). This is the standard least-squares convention (minimizing \(\tfrac{1}{2}\|r\|^2\) so the normal equations have no factor of 2), and matches the RTB formulation in Venkatraman et al. (2024).
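A quick numerical check of the gradient claim, using a central finite difference in plain Python (half_squared and central_difference are illustrative names, not library functions).

```python
def half_squared(t):
    """g(t) = t**2 / 2."""
    return 0.5 * t * t


def central_difference(g, t, h=1e-6):
    """Numerical estimate of g'(t) from a symmetric finite difference."""
    return (g(t + h) - g(t - h)) / (2.0 * h)


for t in (-2.0, 0.3, 1.5):
    # The derivative of the half squared loss is the residual itself
    print(t, central_difference(half_squared, t))
```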

This is the default loss for RelativeTrajectoryBalanceGFlowNet and RelativeLogPartitionVarianceGFlowNet.

__call__(residuals)

Apply the loss elementwise.

Parameters:

residuals (torch.Tensor) – Balance condition residuals (any shape).

Returns:

Non-negative tensor of the same shape.

Return type:

torch.Tensor

class gfn.gflownet.LinexLoss(alpha=1.0)

Bases: RegressionLoss

Linear-exponential (Linex) loss: \(g(t) = \frac{1}{\alpha^2}(e^{\alpha t} - \alpha t - 1)\).

The alpha parameter controls the asymmetry:

  • alpha = 1: corresponds to the forward KL divergence. Zero-avoiding (mass-covering / exploration-favoring): penalizes the learner for missing mass where the target has support, encouraging broader mode coverage at the cost of some spurious mass.

  • alpha = 0.5: corresponds to the alpha-divergence with alpha = 0.5. Balanced: neither purely zero-forcing nor zero-avoiding.

  • alpha < 0: becomes zero-forcing (mode-seeking), similar to but distinct from squared loss.

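The \(1/\alpha^2\) normalization ensures \(g''(0) = 1\) for every nonzero alpha, so near zero \(g(t) \approx \tfrac{1}{2} t^2\). This matches the curvature of HalfSquaredLoss; the plain squared loss \(t^2\) has \(g''(0) = 2\).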
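A scalar sketch of the Linex formula that checks the properties stated above: \(g(0) = 0\), unit curvature at zero, and asymmetry controlled by the sign of alpha. The function names are illustrative. This sketch rejects alpha = 0 because the formula divides by \(\alpha^2\); the limit as alpha approaches 0 is \(t^2/2\).

```python
import math


def linex(t, alpha=1.0):
    """Linex loss (exp(alpha*t) - alpha*t - 1) / alpha**2 for a scalar residual t."""
    if alpha == 0:
        raise ValueError("alpha must be nonzero; the alpha -> 0 limit is t**2 / 2")
    return (math.exp(alpha * t) - alpha * t - 1.0) / alpha**2


def second_derivative(g, t, h=1e-3):
    """Central finite-difference estimate of g''(t)."""
    return (g(t + h) - 2.0 * g(t) + g(t - h)) / (h * h)


for alpha in (-1.0, 0.5, 1.0):
    curvature = second_derivative(lambda t: linex(t, alpha), 0.0)
    # alpha > 0 penalizes positive residuals more; alpha < 0 the reverse
    print(alpha, curvature, linex(2.0, alpha), linex(-2.0, alpha))
```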

References

Hu et al. “Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks” (ICLR 2025, arXiv:2410.02596).

The Linex loss originates from Bayesian decision theory: Varian (1975), Zellner (1986).

Parameters:

alpha (float)

__call__(residuals)

Apply the loss elementwise.

Parameters:

residuals (torch.Tensor) – Balance condition residuals (any shape).

Returns:

Non-negative tensor of the same shape.

Return type:

torch.Tensor

__eq__(other)
Parameters:

other (object)

Return type:

bool

__hash__()
Return type:

int

__repr__()
Return type:

str

alpha = 1.0
class gfn.gflownet.LogPartitionVarianceGFlowNet(pf, pb, constant_pb=False, log_reward_clip_min=float('-inf'), debug=False, loss_fn=None)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

GFlowNet for the Log Partition Variance loss.

The log partition variance loss is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).
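The idea in the cited section is that each trajectory yields its own estimate of \(\log Z\), namely \(\zeta(\tau) = \log R(x) + \sum \log P_B - \sum \log P_F\), and the loss penalizes the spread of those estimates, so no \(\log Z\) parameter is learned. A minimal sketch with plain Python lists, assuming per-trajectory sums of log P_F and log P_B are given and using squared deviations from the batch mean (the SquaredLoss default); the function name is hypothetical.

```python
def log_partition_variance_loss(sum_log_pf, sum_log_pb, log_rewards):
    """Mean squared deviation of per-trajectory log Z estimates from their batch mean.

    Each estimate is zeta = log R(x) + sum log P_B - sum log P_F for one trajectory.
    """
    zetas = [lr + pb - pf for pf, pb, lr in zip(sum_log_pf, sum_log_pb, log_rewards)]
    mean_zeta = sum(zetas) / len(zetas)
    return sum((z - mean_zeta) ** 2 for z in zetas) / len(zetas)


# Both trajectories imply log Z = 1.5, so the estimates agree and the loss is 0.
consistent = log_partition_variance_loss([-1.0, -2.0], [0.0, 0.0], [0.5, -0.5])
```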

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator.

constant_pb

Whether to ignore pb, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1. It must be set explicitly by the user, which ensures that pb is an Estimator in every other case.

log_reward_clip_min

If finite, log rewards are clipped from below at this value.

loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the log partition variance loss.

The log partition variance loss is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).

Parameters:
  • env (gfn.env.Env) – The environment where the trajectories are sampled from (unused).

  • trajectories (gfn.containers.Trajectories) – The Trajectories object to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The computed log partition variance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor

class gfn.gflownet.ModifiedDBGFlowNet(pf, pb, constant_pb=False, debug=False)

Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]

The Modified Detailed Balance GFlowNet.

Only applicable to environments where all states are terminating. See section 3.2 of [Bayesian Structure Learning with Generative Flow Networks](https://arxiv.org/abs/2202.13903) for more details.
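Because every state can terminate, the state flows of DB can be expressed through rewards and stop probabilities. As we read section 3.2 of the cited paper (Deleu et al., 2022), the resulting condition for a transition \(s \to s'\) is \(R(s')P_B(s \mid s')P_F(s_f \mid s) = R(s)P_F(s' \mid s)P_F(s_f \mid s')\). A scalar sketch of its log-space residual follows; the function and argument names are hypothetical, and the sign convention may differ from get_scores().

```python
import math


def modified_db_score(log_r_s, log_r_next, log_pf, log_pb, log_pf_stop_s, log_pf_stop_next):
    """Log-ratio of the two sides of the modified DB condition; zero when balanced.

    Condition: R(s') P_B(s|s') P_F(s_f|s) = R(s) P_F(s'|s) P_F(s_f|s').
    """
    lhs = log_r_next + log_pb + log_pf_stop_s
    rhs = log_r_s + log_pf + log_pf_stop_next
    return lhs - rhs


# R(s) = 1, R(s') = 2, P_F(s'|s) = 0.4, P_F(s_f|s') = 0.5, P_B(s|s') = 1 (tree),
# P_F(s_f|s) = 0.1: both sides equal 0.2.
score = modified_db_score(
    log_r_s=0.0,
    log_r_next=math.log(2.0),
    log_pf=math.log(0.4),
    log_pb=0.0,
    log_pf_stop_s=math.log(0.1),
    log_pf_stop_next=math.log(0.5),
)
```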

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator, or None if the GFlowNet DAG is a tree and pb is therefore always 1.

constant_pb

Whether to ignore the backward policy estimator, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1. It must be set explicitly by the user, which ensures that pb is an Estimator in every other case.

get_scores(transitions, recalculate_all_logprobs=True)

Calculates DAG-GFN-style modified detailed balance scores.

Note that this method is only applicable to environments where all states are terminating, i.e., the sink state is reachable from all states.

If recalculate_all_logprobs=True, we re-evaluate the logprobs of the transitions using the current self.pf. Otherwise, the following applies:

  • If the transitions have a log_probs attribute, use it; this is usually the case for on-policy learning.

  • Otherwise, re-evaluate the log_probs using the current self.pf; this is usually the case for off-policy learning with a replay buffer.

Parameters:
  • transitions (gfn.containers.Transitions) – The Transitions object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

Returns:

A tensor of shape (n_transitions,) containing the scores for each transition.

Return type:

torch.Tensor

loss(env, transitions, recalculate_all_logprobs=True, reduction='mean')

Computes the modified detailed balance loss.

Parameters:
  • env (gfn.env.Env) – The environment where the transitions are sampled from (unused).

  • transitions (gfn.containers.Transitions) – The Transitions object to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).

Returns:

The computed modified detailed balance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor

to_training_samples(trajectories)

Converts trajectories to transitions for modified detailed balance loss.

Parameters:

trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.

Returns:

A Transitions object containing all transitions from the trajectories.

Return type:

gfn.containers.Transitions

class gfn.gflownet.PFBasedGFlowNet(pf, pb, constant_pb=False, log_reward_clip_min=float('-inf'), debug=False, loss_fn=None)

Bases: GFlowNet[TrainingSampleType], abc.ABC

A GFlowNet that uses forward (PF) and backward (PB) policy networks.

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator, or None if it can be ignored (e.g., the gflownet DAG is a tree, and pb is therefore always 1).

constant_pb

Whether to ignore the backward policy estimator.

log_reward_clip_min

If finite, log rewards are clipped from below at this value.

constant_pb = False
log_reward_clip_min
pb
pf
pf_pb_named_parameters()

Returns a dictionary of named parameters containing ‘pf’ or ‘pb’ in their name.

Returns:

A dictionary of named parameters containing ‘pf’ or ‘pb’ in their name.

Return type:

dict[str, torch.Tensor]

pf_pb_parameters()

Returns a list of parameters containing ‘pf’ or ‘pb’ in their name.

Returns:

A list of parameters containing ‘pf’ or ‘pb’ in their name.

Return type:

list[torch.Tensor]

sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)

Samples trajectories using the forward policy network.

Parameters:
  • env (gfn.env.Env) – The environment to sample trajectories from.

  • n (int) – Number of trajectories to sample.

  • conditions (torch.Tensor | None) – Optional conditions tensor for conditional environments.

  • save_logprobs (bool) – Whether to save the logprobs of the actions.

  • save_estimator_outputs (bool) – Whether to save the estimator outputs.

  • **policy_kwargs (Any) – Additional keyword arguments for the sampler.

Returns:

A Trajectories object containing the sampled trajectories.

Return type:

gfn.containers.Trajectories

class gfn.gflownet.RegressionLoss

Bases: abc.ABC

Abstract base for regression losses on GFlowNet balance residuals.

Subclasses implement __call__ mapping a residual tensor to a non-negative loss tensor of the same shape.
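A sketch of a user-defined subclass. To keep the example runnable without torchgfn installed, it defines a stand-in abstract base (RegressionLossSketch) with the single abstract __call__ described here, and it operates on Python lists. PseudoHuberLoss is a hypothetical example, not a loss shipped with the library. With the real base class, __call__ would receive and return torch tensors, and an instance would presumably be passed through the loss_fn argument of the GFlowNet constructors.

```python
import abc
import math


class RegressionLossSketch(abc.ABC):
    """Stand-in for gfn.gflownet.RegressionLoss so the example runs without torchgfn."""

    @abc.abstractmethod
    def __call__(self, residuals):
        """Map residuals to non-negative losses of the same shape."""


class PseudoHuberLoss(RegressionLossSketch):
    """Hypothetical loss: about t**2 / 2 near zero, about delta * |t| for large |t|."""

    def __init__(self, delta=1.0):
        self.delta = delta

    def __call__(self, residuals):
        d = self.delta
        return [d * d * (math.sqrt(1.0 + (t / d) ** 2) - 1.0) for t in residuals]


loss_fn = PseudoHuberLoss(delta=1.0)
values = loss_fn([0.0, 3.0, -3.0])
```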

abstract __call__(residuals)

Apply the loss elementwise.

Parameters:

residuals (torch.Tensor) – Balance condition residuals (any shape).

Returns:

Non-negative tensor of the same shape.

Return type:

torch.Tensor

__eq__(other)
Parameters:

other (object)

Return type:

bool

__hash__()
Return type:

int

__repr__()
Return type:

str

class gfn.gflownet.RelativeLogPartitionVarianceGFlowNet(pf, prior_pf, *, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: RelativeTBBase

RTB variant that eliminates the learned logZ via variance minimization.

Analogous to how LogPartitionVarianceGFlowNet relates to TBGFlowNet, this class mean-centers the RTB residuals within each batch so that no explicit logZ parameter is needed.

The loss minimizes

\[\operatorname{Var}_{\tau}\!\bigl[\log p_\phi(\tau) - \log p_\theta(\tau) - \beta\,\log r(x_T)\bigr],\]

which, for the default quadratic loss, equals the RTB loss evaluated at the batch-optimal \(\log Z^* = -\overline{s}\) (the negative batch mean of scores).
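A numerical check of this identity with plain squared residuals: setting \(\log Z\) to the negative batch mean of the scores turns the mean squared RTB residual into the population variance of the scores, and any other \(\log Z\) gives a larger value. The helper name is hypothetical; under the default HalfSquaredLoss both sides are simply halved.

```python
from statistics import fmean, pvariance


def mean_squared_rtb(scores, log_Z):
    """Mean of (log Z + s)^2 over a batch of RTB scores s (no 1/2 factor)."""
    return fmean((log_Z + s) ** 2 for s in scores)


scores = [0.3, -1.2, 2.0, 0.5]  # s = log p_phi - log p_theta - beta * log r
log_Z_star = -fmean(scores)      # batch-optimal normalizer, -s_bar
at_optimum = mean_squared_rtb(scores, log_Z_star)
```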

loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the Relative LPV loss on a batch of trajectories.

Return type:

torch.Tensor

class gfn.gflownet.RelativeTBBase(pf, prior_pf, *, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

Shared base for Relative Trajectory Balance variants.

Manages the prior forward policy and beta scaling. Subclasses only need to implement loss() (deciding how to handle logZ and reduction).

_compute_rtb_scores(env, trajectories, log_rewards=None, recalculate_all_logprobs=True)

RTB residuals: log_pf_post - log_pf_prior - beta * log_rewards.

Parameters:
  • env (gfn.env.Env | None) – The environment (unused, kept for API consistency).

  • trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

Returns:

Shape (N,) per-trajectory scores.

Return type:

torch.Tensor

get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)

RTB residuals (without logZ): log_pf_post - log_pf_prior - beta * log_R.

This is the public interface to the RTB balance residuals, analogous to TrajectoryBasedGFlowNet.get_scores() for standard TB.

Returns:

Shape (N,) per-trajectory scores.

Return type:

torch.Tensor

property prior_pf: gfn.estimators.Estimator

The fixed prior forward policy (not registered as a submodule).

Return type:

gfn.estimators.Estimator

class gfn.gflownet.RelativeTrajectoryBalanceGFlowNet(pf, prior_pf, *, logZ=None, init_logZ=0.0, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: RelativeTBBase

GFlowNet for the Relative Trajectory Balance (RTB) loss.

This objective matches a posterior sampler to a prior diffusion (or other sequential) model by minimizing

\[\left(\log Z_\phi + \log p_\phi(\tau) - \log p_\theta(\tau) - \beta \log r(x_T)\right)^2,\]

where \(p_\theta\) is a fixed prior process, \(p_\phi\) is the learnable posterior, \(r\) is a positive reward/constraint on the terminal state \(x_T\), and \(\log Z_\phi\) is a learned scalar normalizer.
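A scalar-list sketch of this objective, reduced with the mean of the half squared loss (HalfSquaredLoss is the documented default for this class). The function names are hypothetical; in the library, the log-probabilities come from pf and prior_pf evaluated on Trajectories. The example constructs posterior log-probabilities equal to the prior reweighted by \(r^\beta / Z\), so every residual is zero.

```python
def rtb_residuals(log_pf_post, log_pf_prior, log_rewards, log_Z, beta=1.0):
    """log Z + log p_phi(tau) - log p_theta(tau) - beta * log r(x_T), per trajectory."""
    return [
        log_Z + post - prior - beta * lr
        for post, prior, lr in zip(log_pf_post, log_pf_prior, log_rewards)
    ]


def half_squared_mean(residuals):
    # HalfSquaredLoss followed by the 'mean' reduction
    return sum(0.5 * t * t for t in residuals) / len(residuals)


# log p_phi = log p_theta + beta * log r - log Z  ->  zero residuals
residuals = rtb_residuals([-1.75, -3.75], [-2.0, -3.0], [1.0, -1.0], log_Z=0.25, beta=0.5)
loss = half_squared_mean(residuals)
```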

Parameters:
logZ
loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the RTB loss on a batch of trajectories.

Return type:

torch.Tensor

class gfn.gflownet.ShiftedCoshLoss

Bases: RegressionLoss

Shifted hyperbolic cosine: \(g(t) = e^t + e^{-t} - 2 = 2(\cosh(t) - 1)\).

This is the only loss in the family that is simultaneously zero-forcing (penalizes spurious mass) and zero-avoiding (penalizes missing modes). It is symmetric: g(t) = g(-t).

Near \(t = 0\) it behaves like \(t^2\) (the same curvature as the squared loss), but for large \(|t|\) it grows exponentially, providing stronger gradients for poorly matched trajectories.
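A scalar sketch of this loss. It uses the identity \(2(\cosh t - 1) = 4\sinh^2(t/2)\), which is algebraically identical but avoids the cancellation in cosh(t) - 1 for very small residuals; that choice belongs to this sketch, not necessarily to the library.

```python
import math


def shifted_cosh(t):
    """g(t) = e^t + e^{-t} - 2, computed as 4 * sinh(t / 2) ** 2 for accuracy near 0."""
    return 4.0 * math.sinh(0.5 * t) ** 2


for t in (1e-4, 0.5, 5.0):
    # Close to t**2 for small t, much larger than t**2 for large |t|
    print(t, shifted_cosh(t), t * t)
```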

Hu et al. (ICLR 2025) found this loss generally outperforms squared error on convergence speed and mode coverage across HyperGrid, bit-sequence, and sEH molecule benchmarks.

References

Hu et al. “Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks” (ICLR 2025, arXiv:2410.02596).

__call__(residuals)

Apply the loss elementwise.

Parameters:

residuals (torch.Tensor) – Balance condition residuals (any shape).

Returns:

Non-negative tensor of the same shape.

Return type:

torch.Tensor

class gfn.gflownet.SquaredLoss

Bases: RegressionLoss

Standard squared loss: \(g(t) = t^2\).

Corresponds to the reverse KL divergence (Malkin et al. 2022). This is zero-forcing (mode-seeking): it penalizes the learner for placing probability mass where the target has none, but does not penalize missing modes. This can lead to mode collapse in multi-modal targets.

This is the default loss for TB, DB, SubTB, LPV, and FM classes, reproducing the standard behavior from the literature.

__call__(residuals)

Apply the loss elementwise.

Parameters:

residuals (torch.Tensor) – Balance condition residuals (any shape).

Returns:

Non-negative tensor of the same shape.

Return type:

torch.Tensor

class gfn.gflownet.SubTBGFlowNet(pf, pb, logF, weighting='geometric_within', lamda=0.9, log_reward_clip_min=-float('inf'), forward_looking=False, constant_pb=False, debug=False, loss_fn=None)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

GFlowNet for the Sub-Trajectory Balance loss.

An implementation of the sub-trajectory balance loss as described in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782).

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator, or None if the GFlowNet DAG is a tree and pb is therefore always 1.

logF

A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states.

weighting

The sub-trajectory weighting scheme:

  • “DB”: Considers all one-step transitions of each trajectory in the batch and weights them equally (regardless of trajectory length). Should be equivalent to the DetailedBalance loss.

  • “ModifiedDB”: Considers all one-step transitions of each trajectory in the batch and weights them inversely proportional to the trajectory length. This ensures that the loss is not dominated by long trajectories: each trajectory contributes equally to the loss.

  • “TB”: Considers only the full trajectory. Should be equivalent to the TrajectoryBalance loss.

  • “equal_within”: Each sub-trajectory of a trajectory is weighted equally within that trajectory; each trajectory is then weighted equally within the batch.

  • “equal”: Each sub-trajectory is weighted equally within the set of all sub-trajectories in the batch.

  • “geometric_within”: Each sub-trajectory is weighted proportionally to (lamda ** len(sub_trajectory)) within its trajectory. This is the weighting used in the paper.

  • “geometric”: Each sub-trajectory is weighted proportionally to (lamda ** len(sub_trajectory)) within the set of all sub-trajectories.
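A sketch of the “geometric_within” weights for a single trajectory with traj_len transitions. Every sub-trajectory from state \(s_i\) to state \(s_j\) (with \(0 \le i < j \le\) traj_len) receives a weight proportional to lamda ** (j - i), normalized over that trajectory. Taking len(sub_trajectory) to be the number of transitions j - i is an assumption. The function name is hypothetical, and batching and padding masks are left out.

```python
def geometric_within_weights(traj_len, lamda=0.9):
    """Normalized weights for every sub-trajectory (i, j) of one trajectory.

    Sub-trajectory (i, j) runs from state s_i to state s_j, i.e. j - i transitions;
    its unnormalized weight is lamda ** (j - i).
    """
    raw = {
        (i, j): lamda ** (j - i)
        for i in range(traj_len)
        for j in range(i + 1, traj_len + 1)
    }
    total = sum(raw.values())
    return {span: w / total for span, w in raw.items()}


weights = geometric_within_weights(3, lamda=0.9)
# 3 transitions -> 3 + 2 + 1 = 6 sub-trajectories; the weights sum to 1
```

With lamda < 1, shorter sub-trajectories receive more weight; lamda > 1 would favor longer ones.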

lamda

Discount factor applied to longer sub-trajectories (used by the geometric weighting schemes).

log_reward_clip_min

If finite, log rewards are clipped from below at this value.

forward_looking

Whether to use the forward-looking GFN loss.

constant_pb

Whether to ignore the backward policy estimator, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1.

calculate_log_state_flows(env, trajectories, log_pf_trajectories)

Calculates log flows of each state in the trajectories.

Parameters:
  • env (gfn.env.Env) – The environment object.

  • trajectories (gfn.containers.Trajectories) – The batch of trajectories.

  • log_pf_trajectories (LogTrajectoriesTensor) – Tensor of shape (max_length, batch_size) containing the logprobs of the forward actions of the trajectories.

Returns:

A tensor of shape (max_length, batch_size) containing the log flows of each state in the trajectories.

Return type:

LogStateFlowsTensor

calculate_masks(log_state_flows, trajectories)

Calculates masks indicating sink and terminal states.

Parameters:
  • log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the log flows of the states.

  • trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

A tuple of two mask tensors (sink_states_mask, is_terminal_mask), each of shape (max_length, batch_size).

Return type:

Tuple[MaskTensor, MaskTensor]

calculate_preds(log_pf_traj_cum, log_state_flows, i)

Calculates the predictions tensor for the current sub-trajectory length.

Parameters:
  • log_pf_traj_cum (CumulativeLogProbsTensor) – Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs of the forward actions for each trajectory.

  • log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the estimated log flow of the states.

  • i (int) – The sub-trajectory length.

Returns:

The predictions tensor of shape (max_length + 1 - i, batch_size).

Return type:

PredictionsTensor
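To make the shapes concrete, here is a minimal sketch of the prediction term on plain tensors. It assumes that row n of the result is \(\log F(s_n) + \log P_F(\tau_{n:n+i})\), obtained as a difference of cumulative sums, and that row 0 of the cumulative tensor is all zeros. The function name subtb_preds is hypothetical and is not the library's implementation.

```python
import torch


def subtb_preds(log_pf_traj_cum: torch.Tensor, log_state_flows: torch.Tensor, i: int) -> torch.Tensor:
    """Hypothetical stand-in for calculate_preds.

    log_pf_traj_cum: (max_length + 1, batch_size), row 0 assumed to be zeros.
    log_state_flows: (max_length, batch_size), log F(s_n) for n = 0..max_length-1.
    Returns a tensor of shape (max_length + 1 - i, batch_size).
    """
    max_length = log_state_flows.shape[0]
    # Sum of log P_F over the i transitions that start at s_n.
    log_pf_sub = log_pf_traj_cum[i:] - log_pf_traj_cum[:-i]
    # Log flows of the possible starting states s_0 .. s_{max_length - i}.
    start_flows = log_state_flows[: max_length + 1 - i]
    return log_pf_sub + start_flows
```

With max_length = 3 and i = 2 the result has two rows, one per possible starting state of a length-2 sub-trajectory.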

calculate_targets(trajectories, preds, log_pb_traj_cum, log_state_flows, is_terminal_mask, sink_states_mask, i, log_rewards=None)

Calculates the targets tensor for the current sub-trajectory length.

Parameters:
  • trajectories (gfn.containers.Trajectories) – The batch of trajectories.

  • preds (PredictionsTensor) – Tensor of shape (max_length + 1 - i, batch_size) containing the predictions for the current sub-trajectory length.

  • log_pb_traj_cum (CumulativeLogProbsTensor) – Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs of the backward actions for each trajectory.

  • log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the estimated log flow of the states.

  • is_terminal_mask (MaskTensor) – A mask of shape (max_length, batch_size) indicating whether the state is terminal.

  • sink_states_mask (MaskTensor) – A mask of shape (max_length, batch_size) indicating whether the state is a sink state.

  • i (int) – The sub-trajectory length.

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The targets tensor of shape (max_length + 1 - i, batch_size).

Return type:

TargetsTensor

cumulative_logprobs(trajectories, log_p_trajectories)

Calculates the cumulative logprobs for all trajectories.

Parameters:
  • trajectories (gfn.containers.Trajectories) – The batch of trajectories.

  • log_p_trajectories (LogTrajectoriesTensor) – Tensor of shape (max_length, batch_size) containing the logprobs of the forward or backward actions of the trajectories.

Returns:

A tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs for each trajectory.

Return type:

CumulativeLogProbsTensor
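A minimal sketch of the cumulative sum described above, assuming that padded steps of shorter trajectories already hold a log-probability of 0 so they do not change the running total. A row of zeros is prepended so that the difference of two rows gives the log-probability of any sub-trajectory. The function below is an illustration, not the library's code.

```python
import torch


def cumulative_logprobs(log_p_trajectories: torch.Tensor) -> torch.Tensor:
    """Prepend a row of zeros and take the running sum over time steps.

    log_p_trajectories: (max_length, batch_size); padded steps assumed to be 0.
    Returns (max_length + 1, batch_size); row t is the sum of the first t log-probs.
    """
    zeros = torch.zeros_like(log_p_trajectories[:1])
    return torch.cat((zeros, log_p_trajectories.cumsum(dim=0)), dim=0)
```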

forward_looking = False
get_equal_contributions(trajectories)

Calculates contributions for the ‘equal’ weighting method.

Parameters:

trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).

Return type:

ContributionsTensor

get_equal_within_contributions(trajectories)

Calculates contributions for the ‘equal_within’ weighting method.

Parameters:

trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).

Return type:

ContributionsTensor

get_geometric_within_contributions(trajectories)

Calculates contributions for the ‘geometric_within’ weighting method.

Parameters:

trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).

Return type:

ContributionsTensor
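The “geometric_within” rule can be sketched on plain tensors as follows. This assumes a row layout in which rows are grouped by sub-trajectory length k = 1..max_len and, within each group, row n is the sub-trajectory starting at step n (max_len - k + 1 rows per group, max_len * (max_len+1) / 2 in total), matching the per-length shapes returned by get_scores. The exact ordering inside the library is an assumption, the function name is hypothetical, and every trajectory is assumed to have at least one transition.

```python
import torch


def geometric_within_contributions(traj_lengths: torch.Tensor, max_len: int, lamda: float) -> torch.Tensor:
    """Hypothetical 'geometric_within' weights: lamda ** k, normalized per trajectory.

    traj_lengths: (batch_size,) number of transitions in each trajectory (>= 1).
    Returns (max_len * (max_len + 1) // 2, batch_size); each column sums to 1.
    """
    blocks = []
    for k in range(1, max_len + 1):
        starts = torch.arange(max_len - k + 1).unsqueeze(1)  # (max_len - k + 1, 1)
        # A sub-trajectory of length k starting at step n exists iff n + k <= length.
        exists = (starts + k) <= traj_lengths.unsqueeze(0)
        blocks.append(exists.to(torch.float32) * lamda ** k)
    contributions = torch.cat(blocks, dim=0)
    # Normalize within each trajectory (column).
    return contributions / contributions.sum(dim=0, keepdim=True)
```

Because the weights are normalized per column, using lamda ** k or lamda ** (k - 1) gives the same result.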

get_modified_db_contributions(trajectories)

Calculates contributions for the ‘ModifiedDB’ weighting method.

Parameters:

trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).

Return type:

ContributionsTensor

get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)

Computes sub-trajectory balance scores for all submitted trajectories.

Parameters:
  • trajectories (gfn.containers.Trajectories) – The batch of trajectories to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • env (gfn.env.Env | None) – The environment where the trajectories are sampled from.

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

  • scores: A list of tensors, each representing the scores of all sub-trajectories of length k, for k in [1, …, max_length], where the score of a sub-trajectory \(\tau_{n:n+k} = (s_n, ..., s_{n+k})\) is \(\log P_F(\tau_{n:n+k}) + \log F(s_n) - \log P_B(\tau_{n:n+k}) - \log F(s_{n+k})\). The scores of the length-k sub-trajectories have shape (max_length - k + 1, batch_size).

  • flattening_masks: A list of tensors indicating what should be masked out from each element of the first list (scores), since not every trajectory contains all sub-trajectories of length k. An entry is True if the corresponding sub-trajectory does not exist.

Return type:

A tuple (scores, flattening_masks)
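The score formula and the masks can be illustrated with a simplified sketch. It assumes cumulative log-probabilities with a leading zero row and a log-flow tensor over all max_length + 1 states in which the caller has already substituted \(\log R(x)\) for the last state of each trajectory. The library's handling of terminal and sink states is more involved, so treat this only as a reading aid; subtb_scores is a hypothetical name.

```python
import torch


def subtb_scores(log_pf_cum, log_pb_cum, log_flows, traj_lengths):
    """Hypothetical per-length sub-trajectory scores and existence masks.

    log_pf_cum, log_pb_cum: (max_length + 1, B), row 0 assumed to be zeros.
    log_flows: (max_length + 1, B), log F(s_n), with log R(x) at each last state.
    traj_lengths: (B,) number of transitions per trajectory.
    """
    max_length = log_pf_cum.shape[0] - 1
    scores, masks = [], []
    for k in range(1, max_length + 1):
        n_rows = max_length - k + 1
        log_pf_sub = log_pf_cum[k:] - log_pf_cum[:-k]  # log P_F(tau_{n:n+k})
        log_pb_sub = log_pb_cum[k:] - log_pb_cum[:-k]  # log P_B(tau_{n:n+k})
        # log P_F + log F(s_n) - log P_B - log F(s_{n+k})
        scores.append(log_pf_sub + log_flows[:n_rows] - log_pb_sub - log_flows[k:])
        starts = torch.arange(n_rows).unsqueeze(1)
        # True where the sub-trajectory runs past the end of the trajectory.
        masks.append((starts + k) > traj_lengths.unsqueeze(0))
    return scores, masks
```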

get_tb_contributions(trajectories)

Calculates contributions for the ‘TB’ weighting method.

Parameters:

trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).

Return type:

ContributionsTensor

lamda = 0.9
logF
logF_named_parameters()

Returns a dictionary of named parameters containing ‘logF’ in their name.

Returns:

A dictionary of named parameters containing ‘logF’ in their name.

Return type:

dict[str, torch.Tensor]

logF_parameters()

Returns a list of parameters containing ‘logF’ in their name.

Returns:

A list of parameters containing ‘logF’ in their name.

Return type:

list[torch.Tensor]

loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the sub-trajectory balance loss.

Parameters:
  • env (gfn.env.Env) – The environment where the trajectories are sampled from.

  • trajectories (gfn.containers.Trajectories) – The batch of trajectories to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’). Note: for geometric-based sub-trajectory weighting, ‘mean’ is not supported and is coerced to ‘sum’ (a warning is emitted when debug=True).

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. When provided, it overrides the terminal reward term used by the loss. Note, however, that with forward_looking=True the state-flow computation may still depend on env.log_reward(...), so custom log_rewards do not fully replace the environment rewards in that mode. Useful for intrinsic rewards that affect the terminal boundary term (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The computed sub-trajectory balance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
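One plausible way to combine the per-length scores, the flattening masks and a contributions tensor into a loss is sketched below. It assumes the squared loss \(g(t) = t^2\) and the row layout implied by concatenating the per-length score tensors. The reduction handling is illustrative, and weighted_subtb_loss is a hypothetical name, not the library's method.

```python
import torch


def weighted_subtb_loss(scores, flattening_masks, contributions, reduction="sum"):
    """Hypothetical contribution-weighted squared sub-trajectory residuals.

    scores, flattening_masks: lists as returned by get_scores.
    contributions: (max_len * (max_len + 1) // 2, B) weights.
    """
    all_scores = torch.cat(scores, dim=0)
    missing = torch.cat(flattening_masks, dim=0)  # True where the sub-trajectory does not exist
    per_entry = contributions * all_scores.pow(2)
    # masked_fill also clears any NaN/inf that padded scores could produce.
    per_traj = per_entry.masked_fill(missing, 0.0).sum(dim=0)  # (B,)
    if reduction == "none":
        return per_traj
    if reduction == "sum":
        return per_traj.sum()
    if reduction == "mean":
        return per_traj.mean()
    raise ValueError(f"unknown reduction: {reduction}")
```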

weighting = 'geometric_within'
class gfn.gflownet.TBGFlowNet(pf, pb, logZ=None, init_logZ=0.0, constant_pb=False, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

GFlowNet for the Trajectory Balance loss.

\(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) represents the possible values for logZ, and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.pb is a fixed DiscretePBEstimator.

See [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259) for more details.

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator, or None if the GFlowNet DAG is a tree and pb is therefore always 1.

logZ

A learnable parameter or a ScalarEstimator instance (for conditional GFNs).

constant_pb

Whether to ignore pb, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1. Must be set explicitly by the user, so that pb is required to be an Estimator except in this special case.

log_reward_clip_min

If finite, log rewards are clipped from below at this value.

logZ
loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the trajectory balance loss.

The trajectory balance loss is described in section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).

Parameters:
  • env (gfn.env.Env) – The environment where the trajectories are sampled from (unused).

  • trajectories (gfn.containers.Trajectories) – The Trajectories object to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The computed trajectory balance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
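For a fixed policy, the trajectory balance residual is \(\log Z\) plus the trajectory score \(\log P_F(\tau) - \log P_B(\tau \mid x) - \log R(x)\). The sketch below assumes the squared loss with the ‘mean’ reduction, uses made-up scores, and optimizes only logZ with plain SGD. It shows that the minimizing logZ is minus the mean score; it is an illustration, not the library's training loop.

```python
import torch

# Made-up per-trajectory scores log P_F(tau) - log P_B(tau|x) - log R(x).
scores = torch.tensor([-2.0, -3.0, -4.0])
logZ = torch.nn.Parameter(torch.tensor(0.0))
optimizer = torch.optim.SGD([logZ], lr=0.1)

for _ in range(200):
    optimizer.zero_grad()
    loss = (logZ + scores).pow(2).mean()  # squared TB residual, 'mean' reduction
    loss.backward()
    optimizer.step()

print(round(logZ.item(), 3))  # converges to -scores.mean() = 3.0
```

In practice the policy parameters are trained jointly with logZ, so the scores themselves move during training.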

class gfn.gflownet.TrajectoryBasedGFlowNet(pf, pb, constant_pb=False, log_reward_clip_min=float('-inf'), debug=False, loss_fn=None)

Bases: PFBasedGFlowNet[gfn.containers.Trajectories], abc.ABC

A GFlowNet that operates on complete trajectories.

Parameters:
pf

The forward policy module.

pb

The backward policy module, or None if the GFlowNet DAG is a tree and pb is therefore always 1.

constant_pb

Whether to ignore the backward policy estimator, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1.

log_reward_clip_min

If finite, log rewards are clipped from below at this value.

get_pfs_and_pbs(trajectories, recalculate_all_logprobs=True)

Evaluates forward and backward logprobs for each trajectory in the batch.

More specifically, it evaluates \(\log P_F(s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in each trajectory in the batch.

If recalculate_all_logprobs=True, we re-evaluate the logprobs of the trajectories using the current self.pf. Otherwise, the following applies:

  • If the trajectories have a logprobs attribute, use it; this is usually the case for on-policy learning.

  • Elif the trajectories have an estimator_outputs attribute, transform the outputs into logprobs; this is usually the case for off-policy learning with a tempered policy.

  • Else (the trajectories have neither), re-evaluate the logprobs using the current self.pf; this is usually the case for off-policy learning with a replay buffer.

Parameters:
  • trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

Returns:

A tuple of tensors of shape (max_length, batch_size) containing the log_pf and log_pb for each action in each trajectory.

Return type:

Tuple[torch.Tensor, torch.Tensor]
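The decision rules above can be summarized as a small lookup. This is a hypothetical sketch: FakeTrajectories stands in for gfn.containers.Trajectories, and “has an attribute” is interpreted here as “the attribute is not None”, which is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeTrajectories:
    """Hypothetical stand-in carrying only the two optional attributes."""

    logprobs: Optional[Any] = None
    estimator_outputs: Optional[Any] = None


def logprob_source(trajectories: FakeTrajectories, recalculate_all_logprobs: bool) -> str:
    """Return which source of forward logprobs the documented rules would pick."""
    if recalculate_all_logprobs:
        return "recompute with self.pf"
    if trajectories.logprobs is not None:
        return "stored logprobs"  # usually on-policy
    if trajectories.estimator_outputs is not None:
        return "transform estimator_outputs"  # usually tempered off-policy
    return "recompute with self.pf"  # e.g. samples from a replay buffer
```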

get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)

Calculates scores for a batch of trajectories.

The scores for each trajectory are defined as: \(\log \left( \frac{P_F(\tau)}{P_B(\tau \mid x) R(x)} \right)\).

Parameters:
  • trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • env (gfn.env.Env | None) – The environment (unused in base TB, but required by some subclasses such as RTB and SubTB).

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards from the trajectories. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

A tensor of shape (batch_size,) containing the scores for each trajectory.

Return type:

torch.Tensor
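As a reading aid, the score can be computed from padded per-step log-probabilities of shape (max_length, batch_size) by masking out the steps beyond each trajectory's length and summing over time. The masking strategy and the function name trajectory_scores are assumptions made for this sketch.

```python
import torch


def trajectory_scores(log_pf, log_pb, log_rewards, traj_lengths):
    """Hypothetical score log P_F(tau) - log P_B(tau|x) - log R(x) per trajectory.

    log_pf, log_pb: (max_length, B) per-step log-probs (padding may hold any value).
    log_rewards: (B,); traj_lengths: (B,) number of transitions per trajectory.
    """
    max_length = log_pf.shape[0]
    # valid[t, b] is True for the first traj_lengths[b] steps of trajectory b.
    valid = torch.arange(max_length).unsqueeze(1) < traj_lengths.unsqueeze(0)
    total_pf = log_pf.masked_fill(~valid, 0.0).sum(dim=0)
    total_pb = log_pb.masked_fill(~valid, 0.0).sum(dim=0)
    return total_pf - total_pb - log_rewards
```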

logz_named_parameters()

Returns named parameters containing ‘logZ’ in their name.

Works for any subclass that registers a logZ parameter (e.g. TBGFlowNet, RelativeTrajectoryBalanceGFlowNet). Returns an empty dict for subclasses without logZ.

Return type:

dict[str, torch.Tensor]

logz_parameters()

Returns parameters containing ‘logZ’ in their name.

Works for any subclass that registers a logZ parameter (e.g. TBGFlowNet, RelativeTrajectoryBalanceGFlowNet). Returns an empty list for subclasses without logZ.

Return type:

list[torch.Tensor]
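Filtering parameters by name makes it easy to give logZ its own optimizer settings; a larger learning rate for logZ is a common choice when training with trajectory balance. The sketch below uses a hypothetical toy module (ToyGFlowNet) that mimics the name-based filtering described above; it is not a gfn class, and the learning rates are illustrative.

```python
import torch
from torch import nn


class ToyGFlowNet(nn.Module):
    """Hypothetical module: a policy network plus a learnable logZ scalar."""

    def __init__(self):
        super().__init__()
        self.pf = nn.Linear(4, 3)
        self.logZ = nn.Parameter(torch.tensor(0.0))

    def logz_named_parameters(self):
        return {n: p for n, p in self.named_parameters() if "logZ" in n}

    def logz_parameters(self):
        return list(self.logz_named_parameters().values())


model = ToyGFlowNet()
logz_ids = {id(p) for p in model.logz_parameters()}
optimizer = torch.optim.Adam(
    [
        # Policy parameters: everything that is not logZ.
        {"params": [p for p in model.parameters() if id(p) not in logz_ids], "lr": 1e-3},
        # logZ in its own parameter group with a larger step size.
        {"params": model.logz_parameters(), "lr": 1e-1},
    ]
)
```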

to_training_samples(trajectories)

Returns the input trajectories as training samples.

Parameters:

trajectories (gfn.containers.Trajectories) – The Trajectories object to use as training samples.

Returns:

The same Trajectories object.

Return type:

gfn.containers.Trajectories

trajectory_log_probs_backward(trajectories)

Evaluates backward logprobs only for each trajectory in the batch.

Parameters:

trajectories (gfn.containers.Trajectories)

Return type:

torch.Tensor

trajectory_log_probs_forward(trajectories, recalculate_all_logprobs=True)

Evaluates forward logprobs only for each trajectory in the batch.

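Parameters:
  • trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

Return type: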

torch.Tensor

class gfn.gflownet.TrustPCLGFlowNet(policy, reference_policy, *, alpha=1.0, init_v_soft_s0=0.0, logZ=None, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)

Bases: RelativeTrajectoryBalanceGFlowNet

Trust-PCL view of Relative Trajectory Balance.

Deleu et al. (2025) proved that RTB is mathematically equivalent to Trust-PCL, an off-policy RL method with KL regularization toward a reference policy. This class exposes the same algorithm through an interface that uses reinforcement learning terminology.

The equivalence (Proposition 3.1 of Deleu et al.):

\[\mathcal{L}_{\text{Trust-PCL}}(\phi, \psi) = \alpha^2 \,\mathcal{L}_{\text{RTB}}(\phi, \psi)\]

where \(\alpha = 1/\beta\) is the Trust-PCL temperature.

Parameter correspondence:

Interpretation of the learned scalar:

In RTB, logZ estimates the log-partition function \(\log \int p_\theta(x)\,r(x)\,dx\). In Trust-PCL, the same quantity is the soft value function at the initial state: \(V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi\). This connects GFlowNet training to entropy-regularized RL, where the soft value satisfies the soft Bellman equation.
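The correspondences stated above (\(\alpha = 1/\beta\), \(V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi\), and the \(\alpha^2\) loss scaling) amount to simple arithmetic, sketched below on plain floats. The inverse map from an initial soft value to an initial logZ is presumed from the same relation and is an assumption about how init_v_soft_s0 is used; both function names are hypothetical.

```python
def trust_pcl_from_rtb(beta: float, log_z: float, rtb_loss: float):
    """Map RTB quantities to their Trust-PCL counterparts (illustrative only)."""
    alpha = 1.0 / beta  # Trust-PCL temperature
    v_soft_s0 = alpha * log_z  # soft value at the initial state
    trust_pcl_loss = alpha**2 * rtb_loss  # L_Trust-PCL = alpha^2 * L_RTB
    return alpha, v_soft_s0, trust_pcl_loss


def log_z_from_v_soft(v_soft_s0: float, alpha: float) -> float:
    """Inverse of V_soft(s0) = alpha * logZ, e.g. to derive an initial logZ."""
    return v_soft_s0 / alpha
```

For example, beta = 0.5 gives alpha = 2, so logZ = 3 corresponds to a soft value of 6 and an RTB loss of 0.25 to a Trust-PCL loss of 1.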

Why this class exists:

The underlying computation is identical to RelativeTrajectoryBalanceGFlowNet (the loss is just scaled by \(\alpha^2\)). This class exists to:

  1. Provide an RL-native constructor (policy, reference_policy, alpha, init_v_soft_s0) for researchers familiar with Trust-PCL / SAC / entropy-regularized RL.

  2. Expose alpha and v_soft_s0 properties for interpretability and monitoring.

  3. Serve as a pedagogical bridge between the GFlowNet and RL communities.

References

Deleu et al. “Relative Trajectory Balance is equivalent to Trust-PCL” (2025, arXiv:2509.01632).

Nachum et al. “Trust-PCL: An Off-Policy Trust Region Method for Continuous Control” (NeurIPS 2017, arXiv:1707.01891).

Venkatraman et al. “Amortizing intractable inference in diffusion models for vision, language, and control” (NeurIPS 2024, arXiv:2405.20971).

Parameters:
property alpha: torch.Tensor

Trust-PCL temperature: \(\alpha = 1/\beta\).

Controls the strength of KL regularization toward the reference policy. At convergence, the learned policy satisfies:

\[\pi_\phi(a|s) \propto \pi_{\text{ref}}(a|s) \exp\!\bigl(Q^{\text{soft}}(s,a) / \alpha\bigr)\]

Higher alpha → policy stays closer to the reference (more regularization). Lower alpha → policy deviates more toward reward-maximizing behavior.

Return type:

torch.Tensor
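The effect of alpha can be checked numerically on a single state with three actions. The sketch below forms the tilted policy \(\pi(a|s) \propto \pi_{\text{ref}}(a|s)\exp(Q^{\text{soft}}(s,a)/\alpha)\) from made-up reference probabilities and soft Q-values, then measures its KL divergence to the reference. The numbers and function names are hypothetical; the trend follows from the formula, since a larger alpha means a weaker exponential tilt.

```python
import torch


def tilted_policy(log_pi_ref: torch.Tensor, q_soft: torch.Tensor, alpha: float) -> torch.Tensor:
    """log pi(a|s) with pi proportional to pi_ref * exp(Q_soft / alpha), normalized over actions."""
    return torch.log_softmax(log_pi_ref + q_soft / alpha, dim=-1)


def kl_to_reference(log_pi: torch.Tensor, log_pi_ref: torch.Tensor) -> torch.Tensor:
    """KL(pi || pi_ref) over the last (action) dimension."""
    return (log_pi.exp() * (log_pi - log_pi_ref)).sum(dim=-1)


log_pi_ref = torch.log(torch.tensor([0.5, 0.3, 0.2]))
q_soft = torch.tensor([0.0, 1.0, 2.0])
for alpha in (0.5, 1.0, 10.0):
    kl = kl_to_reference(tilted_policy(log_pi_ref, q_soft, alpha), log_pi_ref)
    print(alpha, round(kl.item(), 4))  # KL to the reference shrinks as alpha grows
```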

loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the Trust-PCL loss: \(\alpha^2 \cdot \mathcal{L}_{\text{RTB}}\).

The scaling by \(\alpha^2\) is the only difference from RelativeTrajectoryBalanceGFlowNet.loss(). It ensures gradient magnitudes match the Trust-PCL formulation.

Parameters:
Return type:

torch.Tensor

property v_soft_s0: torch.Tensor

Soft value function at the initial state: \(V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi\).

This is the expected return under the optimal entropy-regularized policy, starting from \(s_0\):

\[V^{\text{soft}}(s_0) = \mathbb{E}_{\pi_\phi}\!\left[ \sum_t r(s_t, a_t) + \alpha \sum_t \log \frac{\pi_{\text{ref}}(a_t|s_t)}{\pi_\phi(a_t|s_t)} \right]\]

The KL regularization term \(\alpha \log(\pi_{\text{ref}} / \pi_\phi)\) in the sum emerges from the log-ratio of prior to posterior probabilities in the RTB balance condition.

Monitoring this value during training shows how the expected (regularized) return evolves. At convergence it equals \(\alpha \log \int p_\theta(x)\,r(x)\,dx\).

Return type:

torch.Tensor