tutorials.examples.train_hypergrid_ddp

The goal of this script is to reproduce some of the published results on the HyperGrid environment. Run one of the following commands to reproduce some of the results in [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259):

    python train_hypergrid_ddp.py --ndim 4 --height 8 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}
    python train_hypergrid_ddp.py --ndim 2 --height 64 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}

Run one of the following to reproduce some of the results in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782):

    python train_hypergrid_ddp.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}
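
In these commands a brace group lists alternative values (and `{--uniform_pb}` marks an optional flag), so each line stands for several runs. The sketch below enumerates the concrete runs behind the partial-episodes command; `expand_commands` and `GRID` are hypothetical helpers for illustration and are not part of the script.

```python
# Hypothetical helper: expand the "{a, b}" alternatives of the SubTB-paper
# command into one concrete command per combination of values.
import itertools

# Alternatives copied from the partial-episodes command.
GRID = {
    "ndim": [2, 4],
    "R0": ["1e-3", "1e-4"],
    "loss": ["TB", "DB", "SubTB"],
}


def expand_commands(grid, height=12):
    """Return one command string per point of the Cartesian product of `grid`."""
    keys = list(grid)
    commands = []
    for values in itertools.product(*(grid[k] for k in keys)):
        opts = dict(zip(keys, values))
        commands.append(
            "python train_hypergrid_ddp.py "
            f"--ndim {opts['ndim']} --height {height} --R0 {opts['R0']} "
            f"--tied --loss {opts['loss']}"
        )
    return commands


if __name__ == "__main__":
    for cmd in expand_commands(GRID):
        print(cmd)
```

The grid above yields 2 × 2 × 3 = 12 separate training runs.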

This script uses DDP (DistributedDataParallel) for multi-GPU gradient-parallel training. Launch with torchrun:

    torchrun --nproc_per_node=4 train_hypergrid_ddp.py --loss TB --batch_size 64

Each GPU processes a portion of the batch, and gradients are synchronized via all-reduce after every backward pass.
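
Why averaging gradients across ranks is equivalent to training on the whole batch can be checked with a toy model. The sketch below simulates the all-reduce in pure Python (no PyTorch) for a one-parameter least-squares loss; it assumes every rank receives an equally sized shard, which is what makes the mean of the per-shard mean gradients equal the full-batch gradient. The helper names are illustrative, not the script's.

```python
# Pure-Python simulation of DDP-style gradient averaging.
# Model: a single scalar weight w; loss = mean over samples of (w * x - y) ** 2.


def grad(w, xs, ys):
    """Gradient of the mean squared error with respect to w on one shard."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def all_reduce_mean(values):
    """Stand-in for an all-reduce that sums across ranks, then divides by world size."""
    return sum(values) / len(values)


def ddp_gradient(w, xs, ys, world_size):
    """Split the batch into equal shards (one per rank) and average local gradients."""
    shard = len(xs) // world_size
    local_grads = [
        grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    return all_reduce_mean(local_grads)


if __name__ == "__main__":
    xs = [float(i) for i in range(8)]
    ys = [2.0 * x + 1.0 for x in xs]
    # Both numbers agree: 4 ranks with 2 samples each vs. one rank with all 8.
    print(ddp_gradient(0.5, xs, ys, world_size=4), grad(0.5, xs, ys))
```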

Attributes

logger

parser

Functions

_make_optimizer_for(gflownet, args)

Build a fresh AdamW optimizer for a (re)built GFlowNet with logZ group.

main(args)

Train a GFlowNet on the HyperGrid environment using DDP.

set_up_fm_gflownet(args, env, preprocessor)

Build a Flow Matching GFlowNet.

set_up_gflownet(args, env, preprocessor)

Build a GFlowNet for the requested loss function.

set_up_logF_estimator(args, env, preprocessor, pf_module)

Build the log-state-flow estimator (used by DB and SubTB losses).

set_up_pb_pf_estimators(args, env, preprocessor)

Build the forward (PF) and backward (PB) policy estimators.

Module Contents

tutorials.examples.train_hypergrid_ddp._make_optimizer_for(gflownet, args)

Build a fresh AdamW optimizer for a (re)built GFlowNet with logZ group.

Return type:

torch.optim.Optimizer
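
A "logZ group" refers to PyTorch's per-parameter options: the optimizer receives a list of dicts, each with its own `params` and hyperparameters. The sketch below shows one way such groups could be assembled so that the learnable log-partition gets a separate learning rate. It is an assumption about the general pattern, not this function's code: the name-matching rule, `make_param_groups`, and the `lr`/`lr_logz` values are hypothetical.

```python
# Hypothetical sketch: build optimizer parameter groups in which the
# log-partition parameter logZ has its own learning rate.


def make_param_groups(named_params, lr, lr_logz, weight_decay=0.0):
    """Split (name, param) pairs into a logZ group and a group for everything else.

    The result uses the list-of-dicts format that torch.optim optimizers,
    including AdamW, accept as their first argument.
    """
    logz_params, other_params = [], []
    for name, param in named_params:
        # Assumption: the log-partition parameter's name contains "logZ".
        (logz_params if "logZ" in name else other_params).append(param)
    groups = [{"params": other_params, "lr": lr, "weight_decay": weight_decay}]
    if logz_params:
        # In this sketch logZ gets no weight decay, since decay would pull
        # the log-partition estimate toward 0.
        groups.append({"params": logz_params, "lr": lr_logz, "weight_decay": 0.0})
    return groups


if __name__ == "__main__":
    # Strings stand in for tensors so the example runs without torch.
    named = [("pf.module.0.weight", "W0"), ("pb.module.0.weight", "W1"), ("logZ", "Z")]
    groups = make_param_groups(named, lr=1e-3, lr_logz=1e-1)
    # With real tensors one would then call: torch.optim.AdamW(groups)
    print(groups)
```

Keeping logZ in its own group lets its step size be tuned independently of the much larger policy networks.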

tutorials.examples.train_hypergrid_ddp.logger

tutorials.examples.train_hypergrid_ddp.main(args)

Train a GFlowNet on the HyperGrid environment using DDP.

High-level flow:

  1. DDP initialization — detect rank/world-size from torchrun, MPI, or SLURM environment variables and create a DistributedContext.

  2. Buffer ranks — if --num_remote_buffers > 0, the last N ranks are dedicated replay-buffer servers and never enter the training loop.

  3. Model setup — build the GFlowNet, optimizer, and optional replay buffer on each training rank.

  4. Training loop — each iteration: sample trajectories → compute loss → backward → all-reduce gradients across training ranks → optimizer step.

  5. Validation & logging — periodically compute L1 distance to the true distribution and log metrics to WandB.

  6. Cleanup — send termination signals to buffer ranks, barrier, return.
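
Steps 1 and 2 can be sketched as follows. The environment variables are the ones the respective launchers set (torchrun: `RANK`/`WORLD_SIZE`; Open MPI: `OMPI_COMM_WORLD_RANK`/`OMPI_COMM_WORLD_SIZE`; SLURM: `SLURM_PROCID`/`SLURM_NTASKS`), but the lookup order, the single-process fallback, and the helper names are assumptions for illustration, not the script's actual `DistributedContext` logic.

```python
# Hypothetical sketch of steps 1 and 2: infer (rank, world_size) from launcher
# environment variables, then decide whether this rank trains or serves a buffer.
import os

# (rank variable, world-size variable) pairs, checked in this assumed order.
LAUNCHER_VARS = [
    ("RANK", "WORLD_SIZE"),                            # torchrun
    ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"),  # Open MPI
    ("SLURM_PROCID", "SLURM_NTASKS"),                  # SLURM srun
]


def detect_rank_and_world_size(environ=None):
    """Return (rank, world_size); fall back to a single process (0, 1)."""
    environ = os.environ if environ is None else environ
    for rank_var, size_var in LAUNCHER_VARS:
        if rank_var in environ and size_var in environ:
            return int(environ[rank_var]), int(environ[size_var])
    return 0, 1


def role_for_rank(rank, world_size, num_remote_buffers):
    """The last `num_remote_buffers` ranks serve the replay buffer; the rest train."""
    num_trainers = world_size - num_remote_buffers
    if num_trainers < 1:
        raise ValueError("need at least one training rank")
    return "buffer" if rank >= num_trainers else "trainer"


if __name__ == "__main__":
    rank, world_size = detect_rank_and_world_size()
    print(rank, world_size, role_for_rank(rank, world_size, num_remote_buffers=0))
```

With `torchrun --nproc_per_node=4` and one remote buffer, ranks 0 to 2 would train and rank 3 would only serve the buffer.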

Parameters:

args – Parsed CLI arguments (see the module-level parser).

Returns:

A dict of final training metrics (loss, l1_dist, modes found, etc.).

Return type:

dict
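
The `l1_dist` metric compares the empirical distribution of sampled terminal states with the reward-proportional target p(x) = R(x) / Z. The sketch below computes both on a small grid. It assumes the standard HyperGrid reward from the GFlowNet literature, R(x) = R0 + R1·∏ 1[0.25 < |x_d/(H−1) − 0.5|] + R2·∏ 1[0.3 < |x_d/(H−1) − 0.5| < 0.4] with R1 = 0.5 and R2 = 2, and it averages (rather than sums) the absolute differences over cells; both the constants and the normalization used by the script are assumptions here.

```python
# Sketch of an L1 validation metric on a small HyperGrid (stdlib only).
import itertools
from collections import Counter


def hypergrid_reward(x, height, R0=0.1, R1=0.5, R2=2.0):
    """Assumed standard HyperGrid reward for cell x (ints in [0, height - 1])."""
    ax = [abs(xi / (height - 1) - 0.5) for xi in x]
    outer = all(a > 0.25 for a in ax)        # every coordinate near an edge
    ring = all(0.3 < a < 0.4 for a in ax)    # every coordinate in the high-reward band
    return R0 + R1 * outer + R2 * ring


def true_distribution(ndim, height, **reward_kwargs):
    """Exact target pmf p(x) = R(x) / Z, enumerating all height ** ndim cells."""
    cells = itertools.product(range(height), repeat=ndim)
    rewards = {c: hypergrid_reward(c, height, **reward_kwargs) for c in cells}
    Z = sum(rewards.values())
    return {c: r / Z for c, r in rewards.items()}


def l1_distance(samples, target):
    """Mean absolute difference between the empirical and target pmfs over all cells."""
    counts = Counter(samples)
    n = len(samples)
    return sum(abs(counts[c] / n - p) for c, p in target.items()) / len(target)


if __name__ == "__main__":
    target = true_distribution(ndim=2, height=8)
    # A sampler stuck in one corner is far from the target distribution.
    print(l1_distance([(0, 0)] * 100, target))
```

Lowering `R0` shrinks the reward between modes, which is why the smaller `--R0` values in the reproduction commands are harder to explore.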

tutorials.examples.train_hypergrid_ddp.parser

tutorials.examples.train_hypergrid_ddp.set_up_fm_gflownet(args, env, preprocessor)

Build a Flow Matching GFlowNet.

Flow Matching (FM) only requires a forward policy estimator and learns by matching incoming/outgoing flow at each state.

Parameters:
  • args – CLI arguments (uses tabular, hidden_dim, n_hidden).

  • env – HyperGrid environment instance.

  • preprocessor – State preprocessor (e.g. KHotPreprocessor).

Returns:

An FMGFlowNet ready for training.
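
The flow-matching condition at a single state can be written as a squared log-space residual between total inflow and total outflow, where a terminating state's reward counts as one more outgoing flow (the edge to the sink). The sketch below computes that residual from plain log-flows; it is an illustration of the condition, not FMGFlowNet's implementation, and it omits the small constant some formulations add inside the logarithms for numerical stability.

```python
# Sketch of the flow-matching residual at one state, in log space.
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def flow_matching_residual(log_in_flows, log_out_flows, log_reward=None):
    """(log total inflow - log total outflow) ** 2 for a single state.

    If the state can terminate, its log-reward is appended to the outgoing
    log-flows as the flow into the sink state.
    """
    outs = list(log_out_flows)
    if log_reward is not None:
        outs.append(log_reward)
    return (logsumexp(log_in_flows) - logsumexp(outs)) ** 2


if __name__ == "__main__":
    # Inflow 1 + 3 = 4 equals outflow 1 + 1 plus reward 2, so the residual is ~0.
    log = math.log
    print(flow_matching_residual([log(1), log(3)], [log(1), log(1)], log_reward=log(2)))
```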

tutorials.examples.train_hypergrid_ddp.set_up_gflownet(args, env, preprocessor)

Build a GFlowNet for the requested loss function.

Constructs the appropriate estimators and GFlowNet variant:

  • FM (Flow Matching): forward policy only.

  • TB (Trajectory Balance): forward + backward policies, learnable logZ.

  • DB (Detailed Balance): forward + backward policies + log-state-flow F.

  • SubTB (Sub-Trajectory Balance): like DB, with configurable sub-trajectory weighting.

  • ZVar (Log-Partition Variance): forward + backward policies.

  • ModifiedDB: a variant of DB with edge-based rewards.

Parameters:
  • args – CLI arguments (loss selects the variant; architecture flags are forwarded to the estimator builders).

  • env – HyperGrid environment instance.

  • preprocessor – State preprocessor.

Returns:

A GFlowNet instance, or None if the loss is unrecognized.
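
The TB and DB variants differ in what their losses constrain. The sketch below writes both residuals with plain log-probabilities: TB penalizes a complete trajectory via log Z + Σ log PF − log R(x) − Σ log PB, and DB penalizes each transition s → s' via log F(s) + log PF(s'|s) − log F(s') − log PB(s|s'). These are the textbook formulas, not the library's batched implementations; in DB the flow of a terminal state is commonly replaced by its reward.

```python
# Textbook Trajectory Balance and Detailed Balance residuals with plain floats.
import math


def tb_loss(log_Z, log_pf, log_pb, log_reward):
    """(log Z + sum log PF - log R(x) - sum log PB) ** 2 for one complete trajectory."""
    return (log_Z + sum(log_pf) - log_reward - sum(log_pb)) ** 2


def db_loss(log_F_s, log_pf_edge, log_F_next, log_pb_edge):
    """(log F(s) + log PF(s'|s) - log F(s') - log PB(s|s')) ** 2 for one transition."""
    return (log_F_s + log_pf_edge - log_F_next - log_pb_edge) ** 2


if __name__ == "__main__":
    log = math.log
    # Z = 10, PF steps 0.5 and 0.2, PB steps 1 and 0.5, so balance needs
    # R = 10 * 0.5 * 0.2 / (1 * 0.5) = 2; with that reward the TB loss is ~0.
    print(tb_loss(log(10), [log(0.5), log(0.2)], [0.0, log(0.5)], log(2)))
    # A mismatched reward yields a positive loss.
    print(tb_loss(log(10), [log(0.5), log(0.2)], [0.0, log(0.5)], log(1)))
```

TB needs only the learnable log Z in addition to the two policies, while DB needs the extra log-state-flow estimator F, matching the estimator lists above.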

tutorials.examples.train_hypergrid_ddp.set_up_logF_estimator(args, env, preprocessor, pf_module)

Build the log-state-flow estimator (used by DB and SubTB losses).

When --tied is set, the MLP trunk is shared with the PF module to reduce the number of trainable parameters.

Parameters:
  • args – CLI arguments (uses tabular, tied, hidden_dim, n_hidden).

  • env – HyperGrid environment instance.

  • preprocessor – State preprocessor.

  • pf_module – The forward policy’s underlying module (trunk may be reused).

Returns:

A ScalarEstimator for log F(s).
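
How much tying saves can be estimated with simple arithmetic. The sketch below counts parameters under an assumed architecture: a trunk of one Linear(input, hidden) layer followed by n_hidden − 1 Linear(hidden, hidden) layers, plus one Linear(hidden, out) head per estimator. It also assumes a one-hot input of size ndim × height, a PF head over ndim + 1 actions (one per dimension plus exit), a PB head over ndim actions, and a scalar log F head. The architecture details and default widths are illustrative assumptions, not read from the script.

```python
# Back-of-the-envelope parameter count for tied vs. untied PF / PB / log F estimators.
# The architecture below is an assumption made for illustration.


def linear_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def trunk_params(input_dim, hidden_dim, n_hidden):
    """Input layer plus (n_hidden - 1) hidden-to-hidden layers."""
    first = linear_params(input_dim, hidden_dim)
    rest = (n_hidden - 1) * linear_params(hidden_dim, hidden_dim)
    return first + rest


def count_params(ndim, height, hidden_dim=256, n_hidden=2, tied=True):
    """Total parameters of the three estimators, sharing one trunk if `tied`."""
    input_dim = ndim * height
    heads = [ndim + 1, ndim, 1]  # output sizes of PF, PB, log F
    head_total = sum(linear_params(hidden_dim, out) for out in heads)
    trunk = trunk_params(input_dim, hidden_dim, n_hidden)
    n_trunks = 1 if tied else len(heads)
    return n_trunks * trunk + head_total


if __name__ == "__main__":
    # ndim=4, height=8: the tied setup needs roughly a third of the parameters.
    print(count_params(4, 8, tied=True), count_params(4, 8, tied=False))
```

Because the trunk dominates the count, sharing it across the three estimators cuts the total by nearly a factor of three in this setting.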

tutorials.examples.train_hypergrid_ddp.set_up_pb_pf_estimators(args, env, preprocessor)

Build the forward (PF) and backward (PB) policy estimators.

Supports tabular or MLP-based modules with optional parameter tying (--tied shares the MLP trunk between PF, PB, and log-state-flow F) and uniform backward policy (--uniform_pb).

Parameters:
  • args – CLI arguments (uses tabular, uniform_pb, tied, hidden_dim, n_hidden, n_noisy_layers, noisy_std_init).

  • env – HyperGrid environment instance.

  • preprocessor – State preprocessor.

Returns:

A (pf_estimator, pb_estimator) tuple.
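
A uniform backward policy assigns equal probability to every parent of a state. In HyperGrid each forward action increments one coordinate by 1, so the parents of a state are obtained by decrementing one strictly positive coordinate, and the uniform log-probability is −log(number of positive coordinates). The sketch below computes this directly from a state tuple; the function names are illustrative, not the library's.

```python
# Sketch of a uniform backward policy on HyperGrid states (tuples of ints).
import math


def num_parents(state):
    """Parents are reached by undoing one +1 step, i.e. decrementing a positive coordinate."""
    return sum(1 for x in state if x > 0)


def uniform_log_pb(state):
    """log PB(parent | state) when every parent is equally likely."""
    n = num_parents(state)
    if n == 0:
        raise ValueError("the initial state (all zeros) has no parents")
    return -math.log(n)


if __name__ == "__main__":
    # (0, 3, 2) has two parents, (0, 2, 2) and (0, 3, 1), so log PB = log(1/2).
    print(uniform_log_pb((0, 3, 2)))
```

With a fixed PB like this, only PF (and log Z or log F, depending on the loss) has to be learned.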