tutorials.examples.train_hypergrid_ddp
======================================

.. py:module:: tutorials.examples.train_hypergrid_ddp

.. autoapi-nested-parse::

   The goal of this script is to reproduce some of the published results on the
   HyperGrid environment. Run one of the following commands to reproduce some of
   the results in
   `Trajectory balance: Improved credit assignment in GFlowNets <https://arxiv.org/abs/2201.13259>`__:

   ::

      python train_hypergrid_ddp.py --ndim 4 --height 8 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}
      python train_hypergrid_ddp.py --ndim 2 --height 64 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}

   And run one of the following to reproduce some of the results in
   `Learning GFlowNets from partial episodes for improved convergence and stability <https://arxiv.org/abs/2209.12782>`__:

   ::

      python train_hypergrid_ddp.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}

   This script uses DDP (DistributedDataParallel) for multi-GPU gradient-parallel
   training. Launch with torchrun:

   ::

      torchrun --nproc_per_node=4 train_hypergrid_ddp.py --loss TB --batch_size 64

   Each GPU processes a portion of the batch, and gradients are synchronized via
   all-reduce after every backward pass.


Attributes
----------

.. autoapisummary::

   tutorials.examples.train_hypergrid_ddp.logger
   tutorials.examples.train_hypergrid_ddp.parser


Functions
---------

.. autoapisummary::

   tutorials.examples.train_hypergrid_ddp._make_optimizer_for
   tutorials.examples.train_hypergrid_ddp.main
   tutorials.examples.train_hypergrid_ddp.set_up_fm_gflownet
   tutorials.examples.train_hypergrid_ddp.set_up_gflownet
   tutorials.examples.train_hypergrid_ddp.set_up_logF_estimator
   tutorials.examples.train_hypergrid_ddp.set_up_pb_pf_estimators


Module Contents
---------------

.. py:function:: _make_optimizer_for(gflownet, args)

   Build a fresh AdamW optimizer for a (re)built GFlowNet with logZ group.


.. py:data:: logger

.. py:function:: main(args)

   Train a GFlowNet on the HyperGrid environment using DDP.

   High-level flow:

   1. **DDP initialization**: detect rank/world-size from torchrun, MPI, or SLURM
      environment variables and create a ``DistributedContext``.
   2. **Buffer ranks**: if ``--num_remote_buffers > 0``, the last *N* ranks are
      dedicated replay-buffer servers and never enter the training loop.
   3. **Model setup**: build the GFlowNet, optimizer, and optional replay buffer
      on each training rank.
   4. **Training loop**: each iteration: sample trajectories → compute loss →
      backward → all-reduce gradients across training ranks → optimizer step.
   5. **Validation & logging**: periodically compute L1 distance to the true
      distribution and log metrics to WandB.
   6. **Cleanup**: send termination signals to buffer ranks, barrier, return.

   :param args: Parsed CLI arguments (see ``__main__`` block below).

   :returns: A dict of final training metrics (loss, l1_dist, modes found, etc.).


.. py:data:: parser

.. py:function:: set_up_fm_gflownet(args, env, preprocessor)

   Build a Flow Matching GFlowNet.

   Flow Matching (FM) only requires a forward policy estimator and learns by
   matching incoming/outgoing flow at each state.

   :param args: CLI arguments (uses ``tabular``, ``hidden_dim``, ``n_hidden``).
   :param env: HyperGrid environment instance.
   :param preprocessor: State preprocessor (e.g. KHotPreprocessor).

   :returns: An FMGFlowNet ready for training.


.. py:function:: set_up_gflownet(args, env, preprocessor)

   Build a GFlowNet for the requested loss function.

   Constructs the appropriate estimators and GFlowNet variant:

   - **FM** (Flow Matching): forward policy only.
   - **TB** (Trajectory Balance): forward + backward policies, learnable logZ.
   - **DB** (Detailed Balance): forward + backward policies + log-state-flow F.
   - **SubTB** (Sub-Trajectory Balance): like DB, with configurable
     sub-trajectory weighting.
   - **ZVar** (Log-Partition Variance): forward + backward policies.
   - **ModifiedDB**: a variant of DB with edge-based rewards.

   :param args: CLI arguments (``loss`` selects the variant; architecture flags
                are forwarded to the estimator builders).
   :param env: HyperGrid environment instance.
   :param preprocessor: State preprocessor.

   :returns: A GFlowNet instance, or ``None`` if the loss is unrecognized.


.. py:function:: set_up_logF_estimator(args, env, preprocessor, pf_module)

   Build the log-state-flow estimator (used by DB and SubTB losses).

   When ``--tied`` is set, the MLP trunk is shared with the PF module to reduce
   the number of trainable parameters.

   :param args: CLI arguments (uses ``tabular``, ``tied``, ``hidden_dim``, ``n_hidden``).
   :param env: HyperGrid environment instance.
   :param preprocessor: State preprocessor.
   :param pf_module: The forward policy's underlying module (trunk may be reused).

   :returns: A ScalarEstimator for log F(s).


.. py:function:: set_up_pb_pf_estimators(args, env, preprocessor)

   Build the forward (PF) and backward (PB) policy estimators.

   Supports tabular or MLP-based modules with optional parameter tying
   (``--tied`` shares the MLP trunk between PF, PB, and log-state-flow F) and
   uniform backward policy (``--uniform_pb``).

   :param args: CLI arguments (uses ``tabular``, ``uniform_pb``, ``tied``,
                ``hidden_dim``, ``n_hidden``, ``n_noisy_layers``, ``noisy_std_init``).
   :param env: HyperGrid environment instance.
   :param preprocessor: State preprocessor.

   :returns: A ``(pf_estimator, pb_estimator)`` tuple.
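
The rank layout and gradient synchronization described for ``main`` can be
illustrated with a small, framework-free sketch. It is not code from the script:
the helper names ``split_ranks``, ``per_rank_batch_sizes``, and
``all_reduce_mean`` are hypothetical, the even split of ``--batch_size`` across
training ranks is an assumption, and the averaging step only mimics what an
all-reduce followed by a division by the number of ranks would produce.

.. code-block:: python

   from typing import List, Tuple


   def split_ranks(world_size: int, num_remote_buffers: int) -> Tuple[List[int], List[int]]:
       """Return (training_ranks, buffer_ranks); the last N ranks act as buffers."""
       if not 0 <= num_remote_buffers < world_size:
           raise ValueError("need at least one training rank")
       n_train = world_size - num_remote_buffers
       ranks = list(range(world_size))
       return ranks[:n_train], ranks[n_train:]


   def per_rank_batch_sizes(batch_size: int, n_training_ranks: int) -> List[int]:
       """Split a global batch as evenly as possible (assumed policy)."""
       base, extra = divmod(batch_size, n_training_ranks)
       # The first `extra` ranks each take one additional sample.
       return [base + (1 if i < extra else 0) for i in range(n_training_ranks)]


   def all_reduce_mean(per_rank_grads: List[List[float]]) -> List[List[float]]:
       """Simulate an averaging all-reduce over per-rank gradient vectors.

       After the call every rank holds the same element-wise mean.
       """
       n = len(per_rank_grads)
       mean = [sum(values) / n for values in zip(*per_rank_grads)]
       return [list(mean) for _ in range(n)]


   # Hypothetical layout: 6 processes, 2 of them replay-buffer servers.
   training_ranks, buffer_ranks = split_ranks(world_size=6, num_remote_buffers=2)
   sizes = per_rank_batch_sizes(batch_size=64, n_training_ranks=len(training_ranks))
   synced = all_reduce_mean([[1.0, 2.0], [3.0, 6.0]])

With the launch command shown in the module description (4 processes,
``--batch_size 64``, no buffer ranks), this sketch would assign 16 samples to
each rank. How the actual script divides the batch should be checked against
its source.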