tutorials.examples.train_hypergrid_gafn
=======================================

.. py:module:: tutorials.examples.train_hypergrid_gafn

.. autoapi-nested-parse::

   A version of GFlowNet training that implements Generative Augmented Flow Networks
   (GAFN, https://arxiv.org/abs/2210.03308). It is a variant of GFlowNet that
   introduces intrinsic rewards to the GFlowNet training. It relies on the Random
   Network Distillation (RND, https://arxiv.org/abs/1810.12894) to define intrinsic
   rewards.

   Example usage::

      python train_hypergrid_gafn.py --ndim 2 --height 8 --rnd_reward_scale 0.005

   Key features:

   - Implements GAFN training
   - Uses RND to define intrinsic rewards
   - Based on TB loss like the train_hypergrid_simple.py example


Attributes
----------

.. autoapisummary::

   tutorials.examples.train_hypergrid_gafn.parser


Classes
-------

.. autoapisummary::

   tutorials.examples.train_hypergrid_gafn.RND
   tutorials.examples.train_hypergrid_gafn.TBGAFN


Functions
---------

.. autoapisummary::

   tutorials.examples.train_hypergrid_gafn.main


Module Contents
---------------

.. py:class:: RND(state_dim, preprocessor, reward_scale = 0.1, loss_scale = 0.1, hidden_dim = 256, s_latent_dim = 128)

   Bases: :py:obj:`torch.nn.Module`

   Random Network Distillation (RND) module.

   It is a module that predicts the random target net from the state.

   .. py:method:: compute_intrinsic_reward(states)

   .. py:method:: compute_rnd_loss(states)

   .. py:method:: forward(states)

   .. py:attribute:: loss_scale
      :value: 0.1

   .. py:attribute:: predictor_net

   .. py:attribute:: preprocessor

   .. py:attribute:: random_target_net

   .. py:attribute:: reward_scale
      :value: 0.1


.. py:class:: TBGAFN(pf, pb, rnd, logZ = None, init_logZ = 0.0, use_edge_ri = False, flow_estimator = None, log_reward_clip_min = -float('inf'))

   Bases: :py:obj:`gfn.gflownet.TBGFlowNet`

   Generative Augmented Flow Networks based on the Trajectory Balance loss.

   :param pf: The forward policy estimator.

   .. py:attribute:: flow_estimator
      :value: None

   .. py:method:: flow_parameters()

   .. py:method:: get_scores(trajectories, recalculate_all_logprobs = True, *, log_rewards = None)

      Computes Trajectory Balance scores with intrinsic rewards for a batch of
      trajectories.

      :param trajectories: The Trajectories object to evaluate.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
      :param log_rewards: Optional override for the trajectories' log rewards.
          If None, defaults to `trajectories.log_rewards`.

      :returns: A tensor of shape (batch_size,) containing the scores for each
                trajectory.

   .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None)

      Computes the trajectory balance loss.

      The trajectory balance loss is described in section 2.3 of
      [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259).

      :param env: The environment where the trajectories are sampled from (unused).
      :param trajectories: The Trajectories object to compute the loss with.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
      :param reduction: The reduction method to use ('mean', 'sum', or 'none').
      :param log_rewards: Optional custom log rewards tensor of shape
          (n_trajectories,). When None, uses the environment rewards. Useful for
          intrinsic rewards (see "Towards Improving Exploration through Sibling
          Augmented GFlowNets", Madan et al., ICLR 2025).

      :returns: The computed trajectory balance loss as a tensor. The shape depends
                on the reduction method.

   .. py:attribute:: rnd

   .. py:method:: rnd_parameters()

   .. py:attribute:: use_edge_ri
      :value: False


.. py:function:: main(args)

.. py:data:: parser
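
The ``RND`` class above is documented only by its signature, so the following toy sketch illustrates the general Random Network Distillation idea it is named after: a frozen random target, a trainable predictor, and an intrinsic reward equal to the scaled squared prediction error. This is not the module's implementation. ``TinyRND``, its ``update`` method, and its ``lr`` and ``latent_dim`` parameters are hypothetical names chosen for illustration, linear maps stand in for the neural networks, and plain Python lists are used instead of ``torch`` so that the sketch has no dependencies.

```python
import random


class TinyRND:
    """Toy RND (hypothetical): frozen random linear target, trainable linear predictor."""

    def __init__(self, state_dim, latent_dim=4, reward_scale=0.1, lr=0.05, seed=0):
        rng = random.Random(seed)
        # Frozen target weights: never updated after construction.
        self.target = [[rng.gauss(0.0, 1.0) for _ in range(state_dim)]
                       for _ in range(latent_dim)]
        # Predictor weights start at zero and are trained to imitate the target.
        self.predictor = [[0.0] * state_dim for _ in range(latent_dim)]
        self.reward_scale = reward_scale
        self.lr = lr

    @staticmethod
    def _apply(weights, state):
        # Matrix-vector product: one output per row of weights.
        return [sum(w * s for w, s in zip(row, state)) for row in weights]

    def errors(self, state):
        """Per-latent-dimension prediction error (predictor minus target)."""
        pred = self._apply(self.predictor, state)
        tgt = self._apply(self.target, state)
        return [p - t for p, t in zip(pred, tgt)]

    def intrinsic_reward(self, state):
        """Scaled squared prediction error: large for unfamiliar states."""
        return self.reward_scale * sum(e * e for e in self.errors(state))

    def update(self, state):
        """One SGD step on the squared prediction error for this state."""
        errs = self.errors(state)
        for row, e in zip(self.predictor, errs):
            for j, s in enumerate(state):
                # Gradient of e**2 with respect to row[j] is 2 * e * s.
                row[j] -= self.lr * 2.0 * e * s


rnd = TinyRND(state_dim=2)
visited, novel = [1.0, 0.0], [0.0, 1.0]
reward_before = rnd.intrinsic_reward(visited)
for _ in range(50):
    rnd.update(visited)  # repeatedly "visit" the same state
reward_after = rnd.intrinsic_reward(visited)
print(reward_before, reward_after, rnd.intrinsic_reward(novel))
```

In this sketch the reward for the repeatedly visited state shrinks toward zero while the reward for the unvisited state is unchanged, which is the exploration signal that intrinsic-reward methods such as GAFN build on.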