tutorials.examples.train_hypergrid_ddp
The goal of this script is to reproduce some of the published results on the HyperGrid environment. Run one of the following commands to reproduce some of the results in [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259):
python train_hypergrid_ddp.py --ndim 4 --height 8 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}
python train_hypergrid_ddp.py --ndim 2 --height 64 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}
Run one of the following to reproduce some of the results in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782):
python train_hypergrid_ddp.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}
Values in braces are alternatives (pick one per run); `{--uniform_pb}` marks an optional flag.
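The brace notation packs several runs into one line. As a convenience, here is a small Python sketch that expands such a template into concrete commands. It assumes that values in braces are alternatives and that `{--uniform_pb}` means the flag is either present or absent; `expand_grid` is a hypothetical helper, not part of the script.

```python
import itertools
import shlex


def expand_grid(script, fixed_args, choices, optional_flags=()):
    """Expand a brace-style command template into one command per combination.

    Hypothetical helper. fixed_args appear in every command; choices maps a flag
    to its alternative values (the "{a, b}" parts); optional_flags are either
    included or omitted (e.g. "{--uniform_pb}").
    """
    flags = list(choices)
    commands = []
    for values in itertools.product(*(choices[f] for f in flags)):
        for mask in itertools.product((False, True), repeat=len(optional_flags)):
            cmd = ["python", script, *fixed_args]
            for flag, value in zip(flags, values):
                cmd += [flag, str(value)]
            cmd += [f for f, on in zip(optional_flags, mask) if on]
            commands.append(shlex.join(cmd))
    return commands


# First Trajectory Balance command: --ndim 4 --height 8.
tb_grid = expand_grid(
    "train_hypergrid_ddp.py",
    fixed_args=["--ndim", "4", "--height", "8", "--tied"],
    choices={"--R0": [0.1, 0.01, 0.001], "--loss": ["TB", "DB"]},
    optional_flags=["--uniform_pb"],
)
print(len(tb_grid))  # 12
print(tb_grid[0])    # python train_hypergrid_ddp.py --ndim 4 --height 8 --tied --R0 0.1 --loss TB
```

Argument order differs slightly from the commands above, which makes no difference to an argparse-based CLI.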
This script uses DDP (DistributedDataParallel) for multi-GPU data-parallel training. Launch it with torchrun:
torchrun --nproc_per_node=4 train_hypergrid_ddp.py --loss TB --batch_size 64
Each GPU processes a portion of the batch, and gradients are synchronized via all-reduce after every backward pass.
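The gradient synchronization described above can be illustrated without GPUs. This pure-Python sketch simulates four ranks: each computes the mean-loss gradient of a toy linear model on its shard, and an in-process stand-in for all-reduce averages the results. It assumes equal-sized shards and mean-reduced losses, the case in which the averaged gradient equals the full-batch gradient. In real DDP the averaging is done by `torch.distributed` collectives on GPU tensors; the script's exact mechanism may differ.

```python
import math


def grad_mse(w, xs, ys):
    """Gradient d/dw of mean((w*x - y)**2) over one shard of the batch."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(local_values):
    """In-process stand-in for all-reduce(SUM) followed by division by world size.

    Every rank receives the same averaged value, which keeps the replicas in sync.
    """
    mean = sum(local_values) / len(local_values)
    return [mean for _ in local_values]


world_size = 4
xs = [float(i) for i in range(1, 17)]  # global batch of 16 examples
ys = [2.0 * x + 1.0 for x in xs]
w = 0.5  # current (replicated) model parameter

# Strided split, the same pattern torch.utils.data.DistributedSampler uses.
shards = [(xs[r::world_size], ys[r::world_size]) for r in range(world_size)]
local_grads = [grad_mse(w, sx, sy) for sx, sy in shards]  # differs per rank
synced_grads = all_reduce_mean(local_grads)                # identical on every rank
full_batch_grad = grad_mse(w, xs, ys)

print(math.isclose(synced_grads[0], full_batch_grad))  # True
```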
Attributes

| Attribute |
| --- |
| `logger` |
| `parser` |
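Functions

| Function | Description |
| --- | --- |
| `_make_optimizer_for(gflownet, args)` | Build a fresh AdamW optimizer for a (re)built GFlowNet, with a separate parameter group for logZ. |
| `main(args)` | Train a GFlowNet on the HyperGrid environment using DDP. |
| `set_up_fm_gflownet(args, env, preprocessor)` | Build a Flow Matching GFlowNet. |
| `set_up_gflownet(args, env, preprocessor)` | Build a GFlowNet for the requested loss function. |
| `set_up_logF_estimator(args, env, preprocessor, pf_module)` | Build the log-state-flow estimator (used by DB and SubTB losses). |
| `set_up_pb_pf_estimators(args, env, preprocessor)` | Build the forward (PF) and backward (PB) policy estimators. |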
Module Contents
- tutorials.examples.train_hypergrid_ddp._make_optimizer_for(gflownet, args)
Build a fresh AdamW optimizer for a (re)built GFlowNet, with a separate parameter group for logZ.
- Return type:
torch.optim.Optimizer
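A minimal PyTorch sketch of a logZ-aware optimizer builder, assuming the GFlowNet exposes its log-partition estimate as a parameter named `logZ`. `ToyTBGFlowNet`, `make_optimizer_for`, and the learning rates are hypothetical; the script's `_make_optimizer_for(gflownet, args)` reads its settings from `args` instead.

```python
import torch
from torch import nn


class ToyTBGFlowNet(nn.Module):
    """Hypothetical stand-in for a TB-style GFlowNet: a policy net plus a scalar logZ."""

    def __init__(self, n_in=8, n_actions=5):
        super().__init__()
        self.pf = nn.Linear(n_in, n_actions)
        self.logZ = nn.Parameter(torch.zeros(()))


def make_optimizer_for(gflownet, lr=1e-3, lr_logz=1e-1, weight_decay=1e-2):
    """Build AdamW with one parameter group for the networks and one for logZ."""
    named = list(gflownet.named_parameters())
    # Match on the last name component so a DDP wrapper ("module.logZ") still works.
    logz_params = [p for name, p in named if name.split(".")[-1] == "logZ"]
    other_params = [p for name, p in named if name.split(".")[-1] != "logZ"]
    groups = [{"params": other_params, "lr": lr, "weight_decay": weight_decay}]
    if logz_params:
        # Separate learning rate for logZ; no weight decay, so logZ is not pulled toward 0.
        groups.append({"params": logz_params, "lr": lr_logz, "weight_decay": 0.0})
    return torch.optim.AdamW(groups)


net = ToyTBGFlowNet()
opt = make_optimizer_for(net)
print([g["lr"] for g in opt.param_groups])  # [0.001, 0.1]
```

Two design notes: an optimizer holds references to specific `Parameter` objects, so a rebuilt GFlowNet needs a new optimizer; and exempting `logZ` from weight decay is a choice made in this sketch, not something documented for the script.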
- tutorials.examples.train_hypergrid_ddp.logger
- tutorials.examples.train_hypergrid_ddp.main(args)
Train a GFlowNet on the HyperGrid environment using DDP.
High-level flow:
1. DDP initialization: detect the rank and world size from torchrun, MPI, or SLURM environment variables and create a `DistributedContext`.
2. Buffer ranks: if `--num_remote_buffers` is N > 0, the last N ranks become dedicated replay-buffer servers and never enter the training loop.
3. Model setup: build the GFlowNet, optimizer, and optional replay buffer on each training rank.
4. Training loop: in each iteration, sample trajectories → compute loss → backward → all-reduce gradients across training ranks → optimizer step.
5. Validation and logging: periodically compute the L1 distance to the true distribution and log metrics to WandB.
6. Cleanup: send termination signals to the buffer ranks, synchronize at a barrier, and return.
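Steps 1 and 2 can be sketched in pure Python. The environment-variable names below are the ones set by torchrun, Open MPI, MPICH/Intel MPI, and SLURM; which of them the script actually checks, and in what order, is an assumption. `detect_rank_and_world_size` and `assign_role` are hypothetical helpers.

```python
import os

# Launcher-specific (rank, world size) variable pairs, checked in this order.
# Which launchers the script supports, and in what order, is an assumption here.
LAUNCHER_VARS = (
    ("RANK", "WORLD_SIZE"),                            # torchrun
    ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"),  # Open MPI
    ("PMI_RANK", "PMI_SIZE"),                          # MPICH / Intel MPI
    ("SLURM_PROCID", "SLURM_NTASKS"),                  # SLURM
)


def detect_rank_and_world_size(env=None):
    """Return (rank, world_size); fall back to a single process if nothing is set."""
    env = os.environ if env is None else env
    for rank_var, size_var in LAUNCHER_VARS:
        if rank_var in env and size_var in env:
            return int(env[rank_var]), int(env[size_var])
    return 0, 1


def assign_role(rank, world_size, num_remote_buffers):
    """The last `num_remote_buffers` ranks serve the replay buffer; the rest train."""
    num_training = world_size - num_remote_buffers
    if num_training < 1:
        raise ValueError("at least one training rank is required")
    return "buffer" if rank >= num_training else "train"


rank, world_size = detect_rank_and_world_size({"SLURM_PROCID": "5", "SLURM_NTASKS": "6"})
print(rank, world_size, assign_role(rank, world_size, num_remote_buffers=2))  # 5 6 buffer
```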
- Parameters:
args – Parsed CLI arguments (see the script's `__main__` block).
- Returns:
A dict of final training metrics (loss, l1_dist, modes found, etc.).
- Return type:
dict
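The L1 metric from step 5 compares the empirical distribution of sampled terminal states with the reward-proportional target. This sketch assumes the standard HyperGrid reward from the trajectory-balance paper (constants R1 = 0.5 and R2 = 2, as implemented in `hypergrid_reward` below) and sums absolute differences over all states; the script may normalize differently, for example by taking a mean. All function names are hypothetical, and the "samples" are drawn directly from the target rather than from a trained model.

```python
import itertools
import random
from collections import Counter


def hypergrid_reward(x, height, R0=1e-2, R1=0.5, R2=2.0):
    """Assumed standard HyperGrid reward: R0 + R1 * [outer region] + R2 * [ring]."""
    ax = [abs(xi / (height - 1) - 0.5) for xi in x]
    in_outer = all(0.25 < a <= 0.5 for a in ax)
    in_ring = all(0.3 < a < 0.4 for a in ax)
    return R0 + R1 * in_outer + R2 * in_ring


def true_distribution(ndim, height, **reward_kwargs):
    """Reward-proportional target pmf over all height**ndim states."""
    states = list(itertools.product(range(height), repeat=ndim))
    rewards = [hypergrid_reward(s, height, **reward_kwargs) for s in states]
    total = sum(rewards)
    return {s: r / total for s, r in zip(states, rewards)}


def empirical_distribution(samples, states):
    """Fraction of samples that landed in each state (0 for unvisited states)."""
    counts = Counter(samples)
    return {s: counts[s] / len(samples) for s in states}


def l1_distance(p, q):
    """Sum of absolute differences between two pmfs over the same states."""
    return sum(abs(p[s] - q[s]) for s in p)


p_true = true_distribution(ndim=2, height=8, R0=0.1)
states = list(p_true)
# Stand-in for terminal states sampled from a trained GFlowNet.
rng = random.Random(0)
samples = rng.choices(states, weights=[p_true[s] for s in states], k=20_000)
p_emp = empirical_distribution(samples, states)
print(f"L1 = {l1_distance(p_emp, p_true):.4f}")
```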
- tutorials.examples.train_hypergrid_ddp.parser
- tutorials.examples.train_hypergrid_ddp.set_up_fm_gflownet(args, env, preprocessor)
Build a Flow Matching GFlowNet.
Flow Matching (FM) only requires a forward policy estimator and learns by matching incoming/outgoing flow at each state.
- Parameters:
args – CLI arguments (uses `tabular`, `hidden_dim`, `n_hidden`).
env – HyperGrid environment instance.
preprocessor – State preprocessor (e.g. KHotPreprocessor).
- Returns:
An FMGFlowNet ready for training.
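One common way to write the flow-matching objective: for every non-initial state s, the total incoming edge flow should equal R(s) plus the total outgoing edge flow (every HyperGrid state can terminate, so every state has a reward term). This sketch scores the squared log-ratio on a hand-built three-state chain. `fm_loss` and its dict-based input format are hypothetical; an FMGFlowNet would instead obtain log edge flows from its estimator and combine them with log-sum-exp.

```python
import math


def fm_loss(edge_flows, rewards, initial_state):
    """Mean squared log-ratio between inflow and (reward + outflow) over non-initial states.

    edge_flows: dict (parent, child) -> F(parent -> child) > 0  (hypothetical format)
    rewards:    dict state -> R(state) >= 0, the flow that leaves via termination
    """
    residuals = []
    for s, reward in rewards.items():
        if s == initial_state:
            continue  # the initial state has no incoming edges
        inflow = sum(f for (_, child), f in edge_flows.items() if child == s)
        outflow = reward + sum(f for (parent, _), f in edge_flows.items() if parent == s)
        residuals.append((math.log(inflow) - math.log(outflow)) ** 2)
    return sum(residuals) / len(residuals)


# 1-D chain 0 -> 1 -> 2 where every state can also terminate.
rewards = {0: 1.0, 1: 2.0, 2: 3.0}
consistent = {(0, 1): 5.0, (1, 2): 3.0}  # F(1->2) = R(2); F(0->1) = R(1) + F(1->2)
perturbed = {(0, 1): 4.0, (1, 2): 3.0}   # too little flow into state 1
print(fm_loss(consistent, rewards, initial_state=0))
print(fm_loss(perturbed, rewards, initial_state=0))
```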
- tutorials.examples.train_hypergrid_ddp.set_up_gflownet(args, env, preprocessor)
Build a GFlowNet for the requested loss function.
Constructs the appropriate estimators and GFlowNet variant:
FM (Flow Matching): forward policy only.
TB (Trajectory Balance): forward + backward policies, learnable logZ.
DB (Detailed Balance): forward + backward policies + log-state-flow F.
SubTB (Sub-Trajectory Balance): like DB, with configurable sub-trajectory weighting.
ZVar (Log-Partition Variance): forward + backward policies.
ModifiedDB: a variant of DB with edge-based rewards.
- Parameters:
args – CLI arguments (`loss` selects the variant; architecture flags are forwarded to the estimator builders).
env – HyperGrid environment instance.
preprocessor – State preprocessor.
- Returns:
A GFlowNet instance, or `None` if the loss is unrecognized.
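To make the differences between the variants concrete, here are the per-trajectory TB, DB, and ZVar objectives written out in pure Python on a toy two-step trajectory. The function names are hypothetical; library versions are batched and handle the exit action explicitly, and SubTB (a weighted combination over sub-trajectories) is not shown. The toy numbers are chosen so that all three losses vanish, with Z = 4 and state flows F(s0) = 4, F(s1) = 2.

```python
import math
import statistics


def tb_loss(log_z, log_pf, log_pb, log_reward):
    """Trajectory balance: (log Z + sum log PF - log R(x) - sum log PB)^2."""
    return (log_z + sum(log_pf) - log_reward - sum(log_pb)) ** 2


def db_loss(log_f, log_pf, log_pb, log_reward):
    """Detailed balance, summed over the transitions s_t -> s_{t+1} of one trajectory.

    log_f[t] is log F(s_t) for the non-terminal states; the terminal state's
    flow is taken to be its reward, F(x) = R(x).
    """
    total = 0.0
    for t in range(len(log_pf)):
        log_f_next = log_f[t + 1] if t + 1 < len(log_f) else log_reward
        total += (log_f[t] + log_pf[t] - log_f_next - log_pb[t]) ** 2
    return total


def zvar_loss(trajectories):
    """Log-partition variance: spread of the log Z implied by each trajectory."""
    implied_log_z = [lr + sum(pb) - sum(pf) for pf, pb, lr in trajectories]
    return statistics.pvariance(implied_log_z)


# Toy trajectory s0 -> s1 -> x with PF = (0.5, 0.5), PB = (1.0, 0.5), R(x) = 2.
log_pf = [math.log(0.5), math.log(0.5)]
log_pb = [math.log(1.0), math.log(0.5)]
log_r = math.log(2.0)

print(tb_loss(math.log(4.0), log_pf, log_pb, log_r))  # consistent Z = 4
print(db_loss([math.log(4.0), math.log(2.0)], log_pf, log_pb, log_r))
print(tb_loss(0.0, log_pf, log_pb, log_r))            # wrong Z = 1
# A second, one-step trajectory with PF = PB = 1 and R = 4 implies the same log Z.
print(zvar_loss([(log_pf, log_pb, log_r), ([0.0], [0.0], math.log(4.0))]))
```

TB needs only PF, PB, and a learnable log Z; DB replaces log Z with a per-state flow estimate; ZVar drops both and only asks that all trajectories agree on the implied log Z.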
- tutorials.examples.train_hypergrid_ddp.set_up_logF_estimator(args, env, preprocessor, pf_module)
Build the log-state-flow estimator (used by DB and SubTB losses).
When `--tied` is set, the MLP trunk is shared with the PF module to reduce the number of trainable parameters.
- Parameters:
args – CLI arguments (uses `tabular`, `tied`, `hidden_dim`, `n_hidden`).
env – HyperGrid environment instance.
preprocessor – State preprocessor.
pf_module – The forward policy’s underlying module (trunk may be reused).
- Returns:
A ScalarEstimator for log F(s).
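Parameter tying amounts to handing the same trunk object to two modules. The PyTorch sketch below compares a tied and an untied log F head on top of a PF trunk and counts unique parameters. `mlp_trunk`, `HeadOnTrunk`, and `count_unique_params` are hypothetical; the script builds its estimators from library modules instead.

```python
import torch
from torch import nn


def mlp_trunk(n_in, hidden_dim, n_hidden):
    """Hypothetical MLP trunk: n_hidden (>= 1) Linear+ReLU layers of width hidden_dim."""
    layers, width = [], n_in
    for _ in range(n_hidden):
        layers += [nn.Linear(width, hidden_dim), nn.ReLU()]
        width = hidden_dim
    return nn.Sequential(*layers)


class HeadOnTrunk(nn.Module):
    """Hypothetical module: a (possibly shared) trunk followed by its own linear head."""

    def __init__(self, trunk, hidden_dim, n_out):
        super().__init__()
        self.trunk = trunk
        self.head = nn.Linear(hidden_dim, n_out)

    def forward(self, x):
        return self.head(self.trunk(x))


def count_unique_params(*modules):
    """Count parameters once even if several modules share them."""
    seen = {}
    for m in modules:
        for p in m.parameters():
            seen[id(p)] = p.numel()
    return sum(seen.values())


n_in, hidden_dim, n_hidden, n_actions = 16, 32, 2, 5
pf = HeadOnTrunk(mlp_trunk(n_in, hidden_dim, n_hidden), hidden_dim, n_actions)
# Tied: log F reuses the PF trunk object, so only its 1-output head is new.
logF_tied = HeadOnTrunk(pf.trunk, hidden_dim, 1)
# Untied: log F gets its own trunk.
logF_untied = HeadOnTrunk(mlp_trunk(n_in, hidden_dim, n_hidden), hidden_dim, 1)

tied = count_unique_params(pf, logF_tied)
untied = count_unique_params(pf, logF_untied)
print(tied, untied)  # 1798 3398
```

With tying, gradients from the log F term of the loss also update the shared trunk, so PF and log F are trained jointly through it.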
- tutorials.examples.train_hypergrid_ddp.set_up_pb_pf_estimators(args, env, preprocessor)
Build the forward (PF) and backward (PB) policy estimators.
Supports tabular or MLP-based modules, optional parameter tying (`--tied` shares the MLP trunk between PF, PB, and the log-state-flow F), and a uniform backward policy (`--uniform_pb`).
- Parameters:
args – CLI arguments (uses `tabular`, `uniform_pb`, `tied`, `hidden_dim`, `n_hidden`, `n_noisy_layers`, `noisy_std_init`).
env – HyperGrid environment instance.
preprocessor – State preprocessor.
- Returns:
A `(pf_estimator, pb_estimator)` tuple.
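With `--uniform_pb`, the backward policy is not learned: each parent of a state is chosen with equal probability. In HyperGrid a parent is obtained by decrementing any coordinate that is greater than zero, so log PB(parent | child) = −log(number of parents). This pure-Python sketch computes that quantity along a trajectory; the helpers are hypothetical, and the exit transition (whose backward probability is 1) is left out.

```python
import math


def hypergrid_parents(state):
    """Parents of a HyperGrid state: decrement any coordinate that is greater than 0."""
    return [state[:i] + (x - 1,) + state[i + 1:] for i, x in enumerate(state) if x > 0]


def uniform_log_pb(child, parent):
    """log PB(parent | child) under a uniform backward policy (-inf if not a parent)."""
    parents = hypergrid_parents(child)
    return -math.log(len(parents)) if parent in parents else -math.inf


def trajectory_log_pb(states):
    """Sum of uniform log PB over consecutive pairs (s_t, s_{t+1}); exit step excluded."""
    return sum(uniform_log_pb(child, parent) for parent, child in zip(states, states[1:]))


# (1, 0) has one parent; (1, 1) and (2, 1) have two each, so PB = 1 * 1/2 * 1/2.
trajectory = [(0, 0), (1, 0), (1, 1), (2, 1)]
print(trajectory_log_pb(trajectory), math.log(0.25))
```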