tutorials.examples.train_hypergrid_gafn

A version of GFlowNet training that implements Generative Augmented Flow Networks (GAFN, https://arxiv.org/abs/2210.03308), a GFlowNet variant that adds intrinsic rewards to GFlowNet training to encourage exploration. The intrinsic rewards are defined with Random Network Distillation (RND, https://arxiv.org/abs/1810.12894).

Example usage: python train_hypergrid_gafn.py --ndim 2 --height 8 --rnd_reward_scale 0.005
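The three flags in this command could be declared with the standard argparse module roughly as below. This is a reduced, hypothetical sketch: the module's real `parser` may define more options, and the defaults and help strings here are placeholders, not the script's actual values.

```python
import argparse

# Hypothetical reduced parser: only the flags from the example command are declared,
# and the defaults are placeholders rather than the script's real defaults.
parser = argparse.ArgumentParser(description="GAFN training on HyperGrid (sketch)")
parser.add_argument("--ndim", type=int, default=2, help="Number of grid dimensions")
parser.add_argument("--height", type=int, default=8, help="Grid height (cells per dimension)")
parser.add_argument("--rnd_reward_scale", type=float, default=0.1, help="Scale of the RND intrinsic reward")

# Parse the same arguments as the example command line.
args = parser.parse_args(["--ndim", "2", "--height", "8", "--rnd_reward_scale", "0.005"])
```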

Key features:
- Implements GAFN training
- Uses RND to define intrinsic rewards
- Uses the Trajectory Balance (TB) loss, like the train_hypergrid_simple.py example

Attributes

parser

Classes

RND

Random Network Distillation (RND) module. It predicts the output of a random target network from the state.

TBGAFN

Generative Augmented Flow Networks based on the Trajectory Balance loss.

Functions

main(args)

Module Contents

class tutorials.examples.train_hypergrid_gafn.RND(state_dim, preprocessor, reward_scale=0.1, loss_scale=0.1, hidden_dim=256, s_latent_dim=128)

Bases: torch.nn.Module

Random Network Distillation (RND) module. Its predictor network learns to predict, from the state, the output of a fixed, randomly initialized target network.

compute_intrinsic_reward(states)
Parameters:

states (gfn.states.States)

Return type:

torch.Tensor

compute_rnd_loss(states)
Parameters:

states (gfn.states.States)

Return type:

torch.Tensor

forward(states)
Parameters:

states (gfn.states.States)

Return type:

torch.Tensor

loss_scale = 0.1
predictor_net
preprocessor
random_target_net
reward_scale = 0.1
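The RND idea can be sketched without torch as follows. This is an illustrative stand-in, not the module's implementation: linear maps replace the `predictor_net` and `random_target_net` MLPs (so `hidden_dim`, `s_latent_dim` and the `preprocessor` are not modeled), states are plain float lists, and scaling the reward by `reward_scale` and the loss by `loss_scale` is an assumption based on the parameter names. `TinyRND` and its `train_step` helper are hypothetical; in torch an optimizer would minimize `compute_rnd_loss` instead of a hand-coded gradient step.

```python
import random


class TinyRND:
    """Hypothetical RND sketch: linear maps stand in for the predictor and target MLPs."""

    def __init__(self, state_dim, latent_dim=4, reward_scale=0.1, loss_scale=0.1, lr=0.05, seed=0):
        rng = random.Random(seed)
        # Fixed, randomly initialized target network: never updated.
        self.target = [[rng.gauss(0.0, 1.0) for _ in range(state_dim)] for _ in range(latent_dim)]
        # Predictor network: trained to reproduce the target's embeddings.
        self.predictor = [[rng.gauss(0.0, 1.0) for _ in range(state_dim)] for _ in range(latent_dim)]
        self.reward_scale = reward_scale  # assumed to scale the intrinsic reward
        self.loss_scale = loss_scale  # assumed to scale the RND loss
        self.lr = lr

    @staticmethod
    def _embed(weights, x):
        # Matrix-vector product: one output per weight row.
        return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

    def _error(self, x):
        # Squared distance between predictor and target embeddings of state x.
        pred, targ = self._embed(self.predictor, x), self._embed(self.target, x)
        return sum((p - t) ** 2 for p, t in zip(pred, targ))

    def compute_intrinsic_reward(self, states):
        # High prediction error means a rarely visited state, hence a large bonus.
        return [self.reward_scale * self._error(x) for x in states]

    def compute_rnd_loss(self, states):
        # Mean prediction error over the batch, scaled by loss_scale.
        return self.loss_scale * sum(self._error(x) for x in states) / len(states)

    def train_step(self, states):
        # One gradient-descent step on the mean prediction error; only the predictor moves.
        n = len(states)
        grads = [[0.0] * len(row) for row in self.predictor]
        for x in states:
            pred, targ = self._embed(self.predictor, x), self._embed(self.target, x)
            for k, (p, t) in enumerate(zip(pred, targ)):
                for j, xj in enumerate(x):
                    grads[k][j] += 2.0 * (p - t) * xj / n
        for row, grad_row in zip(self.predictor, grads):
            for j, g in enumerate(grad_row):
                row[j] -= self.lr * g


rnd = TinyRND(state_dim=4)
visited, novel = [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]
before = rnd.compute_intrinsic_reward([visited, novel])
for _ in range(200):
    rnd.train_step([visited])  # the predictor only ever sees `visited`
after = rnd.compute_intrinsic_reward([visited, novel])
```

After repeated training on `visited`, its bonus shrinks toward zero while the bonus for the never-seen `novel` state stays the same; this gap between familiar and unfamiliar states is the exploration signal that RND supplies.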
class tutorials.examples.train_hypergrid_gafn.TBGAFN(pf, pb, rnd, logZ=None, init_logZ=0.0, use_edge_ri=False, flow_estimator=None, log_reward_clip_min=-float('inf'))

Bases: gfn.gflownet.TBGFlowNet

Generative Augmented Flow Networks based on the Trajectory Balance loss.

flow_estimator = None
flow_parameters()
Return type:

list[torch.Tensor]

get_scores(trajectories, recalculate_all_logprobs=True, *, log_rewards=None)

Computes Trajectory Balance scores with intrinsic rewards for a batch of trajectories.

Parameters:
  • trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • log_rewards (torch.Tensor | None) – Optional override for the trajectories’ log rewards. If None, defaults to trajectories.log_rewards.

Returns:

A tensor of shape (batch_size,) containing the scores for each trajectory.

Return type:

torch.Tensor
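A minimal sketch of a per-trajectory score, assuming the intrinsic bonus is added to the extrinsic reward of the terminal state, R'(x) = R(x) + r(x). It uses plain floats instead of a Trajectories batch and does not model edge-level intrinsic rewards (`use_edge_ri`) or the `flow_estimator`. Whether log Z is folded into the score or added later inside `loss` is an implementation detail this sketch does not claim to reproduce; here it is included so that a balanced trajectory scores 0. `tb_score` is a hypothetical helper.

```python
import math


def tb_score(log_z, log_pf, log_pb, log_reward, intrinsic_reward=0.0):
    """Hypothetical TB residual for one trajectory with a bonus on the terminal reward.

    log_pf / log_pb: per-step forward and backward log-probabilities.
    log_reward: extrinsic log-reward log R(x) of the terminal state x.
    intrinsic_reward: non-negative bonus r(x), e.g. an RND prediction error (assumption).
    """
    # The bonus is added in reward space, R'(x) = R(x) + r(x), so combine via exp/log.
    # Requires R(x) + r(x) > 0.
    augmented_log_reward = math.log(math.exp(log_reward) + intrinsic_reward)
    # A perfectly balanced trajectory gives a residual of 0.
    return log_z + sum(log_pf) - sum(log_pb) - augmented_log_reward


# Two-step trajectory with Z = 1, P_F = 0.5 and P_B = 1 at each step, R(x) = 0.2.
log_pf = [math.log(0.5), math.log(0.5)]
log_pb = [0.0, 0.0]
without_bonus = tb_score(0.0, log_pf, log_pb, math.log(0.2))
with_bonus = tb_score(0.0, log_pf, log_pb, math.log(0.2), intrinsic_reward=0.05)
```

With the bonus the augmented reward is 0.25, which equals Z times the product of forward probabilities, so the residual vanishes; without it the residual is log(1.25).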

loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the trajectory balance loss.

The trajectory balance loss is described in section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).
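For reference, the objective from that paper for a complete trajectory τ = (s_0 → s_1 → ... → s_n = x), with learned partition function Z_θ, forward policy P_F and backward policy P_B:

```latex
\mathcal{L}_{\mathrm{TB}}(\tau) = \left( \log Z_\theta + \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t; \theta) - \log R(x) - \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1}; \theta) \right)^2
```

When `log_rewards` is passed, it takes the place of log R(x) in this expression, which is one way an intrinsic bonus can enter the objective.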

Parameters:
  • env (gfn.env.Env) – The environment where the trajectories are sampled from (unused).

  • trajectories (gfn.containers.Trajectories) – The Trajectories object to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The computed trajectory balance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
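Assuming the per-trajectory loss is the squared TB residual, the three `reduction` modes behave as in this plain-Python sketch; `reduce_tb_losses` is a hypothetical helper, not part of the library.

```python
def reduce_tb_losses(residuals, reduction="mean"):
    """Square per-trajectory TB residuals and reduce them like the `reduction` argument."""
    squared = [r * r for r in residuals]
    if reduction == "mean":
        return sum(squared) / len(squared)
    if reduction == "sum":
        return sum(squared)
    if reduction == "none":
        # One loss value per trajectory, i.e. shape (n_trajectories,).
        return squared
    raise ValueError(f"Unknown reduction: {reduction!r}")


residuals = [1.0, -2.0, 0.5]
```

Only `'none'` keeps the per-trajectory values, which is useful when losses must be reweighted before averaging.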

rnd
rnd_parameters()
Return type:

list[torch.Tensor]

use_edge_ri = False
tutorials.examples.train_hypergrid_gafn.main(args)
tutorials.examples.train_hypergrid_gafn.parser