tutorials.examples.train_hypergrid_mog
======================================

.. py:module:: tutorials.examples.train_hypergrid_mog

.. autoapi-nested-parse::

   Experimental: train a Mixture of GFlowNets (MoG) on the HyperGrid environment.

   This script trains multiple GFlowNet components in parallel using DDP, where each
   training rank owns one mixture component. A shared classifier ``f_theta`` learns to
   partition the state space across components so that each component specialises on a
   different region of the reward landscape.

   The mixture reward shaping follows:

   .. math::

      \tilde{R}_i(x) = R(x) \cdot f_\theta(x)_i

   where :math:`f_\theta(x)_i` is the softmax probability that state *x* belongs to
   component *i*. Each rank trains its own GFlowNet on :math:`\tilde{R}_i`, while
   ``f_theta`` is trained as a cross-entropy classifier over component IDs. Its
   gradients are all-reduced across ranks so that all ranks share the same partitioner.

   Example usage (single-node, 3 components + 1 buffer rank)::

      torchrun --nproc_per_node=4 train_hypergrid_mog.py \
          --ndim 2 --height 8 --loss TB --batch_size 64

   Key features:

   - Mixture of GFlowNets with learned state-space partitioning via ``f_theta``
   - DDP-based parallel training (one component per training rank)
   - Supports FM, TB, DB, SubTB, ZVar, and ModifiedDB losses
   - Optional replay buffers (local and/or remote) with diversity-based prioritization
   - WandB logging, PyTorch profiler support, and mode-tracking heatmaps


Attributes
----------

.. autoapisummary::

   tutorials.examples.train_hypergrid_mog.logger
   tutorials.examples.train_hypergrid_mog.parser


Functions
---------

.. autoapisummary::

   tutorials.examples.train_hypergrid_mog._make_optimizer_for
   tutorials.examples.train_hypergrid_mog.main
   tutorials.examples.train_hypergrid_mog.set_up_f_theta_classifier
   tutorials.examples.train_hypergrid_mog.set_up_fm_gflownet
   tutorials.examples.train_hypergrid_mog.set_up_gflownet
   tutorials.examples.train_hypergrid_mog.set_up_logF_estimator
   tutorials.examples.train_hypergrid_mog.set_up_pb_pf_estimators


Module Contents
---------------

.. py:function:: _make_optimizer_for(gflownet, args)

   Build a fresh AdamW optimizer for a (re)built GFlowNet with a logZ parameter group.


.. py:data:: logger

.. py:function:: main(args)

   Train a Mixture of GFlowNets on the HyperGrid environment using DDP.

   High-level flow:

   1. **DDP initialization**: detect rank/world-size and create a
      :class:`DistributedContext`.
   2. **Buffer ranks**: if ``--num_remote_buffers > 0``, the last *N* ranks run as
      dedicated replay-buffer servers and never enter training.
   3. **Model setup**: each training rank builds its own GFlowNet component, plus a
      shared ``f_theta`` classifier for state-space partitioning.
   4. **Training loop**: each iteration performs the following steps:

      a. Sample trajectories with the local GFlowNet.
      b. Shape rewards using ``f_theta`` (multiply by component probability).
      c. Compute and backprop the GFlowNet loss (local gradients only).
      d. Train ``f_theta`` via cross-entropy, then all-reduce its gradients so
         every rank keeps the same partitioner weights.
      e. Take optimizer steps for both the local GFlowNet and the shared ``f_theta``.

   5. **Validation & logging**: periodically compute the L1 distance and log to WandB.
   6. **Cleanup**: terminate buffer ranks, barrier, return.

   :param args: Parsed CLI arguments (see ``__main__`` block below).

   :returns: A dict of final training metrics (loss, l1_dist, modes found, etc.).


.. py:data:: parser

.. py:function:: set_up_f_theta_classifier(args, env, preprocessor, n_components)

   Build the shared ``f_theta`` classifier that partitions states across components.

   The classifier maps preprocessed states to ``n_components`` logits. After softmax,
   the *i*-th output gives the probability that a state belongs to component *i*,
   which is used for mixture reward shaping.

   :param args: Parsed CLI arguments (controls ``--tabular``, ``--hidden_dim``, etc.).
   :param env: The HyperGrid environment.
   :param preprocessor: State preprocessor (e.g. :class:`KHotPreprocessor`).
   :param n_components: Number of mixture components (equal to the number of
                        training ranks).

   :returns: An ``nn.Module`` producing logits of shape ``(batch, n_components)``.


.. py:function:: set_up_fm_gflownet(args, env, preprocessor)

   Returns an FM GFlowNet.


.. py:function:: set_up_gflownet(args, env, preprocessor)

   Returns a GFlowNet complete with the required estimators.


.. py:function:: set_up_logF_estimator(args, env, preprocessor, pf_module)

   Returns a LogStateFlowEstimator.


.. py:function:: set_up_pb_pf_estimators(args, env, preprocessor)

   Returns a pair of estimators for the forward and backward policies.
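
The mixture reward shaping from the module overview, where the reward ``R(x)`` is
multiplied by the softmax output of ``f_theta`` for each component, can be illustrated
with a minimal plain-Python sketch. This is not the script's implementation, which
works with an ``nn.Module`` on batched inputs. The helper names ``softmax`` and
``shape_rewards`` and the numeric values are hypothetical, chosen only for illustration.

.. code-block:: python

   import math


   def softmax(logits):
       """Numerically stable softmax over a list of floats."""
       shift = max(logits)  # subtract the max to avoid overflow in exp
       exps = [math.exp(z - shift) for z in logits]
       total = sum(exps)
       return [e / total for e in exps]


   def shape_rewards(reward, logits):
       """Return R(x) * f_theta(x)_i for every component i.

       ``logits`` stands in for the classifier output on a single state.
       """
       return [reward * p for p in softmax(logits)]


   # Hypothetical example: one terminal state with reward 2.0 and
   # classifier logits for three mixture components.
   shaped = shape_rewards(2.0, [1.0, 0.0, -1.0])

   # Component 0 has the largest logit, so it receives the largest share.
   for i, r in enumerate(shaped):
       print(f"component {i}: shaped reward = {r:.4f}")

Because the softmax probabilities sum to one, the shaped rewards of all components add
up to the original ``R(x)``. The components therefore split the reward mass of each
state among themselves rather than duplicating it.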