gfn.gflownet
Submodules
Classes

| Class | Description |
| --- | --- |
| DBGFlowNet | GFlowNet for the Detailed Balance loss. |
| FMGFlowNet | GFlowNet for the Flow Matching loss with an edge flow estimator. |
| GFlowNet | Abstract base class for GFlowNets. |
| HalfSquaredLoss | Half squared loss: \(g(t) = \tfrac{1}{2} t^2\). |
| LinexLoss | Linear-exponential (Linex) loss: \(g(t) = \frac{1}{\alpha^2}(e^{\alpha t} - \alpha t - 1)\). |
| LogPartitionVarianceGFlowNet | GFlowNet for the Log Partition Variance loss. |
| ModifiedDBGFlowNet | The Modified Detailed Balance GFlowNet. |
| PFBasedGFlowNet | A GFlowNet that uses forward (PF) and backward (PB) policy networks. |
| RegressionLoss | Abstract base for regression losses on GFlowNet balance residuals. |
| RelativeLogPartitionVarianceGFlowNet | RTB variant that eliminates the learned logZ via variance minimization. |
| RelativeTBBase | Shared base for Relative Trajectory Balance variants. |
| RelativeTrajectoryBalanceGFlowNet | GFlowNet for the Relative Trajectory Balance (RTB) loss. |
| ShiftedCoshLoss | Shifted hyperbolic cosine: \(g(t) = e^t + e^{-t} - 2 = 2(\cosh(t) - 1)\). |
| SquaredLoss | Standard squared loss: \(g(t) = t^2\). |
| SubTBGFlowNet | GFlowNet for the Sub-Trajectory Balance loss. |
| TBGFlowNet | GFlowNet for the Trajectory Balance loss. |
| TrajectoryBasedGFlowNet | A GFlowNet that operates on complete trajectories. |
|  | Trust-PCL view of Relative Trajectory Balance. |
Package Contents
- class gfn.gflownet.DBGFlowNet(pf, pb, logF, forward_looking=False, constant_pb=False, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)
Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]
GFlowNet for the Detailed Balance loss.
Corresponds to \(\mathcal{O}_{PF} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1\) is the set of functions from the internal states (no \(s_f\)) to \(\mathbb{R}^+\) (which we parametrize with logs, to avoid the non-negativity constraint), and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.pb is a fixed DiscretePBEstimator.
The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
- Parameters:
pb (gfn.estimators.Estimator | None)
logF (gfn.estimators.ScalarEstimator | gfn.estimators.ConditionalScalarEstimator)
forward_looking (bool)
constant_pb (bool)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- pf
The forward policy estimator.
- pb
The backward policy estimator.
- logF
A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states.
- forward_looking
Whether to use the forward-looking GFN loss. When True, rewards must be defined over edges; this implementation treats the edge reward as the difference between the successor and current state rewards, so it is only valid if the environment follows that assumption.
- constant_pb
Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1.
- log_reward_clip_min
If finite, log rewards are clipped from below at this value.
- forward_looking = False
- get_pfs_and_pbs(transitions, recalculate_all_logprobs=True)
Evaluates forward and backward logprobs for each transition in the batch.
More specifically, it evaluates \(\log P_F(s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in the batch.
If recalculate_all_logprobs=True, we re-evaluate the logprobs of the transitions using the current self.pf. Otherwise, the following applies:
- If the transitions have a log_probs attribute, use it; this is usually the case for on-policy learning.
- Otherwise (no stored log_probs), re-evaluate the logprobs using the current self.pf; this is usually the case for off-policy learning with a replay buffer.
- Parameters:
transitions (gfn.containers.Transitions) – The Transitions object to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
- Returns:
A tuple of tensors of shape (n_transitions,) containing the log_pf and log_pb for each transition.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- get_scores(env, transitions, recalculate_all_logprobs=True, *, log_rewards=None)
Calculates the scores for a batch of transitions.
The scores for each transition are defined as: \(\log \left( \frac{F(s)P_F(s' \mid s)}{F(s') P_B(s \mid s')} \right)\).
- Parameters:
env (gfn.env.Env) – The environment where the transitions are sampled from.
transitions (gfn.containers.Transitions) – The Transitions object to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_transitions,). When None, uses the environment rewards from the transitions. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025). Not supported when forward_looking=True: a ValueError is raised in that case, because the forward-looking objective still calls env.log_reward() for intermediate state adjustments, so custom rewards cannot fully replace environment rewards.
- Returns:
A tensor of shape (n_transitions,) representing the scores for each transition.
- Return type:
torch.Tensor
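The score formula above can be illustrated with a minimal plain-Python sketch. It uses lists of floats in place of torch tensors, computes \(\log F(s) + \log P_F(s' \mid s) - \log F(s') - \log P_B(s \mid s')\) per transition, and then averages the squared scores. It is not the library's implementation: terminating transitions, reward clipping, conditioning, forward-looking rewards, and loss_fn are all ignored, and the helper names db_scores and db_loss are hypothetical.

```python
import math


def db_scores(log_F_s, log_pf, log_F_next, log_pb):
    """Hypothetical helper: per-transition log(F(s) PF(s'|s) / (F(s') PB(s|s')))."""
    return [
        lf + lpf - lfn - lpb
        for lf, lpf, lfn, lpb in zip(log_F_s, log_pf, log_F_next, log_pb)
    ]


def db_loss(scores):
    """Squared scores averaged over the batch ('mean' reduction)."""
    return sum(s * s for s in scores) / len(scores)


# The first two transitions satisfy detailed balance exactly; the third does not.
log_F_s = [math.log(2.0), 0.0, 0.0]
log_pf = [math.log(0.5), math.log(0.5), math.log(0.25)]
log_F_next = [0.0, math.log(0.5), 0.0]
log_pb = [0.0, 0.0, 0.0]

scores = db_scores(log_F_s, log_pf, log_F_next, log_pb)
loss = db_loss(scores)
```

A transition contributes zero loss exactly when the flow through the edge computed from \(s\) matches the flow computed from \(s'\).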
- logF
- logF_named_parameters()
Returns a dictionary of named parameters containing ‘logF’ in their name.
- Returns:
A dictionary of named parameters containing ‘logF’ in their name.
- Return type:
dict[str, torch.Tensor]
- logF_parameters()
Returns a list of parameters containing ‘logF’ in their name.
- Returns:
A list of parameters containing ‘logF’ in their name.
- Return type:
list[torch.Tensor]
- loss(env, transitions, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)
Computes the detailed balance loss.
The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
- Parameters:
env (gfn.env.Env) – The environment where the transitions are sampled from.
transitions (gfn.containers.Transitions) – The Transitions object to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’). Run with self.debug=False for improved performance.
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_transitions,). When None, uses the environment rewards.
- Returns:
The computed detailed balance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
- to_training_samples(trajectories)
Converts trajectories to transitions for detailed balance loss.
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.
- Returns:
A Transitions object containing all transitions from the trajectories.
- Return type:
gfn.containers.Transitions
- class gfn.gflownet.FMGFlowNet(logF, alpha=1.0, debug=False, loss_fn=None)
Bases: gfn.gflownet.base.GFlowNet[gfn.containers.StatesContainer[gfn.states.DiscreteStates]]
GFlowNet for the Flow Matching loss with an edge flow estimator.
\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). This is equivalent to the set of functions from the internal nodes (i.e., without \(s_f\)) to \(\mathbb{R}^{n_{\text{actions}}}\), excluding the exit action (no positivity constraint is needed if we parametrize log-flows).
The flow matching loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
- Parameters:
alpha (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- logF
A DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator for estimating the log flow of the edges (states -> next_states).
- alpha
A scalar weight for the reward matching loss.
Flow Matching does not rely on PF/PB probability recomputation. Any trajectory sampling provided by this class is for diagnostics/visualization and can only use the default (non-recurrent) PolicyMixin interface.
- alpha = 1.0
- flow_matching_loss(env, states, reduction='mean')
Computes the flow matching loss for the (non-initial) states.
The Flow Matching loss is based on the difference between the log-sum of the incoming flows and the log-sum of the outgoing flows of each state. The states should not include \(s_0\), and the batch shape should be (n_states,). Currently, only discrete environments are handled.
- Parameters:
env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from.
states (gfn.states.DiscreteStates) – The DiscreteStates object to evaluate (should not include \(s_0\)).
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
- Returns:
The computed flow matching loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
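As a rough illustration of the flow matching residual, the plain-Python sketch below takes, for each non-initial state, the log flows of its incoming and outgoing edges, squares the difference of their log-sum-exps, and then applies a reduction. The input layout (pairs of lists) and all helper names are hypothetical; the actual method works on DiscreteStates and an edge flow estimator.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def flow_matching_loss(states, reduction="mean"):
    """states: list of (log_in_flows, log_out_flows) pairs, one per non-initial state."""
    # Residual: log of total incoming flow minus log of total outgoing flow.
    losses = [(logsumexp(ins) - logsumexp(outs)) ** 2 for ins, outs in states]
    if reduction == "mean":
        return sum(losses) / len(losses)
    if reduction == "sum":
        return sum(losses)
    if reduction == "none":
        return losses
    raise ValueError(f"unknown reduction: {reduction!r}")


# State A: inflow 1 + 3 = 4 and outflow 2 + 2 = 4, so it is balanced.
# State B: inflow 2 and outflow 1, so it is off by a factor of 2.
states = [
    ([math.log(1.0), math.log(3.0)], [math.log(2.0), math.log(2.0)]),
    ([math.log(2.0)], [math.log(1.0)]),
]
per_state = flow_matching_loss(states, reduction="none")
```

Working in log space with log-sum-exp avoids overflow when individual edge flows are large.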
- logF
- loss(env, states_container, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)
Computes the flow matching loss for a batch of states.
The flow matching loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266). Unlike the original implementation, we allow more flexibility by treating the intermediary and terminating states separately.
- Parameters:
env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from.
states_container (gfn.containers.StatesContainer[gfn.states.DiscreteStates]) – The StatesContainer object containing both intermediary and terminating states.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs (unused for FM).
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_terminating_states,). When None, uses the environment rewards from the states container. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The computed flow matching loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
- reward_matching_loss(env, terminating_states, log_rewards, reduction='mean')
Computes the reward matching loss for the terminating states.
- Parameters:
env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from (unused).
terminating_states (gfn.states.DiscreteStates) – The DiscreteStates object containing terminating states.
conditions – Optional conditions tensor for conditional environments.
log_rewards (torch.Tensor | None) – The log rewards for the terminating states.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
- Returns:
The computed reward matching loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
- sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)
Samples trajectories using the edge flow estimator.
- Parameters:
env (gfn.env.DiscreteEnv) – The discrete environment to sample trajectories from.
n (int) – Number of trajectories to sample.
conditions (torch.Tensor | None) – Optional conditions tensor for conditional environments.
save_logprobs (bool) – Whether to save the log-probabilities of the actions.
save_estimator_outputs (bool) – Whether to save the estimator outputs.
**policy_kwargs (Any) – Additional keyword arguments for the sampler.
- Returns:
A Trajectories object containing the sampled trajectories.
- Return type:
gfn.containers.Trajectories
- to_training_samples(trajectories)
Converts trajectories to a StatesContainer for flow matching loss.
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.
- Returns:
A StatesContainer object containing all states from the trajectories.
- Return type:
gfn.containers.StatesContainer[gfn.states.DiscreteStates]
- class gfn.gflownet.GFlowNet(debug=False, loss_fn=None)
Bases: abc.ABC, torch.nn.Module, Generic[TrainingSampleType]
Abstract base class for GFlowNets.
A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266).
- Parameters:
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- assert_finite_gradients()
Asserts that the gradients are finite.
- assert_finite_parameters()
Asserts that the parameters are finite.
- debug = False
- log_reward_clip_min
- abstract loss(env, training_objects, recalculate_all_logprobs=True)
Computes the loss given the training objects.
- Parameters:
env (gfn.env.Env) – The environment where the training objects are sampled from.
training_objects (Any) – The objects to compute the loss with.
recalculate_all_logprobs (bool) – If True, always recalculate logprobs even if they exist. If False, use existing logprobs when available.
- Returns:
The computed loss as a tensor.
- Return type:
torch.Tensor
- loss_fn
- loss_from_trajectories(env, trajectories, recalculate_all_logprobs=True)
Helper method to compute loss directly from trajectories.
This method converts trajectories to the appropriate training samples and computes the loss with the correct arguments based on the type of GFlowNet subclass.
- Parameters:
env (gfn.env.Env) – The environment where the training objects are sampled from.
trajectories (gfn.containers.Trajectories) – The trajectories to compute the loss with.
recalculate_all_logprobs (bool) – If True, always recalculate logprobs even if they exist. If False, use existing logprobs when available.
- Returns:
The computed loss as a tensor.
- Return type:
torch.Tensor
- sample_terminating_states(env, n)
Rolls out the policy and returns the terminating states.
- Parameters:
env (gfn.env.Env) – The environment to sample terminating states from.
n (int) – Number of terminating states to sample.
- Returns:
The sampled terminating states as a States object.
- Return type:
gfn.states.States
- abstract sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)
Samples a specific number of complete trajectories from the environment.
- Parameters:
env (gfn.env.Env) – The environment to sample trajectories from.
n (int) – Number of trajectories to sample.
conditions (torch.Tensor | None) – Optional conditions tensor for conditional environments.
save_logprobs (bool) – Whether to save the logprobs of the actions (useful for on-policy learning).
save_estimator_outputs (bool) – Whether to save the estimator outputs (useful for off-policy learning with a tempered policy).
policy_kwargs (Any)
- Returns:
A Trajectories object containing the sampled trajectories.
- Return type:
gfn.containers.Trajectories
- abstract to_training_samples(trajectories)
Converts trajectories to training samples.
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.
- Returns:
The training samples, type depends on the type of GFlowNet subclass.
- Return type:
TrainingSampleType
- class gfn.gflownet.HalfSquaredLoss
Bases: RegressionLoss
Half squared loss: \(g(t) = \tfrac{1}{2} t^2\).
The \(\tfrac{1}{2}\) factor ensures the gradient equals the residual itself: \(g'(t) = t\) rather than \(2t\). This is the standard least-squares convention (minimizing \(\tfrac{1}{2}\|r\|^2\) so the normal equations have no factor of 2), and matches the RTB formulation in Venkatraman et al. (2024).
This is the default loss for RelativeTrajectoryBalanceGFlowNet and RelativeLogPartitionVarianceGFlowNet.
- __call__(residuals)
Apply the loss elementwise.
- Parameters:
residuals (torch.Tensor) – Balance condition residuals (any shape).
- Returns:
Non-negative tensor of the same shape.
- Return type:
torch.Tensor
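A quick numerical check of the gradient claim: a central finite difference of \(\tfrac{1}{2}t^2\) recovers the residual \(t\) itself, whereas \(t^2\) would give \(2t\). This is a plain-Python sketch on scalars; the helper names are illustrative only.

```python
def half_squared_loss(t):
    """g(t) = t^2 / 2."""
    return 0.5 * t * t


def central_difference(g, t, h=1e-6):
    """Finite-difference estimate of g'(t)."""
    return (g(t + h) - g(t - h)) / (2.0 * h)


residuals = [-2.0, -0.5, 0.0, 1.5, 3.0]
# Each estimated gradient should match its residual (up to rounding).
grads = [central_difference(half_squared_loss, r) for r in residuals]
```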
- class gfn.gflownet.LinexLoss(alpha=1.0)
Bases: RegressionLoss
Linear-exponential (Linex) loss: \(g(t) = \frac{1}{\alpha^2}(e^{\alpha t} - \alpha t - 1)\).
The alpha parameter controls the asymmetry:
- alpha = 1: corresponds to the forward KL divergence. Zero-avoiding (mass-covering / exploration-favoring): penalizes the learner for missing mass where the target has support, encouraging broader mode coverage at the cost of some spurious mass.
- alpha = 0.5: corresponds to the alpha-divergence with alpha = 0.5. Balanced: neither purely zero-forcing nor zero-avoiding.
- alpha < 0: becomes zero-forcing (mode-seeking), similar to but distinct from squared loss.
The \(1/\alpha^2\) normalization ensures \(g''(0) = 1\) for all alpha, so near zero the loss has the same curvature as the half squared loss \(\tfrac{1}{2}t^2\) (the plain squared loss \(t^2\) has \(g''(0) = 2\)).
References
Hu et al. “Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks” (ICLR 2025, arXiv:2410.02596).
The Linex loss originates from Bayesian decision theory: Varian (1975), Zellner (1986).
- Parameters:
alpha (float)
- __call__(residuals)
Apply the loss elementwise.
- Parameters:
residuals (torch.Tensor) – Balance condition residuals (any shape).
- Returns:
Non-negative tensor of the same shape.
- Return type:
torch.Tensor
- __eq__(other)
- Parameters:
other (object)
- Return type:
bool
- __hash__()
- Return type:
int
- __repr__()
- Return type:
str
- alpha = 1.0
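The plain-Python sketch below evaluates the Linex formula for scalar residuals and checks two properties stated above: \(g''(0) \approx 1\) for several values of alpha (by finite differences), and the asymmetry between positive and negative residuals, which flips with the sign of alpha. It uses math.expm1 for accuracy near zero; it is an illustration rather than the library's tensor implementation, and the function names are hypothetical.

```python
import math


def linex_loss(t, alpha=1.0):
    """g(t) = (exp(alpha*t) - alpha*t - 1) / alpha^2, for alpha != 0."""
    # expm1 keeps precision when alpha * t is close to zero.
    return (math.expm1(alpha * t) - alpha * t) / alpha**2


def second_derivative(g, t, h=1e-4):
    """Central finite-difference estimate of g''(t)."""
    return (g(t + h) - 2.0 * g(t) + g(t - h)) / (h * h)


# Curvature at zero is 1 regardless of alpha, thanks to the 1/alpha^2 factor.
curvatures = {
    a: second_derivative(lambda t: linex_loss(t, a), 0.0) for a in (1.0, 0.5, -1.0)
}

# For alpha > 0 a positive residual costs more than a negative one of the
# same magnitude; a negative alpha reverses this.
pos, neg = linex_loss(2.0, 1.0), linex_loss(-2.0, 1.0)
pos_neg_alpha, neg_neg_alpha = linex_loss(2.0, -1.0), linex_loss(-2.0, -1.0)
```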
- class gfn.gflownet.LogPartitionVarianceGFlowNet(pf, pb, constant_pb=False, log_reward_clip_min=float('-inf'), debug=False, loss_fn=None)
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
GFlowNet for the Log Partition Variance loss.
The log partition variance loss is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).
- Parameters:
pb (gfn.estimators.Estimator | None)
constant_pb (bool)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- pf
The forward policy estimator.
- pb
The backward policy estimator.
- constant_pb
Whether to ignore pb, e.g., if the GFlowNet DAG is a tree and pb is therefore always 1. Must be set explicitly by the user, so that pb is required to be an Estimator except in this special case.
- log_reward_clip_min
If finite, log rewards are clipped from below at this value.
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)
Computes the log partition variance loss.
The log partition variance loss is described in section 3.2 of [Robust Scheduling with GFlowNets](https://arxiv.org/abs/2302.05446).
- Parameters:
env (gfn.env.Env) – The environment where the trajectories are sampled from (unused).
trajectories (gfn.containers.Trajectories) – The Trajectories object to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The computed log partition variance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
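For intuition, the log partition variance objective can be sketched in plain Python as the batch variance of the per-trajectory scores \(\log P_F(\tau) - \log P_B(\tau \mid x) - \log R(x)\). The sketch assumes these per-trajectory log-probabilities are already summed over steps and uses a squared penalty on the mean-centered scores with 'mean' reduction; it ignores reward clipping and loss_fn, and the function name is hypothetical.

```python
def lpv_loss(log_pf_traj, log_pb_traj, log_rewards):
    """Batch variance of log PF(tau) - log PB(tau | x) - log R(x) ('mean' reduction)."""
    scores = [pf - pb - r for pf, pb, r in zip(log_pf_traj, log_pb_traj, log_rewards)]
    mean = sum(scores) / len(scores)
    # Mean-centering removes any constant offset, which is where log Z would enter.
    return sum((s - mean) ** 2 for s in scores) / len(scores)


log_pf_traj = [-1.0, -2.0, -3.0]
log_pb_traj = [0.0, 0.0, 0.0]

# Scores are all -2, so every trajectory implies the same log Z: the loss is zero.
balanced = lpv_loss(log_pf_traj, log_pb_traj, [1.0, 0.0, -1.0])
# Scores are -2, -2, -3, so the implied log Z disagrees across trajectories.
unbalanced = lpv_loss(log_pf_traj, log_pb_traj, [1.0, 0.0, 0.0])
```

Because only the spread of the scores matters, no separate logZ parameter has to be learned.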
- class gfn.gflownet.ModifiedDBGFlowNet(pf, pb, constant_pb=False, debug=False)
Bases: gfn.gflownet.base.PFBasedGFlowNet[gfn.containers.Transitions]
The Modified Detailed Balance GFlowNet.
Only applicable to environments where all states are terminating. See section 3.2 of [Bayesian Structure Learning with Generative Flow Networks](https://arxiv.org/abs/2202.13903) for more details.
- Parameters:
pb (gfn.estimators.Estimator | None)
constant_pb (bool)
debug (bool)
- pf
The forward policy estimator.
- pb
The backward policy estimator, or None if the gflownet DAG is a tree, and pb is therefore always 1.
- constant_pb
Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1. Must be set explicitly by the user, so that pb is required to be an Estimator except in this special case.
- get_scores(transitions, recalculate_all_logprobs=True)
Calculates DAG-GFN-style modified detailed balance scores.
Note that this method is only applicable to environments where all states are terminating, i.e., the sink state is reachable from all states.
If recalculate_all_logprobs=True, we re-evaluate the logprobs of the transitions using the current self.pf. Otherwise, the following applies:
- If the transitions have a log_probs attribute, use it; this is usually the case for on-policy learning.
- Otherwise, re-evaluate the log_probs using the current self.pf; this is usually the case for off-policy learning with a replay buffer.
- Parameters:
transitions (gfn.containers.Transitions) – The Transitions object to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
- Returns:
A tensor of shape (n_transitions,) containing the scores for each transition.
- Return type:
torch.Tensor
- loss(env, transitions, recalculate_all_logprobs=True, reduction='mean')
Computes the modified detailed balance loss.
- Parameters:
env (gfn.env.Env) – The environment where the transitions are sampled from (unused).
transitions (gfn.containers.Transitions) – The Transitions object to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
- Returns:
The computed modified detailed balance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
- to_training_samples(trajectories)
Converts trajectories to transitions for modified detailed balance loss.
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.
- Returns:
A Transitions object containing all transitions from the trajectories.
- Return type:
gfn.containers.Transitions
- class gfn.gflownet.PFBasedGFlowNet(pf, pb, constant_pb=False, log_reward_clip_min=float('-inf'), debug=False, loss_fn=None)
Bases: GFlowNet[TrainingSampleType], abc.ABC
A GFlowNet that uses forward (PF) and backward (PB) policy networks.
- Parameters:
pb (gfn.estimators.Estimator | None)
constant_pb (bool)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- pf
The forward policy estimator.
- pb
The backward policy estimator, or None if it can be ignored (e.g., the gflownet DAG is a tree, and pb is therefore always 1).
- constant_pb
Whether to ignore the backward policy estimator.
- log_reward_clip_min
If finite, log rewards are clipped from below at this value.
- constant_pb = False
- log_reward_clip_min
- pb
- pf
- pf_pb_named_parameters()
Returns a dictionary of named parameters containing ‘pf’ or ‘pb’ in their name.
- Returns:
A dictionary of named parameters containing ‘pf’ or ‘pb’ in their name.
- Return type:
dict[str, torch.Tensor]
- pf_pb_parameters()
Returns a list of parameters containing ‘pf’ or ‘pb’ in their name.
- Returns:
A list of parameters containing ‘pf’ or ‘pb’ in their name.
- Return type:
list[torch.Tensor]
- sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)
Samples trajectories using the forward policy network.
- Parameters:
env (gfn.env.Env) – The environment to sample trajectories from.
n (int) – Number of trajectories to sample.
conditions (torch.Tensor | None) – Optional conditions tensor for conditional environments.
save_logprobs (bool) – Whether to save the logprobs of the actions.
save_estimator_outputs (bool) – Whether to save the estimator outputs.
**policy_kwargs (Any) – Additional keyword arguments for the sampler.
- Returns:
A Trajectories object containing the sampled trajectories.
- Return type:
gfn.containers.Trajectories
- class gfn.gflownet.RegressionLoss
Bases: abc.ABC
Abstract base for regression losses on GFlowNet balance residuals.
Subclasses implement __call__, mapping a residual tensor to a non-negative loss tensor of the same shape.
- abstract __call__(residuals)
Apply the loss elementwise.
- Parameters:
residuals (torch.Tensor) – Balance condition residuals (any shape).
- Returns:
Non-negative tensor of the same shape.
- Return type:
torch.Tensor
- __eq__(other)
- Parameters:
other (object)
- Return type:
bool
- __hash__()
- Return type:
int
- __repr__()
- Return type:
str
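To illustrate the contract described above (an elementwise map from residuals to non-negative losses of the same shape), here is a hypothetical pure-Python analogue that operates on lists instead of tensors. The class names and the equality rule (same class and same constructor settings) are assumptions chosen for illustration; the library's actual __eq__, __hash__ and __repr__ behavior may differ.

```python
import abc
import math


class ResidualLoss(abc.ABC):
    """Hypothetical list-based analogue of RegressionLoss."""

    @abc.abstractmethod
    def __call__(self, residuals):
        """Return one non-negative loss value per residual."""

    def __eq__(self, other):
        # Assumption: equal when same class and same constructor settings.
        return type(self) is type(other) and vars(self) == vars(other)

    def __hash__(self):
        return hash((type(self).__name__, tuple(sorted(vars(self).items()))))

    def __repr__(self):
        args = ", ".join(f"{k}={v!r}" for k, v in sorted(vars(self).items()))
        return f"{type(self).__name__}({args})"


class Squared(ResidualLoss):
    def __call__(self, residuals):
        return [t * t for t in residuals]


class Linex(ResidualLoss):
    def __init__(self, alpha=1.0):
        self.alpha = alpha

    def __call__(self, residuals):
        a = self.alpha
        return [(math.expm1(a * t) - a * t) / a**2 for t in residuals]
```

Defining equality and hashing from the constructor settings lets two identically configured losses be compared or used as dictionary keys.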
- class gfn.gflownet.RelativeLogPartitionVarianceGFlowNet(pf, prior_pf, *, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)
Bases: RelativeTBBase
RTB variant that eliminates the learned logZ via variance minimization.
Analogous to how LogPartitionVarianceGFlowNet relates to TBGFlowNet, this class mean-centers the RTB residuals within each batch so that no explicit logZ parameter is needed.
The loss minimizes
\[\operatorname{Var}_{\tau}\!\bigl[\log p_\phi(\tau) - \log p_\theta(\tau) - \beta\,\log r(x_T)\bigr],\]
which equals the RTB loss evaluated at the batch-optimal \(\log Z^* = -\overline{s}\) (the negative batch mean of the scores).
- Parameters:
prior_pf (gfn.estimators.Estimator)
beta (float)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)
Computes the Relative LPV loss on a batch of trajectories.
- Parameters:
env (gfn.env.Env)
trajectories (gfn.containers.Trajectories)
recalculate_all_logprobs (bool)
reduction (str)
log_rewards (torch.Tensor | None)
- Return type:
torch.Tensor
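The equivalence claimed above can be checked numerically. The plain-Python sketch below computes RTB scores without logZ, then compares the mean squared RTB residual at \(\log Z^* = -\overline{s}\) with the mean squared deviation of the scores from their batch mean. It uses a plain squared penalty (the default HalfSquaredLoss would scale both sides by \(\tfrac{1}{2}\)), and all function names and numbers are illustrative.

```python
def mean(xs):
    return sum(xs) / len(xs)


def rtb_scores(log_p_post, log_p_prior, log_rewards, beta=1.0):
    """RTB residuals without log Z: log p_phi(tau) - log p_theta(tau) - beta * log r(x_T)."""
    return [post - prior - beta * r for post, prior, r in zip(log_p_post, log_p_prior, log_rewards)]


def rtb_objective(scores, log_Z):
    """Mean squared RTB residual for a fixed scalar log Z."""
    return mean([(log_Z + s) ** 2 for s in scores])


def relative_lpv_objective(scores):
    """Mean squared deviation of the scores from their batch mean."""
    s_bar = mean(scores)
    return mean([(s - s_bar) ** 2 for s in scores])


scores = rtb_scores([-1.0, -2.5, -0.5], [-1.5, -2.0, -1.0], [0.2, -0.4, 1.0], beta=2.0)
log_Z_star = -mean(scores)  # batch-optimal log Z

at_optimum = rtb_objective(scores, log_Z_star)
variance = relative_lpv_objective(scores)
```

Moving log Z away from \(\log Z^*\) by \(\delta\) increases the RTB objective by exactly \(\delta^2\), which is why the variance is its minimum.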
- class gfn.gflownet.RelativeTBBase(pf, prior_pf, *, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
Shared base for Relative Trajectory Balance variants.
Manages the prior forward policy and beta scaling. Subclasses only need to implement loss() (deciding how to handle logZ and reduction).
- Parameters:
prior_pf (gfn.estimators.Estimator)
beta (float)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- _compute_rtb_scores(env, trajectories, log_rewards=None, recalculate_all_logprobs=True)
RTB residuals: log_pf_post - log_pf_prior - beta * log_rewards.
- Parameters:
env (gfn.env.Env | None) – The environment (unused, kept for API consistency).
trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
- Returns:
Shape (N,) per-trajectory scores.
- Return type:
torch.Tensor
- get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)
RTB residuals (without logZ): log_pf_post - log_pf_prior - beta * log_R.
This is the public interface to the RTB balance residuals, analogous to TrajectoryBasedGFlowNet.get_scores() for standard TB.
- Returns:
Shape (N,) per-trajectory scores.
- Parameters:
trajectories (gfn.containers.Trajectories)
recalculate_all_logprobs (bool)
env (gfn.env.Env | None)
log_rewards (torch.Tensor | None)
- Return type:
torch.Tensor
- property prior_pf: gfn.estimators.Estimator
The fixed prior forward policy (not registered as a submodule).
- Return type:
gfn.estimators.Estimator
- class gfn.gflownet.RelativeTrajectoryBalanceGFlowNet(pf, prior_pf, *, logZ=None, init_logZ=0.0, beta=1.0, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)
Bases: RelativeTBBase
GFlowNet for the Relative Trajectory Balance (RTB) loss.
This objective matches a posterior sampler to a prior diffusion (or other sequential) model by minimizing
\[\left(\log Z_\phi + \log p_\phi(\tau) - \log p_\theta(\tau) - \beta \log r(x_T)\right)^2,\]
where \(p_\theta\) is a fixed prior process, \(p_\phi\) is the learnable posterior, \(r\) is a positive reward/constraint on the terminal state \(x_T\), and \(\log Z_\phi\) is a learned scalar normalizer.
- Parameters:
prior_pf (gfn.estimators.Estimator)
logZ (torch.nn.Parameter | gfn.estimators.ScalarEstimator | None)
init_logZ (float)
beta (float)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- logZ
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)
Computes the RTB loss on a batch of trajectories.
- Parameters:
env (gfn.env.Env)
trajectories (gfn.containers.Trajectories)
recalculate_all_logprobs (bool)
reduction (str)
log_rewards (torch.Tensor | None)
- Return type:
torch.Tensor
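As a sketch of how the learned normalizer behaves, the plain-Python loop below keeps the trajectory statistics fixed and trains only \(\log Z_\phi\) by gradient descent on the half squared RTB loss with 'mean' reduction. Since the gradient with respect to \(\log Z_\phi\) is the mean residual, \(\log Z_\phi\) converges to the negative batch mean of the logZ-free scores. The numbers, learning rate and helper names are hypothetical; in practice logZ would be trained jointly with the posterior policy by an optimizer.

```python
def rtb_residuals(log_Z, log_p_post, log_p_prior, log_rewards, beta=1.0):
    """log Z + log p_phi(tau) - log p_theta(tau) - beta * log r(x_T), per trajectory."""
    return [
        log_Z + post - prior - beta * r
        for post, prior, r in zip(log_p_post, log_p_prior, log_rewards)
    ]


def rtb_loss(residuals):
    """Half squared loss with 'mean' reduction."""
    return sum(0.5 * t * t for t in residuals) / len(residuals)


# Fixed trajectory statistics; only log Z is trained in this sketch.
log_p_post = [-3.0, -2.0, -4.0]
log_p_prior = [-2.5, -2.5, -3.0]
log_rewards = [0.5, -1.0, 0.0]

log_Z, lr = 0.0, 0.5
initial_loss = rtb_loss(rtb_residuals(log_Z, log_p_post, log_p_prior, log_rewards))
for _ in range(100):
    res = rtb_residuals(log_Z, log_p_post, log_p_prior, log_rewards)
    grad = sum(res) / len(res)  # d/d(log Z) of mean(0.5 * t**2) is mean(t)
    log_Z -= lr * grad
final_loss = rtb_loss(rtb_residuals(log_Z, log_p_post, log_p_prior, log_rewards))
```

Here the logZ-free scores are -1.0, 1.5 and -1.0, so log_Z settles at 1/6.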
- class gfn.gflownet.ShiftedCoshLoss
Bases: RegressionLoss
Shifted hyperbolic cosine: \(g(t) = e^t + e^{-t} - 2 = 2(\cosh(t) - 1)\).
This is the only loss in the family that is simultaneously zero-forcing (penalizes spurious mass) and zero-avoiding (penalizes missing modes). It is symmetric: g(t) = g(-t).
Near t = 0 it behaves like t^2 (same curvature as squared loss), but for large |t| it grows exponentially, providing stronger gradients for poorly matched trajectories.
Hu et al. (ICLR 2025) found this loss generally outperforms squared error on convergence speed and mode coverage across HyperGrid, bit-sequence, and sEH molecule benchmarks.
References
Hu et al. “Beyond Squared Error: Exploring Loss Design for Enhanced Training of Generative Flow Networks” (ICLR 2025, arXiv:2410.02596).
- __call__(residuals)
Apply the loss elementwise.
- Parameters:
residuals (torch.Tensor) – Balance condition residuals (any shape).
- Returns:
Non-negative tensor of the same shape.
- Return type:
torch.Tensor
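A plain-Python sketch of the shifted cosh loss, using the identity \(2(\cosh t - 1) = 4\sinh^2(t/2)\) to avoid cancellation for small residuals. The checks confirm the symmetry, the \(t^2\) behavior near zero, and the faster-than-quadratic growth for large \(|t|\). The function name is illustrative only.

```python
import math


def shifted_cosh_loss(t):
    """g(t) = exp(t) + exp(-t) - 2, computed as 4 * sinh(t / 2) ** 2."""
    # The sinh form avoids cancellation in cosh(t) - 1 for small |t|.
    return 4.0 * math.sinh(0.5 * t) ** 2


small = 1e-3
near_zero_ratio = shifted_cosh_loss(small) / small**2  # close to 1, like t**2
```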
- class gfn.gflownet.SquaredLoss
Bases: RegressionLoss
Standard squared loss: \(g(t) = t^2\).
Corresponds to the reverse KL divergence (Malkin et al. 2022). This is zero-forcing (mode-seeking): it penalizes the learner for placing probability mass where the target has none, but does not penalize missing modes. This can lead to mode collapse in multi-modal targets.
This is the default loss for TB, DB, SubTB, LPV, and FM classes, reproducing the standard behavior from the literature.
- __call__(residuals)
Apply the loss elementwise.
- Parameters:
residuals (torch.Tensor) – Balance condition residuals (any shape).
- Returns:
Non-negative tensor of the same shape.
- Return type:
torch.Tensor
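The reduction argument accepted by the loss methods ('mean', 'sum', or 'none') is typically applied to the elementwise output of the regression loss. The plain-Python sketch below shows this for the squared loss; it assumes the elementwise-then-reduce order and uses hypothetical helper names rather than the library's own code.

```python
def squared_loss(residuals):
    """Elementwise g(t) = t**2."""
    return [t * t for t in residuals]


def apply_reduction(values, reduction="mean"):
    """Reduce per-sample losses as the 'reduction' argument describes."""
    if reduction == "mean":
        return sum(values) / len(values)
    if reduction == "sum":
        return sum(values)
    if reduction == "none":
        return list(values)
    raise ValueError(f"unknown reduction: {reduction!r}")


per_sample = squared_loss([0.5, -1.0, 2.0])  # [0.25, 1.0, 4.0]
```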
- class gfn.gflownet.SubTBGFlowNet(pf, pb, logF, weighting='geometric_within', lamda=0.9, log_reward_clip_min=-float('inf'), forward_looking=False, constant_pb=False, debug=False, loss_fn=None)
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
GFlowNet for the Sub-Trajectory Balance loss.
An implementation of the sub-trajectory balance loss as described in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782).
- Parameters:
pb (gfn.estimators.Estimator | None)
logF (gfn.estimators.ScalarEstimator | gfn.estimators.ConditionalScalarEstimator)
weighting (Literal['DB', 'ModifiedDB', 'TB', 'geometric', 'equal', 'geometric_within', 'equal_within'])
lamda (float)
log_reward_clip_min (float)
forward_looking (bool)
constant_pb (bool)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- pf
The forward policy estimator.
- pb
The backward policy estimator, or None if the gflownet DAG is a tree, and pb is therefore always 1.
- logF
A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states.
- weighting
The sub-trajectories weighting scheme:
  - “DB”: Considers all one-step transitions of each trajectory in the batch and weighs them equally (regardless of the length of the trajectory). Should be equivalent to the DetailedBalance loss.
  - “ModifiedDB”: Considers all one-step transitions of each trajectory in the batch and weighs them inversely proportional to the trajectory length. This ensures that the loss is not dominated by long trajectories: each trajectory contributes equally to the loss.
  - “TB”: Considers only the full trajectory. Should be equivalent to the TrajectoryBalance loss.
  - “equal_within”: Each sub-trajectory of each trajectory is weighed equally within the trajectory. Then each trajectory is weighed equally within the batch.
  - “equal”: Each sub-trajectory of each trajectory is weighed equally within the set of all sub-trajectories.
  - “geometric_within”: Each sub-trajectory of each trajectory is weighed proportionally to (lamda ** len(sub_trajectory)), within each trajectory. This is the weighting used in the paper.
  - “geometric”: Each sub-trajectory of each trajectory is weighed proportionally to (lamda ** len(sub_trajectory)), within the set of all sub-trajectories.
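One reading of the “geometric_within” description is sketched below in plain Python: a trajectory with n transitions has n - i + 1 sub-trajectories of length i, each weighted proportionally to lamda ** i, and the weights are normalized so that they sum to one within the trajectory. Any constant factor in the exponent convention (for example lamda ** (i - 1)) cancels after normalization. The function name and this normalization are assumptions, not the library's code.

```python
def geometric_within_weights(n, lamda=0.9):
    """Weight of each sub-trajectory of length i (1..n) in one trajectory of n transitions."""
    unnormalized = {i: lamda**i for i in range(1, n + 1)}
    # A trajectory with n transitions has (n - i + 1) sub-trajectories of length i.
    total = sum((n - i + 1) * w for i, w in unnormalized.items())
    return {i: w / total for i, w in unnormalized.items()}


weights = geometric_within_weights(3, lamda=0.5)
```

Under this reading, lamda = 1 gives every sub-trajectory the same weight within the trajectory, which matches the “equal_within” description.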
- lamda¶
Discount factor for longer trajectories (used in geometric weighting).
- log_reward_clip_min¶
If finite, log rewards are clipped from below at this value.
- forward_looking¶
Whether to use the forward-looking GFN loss.
- constant_pb¶
Whether to ignore the backward policy estimator, e.g., if the GFlowNet DAG is a tree and pb is therefore always 1.
- calculate_log_state_flows(env, trajectories, log_pf_trajectories)¶
Calculates log flows of each state in the trajectories.
- Parameters:
env (gfn.env.Env) – The environment object.
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
log_pf_trajectories (LogTrajectoriesTensor) – Tensor of shape (max_length, batch_size) containing the logprobs of the forward actions of the trajectories.
- Returns:
A tensor of shape (max_length, batch_size) containing the log flows of each state in the trajectories.
- Return type:
LogStateFlowsTensor
- calculate_masks(log_state_flows, trajectories)¶
Calculates masks indicating sink and terminal states.
- Parameters:
log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the log flows of the states.
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
A tuple of two mask tensors (sink_states_mask, is_terminal_mask), each of shape (max_length, batch_size).
- Return type:
Tuple[MaskTensor, MaskTensor]
- calculate_preds(log_pf_traj_cum, log_state_flows, i)¶
Calculates the predictions tensor for the current sub-trajectory length.
- Parameters:
log_pf_traj_cum (CumulativeLogProbsTensor) – Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs of the forward actions for each trajectory.
log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the estimated log flow of the states.
i (int) – The sub-trajectory length.
- Returns:
The predictions tensor of shape (max_length + 1 - i, batch_size).
- Return type:
PredictionsTensor
- calculate_targets(trajectories, preds, log_pb_traj_cum, log_state_flows, is_terminal_mask, sink_states_mask, i, log_rewards=None)¶
Calculates the targets tensor for the current sub-trajectory length.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
preds (PredictionsTensor) – Tensor of shape (max_length + 1 - i, batch_size) containing the predictions for the current sub-trajectory length.
log_pb_traj_cum (CumulativeLogProbsTensor) – Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs of the backward actions for each trajectory.
log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the estimated log flow of the states.
is_terminal_mask (MaskTensor) – A mask of shape (max_length, batch_size) indicating whether the state is terminal.
sink_states_mask (MaskTensor) – A mask of shape (max_length, batch_size) indicating whether the state is a sink state.
i (int) – The sub-trajectory length.
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The targets tensor of shape (max_length + 1 - i, batch_size).
- Return type:
TargetsTensor
- cumulative_logprobs(trajectories, log_p_trajectories)¶
Calculates the cumulative logprobs for all trajectories.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
log_p_trajectories (LogTrajectoriesTensor) – Tensor of shape (max_length, batch_size) containing the logprobs of the forward or backward actions of the trajectories.
- Returns:
A tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs for each trajectory.
- Return type:
CumulativeLogProbsTensor
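The extra leading row in the (max_length + 1, batch_size) shape lets the log-probability of any sub-trajectory be read off as a difference of two cumulative sums. A minimal single-trajectory sketch in plain Python (the function names are hypothetical, and padding of shorter trajectories is ignored):

```python
from itertools import accumulate
from typing import List


def cumulative_logprobs(log_p: List[float]) -> List[float]:
    """Running sum of per-step logprobs with a leading 0.0.

    For T steps this returns T + 1 entries, mirroring the extra leading row."""
    return [0.0] + list(accumulate(log_p))


def subtrajectory_logprob(cum: List[float], n: int, k: int) -> float:
    """Log-probability of the k steps starting at state s_n, as a difference."""
    return cum[n + k] - cum[n]


log_pf = [-0.5, -1.0, -0.25]
cum = cumulative_logprobs(log_pf)
print(cum)                               # [0.0, -0.5, -1.5, -1.75]
print(subtrajectory_logprob(cum, 1, 2))  # -1.25, i.e. -1.0 + -0.25
```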
- forward_looking = False¶
- get_equal_contributions(trajectories)¶
Calculates contributions for the ‘equal’ weighting method.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).
- Return type:
ContributionsTensor
- get_equal_within_contributions(trajectories)¶
Calculates contributions for the ‘equal_within’ weighting method.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).
- Return type:
ContributionsTensor
- get_geometric_within_contributions(trajectories)¶
Calculates contributions for the ‘geometric_within’ weighting method.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).
- Return type:
ContributionsTensor
- get_modified_db_contributions(trajectories)¶
Calculates contributions for the ‘ModifiedDB’ weighting method.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).
- Return type:
ContributionsTensor
- get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)¶
Computes sub-trajectory balance scores for all submitted trajectories.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
env (gfn.env.Env | None) – The environment where the trajectories are sampled from.
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
- scores: A list of tensors, each representing the scores of all sub-trajectories of length k, for k in [1, …, max_length], where the score of a sub-trajectory \(\tau_{n:n+k} = (s_n, ..., s_{n+k})\) is \(\log P_F(\tau_{n:n+k}) + \log F(s_n) - \log P_B(\tau_{n:n+k}) - \log F(s_{n+k})\). The score tensor for length-k sub-trajectories has shape (max_length - k + 1, batch_size).
- flattening_masks: A list of tensors indicating what should be masked out from each element of the first list (scores), since not all sub-trajectories of length k exist for each trajectory. An entry is True if the corresponding sub-trajectory does not exist.
- Return type:
A tuple (scores, flattening_masks)
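The score formula above can be checked on a single trajectory with a plain-Python sketch. It is not the library's batched implementation: `subtb_scores` is hypothetical, and it assumes that the flow of the final state s_T is replaced by the log reward, as is standard for sub-trajectory balance.

```python
import math
from typing import List


def subtb_scores(log_pf: List[float], log_pb: List[float],
                 log_F: List[float], log_reward: float) -> List[List[float]]:
    """scores[k - 1][n] is the score of tau_{n:n+k} for one trajectory with T steps.

    log_F holds log F(s_0), ..., log F(s_{T-1}); this sketch assumes the flow of
    the final state s_T is replaced by the log reward."""
    T = len(log_pf)
    if len(log_pb) != T or len(log_F) != T:
        raise ValueError("log_pf, log_pb and log_F must all have length T")
    flows = list(log_F) + [log_reward]  # log F(s_0), ..., log F(s_T)
    scores = []
    for k in range(1, T + 1):
        row = []
        for n in range(T - k + 1):  # T - k + 1 sub-trajectories of length k
            row.append(sum(log_pf[n:n + k]) + flows[n]
                       - sum(log_pb[n:n + k]) - flows[n + k])
        scores.append(row)
    return scores


# A balanced toy chain s_0 -> s_1 -> s_2 with F(s_0) = 4, F(s_1) = 2, R(s_2) = 2.
scores = subtb_scores(log_pf=[math.log(0.5), 0.0], log_pb=[0.0, 0.0],
                      log_F=[math.log(4.0), math.log(2.0)], log_reward=math.log(2.0))
print(max(abs(s) for row in scores for s in row))  # ~0, up to floating-point rounding
```

The ragged rows (lengths T, T - 1, ..., 1) play the role that flattening_masks play in the padded batch tensors.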
- get_tb_contributions(trajectories)¶
Calculates contributions for the ‘TB’ weighting method.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).
- Return type:
ContributionsTensor
- lamda = 0.9¶
- logF_named_parameters()¶
Returns a dictionary of named parameters containing ‘logF’ in their name.
- Returns:
A dictionary of named parameters containing ‘logF’ in their name.
- Return type:
dict[str, torch.Tensor]
- logF_parameters()¶
Returns a list of parameters containing ‘logF’ in their name.
- Returns:
A list of parameters containing ‘logF’ in their name.
- Return type:
list[torch.Tensor]
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the sub-trajectory balance loss.
- Parameters:
env (gfn.env.Env) – The environment where the trajectories are sampled from.
trajectories (gfn.containers.Trajectories) – The batch of trajectories to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’). Note: for geometric-based sub-trajectory weighting, ‘mean’ is not supported and is coerced to ‘sum’ (a warning is emitted when debug=True).
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. When provided, this overrides the terminal reward term used by the loss. Note, however, that with forward_looking=True the state-flow computation may still depend on env.log_reward(...), so custom log_rewards do not fully replace environment rewards in that mode. Useful for intrinsic rewards affecting the terminal boundary term (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The computed sub-trajectory balance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
- weighting = 'geometric_within'¶
- class gfn.gflownet.TBGFlowNet(pf, pb, logZ=None, init_logZ=0.0, constant_pb=False, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)¶
Bases:
gfn.gflownet.base.TrajectoryBasedGFlowNet
GFlowNet for the Trajectory Balance loss.
Corresponds to \(\mathcal{O}_{PFZ} = \mathcal{O}_1 \times \mathcal{O}_2 \times \mathcal{O}_3\), where \(\mathcal{O}_1 = \mathbb{R}\) represents the possible values for logZ, and \(\mathcal{O}_2\) is the set of forward probability functions consistent with the DAG. \(\mathcal{O}_3\) is the set of backward probability functions consistent with the DAG, or a singleton thereof, if self.pb is a fixed DiscretePBEstimator.
See [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259) for more details.
- Parameters:
pb (gfn.estimators.Estimator | None)
logZ (torch.nn.Parameter | gfn.estimators.ScalarEstimator | None)
init_logZ (float)
constant_pb (bool)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- pf¶
The forward policy estimator.
- pb¶
The backward policy estimator, or None if the GFlowNet DAG is a tree and pb is therefore always 1.
- logZ¶
A learnable parameter or a ScalarEstimator instance (for conditional GFNs).
- constant_pb¶
Whether to ignore pb, e.g., if the GFlowNet DAG is a tree and pb is therefore always 1. Must be set explicitly by the user, so that pb is required to be an Estimator except in this special case.
- log_reward_clip_min¶
If finite, log rewards are clipped from below at this value.
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the trajectory balance loss.
The trajectory balance loss is described in section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).
- Parameters:
env (gfn.env.Env) – The environment where the trajectories are sampled from (unused).
trajectories (gfn.containers.Trajectories) – The Trajectories object to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The computed trajectory balance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
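For reference, the trajectory balance residual is log Z + log P_F(τ) - log P_B(τ | x) - log R(x). The sketch below squares it per trajectory and applies the three documented reductions. It is a plain-Python illustration, not the library's code: `tb_loss` is hypothetical, trajectories are unpadded lists of per-step logprobs, and it assumes a plain squared residual, whereas the library's `loss_fn` may apply a different regression loss.

```python
from typing import List, Sequence, Union


def tb_loss(log_z: float,
            log_pf: Sequence[Sequence[float]],
            log_pb: Sequence[Sequence[float]],
            log_rewards: Sequence[float],
            reduction: str = "mean") -> Union[float, List[float]]:
    """Squared trajectory balance residual per trajectory, then reduced.

    Each inner sequence holds the per-step logprobs of one (unpadded) trajectory."""
    if not (len(log_pf) == len(log_pb) == len(log_rewards)):
        raise ValueError("log_pf, log_pb and log_rewards must have the same batch size")
    # Residual: log Z + log P_F(tau) - log P_B(tau | x) - log R(x)
    losses = [(log_z + sum(pf) - sum(pb) - lr) ** 2
              for pf, pb, lr in zip(log_pf, log_pb, log_rewards)]
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        return sum(losses) / len(losses)
    raise ValueError(f"unknown reduction: {reduction}")


batch_pf = [[-1.0], [-0.5, -0.5]]
batch_pb = [[0.0], [0.0, 0.0]]
batch_log_r = [-1.0, -2.0]
print(tb_loss(0.0, batch_pf, batch_pb, batch_log_r, reduction="none"))  # [0.0, 1.0]
print(tb_loss(0.0, batch_pf, batch_pb, batch_log_r))                    # 0.5
```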
- class gfn.gflownet.TrajectoryBasedGFlowNet(pf, pb, constant_pb=False, log_reward_clip_min=float('-inf'), debug=False, loss_fn=None)¶
Bases:
PFBasedGFlowNet[gfn.containers.Trajectories], abc.ABC
A GFlowNet that operates on complete trajectories.
- Parameters:
pb (gfn.estimators.Estimator | None)
constant_pb (bool)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- pf¶
The forward policy module.
- pb¶
The backward policy module, or None if the GFlowNet DAG is a tree and pb is therefore always 1.
- constant_pb¶
Whether to ignore the backward policy estimator, e.g., if the GFlowNet DAG is a tree and pb is therefore always 1.
- log_reward_clip_min¶
If finite, log rewards are clipped from below at this value.
- get_pfs_and_pbs(trajectories, recalculate_all_logprobs=True)¶
Evaluates forward and backward logprobs for each trajectory in the batch.
More specifically, it evaluates \(\log P_F(s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in each trajectory in the batch.
If recalculate_all_logprobs=True, the logprobs of the trajectories are re-evaluated using the current self.pf. Otherwise, the following applies:
- If the trajectories have a logprobs attribute, use it; this is usually the case for on-policy learning.
- Elif the trajectories have an estimator_outputs attribute, transform it into logprobs; this is usually the case for off-policy learning with a tempered policy.
- Else (the trajectories have neither), re-evaluate the logprobs using the current self.pf; this is usually the case for off-policy learning with a replay buffer.
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
- Returns:
A tuple of tensors of shape (max_length, batch_size) containing the log_pf and log_pb for each action in each trajectory.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
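The precedence rule for forward logprobs can be summarized as a small decision function. This is a hypothetical helper that only mirrors the rule as documented above; it does not call the library, and the returned strings are illustrative labels.

```python
def forward_logprob_source(recalculate_all_logprobs: bool,
                           has_logprobs: bool,
                           has_estimator_outputs: bool) -> str:
    """Hypothetical helper mirroring the documented precedence of get_pfs_and_pbs."""
    if recalculate_all_logprobs:
        return "recompute"          # always re-evaluate with the current pf
    if has_logprobs:
        return "stored_logprobs"    # typical for on-policy learning
    if has_estimator_outputs:
        return "estimator_outputs"  # e.g. off-policy sampling with a tempered policy
    return "recompute"              # e.g. off-policy learning from a replay buffer


print(forward_logprob_source(False, True, True))  # stored_logprobs
print(forward_logprob_source(True, True, True))   # recompute
```

Note that recalculate_all_logprobs=True takes priority over any stored values, so stale logprobs from an older policy are never reused in that mode.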
- get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)¶
Calculates scores for a batch of trajectories.
The scores for each trajectory are defined as: \(\log \left( \frac{P_F(\tau)}{P_B(\tau \mid x) R(x)} \right)\).
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
env (gfn.env.Env | None) – The environment (unused in base TB, but required by some subclasses such as RTB and SubTB).
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards from the trajectories. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
A tensor of shape (batch_size,) containing the scores for each trajectory.
- Return type:
torch.Tensor
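Because trajectories in a batch have different lengths, the per-step logprobs live in padded (max_length, batch_size) tensors and only valid steps should enter the sums. A plain-Python sketch of the score \(\log P_F(\tau) - \log P_B(\tau \mid x) - \log R(x)\) under that layout (the function name and the explicit `valid` mask are assumptions for illustration):

```python
from typing import List


def trajectory_scores(log_pf: List[List[float]], log_pb: List[List[float]],
                      valid: List[List[bool]], log_rewards: List[float]) -> List[float]:
    """Per-trajectory score log P_F(tau) - log P_B(tau | x) - log R(x).

    log_pf, log_pb and valid are indexed [t][b] with shape (max_length, batch_size);
    valid[t][b] is False for padding steps, which are ignored."""
    max_length, batch_size = len(log_pf), len(log_rewards)
    scores = []
    for b in range(batch_size):
        total = 0.0
        for t in range(max_length):
            if valid[t][b]:
                total += log_pf[t][b] - log_pb[t][b]
        scores.append(total - log_rewards[b])
    return scores


# Trajectory 0 has two steps; trajectory 1 has one step followed by padding (99.0).
print(trajectory_scores(log_pf=[[-1.0, -0.5], [-1.0, 99.0]],
                        log_pb=[[0.0, 0.0], [0.0, 99.0]],
                        valid=[[True, True], [True, False]],
                        log_rewards=[-2.0, -0.5]))  # [0.0, 0.0]
```

In trajectory balance, adding logZ to each score gives the residual that is then squared (or passed to the configured regression loss).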
- logz_named_parameters()¶
Returns named parameters containing ‘logZ’ in their name.
Works for any subclass that registers a logZ parameter (e.g. TBGFlowNet, RelativeTrajectoryBalanceGFlowNet). Returns an empty dict for subclasses without logZ.
- Return type:
dict[str, torch.Tensor]
- logz_parameters()¶
Returns parameters containing ‘logZ’ in their name.
Works for any subclass that registers a logZ parameter (e.g. TBGFlowNet, RelativeTrajectoryBalanceGFlowNet). Returns an empty list for subclasses without logZ.
- Return type:
list[torch.Tensor]
- to_training_samples(trajectories)¶
Returns the input trajectories as training samples.
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to use as training samples.
- Returns:
The same Trajectories object.
- Return type:
- trajectory_log_probs_backward(trajectories)¶
Evaluates backward logprobs only for each trajectory in the batch.
- Parameters:
trajectories (gfn.containers.Trajectories)
- Return type:
torch.Tensor
- trajectory_log_probs_forward(trajectories, recalculate_all_logprobs=True)¶
Evaluates forward logprobs only for each trajectory in the batch.
- Parameters:
trajectories (gfn.containers.Trajectories)
recalculate_all_logprobs (bool)
- Return type:
torch.Tensor
- class gfn.gflownet.TrustPCLGFlowNet(policy, reference_policy, *, alpha=1.0, init_v_soft_s0=0.0, logZ=None, log_reward_clip_min=-float('inf'), debug=False, loss_fn=None)¶
Bases:
RelativeTrajectoryBalanceGFlowNet
Trust-PCL view of Relative Trajectory Balance.
Deleu et al. (2025) proved that RTB is mathematically equivalent to Trust-PCL, an off-policy RL method with KL regularization toward a reference policy. This class exposes the same algorithm through an interface that uses reinforcement learning terminology.
The equivalence (Proposition 3.1 of Deleu et al.):
\[\mathcal{L}_{\text{Trust-PCL}}(\phi, \psi) = \alpha^2 \,\mathcal{L}_{\text{RTB}}(\phi, \psi)\]
where \(\alpha = 1/\beta\) is the Trust-PCL temperature.
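The \(\alpha^2\) factor can be checked numerically on a single trajectory. The sketch below assumes the mapping \(V^{\text{soft}}(s_0) = \alpha \log Z\) and an RL return of \(\alpha \log r(x)\) (the latter is an assumption of this sketch, chosen so that the soft value at convergence equals \(\alpha \log Z\)). Under that mapping the Trust-PCL path-consistency residual is exactly \(\alpha\) times the RTB residual, so the squared losses differ by \(\alpha^2\). Both residual functions are hypothetical illustrations, and their sign convention does not matter once squared.

```python
import math


def rtb_residual(log_z: float, log_post: float, log_prior: float, log_r: float) -> float:
    """RTB residual for one trajectory: log Z + log p_post(tau) - log p_prior(tau) - log r(x)."""
    return log_z + log_post - log_prior - log_r


def trust_pcl_residual(v_soft_s0: float, alpha: float, log_pi: float,
                       log_pi_ref: float, rl_return: float) -> float:
    """Path-consistency residual with KL regularization toward pi_ref.

    Sign convention chosen to line up with rtb_residual; it cancels once squared."""
    return v_soft_s0 + alpha * (log_pi - log_pi_ref) - rl_return


alpha, log_z = 2.0, 0.3
log_post, log_prior, log_r = -1.2, -0.7, -0.4
delta = rtb_residual(log_z, log_post, log_prior, log_r)
# Assumed mapping: V_soft(s0) = alpha * log Z and RL return = alpha * log r(x).
eps = trust_pcl_residual(alpha * log_z, alpha, log_post, log_prior, alpha * log_r)
print(math.isclose(eps ** 2, alpha ** 2 * delta ** 2))  # True
```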
Parameter correspondence:
Interpretation of the learned scalar:
In RTB, logZ estimates the log-partition function \(\log \int p_\theta(x)\,r(x)\,dx\). In Trust-PCL, the same quantity is the soft value function at the initial state: \(V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi\). This connects GFlowNet training to entropy-regularized RL, where the soft value satisfies the soft Bellman equation.
Why this class exists:
The underlying computation is identical to RelativeTrajectoryBalanceGFlowNet (the loss is just scaled by \(\alpha^2\)). This class exists to:
- Provide an RL-native constructor (policy, reference_policy, alpha, init_v_soft_s0) for researchers familiar with Trust-PCL / SAC / entropy-regularized RL.
- Expose alpha and v_soft_s0 properties for interpretability and monitoring.
- Serve as a pedagogical bridge between the GFlowNet and RL communities.
References
Deleu et al. “Relative Trajectory Balance is equivalent to Trust-PCL” (2025, arXiv:2509.01632).
Nachum et al. “Trust-PCL: An Off-Policy Trust Region Method for Continuous Control” (NeurIPS 2017, arXiv:1707.01891).
Venkatraman et al. “Amortizing intractable inference in diffusion models for vision, language, and control” (NeurIPS 2024, arXiv:2405.20971).
- Parameters:
policy (gfn.estimators.Estimator)
reference_policy (gfn.estimators.Estimator)
alpha (float)
init_v_soft_s0 (float)
logZ (torch.nn.Parameter | gfn.estimators.ScalarEstimator | None)
log_reward_clip_min (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- property alpha: torch.Tensor¶
Trust-PCL temperature: \(\alpha = 1/\beta\).
Controls the strength of KL regularization toward the reference policy. At convergence, the learned policy satisfies:
\[\pi_\phi(a|s) \propto \pi_{\text{ref}}(a|s) \exp\!\bigl(Q^{\text{soft}}(s,a) / \alpha\bigr)\]
Higher alpha → the policy stays closer to the reference (more regularization). Lower alpha → the policy deviates more toward reward-maximizing behavior.
- Return type:
torch.Tensor
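The effect of alpha can be seen on a finite action set by evaluating \(\pi(a) \propto \pi_{\text{ref}}(a) \exp(Q(a)/\alpha)\) directly. In this sketch the soft Q-values are made-up numbers rather than learned quantities, and both function names are hypothetical; it shows the KL divergence to the reference shrinking as alpha grows.

```python
import math
from typing import List


def kl_regularized_policy(pi_ref: List[float], q_soft: List[float], alpha: float) -> List[float]:
    """pi(a) proportional to pi_ref(a) * exp(Q(a) / alpha) over a finite action set.

    Assumes every pi_ref(a) > 0."""
    logits = [math.log(p) + q / alpha for p, q in zip(pi_ref, q_soft)]
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    weights = [math.exp(z - m) for z in logits]
    total = sum(weights)
    return [w / total for w in weights]


def kl_divergence(p: List[float], q: List[float]) -> float:
    """KL(p || q) for discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


pi_ref, q_soft = [0.5, 0.5], [1.0, 0.0]
for alpha in (0.1, 1.0, 10.0):
    pi = kl_regularized_policy(pi_ref, q_soft, alpha)
    print(alpha, [round(x, 3) for x in pi], round(kl_divergence(pi, pi_ref), 4))
```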
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the Trust-PCL loss: \(\alpha^2 \cdot \mathcal{L}_{\text{RTB}}\).
The scaling by \(\alpha^2\) is the only difference from RelativeTrajectoryBalanceGFlowNet.loss(). It ensures that gradient magnitudes match the Trust-PCL formulation.
- Parameters:
env (gfn.env.Env)
trajectories (gfn.containers.Trajectories)
recalculate_all_logprobs (bool)
reduction (str)
log_rewards (torch.Tensor | None)
- Return type:
torch.Tensor
- property v_soft_s0: torch.Tensor¶
Soft value function at the initial state: \(V^{\text{soft}}_\psi(s_0) = \alpha \cdot \log Z_\psi\).
This is the expected return under the optimal entropy-regularized policy, starting from \(s_0\):
\[V^{\text{soft}}(s_0) = \mathbb{E}_{\pi_\phi}\!\left[ \sum_t r(s_t, a_t) + \alpha \sum_t \log \frac{\pi_{\text{ref}}(a_t|s_t)}{\pi_\phi(a_t|s_t)} \right]\]
The KL regularization term \(\alpha \log(\pi_{\text{ref}} / \pi_\phi)\) in the sum arises from the ratio of prior to posterior log-probabilities in the RTB balance condition.
Monitoring this value during training shows how the expected (regularized) return evolves. At convergence, it equals \(\alpha \log \int p_\theta(x)\,r(x)\,dx\).
- Return type:
torch.Tensor
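A one-step toy (a single action from \(s_0\) to a terminal \(x\)) makes the convergence claim checkable: under the RTB-optimal policy \(\pi(x) = p(x)\,r(x)/Z\), the expectation above collapses to \(\alpha \log Z\). The sketch assumes the RL reward is \(\alpha \log r(x)\) and that all rewards are positive; `soft_value_at_optimum` is a hypothetical helper, not part of the library.

```python
import math
from typing import List


def soft_value_at_optimum(prior: List[float], reward: List[float], alpha: float) -> float:
    """One-step toy (s_0 -> x): evaluate the soft-value expectation under the
    RTB-optimal policy pi(x) = prior(x) * r(x) / Z.

    Assumes the RL reward is alpha * log r(x) and that every r(x) > 0."""
    z = sum(p * r for p, r in zip(prior, reward))
    value = 0.0
    for p_ref, r in zip(prior, reward):
        pi = p_ref * r / z
        if pi > 0:
            # reward term + KL term alpha * log(pi_ref / pi)
            value += pi * (alpha * math.log(r) + alpha * math.log(p_ref / pi))
    return value


prior, reward, alpha = [0.25, 0.75], [2.0, 0.5], 0.5
print(soft_value_at_optimum(prior, reward, alpha))
print(alpha * math.log(sum(p * r for p, r in zip(prior, reward))))  # alpha * log Z
```

Each summand equals \(\alpha \log Z\) exactly (the \(\log r(x)\) terms cancel), so the two printed values agree up to floating-point rounding.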