tutorials.examples.train_hypergrid
==================================

.. py:module:: tutorials.examples.train_hypergrid

.. autoapi-nested-parse::

   The goal of this script is to reproduce some of the published results on the HyperGrid
   environment. Run one of the following commands to reproduce some of the results in
   `Trajectory balance: Improved credit assignment in GFlowNets <https://arxiv.org/abs/2201.13259>`_::

      python train_hypergrid.py --ndim 4 --height 8 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}
      python train_hypergrid.py --ndim 2 --height 64 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}

   Run one of the following commands to reproduce some of the results in
   `Learning GFlowNets from partial episodes for improved convergence and stability <https://arxiv.org/abs/2209.12782>`_::

      python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}

   The script also provides a function ``get_exact_P_T`` that computes the exact terminating state
   distribution for the HyperGrid environment, which is useful for evaluation and visualization.


Attributes
----------

.. autoapisummary::

   tutorials.examples.train_hypergrid.logger
   tutorials.examples.train_hypergrid.parser


Functions
---------

.. autoapisummary::

   tutorials.examples.train_hypergrid._make_optimizer_for
   tutorials.examples.train_hypergrid.get_exact_P_T
   tutorials.examples.train_hypergrid.main
   tutorials.examples.train_hypergrid.plot_results
   tutorials.examples.train_hypergrid.set_up_fm_gflownet
   tutorials.examples.train_hypergrid.set_up_gflownet
   tutorials.examples.train_hypergrid.set_up_logF_estimator
   tutorials.examples.train_hypergrid.set_up_pb_pf_estimators


Module Contents
---------------

.. py:function:: _make_optimizer_for(gflownet, args)

   Builds a fresh AdamW optimizer for a (re)built GFlowNet, including a separate parameter group for logZ.


.. py:function:: get_exact_P_T(env, gflownet)

   Evaluates the exact terminating state distribution :math:`P_T` for HyperGrid.

   For each state :math:`s'`, the terminating state probability is computed as:

   .. math::

      P_T(s') = u(s') \, P_F(s_f \mid s')

   where :math:`u(s')`, the probability that a trajectory sampled from :math:`P_F` visits :math:`s'`,
   satisfies the recursion:

   .. math::

      u(s') = \sum_{s \in \text{Par}(s')} u(s) \, P_F(s' \mid s)

   with the base case :math:`u(s_0) = 1`.

   :param env: The HyperGrid environment.
   :param gflownet: The GFlowNet model.
   :returns: The exact terminating state distribution as a tensor.


.. py:data:: logger

.. py:function:: main(args)

   Trains a GFlowNet on the HyperGrid environment.


.. py:data:: parser

.. py:function:: plot_results(env, gflownet, l1_distances, args)

.. py:function:: set_up_fm_gflownet(args, env, preprocessor)

   Returns a flow-matching (FM) GFlowNet.


.. py:function:: set_up_gflownet(args, env, preprocessor)

   Returns a GFlowNet complete with the required estimators.


.. py:function:: set_up_logF_estimator(args, env, preprocessor, pf_module)

   Returns a ``LogStateFlowEstimator``.


.. py:function:: set_up_pb_pf_estimators(args, env, preprocessor)

   Returns a pair of estimators for the forward and backward policies.
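
The recursion documented for ``get_exact_P_T`` can be illustrated in plain Python. The following is a minimal sketch, not the script's implementation: it does not use torchgfn, it assumes a uniform forward policy (at each state, every valid coordinate increment and the exit action are equally likely), and ``exact_P_T_uniform`` is a hypothetical helper name. Processing states in order of increasing coordinate sum guarantees that every parent's :math:`u(s)` is final before it is pushed to its children.

```python
import itertools
from collections import defaultdict


def exact_P_T_uniform(ndim: int, height: int) -> dict:
    """Hypothetical illustration of the P_T recursion on a HyperGrid.

    Assumes a uniform forward policy: from state s, each valid action
    (increment one coordinate while staying inside the grid, or exit to s_f)
    has equal probability. Returns a dict mapping state -> P_T(state).
    """
    states = list(itertools.product(range(height), repeat=ndim))
    # Every parent of s' has coordinate sum sum(s') - 1, so sorting by the
    # coordinate sum gives a valid topological order of the DAG.
    states.sort(key=sum)

    def children(s):
        # Forward moves: increment coordinate i if it stays below height.
        return [
            s[:i] + (s[i] + 1,) + s[i + 1:]
            for i in range(ndim)
            if s[i] < height - 1
        ]

    u = defaultdict(float)
    u[tuple([0] * ndim)] = 1.0  # base case u(s_0) = 1
    p_T = {}
    for s in states:
        kids = children(s)
        n_actions = len(kids) + 1  # +1 for the exit action to s_f
        # P_T(s) = u(s) * P_F(s_f | s), with P_F uniform over valid actions.
        p_T[s] = u[s] / n_actions
        # Accumulate u(s) * P_F(s' | s) into each child s'; summed over all
        # parents this gives u(s') = sum_{s in Par(s')} u(s) P_F(s' | s).
        for c in kids:
            u[c] += u[s] / n_actions
    return p_T
```

For example, ``exact_P_T_uniform(ndim=2, height=8)`` returns 64 probabilities that sum to 1 up to floating-point error, since every trajectory exits exactly once. The documented ``get_exact_P_T`` instead takes its forward probabilities from the ``gflownet`` argument rather than assuming a uniform policy.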