tutorials.examples.train_hypergrid_mog
Experimental: train a Mixture of GFlowNets (MoG) on the HyperGrid environment.
This script trains multiple GFlowNet components in parallel using DDP, where each
training rank owns one mixture component. A shared classifier f_theta learns to
partition the state space across components so that each component specialises on a
different region of the reward landscape.
The mixture reward shaping follows:
\[\tilde{R}_i(x) = f_\theta(x)_i \, R(x)\]
where \(R(x)\) is the HyperGrid reward and \(f_\theta(x)_i\) is the softmax probability that state \(x\) belongs to
component i. Each rank trains its own GFlowNet on \(\tilde{R}_i\), while
f_theta is trained as a cross-entropy classifier over component IDs and its
gradients are all-reduced across ranks so that all ranks share the same partitioner.
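A minimal pure-Python sketch of the shaping for a single state, assuming it is multiplicative (reward times component probability, as in step b of the training loop described under main). The logits and reward are made-up numbers and the helper names are hypothetical. It also shows a useful property: because the softmax outputs sum to 1, the shaped rewards across all components add back up to the original reward, so f_theta splits R(x) among components rather than changing its total.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def shaped_rewards(reward, logits):
    """Return [R~_0(x), ..., R~_{n-1}(x)] with R~_i(x) = f_theta(x)_i * R(x)."""
    probs = softmax(logits)
    return [p * reward for p in probs]


# Hypothetical classifier logits for one state x and 3 components.
logits = [2.0, 0.5, -1.0]
reward = 0.8
per_component = shaped_rewards(reward, logits)
print([round(r, 4) for r in per_component])

# The softmax outputs sum to 1, so the shaped rewards sum back to R(x).
print(round(sum(per_component), 10))
```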
Example usage (single-node, 3 components + 1 buffer rank):
torchrun --nproc_per_node=4 train_hypergrid_mog.py \
--ndim 2 --height 8 --loss TB --batch_size 64
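The rank layout can be sketched as follows, assuming the last num_remote_buffers ranks act as buffer servers (as described for --num_remote_buffers under main) and that component i is owned by training rank i. assign_roles is a hypothetical helper, not a function of this script; with a world size of 4 and one buffer rank it reproduces the "3 components + 1 buffer rank" layout of the example.

```python
def assign_roles(world_size, num_remote_buffers):
    """Split ranks into training ranks (one mixture component each) and buffer ranks.

    Assumption: buffer ranks are the last `num_remote_buffers` ranks, and the
    component ID of a training rank equals its rank.
    """
    if not 0 <= num_remote_buffers < world_size:
        raise ValueError("need at least one training rank")
    n_train = world_size - num_remote_buffers
    roles = {}
    for rank in range(world_size):
        if rank < n_train:
            roles[rank] = ("train", rank)  # this rank owns component `rank`
        else:
            roles[rank] = ("buffer", None)  # replay-buffer server, never trains
    return n_train, roles


n_components, roles = assign_roles(world_size=4, num_remote_buffers=1)
print(n_components)  # 3
print(roles)
```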
Key features:
- Mixture of GFlowNets with learned state-space partitioning via f_theta
- DDP-based parallel training (one component per training rank)
- Supports FM, TB, DB, SubTB, ZVar, and ModifiedDB losses
- Optional replay buffers (local and/or remote) with diversity-based prioritization
- WandB logging, PyTorch profiler support, and mode-tracking heatmaps
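Attributes
- logger
- parser
Functions
| Function | Description |
| --- | --- |
| _make_optimizer_for(gflownet, args) | Build a fresh AdamW optimizer for a (re)built GFlowNet, with a separate logZ parameter group. |
| main(args) | Train a Mixture of GFlowNets on the HyperGrid environment using DDP. |
| set_up_f_theta_classifier(args, env, preprocessor, n_components) | Build the shared f_theta classifier that partitions states across components. |
| set_up_fm_gflownet(args, env, preprocessor) | Returns an FM GFlowNet. |
| set_up_gflownet(args, env, preprocessor) | Returns a GFlowNet complete with the required estimators. |
| set_up_logF_estimator(args, env, preprocessor, pf_module) | Returns a LogStateFlowEstimator. |
| set_up_pb_pf_estimators(args, env, preprocessor) | Returns a pair of estimators for the forward and backward policies. |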
Module Contents
- tutorials.examples.train_hypergrid_mog._make_optimizer_for(gflownet, args)
Build a fresh AdamW optimizer for a (re)built GFlowNet, with a separate parameter group for logZ.
- Return type:
torch.optim.Optimizer
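One way to build the logZ group, sketched without PyTorch so it runs standalone: split named parameters by whether their name contains "logZ" and give that group its own learning rate. The substring match, the lr/lr_logz names, and the parameter names are assumptions for illustration; the list-of-dicts output is the parameter-group format that torch.optim optimizers such as AdamW accept.

```python
def make_param_groups(named_params, lr, lr_logz):
    """Split (name, param) pairs into a main group and a logZ group.

    Returns the list-of-dicts format accepted by torch.optim optimizers.
    Matching on the substring "logZ" is an assumption of this sketch.
    """
    logz, other = [], []
    for name, param in named_params:
        (logz if "logZ" in name else other).append(param)
    groups = [{"params": other, "lr": lr}]
    if logz:
        groups.append({"params": logz, "lr": lr_logz})
    return groups


# Strings stand in for tensors so the sketch runs without PyTorch.
named = [("pf.module.0.weight", "W_pf"), ("pb.module.0.weight", "W_pb"), ("logZ", "logZ")]
groups = make_param_groups(named, lr=1e-3, lr_logz=1e-1)
print(groups)

# With PyTorch installed, the groups would be passed straight to the optimizer:
#   opt = torch.optim.AdamW(make_param_groups(gflownet.named_parameters(), 1e-3, 1e-1))
```

A fresh optimizer is needed after a rebuild because a PyTorch optimizer holds references to the exact parameter objects it was constructed with (plus per-parameter AdamW state), so an optimizer built for the old GFlowNet would not update the new one.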
- tutorials.examples.train_hypergrid_mog.logger
- tutorials.examples.train_hypergrid_mog.main(args)
Train a Mixture of GFlowNets on the HyperGrid environment using DDP.
High-level flow:
1. DDP initialization: detect the rank and world size and create a DistributedContext.
2. Buffer ranks: if --num_remote_buffers > 0, the last N ranks run as dedicated replay-buffer servers and never enter training.
3. Model setup: each training rank builds its own GFlowNet component, plus a shared f_theta classifier for state-space partitioning.
4. Training loop: on each iteration,
   a. sample trajectories with the local GFlowNet;
   b. shape rewards using f_theta (multiply by the component probability);
   c. compute and backpropagate the GFlowNet loss (local gradients only);
   d. train f_theta via cross-entropy, then all-reduce its gradients so every rank keeps the same partitioner weights;
   e. take optimizer steps for both the local GFlowNet and the shared f_theta.
5. Validation & logging: periodically compute the L1 distance and log it to WandB.
6. Cleanup: terminate the buffer ranks, synchronize at a barrier, and return.
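Step d can be illustrated with a pure-Python simulation of gradient all-reduce across three training ranks. Every rank starts with the same f_theta weights, computes a different local gradient, and replaces it with the cross-rank mean; since each rank then applies the same update to the same weights, the partitioners stay identical. Averaging (sum, then divide by the world size) is an assumption here; with torch.distributed the analogous call is dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) followed by division by the world size. The GFlowNet gradients from step c are not reduced, so each component's parameters remain independent.

```python
def all_reduce_mean(per_rank_grads):
    """Simulate an all-reduce (sum) followed by division by the world size.

    Every rank ends up holding the same averaged gradient.
    """
    world_size = len(per_rank_grads)
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    mean = [s / world_size for s in summed]
    return [list(mean) for _ in range(world_size)]


def sgd_step(weights, grad, lr):
    """Plain gradient-descent update (stand-in for the real optimizer step)."""
    return [w - lr * g for w, g in zip(weights, grad)]


# Identical f_theta weights on 3 training ranks (toy 2-parameter model).
weights = [[0.5, -0.2] for _ in range(3)]
# Each rank computes a different local gradient from its own trajectories.
local_grads = [[0.1, 0.3], [-0.2, 0.0], [0.4, -0.3]]

synced = all_reduce_mean(local_grads)
weights = [sgd_step(w, g, lr=0.1) for w, g in zip(weights, synced)]
print(weights)  # every rank prints the same weights
```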
- Parameters:
args – Parsed CLI arguments (see the __main__ block of the script).
- Returns:
A dict of final training metrics (loss, l1_dist, modes found, etc.).
- Return type:
dict
- tutorials.examples.train_hypergrid_mog.parser
- tutorials.examples.train_hypergrid_mog.set_up_f_theta_classifier(args, env, preprocessor, n_components)
Build the shared f_theta classifier that partitions states across components.
The classifier maps preprocessed states to n_components logits. After a softmax, the i-th output gives the probability that a state belongs to component i, which is used for mixture reward shaping.
- Parameters:
args – Parsed CLI arguments (controls --tabular, --hidden_dim, etc.).
env – The HyperGrid environment.
preprocessor – State preprocessor (e.g. KHotPreprocessor).
n_components – Number of mixture components (equal to the number of training ranks).
- Returns:
An nn.Module producing logits of shape (batch, n_components).
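A rough pure-Python stand-in for the classifier's input and training loss, assuming a K-hot encoding (one one-hot vector of length height per dimension, concatenated) and assuming each training rank labels the states it samples with its own component ID. k_hot and cross_entropy are hypothetical helpers for illustration, not torchgfn's KHotPreprocessor or PyTorch's loss function.

```python
import math


def k_hot(state, height):
    """One-hot encode each coordinate of a HyperGrid state and concatenate.

    A state with `ndim` coordinates becomes a vector of length ndim * height
    containing exactly ndim ones.
    """
    vec = []
    for coord in state:
        one_hot = [0.0] * height
        one_hot[coord] = 1.0
        vec.extend(one_hot)
    return vec


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed with the log-sum-exp trick."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]


x = k_hot((3, 5), height=8)
print(len(x), sum(x))  # 16 2.0

# Hypothetical f_theta logits for one state, with n_components = 3.
logits = [0.2, 1.5, -0.7]
# Assumption: a state sampled by training rank r is labelled with component ID r.
my_rank = 1
print(round(cross_entropy(logits, my_rank), 4))
```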
- tutorials.examples.train_hypergrid_mog.set_up_fm_gflownet(args, env, preprocessor)
Returns an FM (flow matching) GFlowNet.
- tutorials.examples.train_hypergrid_mog.set_up_gflownet(args, env, preprocessor)
Returns a GFlowNet complete with the required estimators.
- tutorials.examples.train_hypergrid_mog.set_up_logF_estimator(args, env, preprocessor, pf_module)
Returns a LogStateFlowEstimator.
- tutorials.examples.train_hypergrid_mog.set_up_pb_pf_estimators(args, env, preprocessor)
Returns a pair of estimators for the forward and backward policies.