gfn.gflownet.base

Attributes

TrainingSampleType

Classes

GFlowNet

Abstract base class for GFlowNets.

PFBasedGFlowNet

A GFlowNet that uses forward (PF) and backward (PB) policy networks.

TrajectoryBasedGFlowNet

A GFlowNet that operates on complete trajectories.

Functions

loss_reduce(loss, method)

Utility function to handle loss aggregation strategies.

Module Contents

class gfn.gflownet.base.GFlowNet(debug=False, loss_fn=None)

Bases: abc.ABC, torch.nn.Module, Generic[TrainingSampleType]

Abstract base class for GFlowNets.

A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266).

assert_finite_gradients()

Asserts that the gradients are finite.

assert_finite_parameters()

Asserts that the parameters are finite.

debug = False
log_reward_clip_min
abstract loss(env, training_objects, recalculate_all_logprobs=True)

Computes the loss given the training objects.

Parameters:
  • env (gfn.env.Env) – The environment where the training objects are sampled from.

  • training_objects (Any) – The objects to compute the loss with.

  • recalculate_all_logprobs (bool) – If True, always recalculate logprobs even if they exist. If False, use existing logprobs when available.

Returns:

The computed loss as a tensor.

Return type:

torch.Tensor

loss_fn
loss_from_trajectories(env, trajectories, recalculate_all_logprobs=True)

Helper method to compute loss directly from trajectories.

This method converts trajectories to the appropriate training samples and computes the loss with the correct arguments based on the type of GFlowNet subclass.

Parameters:
  • env (gfn.env.Env) – The environment where the training objects are sampled from.

  • trajectories (gfn.containers.Trajectories) – The trajectories to compute the loss with.

  • recalculate_all_logprobs (bool) – If True, always recalculate logprobs even if they exist. If False, use existing logprobs when available.

Returns:

The computed loss as a tensor.

Return type:

torch.Tensor
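
The convert-then-delegate flow described for loss_from_trajectories can be sketched with the standard library alone. This is a toy illustration, not the gfn implementation: ToyGFlowNet, ToyTrajectoryGFlowNet, SampleT, and the squared-sum loss are hypothetical names and choices introduced here, and trajectories are represented as plain lists of floats.

```python
from abc import ABC, abstractmethod
from typing import Generic, List, TypeVar

# Hypothetical stand-in for TrainingSampleType.
SampleT = TypeVar("SampleT")


class ToyGFlowNet(ABC, Generic[SampleT]):
    """Toy mirror of the abstract interface (assumption, not the gfn class)."""

    @abstractmethod
    def to_training_samples(self, trajectories: List[List[float]]) -> SampleT:
        """Convert trajectories into whatever the subclass trains on."""

    @abstractmethod
    def loss(self, env, training_objects: SampleT) -> float:
        """Compute a scalar loss from the training objects."""

    def loss_from_trajectories(self, env, trajectories: List[List[float]]) -> float:
        # Convert first, then delegate to the subclass-specific loss.
        samples = self.to_training_samples(trajectories)
        return self.loss(env, samples)


class ToyTrajectoryGFlowNet(ToyGFlowNet[List[List[float]]]):
    """Trajectory-based variant: the training samples are the trajectories."""

    def to_training_samples(self, trajectories):
        return trajectories

    def loss(self, env, training_objects):
        # Placeholder loss: mean of squared per-trajectory sums.
        totals = [sum(t) for t in training_objects]
        return sum(x * x for x in totals) / len(totals)


toy = ToyTrajectoryGFlowNet()
toy_loss = toy.loss_from_trajectories(env=None, trajectories=[[1.0, 1.0], [0.5, -0.5]])
print(toy_loss)
```

The point of the helper is that callers holding only trajectories do not need to know which training sample type a given subclass expects.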

sample_terminating_states(env, n)

Rolls out the policy and returns the terminating states.

Parameters:
  • env (gfn.env.Env) – The environment to sample terminating states from.

  • n (int) – Number of terminating states to sample.

Returns:

The sampled terminating states as a States object.

Return type:

gfn.states.States

abstract sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)

Samples a specific number of complete trajectories from the environment.

Parameters:
  • env (gfn.env.Env) – The environment to sample trajectories from.

  • n (int) – Number of trajectories to sample.

  • conditions (torch.Tensor | None) – Optional conditions tensor for conditional environments.

  • save_logprobs (bool) – Whether to save the logprobs of the actions (useful for on-policy learning).

  • save_estimator_outputs (bool) – Whether to save the estimator outputs (useful for off-policy learning with a tempered policy).

  • **policy_kwargs (Any) – Additional keyword arguments for the sampler.

Returns:

A Trajectories object containing the sampled trajectories.

Return type:

gfn.containers.Trajectories

abstract to_training_samples(trajectories)

Converts trajectories to training samples.

Parameters:

trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.

Returns:

The training samples; their type depends on the GFlowNet subclass.

Return type:

TrainingSampleType

class gfn.gflownet.base.PFBasedGFlowNet(pf, pb, constant_pb=False, log_reward_clip_min=float('-inf'), debug=False, loss_fn=None)

Bases: GFlowNet[TrainingSampleType], abc.ABC

A GFlowNet that uses forward (PF) and backward (PB) policy networks.

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator, or None if it can be ignored (e.g., the gflownet DAG is a tree, and pb is therefore always 1).

constant_pb

Whether to ignore the backward policy estimator.

log_reward_clip_min

If finite, clips log rewards to this value.
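
As a rough illustration of log_reward_clip_min, the sketch below assumes the value acts as a lower bound on log rewards and that the default of negative infinity disables clipping. The helper name clip_log_rewards and the use of plain Python floats are assumptions made for this example; the library operates on tensors.

```python
import math


def clip_log_rewards(log_rewards, log_reward_clip_min=float("-inf")):
    """Clamp log rewards from below when the bound is finite (hypothetical helper)."""
    if not math.isfinite(log_reward_clip_min):
        # A non-finite bound means no clipping is applied.
        return list(log_rewards)
    return [max(lr, log_reward_clip_min) for lr in log_rewards]


clipped = clip_log_rewards([-250.0, -3.0, 0.5], log_reward_clip_min=-100.0)
unclipped = clip_log_rewards([-250.0, -3.0, 0.5])
print(clipped)
print(unclipped)
```

A finite lower bound of this kind keeps near-zero rewards from producing extremely negative log rewards that would dominate the loss.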

constant_pb = False
log_reward_clip_min
pb
pf
pf_pb_named_parameters()

Returns a dictionary of named parameters containing ‘pf’ or ‘pb’ in their name.

Returns:

A dictionary of named parameters containing ‘pf’ or ‘pb’ in their name.

Return type:

dict[str, torch.Tensor]

pf_pb_parameters()

Returns a list of parameters containing ‘pf’ or ‘pb’ in their name.

Returns:

A list of parameters containing ‘pf’ or ‘pb’ in their name.

Return type:

list[torch.Tensor]
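
The name-based selection behind pf_pb_named_parameters() and pf_pb_parameters() can be sketched as a substring filter over a name-to-parameter mapping. The helper select_by_name and the toy parameter names below are hypothetical, and lists of floats stand in for tensors.

```python
def select_by_name(named_params, *substrings):
    """Keep entries whose name contains any of the given substrings (hypothetical helper)."""
    return {
        name: param
        for name, param in named_params.items()
        if any(s in name for s in substrings)
    }


# Toy parameter table; real entries would be tensors keyed by module path.
toy_named = {
    "pf.module.weight": [0.1],
    "pb.module.weight": [0.2],
    "logZ": [0.0],
}

pf_pb_named = select_by_name(toy_named, "pf", "pb")  # dict, as in pf_pb_named_parameters()
pf_pb_list = list(pf_pb_named.values())               # list, as in pf_pb_parameters()
print(sorted(pf_pb_named))
```

One plausible use of such a split is to place policy parameters and other parameters, such as logZ, in separate optimizer groups with different learning rates.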

sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)

Samples trajectories using the forward policy network.

Parameters:
  • env (gfn.env.Env) – The environment to sample trajectories from.

  • n (int) – Number of trajectories to sample.

  • conditions (torch.Tensor | None) – Optional conditions tensor for conditional environments.

  • save_logprobs (bool) – Whether to save the logprobs of the actions.

  • save_estimator_outputs (bool) – Whether to save the estimator outputs.

  • **policy_kwargs (Any) – Additional keyword arguments for the sampler.

Returns:

A Trajectories object containing the sampled trajectories.

Return type:

gfn.containers.Trajectories

gfn.gflownet.base.TrainingSampleType
class gfn.gflownet.base.TrajectoryBasedGFlowNet(pf, pb, constant_pb=False, log_reward_clip_min=float('-inf'), debug=False, loss_fn=None)

Bases: PFBasedGFlowNet[gfn.containers.Trajectories], abc.ABC

A GFlowNet that operates on complete trajectories.

Parameters:
pf

The forward policy module.

pb

The backward policy module, or None if the GFlowNet DAG is a tree, in which case pb is always 1.

constant_pb

Whether to ignore the backward policy estimator, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1.

log_reward_clip_min

If finite, clips log rewards to this value.

get_pfs_and_pbs(trajectories, recalculate_all_logprobs=True)

Evaluates forward and backward logprobs for each trajectory in the batch.

More specifically, it evaluates \(\log P_F(s' \mid s)\) and \(\log P_B(s \mid s')\) for each transition in each trajectory in the batch.

If recalculate_all_logprobs=True, we re-evaluate the logprobs of the trajectories using the current self.pf. Otherwise, the following applies:

  • If the trajectories have a logprobs attribute, use it. This is usually the case for on-policy learning.

  • Else, if the trajectories have an estimator_outputs attribute, transform it into logprobs. This is usually the case for off-policy learning with a tempered policy.

  • Otherwise (the trajectories have neither), re-evaluate the logprobs using the current self.pf. This is usually the case for off-policy learning with a replay buffer.

Parameters:
  • trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

Returns:

A tuple of tensors of shape (max_length, batch_size) containing the log_pf and log_pb for each action in each trajectory.

Return type:

Tuple[torch.Tensor, torch.Tensor]
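
The order in which get_pfs_and_pbs picks a source for the forward logprobs can be summarized as a small decision function. This is a sketch of the documented priority only: the function name, its boolean arguments, and the returned labels are hypothetical, and no actual logprob computation is performed.

```python
def choose_log_pf_source(has_logprobs, has_estimator_outputs, recalculate_all_logprobs=True):
    """Return a label for where forward logprobs would come from (hypothetical helper)."""
    if recalculate_all_logprobs:
        # Always re-evaluate with the current forward policy.
        return "recompute_with_current_pf"
    if has_logprobs:
        # Typical for on-policy learning.
        return "stored_logprobs"
    if has_estimator_outputs:
        # Typical for off-policy learning with a tempered policy.
        return "from_estimator_outputs"
    # Neither is stored, e.g. trajectories drawn from a replay buffer.
    return "recompute_with_current_pf"


print(choose_log_pf_source(True, True))
print(choose_log_pf_source(True, True, recalculate_all_logprobs=False))
print(choose_log_pf_source(False, True, recalculate_all_logprobs=False))
print(choose_log_pf_source(False, False, recalculate_all_logprobs=False))
```

Note that with the default recalculate_all_logprobs=True, stored logprobs and estimator outputs are never consulted.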

get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)

Calculates scores for a batch of trajectories.

The scores for each trajectory are defined as: \(\log \left( \frac{P_F(\tau)}{P_B(\tau \mid x) R(x)} \right)\).

Parameters:
  • trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • env (gfn.env.Env | None) – The environment (unused in base TB, but required by some subclasses such as RTB and SubTB).

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards from the trajectories. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

A tensor of shape (batch_size,) containing the scores for each trajectory.

Return type:

torch.Tensor
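
Since a trajectory's probability is the product of its per-step probabilities, the score above expands to the sum of forward logprobs minus the sum of backward logprobs minus the log reward. The sketch below works this out for a single two-step trajectory with made-up probabilities; trajectory_score is a hypothetical helper, and padding of shorter trajectories in a batch is ignored.

```python
import math


def trajectory_score(log_pf_steps, log_pb_steps, log_reward):
    """log(P_F(tau) / (P_B(tau | x) * R(x))) from per-step logprobs (hypothetical helper)."""
    return sum(log_pf_steps) - sum(log_pb_steps) - log_reward


# Made-up example: P_F(tau) = 0.5 * 0.25, P_B(tau | x) = 1.0 * 0.5, R(x) = 2.0.
log_pf = [math.log(0.5), math.log(0.25)]
log_pb = [math.log(1.0), math.log(0.5)]
log_r = math.log(2.0)

score = trajectory_score(log_pf, log_pb, log_r)
# Direct evaluation of the ratio: 0.125 / (0.5 * 2.0) = 0.125.
expected = math.log(0.125)
print(score, expected)
```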

logz_named_parameters()

Returns named parameters containing ‘logZ’ in their name.

Works for any subclass that registers a logZ parameter (e.g. TBGFlowNet, RelativeTrajectoryBalanceGFlowNet). Returns an empty dict for subclasses without logZ.

Return type:

dict[str, torch.Tensor]

logz_parameters()

Returns parameters containing ‘logZ’ in their name.

Works for any subclass that registers a logZ parameter (e.g. TBGFlowNet, RelativeTrajectoryBalanceGFlowNet). Returns an empty list for subclasses without logZ.

Return type:

list[torch.Tensor]

to_training_samples(trajectories)

Returns the input trajectories as training samples.

Parameters:

trajectories (gfn.containers.Trajectories) – The Trajectories object to use as training samples.

Returns:

The same Trajectories object.

Return type:

gfn.containers.Trajectories

trajectory_log_probs_backward(trajectories)

Evaluates backward logprobs only for each trajectory in the batch.

Parameters:

trajectories (gfn.containers.Trajectories)

Return type:

torch.Tensor

trajectory_log_probs_forward(trajectories, recalculate_all_logprobs=True)

Evaluates forward logprobs only for each trajectory in the batch.

Parameters:
  • trajectories (gfn.containers.Trajectories)

  • recalculate_all_logprobs (bool)

Return type:

torch.Tensor

gfn.gflownet.base.loss_reduce(loss, method)

Utility function to handle loss aggregation strategies.

Parameters:
  • loss (torch.Tensor) – The tensor to reduce.

  • method (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).

Returns:

The reduced tensor.

Return type:

torch.Tensor
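
The three reduction methods can be illustrated on a plain list of per-sample losses. This sketch does not reproduce the library function: reduce_values is a hypothetical name, floats stand in for tensor entries, and raising ValueError on an unrecognized method is an assumption.

```python
def reduce_values(values, method):
    """Aggregate per-sample losses by 'mean', 'sum', or 'none' (hypothetical helper)."""
    if method == "mean":
        return sum(values) / len(values)
    if method == "sum":
        return sum(values)
    if method == "none":
        # Leave the per-sample losses unreduced.
        return list(values)
    # Assumed behavior for unknown methods.
    raise ValueError(f"Unknown reduction method: {method!r}")


per_sample = [1.0, 2.0, 3.0]
print(reduce_values(per_sample, "mean"))
print(reduce_values(per_sample, "sum"))
print(reduce_values(per_sample, "none"))
```

Using 'none' keeps one loss value per sample, which is useful when the caller wants to weight or inspect individual losses before aggregating.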