tutorials.examples.train_hypergrid

This script reproduces some of the published results on the HyperGrid environment. Run one of the following commands to reproduce some of the results in [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259):

    python train_hypergrid.py --ndim 4 --height 8 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}
    python train_hypergrid.py --ndim 2 --height 64 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}

And run one of the following to reproduce some of the results in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782):

    python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}
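
The commands above are templates rather than literal shell lines. Reading each `{...}` as a set of alternatives from which one value is chosen per run (an interpretation, not something the script enforces), a sweep can be enumerated with a small helper. The sketch below expands the last command; `SWEEP`, `FIXED`, and `expand_commands` are hypothetical names introduced for illustration, and only the flags and values are taken from the command itself.

```python
import itertools
import shlex

# Choices copied from the command above; each {...} is read as "pick one".
SWEEP = {
    "--ndim": ["2", "4"],
    "--R0": ["1e-3", "1e-4"],
    "--loss": ["TB", "DB", "SubTB"],
}
# Arguments that stay the same in every run of the sweep.
FIXED = ["--height", "12", "--tied"]


def expand_commands(script="train_hypergrid.py"):
    """Return one argv list per combination of the swept values."""
    flags = list(SWEEP)
    commands = []
    for values in itertools.product(*(SWEEP[flag] for flag in flags)):
        argv = ["python", script]
        for flag, value in zip(flags, values):
            argv += [flag, value]
        commands.append(argv + FIXED)
    return commands


if __name__ == "__main__":
    # Print the concrete command lines, one per run.
    for argv in expand_commands():
        print(shlex.join(argv))
```

Each printed line can be pasted into a shell, or the argv lists can be passed to `subprocess.run` to launch the runs one after another.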

This script also provides a function get_exact_P_T that computes the exact terminating state distribution for the HyperGrid environment, which is useful for evaluation and visualization.

Attributes

logger

parser

Functions

_make_optimizer_for(gflownet, args)

Builds a fresh AdamW optimizer for a (re)built GFlowNet, with a separate parameter group for logZ.

get_exact_P_T(env, gflownet)

Evaluates the exact terminating state distribution P_T for HyperGrid.

main(args)

Trains a GFlowNet on the HyperGrid environment.

plot_results(env, gflownet, l1_distances, args)

set_up_fm_gflownet(args, env, preprocessor)

Returns a flow matching (FM) GFlowNet.

set_up_gflownet(args, env, preprocessor)

Returns a GFlowNet complete with the required estimators.

set_up_logF_estimator(args, env, preprocessor, pf_module)

Returns a LogStateFlowEstimator.

set_up_pb_pf_estimators(args, env, preprocessor)

Returns a pair of estimators for the forward and backward policies.

Module Contents

tutorials.examples.train_hypergrid._make_optimizer_for(gflownet, args)

Builds a fresh AdamW optimizer for a (re)built GFlowNet, with a separate parameter group for logZ.

Return type:

torch.optim.Optimizer

tutorials.examples.train_hypergrid.get_exact_P_T(env, gflownet)

Evaluates the exact terminating state distribution P_T for HyperGrid.

For each state s', the terminating state probability is computed as:

\[P_T(s') = u(s') P_F(s_f | s')\]

where u(s') satisfies the recursion:

\[u(s') = \sum_{s \in \text{Par}(s')} u(s) P_F(s' | s)\]

with the base case u(s_0) = 1.

Parameters:

env: The HyperGrid environment.
gflownet: The GFlowNet whose forward policy P_F is evaluated.

Returns:

The exact terminating state distribution as a tensor

Return type:

torch.Tensor
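
The recursion for u and the formula for P_T can be traced by hand on a small grid. The following is a minimal pure-Python sketch, not the script's implementation: it assumes that a HyperGrid action either increments one coordinate (while it is below height - 1) or exits to s_f, and it substitutes a uniform forward policy for the learned P_F. The names `uniform_forward_policy` and `exact_terminating_distribution` are hypothetical.

```python
import itertools
from fractions import Fraction


def uniform_forward_policy(state, height):
    """Stand-in P_F: uniform over the valid increments plus the exit action."""
    moves = [i for i, x in enumerate(state) if x < height - 1]
    p = Fraction(1, len(moves) + 1)  # +1 accounts for the exit action s -> s_f
    children = {}
    for i in moves:
        child = state[:i] + (state[i] + 1,) + state[i + 1:]
        children[child] = p
    # Returns ({s': P_F(s' | s)}, P_F(s_f | s))
    return children, p


def exact_terminating_distribution(ndim, height, policy=uniform_forward_policy):
    """Compute P_T(s') = u(s') * P_F(s_f | s') for every grid state."""
    s0 = (0,) * ndim
    u = {s0: Fraction(1)}  # base case u(s_0) = 1
    p_T = {}
    # Lexicographic order visits every parent before its children, so u[state]
    # is complete by the time the state is processed.
    for state in itertools.product(range(height), repeat=ndim):
        children, p_exit = policy(state, height)
        p_T[state] = u[state] * p_exit
        for child, p in children.items():
            # Accumulate u(s') = sum over parents s of u(s) * P_F(s' | s)
            u[child] = u.get(child, Fraction(0)) + u[state] * p
    return p_T


if __name__ == "__main__":
    for state, prob in exact_terminating_distribution(ndim=2, height=2).items():
        print(state, prob)
```

Exact fractions are used so that the probabilities visibly sum to one; a tensor-based version would replace the dictionaries with arrays indexed by state.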

tutorials.examples.train_hypergrid.logger

tutorials.examples.train_hypergrid.main(args)

Trains a GFlowNet on the HyperGrid environment.

Return type:

dict

tutorials.examples.train_hypergrid.parser

tutorials.examples.train_hypergrid.plot_results(env, gflownet, l1_distances, args)

tutorials.examples.train_hypergrid.set_up_fm_gflownet(args, env, preprocessor)

Returns a flow matching (FM) GFlowNet.

tutorials.examples.train_hypergrid.set_up_gflownet(args, env, preprocessor)

Returns a GFlowNet complete with the required estimators.

tutorials.examples.train_hypergrid.set_up_logF_estimator(args, env, preprocessor, pf_module)

Returns a LogStateFlowEstimator.

tutorials.examples.train_hypergrid.set_up_pb_pf_estimators(args, env, preprocessor)

Returns a pair of estimators for the forward and backward policies.