tutorials.examples.train_hypergrid_gafn
A version of GFlowNet training that implements Generative Augmented Flow Networks (GAFN, https://arxiv.org/abs/2210.03308). GAFN is a GFlowNet variant that adds intrinsic rewards to the training objective to improve exploration. Here the intrinsic rewards are defined with Random Network Distillation (RND, https://arxiv.org/abs/1810.12894).
Example usage: python train_hypergrid_gafn.py --ndim 2 --height 8 --rnd_reward_scale 0.005
Key features:
- Implements GAFN training
- Uses RND to define intrinsic rewards
- Based on the TB loss, like the train_hypergrid_simple.py example
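To make "adding intrinsic rewards" concrete, here is a minimal pure-Python sketch. It assumes, following the GAFN formulation, that the intrinsic bonus is added to the extrinsic reward R(x) in reward space rather than to log R(x). The helper `augmented_log_reward` is hypothetical and is not part of this module or of torchgfn.

```python
import math


def augmented_log_reward(log_reward: float, intrinsic_reward: float) -> float:
    """Return log(R(x) + r_int) from log R(x) and a bonus r_int >= 0.

    Hypothetical helper: the bonus is added in reward space, so it is
    combined with exp(log_reward) through a numerically stable log-sum-exp.
    """
    if intrinsic_reward < 0:
        raise ValueError("the intrinsic reward is assumed to be non-negative")
    if intrinsic_reward == 0:
        return log_reward
    a, b = log_reward, math.log(intrinsic_reward)
    m = max(a, b)
    # log(exp(a) + exp(b)) without overflow; also handles log_reward = -inf.
    return m + math.log(math.exp(a - m) + math.exp(b - m))


# A reward of 0.25 plus a bonus of 0.25 gives log(0.5).
combined = augmented_log_reward(math.log(0.25), 0.25)
```

Because the bonus is added before taking the log, it changes log R(x) most where R(x) is small relative to the bonus.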
Attributes
- parser

Classes
- RND: Random Network Distillation (RND) module. It is a module that predicts the random target net from the state.
- TBGAFN: Generative Augmented Flow Networks based on the Trajectory Balance loss.

Functions
- main(args)
Module Contents
- class tutorials.examples.train_hypergrid_gafn.RND(state_dim, preprocessor, reward_scale=0.1, loss_scale=0.1, hidden_dim=256, s_latent_dim=128)
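Bases: torch.nn.Module
Random Network Distillation (RND) module. It trains a predictor network (predictor_net) to reproduce the output of a fixed, randomly initialized target network (random_target_net) on each state. As in the RND paper, the prediction error serves as the intrinsic reward, so states the predictor has rarely been trained on tend to receive larger rewards.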
- Parameters:
state_dim (int)
preprocessor (gfn.preprocessors.Preprocessor)
reward_scale (float)
loss_scale (float)
hidden_dim (int)
s_latent_dim (int)
- compute_intrinsic_reward(states)
- Parameters:
states (gfn.states.States)
- Return type:
torch.Tensor
- compute_rnd_loss(states)
- Parameters:
states (gfn.states.States)
- Return type:
torch.Tensor
- forward(states)
- Parameters:
states (gfn.states.States)
- Return type:
torch.Tensor
- loss_scale = 0.1
- predictor_net
- preprocessor
- random_target_net
- reward_scale = 0.1
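The interplay between the frozen random target, the trainable predictor, reward_scale and loss_scale can be illustrated with a small pure-Python sketch. `ToyRND` is a hypothetical stand-in, not this module's RND class: it uses one-hot feature vectors, a linear-plus-tanh target and a linear predictor trained by plain gradient descent. The real module presumably uses torch MLPs sized by hidden_dim and s_latent_dim, applied after the preprocessor; those details are assumptions and are omitted here.

```python
import math
import random
from typing import List, Sequence


class ToyRND:
    """Hypothetical pure-Python stand-in for an RND module (not this module's RND)."""

    def __init__(self, in_dim: int, out_dim: int = 8, reward_scale: float = 0.1,
                 loss_scale: float = 0.1, lr: float = 1.0, seed: int = 0) -> None:
        rng = random.Random(seed)
        # Frozen, randomly initialised target: never updated after construction.
        self.target: List[List[float]] = [
            [rng.gauss(0.0, 1.0) for _ in range(in_dim)] for _ in range(out_dim)
        ]
        # Trainable linear predictor, initialised to zero.
        self.predictor: List[List[float]] = [[0.0] * in_dim for _ in range(out_dim)]
        self.reward_scale = reward_scale
        self.loss_scale = loss_scale
        self.lr = lr

    def _target_out(self, x: Sequence[float]) -> List[float]:
        # Fixed random features: a linear map followed by tanh.
        return [math.tanh(sum(w * xi for w, xi in zip(row, x))) for row in self.target]

    def _predict(self, x: Sequence[float]) -> List[float]:
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.predictor]

    def _error(self, x: Sequence[float]) -> float:
        # Mean squared error between predictor and target outputs.
        t, p = self._target_out(x), self._predict(x)
        return sum((pi - ti) ** 2 for pi, ti in zip(p, t)) / len(t)

    def intrinsic_reward(self, x: Sequence[float]) -> float:
        # Novelty bonus: large where the predictor has not yet fitted the target.
        return self.reward_scale * self._error(x)

    def rnd_loss(self, x: Sequence[float]) -> float:
        # Training objective for the predictor: same error, different scale.
        return self.loss_scale * self._error(x)

    def train_step(self, x: Sequence[float]) -> None:
        # One gradient-descent step on rnd_loss(x); the target stays frozen.
        t, p = self._target_out(x), self._predict(x)
        n = len(t)
        for i, row in enumerate(self.predictor):
            grad_out = self.loss_scale * 2.0 * (p[i] - t[i]) / n
            for j, xj in enumerate(x):
                row[j] -= self.lr * grad_out * xj


rnd = ToyRND(in_dim=4)
visited, unseen = [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]
before_visited = rnd.intrinsic_reward(visited)
before_unseen = rnd.intrinsic_reward(unseen)
for _ in range(50):
    rnd.train_step(visited)
# The bonus shrinks for the repeatedly visited state; the other one-hot state is untouched.
after_visited = rnd.intrinsic_reward(visited)
after_unseen = rnd.intrinsic_reward(unseen)
```

Keeping two separate scales lets the size of the exploration bonus be tuned independently of how fast the predictor is trained.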
- class tutorials.examples.train_hypergrid_gafn.TBGAFN(pf, pb, rnd, logZ=None, init_logZ=0.0, use_edge_ri=False, flow_estimator=None, log_reward_clip_min=-float('inf'))
Bases: gfn.gflownet.TBGFlowNet
Generative Augmented Flow Networks based on the Trajectory Balance loss.
- Parameters:
pf (gfn.estimators.Estimator) – The forward policy estimator.
pb – The backward policy estimator.
rnd (RND) – The RND module that provides the intrinsic rewards.
logZ (torch.nn.Parameter | gfn.estimators.ScalarEstimator | None)
init_logZ (float)
use_edge_ri (bool)
flow_estimator (gfn.estimators.ScalarEstimator | None)
log_reward_clip_min (float)
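The name use_edge_ri suggests a choice between terminal-only intrinsic rewards and GAFN's edge-based intrinsic rewards on intermediate transitions. The sketch below shows one way such a switch could aggregate per-state bonuses along a trajectory. Both the meaning of the flag and the rule "an edge s_t -> s_{t+1} is credited with the bonus of s_{t+1}" are assumptions inferred from the parameter name and the GAFN paper, and `total_intrinsic_reward` is a hypothetical helper.

```python
from typing import Sequence


def total_intrinsic_reward(bonuses: Sequence[float], use_edge_ri: bool = False) -> float:
    """Aggregate per-state RND bonuses along one trajectory (hypothetical helper).

    bonuses[t] is the bonus of the state reached by the t-th transition, so the
    last entry belongs to the terminal state x.
    """
    if not bonuses:
        return 0.0
    if use_edge_ri:
        # Assumed edge-based variant: each transition s_t -> s_{t+1} is credited
        # with the bonus of s_{t+1}, and the credits are summed.
        return float(sum(bonuses))
    # Assumed terminal-only variant: only the terminal state's bonus counts.
    return float(bonuses[-1])


bonuses = [0.1, 0.2, 0.3]
terminal_only = total_intrinsic_reward(bonuses)
edge_based = total_intrinsic_reward(bonuses, use_edge_ri=True)
```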
- flow_estimator = None
- flow_parameters()
- Return type:
list[torch.Tensor]
- get_scores(trajectories, recalculate_all_logprobs=True, *, log_rewards=None)
Computes Trajectory Balance scores with intrinsic rewards for a batch of trajectories.
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
log_rewards (torch.Tensor | None) – Optional override for the trajectories’ log rewards. If None, defaults to trajectories.log_rewards.
- Returns:
A tensor of shape (batch_size,) containing the scores for each trajectory.
- Return type:
torch.Tensor
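For intuition about what a per-trajectory score is, here is a minimal pure-Python sketch of the Trajectory Balance residual log Z + Σ log P_F − Σ log P_B − log R(x), with the log reward clipped from below in the spirit of log_reward_clip_min. The function `tb_scores` is hypothetical and works on plain lists instead of a Trajectories object. The log_rewards argument plays the role of the override above, which is where an intrinsically augmented log reward could be passed in. For readability the sketch folds log Z into the score; it does not claim that get_scores does the same.

```python
import math
from typing import List, Sequence


def tb_scores(log_z: float,
              log_pf: Sequence[Sequence[float]],
              log_pb: Sequence[Sequence[float]],
              log_rewards: Sequence[float],
              log_reward_clip_min: float = -math.inf) -> List[float]:
    """Per-trajectory TB residuals log Z + sum(log PF) - sum(log PB) - log R(x)."""
    scores = []
    for pf, pb, log_r in zip(log_pf, log_pb, log_rewards):
        # Clip the log reward from below, mirroring log_reward_clip_min.
        log_r = max(log_r, log_reward_clip_min)
        scores.append(log_z + sum(pf) - sum(pb) - log_r)
    return scores


# Two forward steps with probability 0.5 each, a deterministic backward policy
# and R(x) = 0.25: the residual is numerically zero, i.e. the trajectory is balanced.
balanced = tb_scores(0.0, [[math.log(0.5), math.log(0.5)]], [[0.0, 0.0]], [math.log(0.25)])
# A log reward of -100 clipped at -10 contributes +10 to the residual.
clipped = tb_scores(0.0, [[0.0]], [[0.0]], [-100.0], log_reward_clip_min=-10.0)
```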
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)
Computes the trajectory balance loss.
The trajectory balance loss is described in section 2.3 of [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).
- Parameters:
env (gfn.env.Env) – The environment where the trajectories are sampled from (unused).
trajectories (gfn.containers.Trajectories) – The Trajectories object to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The computed trajectory balance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
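The Trajectory Balance loss squares each trajectory's residual. The three reduction modes then determine whether the result is a scalar or one value per trajectory. The following pure-Python sketch shows that step on plain lists. `tb_loss` is a hypothetical name, and the real method operates on torch tensors.

```python
from typing import List, Sequence, Union


def tb_loss(scores: Sequence[float], reduction: str = "mean") -> Union[float, List[float]]:
    """Square each TB residual, then reduce over the batch."""
    squared = [s * s for s in scores]
    if reduction == "mean":
        return sum(squared) / len(squared)
    if reduction == "sum":
        return sum(squared)
    if reduction == "none":
        # One loss value per trajectory, shape (n_trajectories,).
        return squared
    raise ValueError(f"unknown reduction: {reduction!r}")


batch_scores = [1.0, -3.0]
mean_loss = tb_loss(batch_scores)
sum_loss = tb_loss(batch_scores, reduction="sum")
per_traj = tb_loss(batch_scores, reduction="none")
```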
- rnd
- rnd_parameters()
- Return type:
list[torch.Tensor]
- use_edge_ri = False
- tutorials.examples.train_hypergrid_gafn.main(args)
- tutorials.examples.train_hypergrid_gafn.parser
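The module-level parser is what turns the example command line into the args passed to main. Below is a hedged argparse sketch limited to the three flags that appear in the example usage line. The function `build_parser`, the defaults (taken from the example values, plus 0.1 for the RND reward scale to match RND's reward_scale default) and the help strings are placeholders, and the real parser very likely defines more options.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reconstruction: only the flags from the example usage line;
    # defaults and help strings are placeholders, not the script's actual values.
    parser = argparse.ArgumentParser(description="Train a GAFN on HyperGrid (sketch).")
    parser.add_argument("--ndim", type=int, default=2,
                        help="Number of hypergrid dimensions.")
    parser.add_argument("--height", type=int, default=8,
                        help="Number of cells along each dimension.")
    parser.add_argument("--rnd_reward_scale", type=float, default=0.1,
                        help="Scale applied to the RND intrinsic reward.")
    return parser


args = build_parser().parse_args(
    ["--ndim", "2", "--height", "8", "--rnd_reward_scale", "0.005"]
)
```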