tutorials.examples.train_hypergrid
This script reproduces some of the published results on the HyperGrid environment. To reproduce some of the results in [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259), run one of the following commands:

- `python train_hypergrid.py --ndim 4 --height 8 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}`
- `python train_hypergrid.py --ndim 2 --height 64 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}`

To reproduce some of the results in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782), run one of the following:

- `python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}`

In these commands, braces list alternative values (pick one per run), and `{--uniform_pb}` marks a flag that may be included or omitted.
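The brace notation can be expanded into concrete runs with a small helper. This is a sketch, assuming `{a, b}` means "choose one of a, b" and `{--uniform_pb}` means the flag is either present or absent; `tb_paper_commands` and `subtb_paper_commands` are hypothetical names, and they only build command strings without launching anything.

```python
import itertools
import shlex
from typing import List


def tb_paper_commands() -> List[str]:
    """Expands the Trajectory Balance paper commands into concrete command lines."""
    commands = []
    for ndim, height in [(4, 8), (2, 64)]:
        for r0, uniform_pb, loss in itertools.product(
            ["0.1", "0.01", "0.001"], [False, True], ["TB", "DB"]
        ):
            argv = ["python", "train_hypergrid.py", "--ndim", str(ndim),
                    "--height", str(height), "--R0", r0, "--tied"]
            if uniform_pb:  # {--uniform_pb}: one run without the flag, one with it
                argv.append("--uniform_pb")
            argv += ["--loss", loss]
            commands.append(shlex.join(argv))
    return commands


def subtb_paper_commands() -> List[str]:
    """Expands the partial-episodes (SubTB) paper commands the same way."""
    return [
        shlex.join(["python", "train_hypergrid.py", "--ndim", str(ndim),
                    "--height", "12", "--R0", r0, "--tied", "--loss", loss])
        for ndim, r0, loss in itertools.product([2, 4], ["1e-3", "1e-4"], ["TB", "DB", "SubTB"])
    ]
```

Each resulting string could then be launched with, for example, `subprocess.run(shlex.split(cmd), check=True)`.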
The script also provides `get_exact_P_T`, which computes the exact terminating-state distribution P_T induced by a GFlowNet on the HyperGrid environment; this is useful for evaluation and visualization.
Attributes
- logger
- parser
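Functions

| Function | Description |
| --- | --- |
| `_make_optimizer_for(gflownet, args)` | Builds a fresh AdamW optimizer for a (re)built GFlowNet, including its logZ parameter group. |
| `get_exact_P_T(env, gflownet)` | Evaluates the exact terminating-state distribution P_T for HyperGrid. |
| `main(args)` | Trains a GFlowNet on the HyperGrid environment. |
| `plot_results(env, gflownet, l1_distances, args)` | |
| `set_up_fm_gflownet(args, env, preprocessor)` | Returns a flow-matching (FM) GFlowNet. |
| `set_up_gflownet(args, env, preprocessor)` | Returns a GFlowNet complete with the required estimators. |
| `set_up_logF_estimator(args, env, preprocessor, pf_module)` | Returns a `LogStateFlowEstimator`. |
| `set_up_pb_pf_estimators(args, env, preprocessor)` | Returns a pair of estimators for the forward and backward policies. |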
Module Contents
- tutorials.examples.train_hypergrid._make_optimizer_for(gflownet, args)
Builds a fresh AdamW optimizer for a (re)built GFlowNet, including its logZ parameter group.
- Return type:
torch.optim.Optimizer
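The idea of a separate logZ group can be sketched without torch: split the named parameters into two groups with their own learning rates, then pass the list to `torch.optim.AdamW`, which accepts parameter groups as dicts. `make_param_groups` is a hypothetical helper, and it assumes the logZ parameter can be recognized by the substring `"logZ"` in its name; the script's own selection logic and learning-rate arguments may differ.

```python
from typing import Any, Dict, Iterable, List, Tuple


def make_param_groups(
    named_params: Iterable[Tuple[str, Any]], lr: float, lr_logz: float
) -> List[Dict[str, Any]]:
    """Splits (name, parameter) pairs into a regular group and a logZ group."""
    regular: List[Any] = []
    logz: List[Any] = []
    for name, param in named_params:
        # Assumption: the logZ parameter is identified by its name.
        (logz if "logZ" in name else regular).append(param)
    groups: List[Dict[str, Any]] = [{"params": regular, "lr": lr}]
    if logz:  # only add the logZ group when the model actually has one
        groups.append({"params": logz, "lr": lr_logz})
    return groups
```

With a `torch.nn.Module`, this would be used as `torch.optim.AdamW(make_param_groups(model.named_parameters(), lr=1e-3, lr_logz=1e-1))`, where both learning rates are placeholders.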
- tutorials.examples.train_hypergrid.get_exact_P_T(env, gflownet)
Evaluates the exact terminating-state distribution P_T for HyperGrid.
For each state \(s'\), the terminating-state probability is computed as:
\[P_T(s') = u(s')\, P_F(s_f \mid s')\]
where \(u(s')\), the probability that a forward trajectory visits \(s'\), satisfies the recursion:
\[u(s') = \sum_{s \in \text{Par}(s')} u(s)\, P_F(s' \mid s)\]
with the base case \(u(s_0) = 1\).
- Parameters:
env (gfn.gym.HyperGrid) – The HyperGrid environment
gflownet (gfn.gflownet.GFlowNet) – The GFlowNet model
- Returns:
The exact terminating state distribution as a tensor
- Return type:
torch.Tensor
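The recursion can be checked on a toy grid in pure Python. This sketch assumes HyperGrid actions either increment one coordinate by 1 or exit, and uses a hypothetical uniform forward policy in place of a trained P_F; `uniform_policy` and `exact_terminating_distribution` are illustrative names, not the script's implementation. It pushes \(u(s)\) forward to each child, which is equivalent to summing \(u(s) P_F(s' \mid s)\) over the parents of \(s'\).

```python
import itertools
from typing import Callable, Dict, Tuple

State = Tuple[int, ...]
# Hypothetical policy type: maps a state to {action: probability}, where
# action i < ndim increments coordinate i and action == ndim is the exit action.
Policy = Callable[[State], Dict[int, float]]


def uniform_policy(ndim: int, height: int) -> Policy:
    """Uniform forward policy over valid actions (an illustrative stand-in for P_F)."""
    def pf(state: State) -> Dict[int, float]:
        actions = [i for i in range(ndim) if state[i] < height - 1] + [ndim]
        return {a: 1.0 / len(actions) for a in actions}
    return pf


def exact_terminating_distribution(ndim: int, height: int, pf: Policy) -> Dict[State, float]:
    """Computes P_T(s') = u(s') * P_F(s_f | s') on a small HyperGrid."""
    # Increasing coordinate sum is a topological order: parents come before children.
    states = sorted(itertools.product(range(height), repeat=ndim), key=sum)
    u: Dict[State, float] = {s: 0.0 for s in states}
    u[states[0]] = 1.0  # base case u(s_0) = 1 at the origin
    p_t: Dict[State, float] = {}
    for s in states:
        for action, p in pf(s).items():
            if action == ndim:
                p_t[s] = u[s] * p  # probability of terminating at s
            else:
                child = s[:action] + (s[action] + 1,) + s[action + 1:]
                u[child] += u[s] * p  # contributes u(s) P_F(child | s) to u(child)
    return p_t


pt = exact_terminating_distribution(2, 3, uniform_policy(2, 3))
print(round(sum(pt.values()), 6))  # 1.0
```

Because every trajectory eventually exits, the terminating probabilities sum to 1, which is a cheap sanity check for any exact P_T computation.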
- tutorials.examples.train_hypergrid.logger
- tutorials.examples.train_hypergrid.main(args)
Trains a GFlowNet on the HyperGrid environment.
- Return type:
dict
- tutorials.examples.train_hypergrid.parser
- tutorials.examples.train_hypergrid.plot_results(env, gflownet, l1_distances, args)
- tutorials.examples.train_hypergrid.set_up_fm_gflownet(args, env, preprocessor)
Returns a flow-matching (FM) GFlowNet.
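Flow matching asks that, at every intermediate state \(s'\), the total flow into \(s'\) equals the reward \(R(s')\) plus the total flow out of \(s'\). A minimal sketch of the per-state residual in log space, on plain floats; `flow_matching_residual` is a hypothetical name, and the script's actual loss operates on batched tensors and may differ in details such as numerical smoothing.

```python
import math
from typing import Sequence


def flow_matching_residual(
    in_flows: Sequence[float], out_flows: Sequence[float], reward: float
) -> float:
    """Squared log-mismatch between inflow and (reward + outflow) at one state."""
    inflow = sum(in_flows)             # sum over parents s of F(s -> s')
    outflow = reward + sum(out_flows)  # R(s') plus sum over children of F(s' -> s'')
    return (math.log(inflow) - math.log(outflow)) ** 2
```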
- tutorials.examples.train_hypergrid.set_up_gflownet(args, env, preprocessor)
Returns a GFlowNet complete with the required estimators.
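For the `--loss TB` option, trajectory balance asks that, for a complete trajectory ending in \(x\), \(\log Z + \sum_t \log P_F(s_{t+1} \mid s_t) = \log R(x) + \sum_t \log P_B(s_t \mid s_{t+1})\). A per-trajectory sketch on plain floats follows; `trajectory_balance_loss` is a hypothetical name, and the script itself works on batches of trajectories.

```python
import math
from typing import Sequence


def trajectory_balance_loss(
    log_z: float, log_pf: Sequence[float], log_pb: Sequence[float], log_reward: float
) -> float:
    """Squared trajectory-balance residual for a single complete trajectory."""
    residual = log_z + sum(log_pf) - log_reward - sum(log_pb)
    return residual ** 2
```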
- tutorials.examples.train_hypergrid.set_up_logF_estimator(args, env, preprocessor, pf_module)
Returns a LogStateFlowEstimator.
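A learned log-state flow is what objectives such as detailed balance (DB) are written in terms of: for each transition \(s \to s'\), DB asks that \(\log F(s) + \log P_F(s' \mid s) = \log F(s') + \log P_B(s \mid s')\). A sketch of the per-transition residual, using plain floats in place of the estimator's batched outputs; `detailed_balance_residual` is a hypothetical name.

```python
def detailed_balance_residual(
    log_f_s: float, log_pf: float, log_f_next: float, log_pb: float
) -> float:
    """Squared detailed-balance residual for one transition s -> s'."""
    # log_f_s and log_f_next play the role of a log-state-flow estimator's outputs.
    return (log_f_s + log_pf - log_f_next - log_pb) ** 2
```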
- tutorials.examples.train_hypergrid.set_up_pb_pf_estimators(args, env, preprocessor)
Returns a pair of estimators for the forward and backward policies.
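When the backward policy is not learned (assuming that is what the `--uniform_pb` flag selects), \(P_B(s \mid s') = 1 / |\text{Par}(s')|\). On HyperGrid, assuming each forward action increments one coordinate by 1, every nonzero coordinate of \(s'\) identifies exactly one parent. A hypothetical helper:

```python
import math
from typing import Tuple


def uniform_log_pb(state: Tuple[int, ...]) -> float:
    """log P_B(s | s') for any parent s of s' under a uniform backward policy."""
    # Decrementing any nonzero coordinate of s' yields one parent.
    n_parents = sum(1 for x in state if x > 0)
    if n_parents == 0:
        raise ValueError("the initial state s_0 has no parents")
    return -math.log(n_parents)
```

A fixed uniform P_B leaves only P_F to learn, whereas a learned P_B needs its own estimator, which is why this function returns a pair of estimators.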