tutorials.examples.train_hypergrid_mog

Experimental: train a Mixture of GFlowNets (MoG) on the HyperGrid environment.

This script trains multiple GFlowNet components in parallel using DDP, where each training rank owns one mixture component. A shared classifier f_theta learns to partition the state space across components so that each component specialises on a different region of the reward landscape.

The mixture reward shaping follows:

\[\tilde{R}_i(x) = R(x) \cdot f_\theta(x)_i\]

where \(f_\theta(x)_i\) is the softmax probability that state \(x\) belongs to component \(i\). Each rank trains its own GFlowNet on \(\tilde{R}_i\). Meanwhile, f_theta is trained as a cross-entropy classifier over component IDs, and its gradients are all-reduced across ranks so that all ranks share the same partitioner.
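The shaping rule can be illustrated with a small framework-free sketch. It is not the script's implementation (which operates on PyTorch tensors); the helper names `softmax` and `shaped_rewards` and the numeric values are hypothetical. It shows that, because the softmax probabilities sum to 1, the per-component shaped rewards sum back to the original reward.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def shaped_rewards(reward, logits):
    """Return R(x) * f_theta(x)_i for every component i."""
    return [reward * p for p in softmax(logits)]


# Hypothetical values: one state with reward 2.0 and three component logits.
reward = 2.0
logits = [1.0, 0.0, -1.0]
per_component = shaped_rewards(reward, logits)

# Softmax probabilities sum to 1, so the shaped rewards sum back to R(x).
assert abs(sum(per_component) - reward) < 1e-9
```

A component whose logit is largest for a state receives the largest share of that state's reward, which is what pushes components to specialise.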

Example usage (single-node, 3 components + 1 buffer rank):

torchrun --nproc_per_node=4 train_hypergrid_mog.py \
    --ndim 2 --height 8 --loss TB --batch_size 64
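The rank layout implied by this command (4 processes, of which 3 train and 1 serves as a buffer) can be sketched as follows. This is an illustration only: `rank_role` is a hypothetical helper, and it assumes one remote buffer and that the last N ranks are the buffer ranks, as described for `main` in this module. The script itself derives roles from its DistributedContext.

```python
def rank_role(rank, world_size, num_remote_buffers):
    """Classify a rank as 'train' or 'buffer' (the last N ranks are buffers)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank out of range")
    if not 0 <= num_remote_buffers < world_size:
        raise ValueError("need at least one training rank")
    first_buffer_rank = world_size - num_remote_buffers
    return "buffer" if rank >= first_buffer_rank else "train"


world_size = 4  # matches --nproc_per_node=4
roles = [rank_role(r, world_size, num_remote_buffers=1) for r in range(world_size)]

# One mixture component per training rank.
n_components = roles.count("train")
print(roles, n_components)
```

Under these assumptions, ranks 0 to 2 each own one mixture component and rank 3 acts as the buffer server.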

Key features:

  • Mixture of GFlowNets with learned state-space partitioning via f_theta

  • DDP-based parallel training (one component per training rank)

  • Supports FM, TB, DB, SubTB, ZVar, and ModifiedDB losses

  • Optional replay buffers (local and/or remote) with diversity-based prioritization

  • WandB logging, PyTorch profiler support, and mode-tracking heatmaps

Attributes

logger

parser

Functions

_make_optimizer_for(gflownet, args)

Build a fresh AdamW optimizer for a (re)built GFlowNet with logZ group.

main(args)

Train a Mixture of GFlowNets on the HyperGrid environment using DDP.

set_up_f_theta_classifier(args, env, preprocessor, ...)

Build the shared f_theta classifier that partitions states across components.

set_up_fm_gflownet(args, env, preprocessor)

Returns a flow-matching (FM) GFlowNet.

set_up_gflownet(args, env, preprocessor)

Returns a GFlowNet complete with the required estimators.

set_up_logF_estimator(args, env, preprocessor, pf_module)

Returns a LogStateFlowEstimator.

set_up_pb_pf_estimators(args, env, preprocessor)

Returns a pair of estimators for the forward and backward policies.

Module Contents

tutorials.examples.train_hypergrid_mog._make_optimizer_for(gflownet, args)

Build a fresh AdamW optimizer for a (re)built GFlowNet with logZ group.

Return type:

torch.optim.Optimizer

tutorials.examples.train_hypergrid_mog.logger

tutorials.examples.train_hypergrid_mog.main(args)

Train a Mixture of GFlowNets on the HyperGrid environment using DDP.

High-level flow:

  1. DDP initialization: detect rank/world-size and create a DistributedContext.

  2. Buffer ranks: if --num_remote_buffers > 0, the last N ranks run as dedicated replay-buffer servers and never enter training.

  3. Model setup: each training rank builds its own GFlowNet component, plus a shared f_theta classifier for state-space partitioning.

  4. Training loop: each iteration performs the following steps.

    a. Sample trajectories with the local GFlowNet.

    b. Shape rewards using f_theta (multiply by component probability).

    c. Compute and backprop the GFlowNet loss (local gradients only).

    d. Train f_theta via cross-entropy, then all-reduce its gradients so every rank keeps the same partitioner weights.

    e. Optimizer steps for both the local GFlowNet and the shared f_theta.

  5. Validation & logging: periodically compute L1 distance and log to WandB.

  6. Cleanup: terminate buffer ranks, barrier, return.
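The f_theta synchronisation in the training loop can be simulated in a single process to show why all ranks end up with the same partitioner weights. This is a sketch under stated assumptions, not the script's code: it assumes the all-reduce averages gradients (the source only says they are all-reduced), it uses plain SGD in place of the real optimizer, and the helper names and gradient values are hypothetical.

```python
def all_reduce_mean(per_rank_grads):
    """Simulate an all-reduce that averages gradients element-wise.

    Every rank receives the same averaged gradient, comparable to a
    sum all-reduce followed by division by the number of ranks.
    """
    n_ranks = len(per_rank_grads)
    mean = [sum(vals) / n_ranks for vals in zip(*per_rank_grads)]
    return [list(mean) for _ in range(n_ranks)]


def sgd_step(weights, grads, lr):
    """Plain SGD update, standing in for the real optimizer."""
    return [w - lr * g for w, g in zip(weights, grads)]


# Three training ranks start from identical f_theta weights.
weights = [[0.5, -0.2] for _ in range(3)]
# Hypothetical local gradients: each rank saw different trajectories.
local_grads = [[0.3, 0.0], [-0.1, 0.4], [0.1, 0.2]]

synced = all_reduce_mean(local_grads)
weights = [sgd_step(w, g, lr=0.1) for w, g in zip(weights, synced)]

# Identical start plus identical averaged gradient gives identical weights.
assert weights[0] == weights[1] == weights[2]
```

The GFlowNet gradients, by contrast, stay local to each rank, so the components are free to diverge from one another.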

Parameters:

args – Parsed CLI arguments (see the __main__ block of the script).

Returns:

A dict of final training metrics (loss, l1_dist, modes found, etc.).

Return type:

dict
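The l1_dist metric listed among the returned values can be sketched as a comparison between the empirical distribution of sampled terminal states and the target distribution proportional to the reward. This is an assumption about how the metric is defined: the function name `l1_distance` is hypothetical, and averaging over states (rather than summing) is a choice made here for illustration.

```python
from collections import Counter


def l1_distance(sampled_states, rewards):
    """Mean absolute difference between empirical and target distributions.

    sampled_states: non-empty list of hashable terminal states.
    rewards: dict mapping every terminal state to its reward R(x) > 0.
    """
    total_reward = sum(rewards.values())
    counts = Counter(sampled_states)  # missing states count as 0
    n = len(sampled_states)
    diffs = [
        abs(counts[state] / n - reward / total_reward)
        for state, reward in rewards.items()
    ]
    return sum(diffs) / len(diffs)


# Hypothetical 3-state example with target distribution (0.25, 0.25, 0.5).
rewards = {(0, 0): 1.0, (0, 1): 1.0, (1, 0): 2.0}
perfect = [(0, 0), (0, 1), (1, 0), (1, 0)]
collapsed = [(0, 0)] * 4

print(l1_distance(perfect, rewards))    # 0.0
print(l1_distance(collapsed, rewards))  # 0.5
```

A value near zero indicates that the sampler visits terminal states in proportion to their reward, while mode collapse shows up as a larger distance.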

tutorials.examples.train_hypergrid_mog.parser

tutorials.examples.train_hypergrid_mog.set_up_f_theta_classifier(args, env, preprocessor, n_components)

Build the shared f_theta classifier that partitions states across components.

The classifier maps preprocessed states to n_components logits. After softmax, the i-th output gives the probability that a state belongs to component i, which is used for mixture reward shaping.

Parameters:
  • args – Parsed CLI arguments (controls --tabular, --hidden_dim, etc.).

  • env – The HyperGrid environment.

  • preprocessor – State preprocessor (e.g. KHotPreprocessor).

  • n_components – Number of mixture components (equal to the number of training ranks).

Returns:

An nn.Module producing logits of shape (batch, n_components).

tutorials.examples.train_hypergrid_mog.set_up_fm_gflownet(args, env, preprocessor)

Returns a flow-matching (FM) GFlowNet.

tutorials.examples.train_hypergrid_mog.set_up_gflownet(args, env, preprocessor)

Returns a GFlowNet complete with the required estimators.

tutorials.examples.train_hypergrid_mog.set_up_logF_estimator(args, env, preprocessor, pf_module)

Returns a LogStateFlowEstimator.

tutorials.examples.train_hypergrid_mog.set_up_pb_pf_estimators(args, env, preprocessor)

Returns a pair of estimators for the forward and backward policies.