tutorials.examples.train_hypergrid
==================================

.. py:module:: tutorials.examples.train_hypergrid

.. autoapi-nested-parse::

   The goal of this script is to reproduce some of the published results on the HyperGrid
   environment. Run one of the following commands to reproduce some of the results in
   `Trajectory balance: Improved credit assignment in GFlowNets <https://arxiv.org/abs/2201.13259>`_::

      python train_hypergrid.py --ndim 4 --height 8 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}
      python train_hypergrid.py --ndim 2 --height 64 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}

   And run one of the following to reproduce some of the results in
   `Learning GFlowNets from partial episodes for improved convergence and stability <https://arxiv.org/abs/2209.12782>`_::

      python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}

   SELECTIVE AVERAGING:
   This script also supports selective model averaging for distributed training, where instead of
   averaging all models, the worst performing models are replaced with averaged weights from the
   better performing ones. Use the following flags:

   - ``--use_selective_averaging``: Enable selective averaging instead of standard averaging
   - ``--replacement_ratio 0.2``: Replace the worst 20% of models (adjustable 0.0-1.0)
   - ``--averaging_strategy mean``: How to combine good models ("mean", "weighted_mean", "best_only")
   - ``--momentum 0.0``: Momentum factor for combining with previous weights (0.0-1.0, default 0.0)

   Example with selective averaging::

      python train_hypergrid.py --distributed --use_selective_averaging --replacement_ratio 0.3 --averaging_strategy mean --momentum 0.1

   This script also provides a function ``get_exact_P_T`` that computes the exact terminating state
   distribution for the HyperGrid environment, which is useful for evaluation and visualization.


Attributes
----------

.. autoapisummary::

   tutorials.examples.train_hypergrid.logger
   tutorials.examples.train_hypergrid.parser


Classes
-------

.. autoapisummary::

   tutorials.examples.train_hypergrid.ModesReplayBufferManager


Functions
---------

.. autoapisummary::

   tutorials.examples.train_hypergrid._make_optimizer_for
   tutorials.examples.train_hypergrid._sample_new_strategy
   tutorials.examples.train_hypergrid.build_mode_discovery_figure
   tutorials.examples.train_hypergrid.get_exact_P_T
   tutorials.examples.train_hypergrid.main
   tutorials.examples.train_hypergrid.plot_results
   tutorials.examples.train_hypergrid.set_up_fm_gflownet
   tutorials.examples.train_hypergrid.set_up_gflownet
   tutorials.examples.train_hypergrid.set_up_logF_estimator
   tutorials.examples.train_hypergrid.set_up_pb_pf_estimators


Module Contents
---------------

.. py:class:: ModesReplayBufferManager(env, rank, num_training_ranks, diverse_replay_buffer = False, capacity = 10000, remote_manager_rank = None, w_retained = 1.0, w_novelty = 0.1, w_reward = 1.0, w_mode_bonus = 10.0, p_norm_novelty = 2.0, cdist_max_bytes = 268435456, ema_decay = 0.5)

   Bases: :py:obj:`gfn.containers.replay_buffer_manager.ReplayBufferManager`


   .. py:method:: _compute_metadata()


   .. py:attribute:: _ema_decay
      :type:  float


   .. py:attribute:: _score_ema
      :type:  Optional[float]
      :value: None


   .. py:attribute:: cdist_max_bytes
      :value: 268435456


   .. py:attribute:: discovered_modes


   .. py:attribute:: env


   .. py:attribute:: p_norm_novelty
      :value: 2.0


   .. py:method:: scoring_function(obj)


   .. py:attribute:: w_mode_bonus
      :value: 10.0


   .. py:attribute:: w_novelty
      :value: 0.1


   .. py:attribute:: w_retained
      :value: 1.0


   .. py:attribute:: w_reward
      :value: 1.0


.. py:function:: _make_optimizer_for(gflownet, args)

   Build a fresh AdamW optimizer for a (re)built GFlowNet with logZ group.


.. py:function:: _sample_new_strategy(args, rng)

   Sample a new exploration strategy by independently sampling each parameter.

   Each parameter (epsilon, temperature, n_noisy_layers) is sampled from a normal
   distribution with mean and std specified in args. Values are clamped to valid ranges.

   :param args: Argument namespace containing mean/std for each parameter:

                - epsilon, strategy_epsilon_std
                - temperature, strategy_temperature_std
                - n_noisy_layers, strategy_n_noisy_layers_std
                - strategy_noisy_std_init (optional, default 0.5)
   :param rng: Random number generator instance to use for sampling.

   :returns: A dict with keys name, epsilon, temperature, n_noisy_layers, noisy_std_init.
   :rtype: dict


.. py:function:: build_mode_discovery_figure(env, discovered_indices)

   Build a matplotlib figure showing discovered vs undiscovered mode states.

   Projects mode states onto 2D planes of the first min(ndim, 3) dimensions.
   For ndim=1, shows a single row; for ndim=2, a single 2D heatmap; for ndim>=3,
   three pairwise projections of the first 3 dimensions.

   Color coding:

   - Light gray: no mode state at this position
   - Red: mode state(s) exist but none discovered
   - Green: at least one mode state discovered

   Returns None if ``env.all_states`` is unavailable.


.. py:function:: get_exact_P_T(env, gflownet)

   Evaluates the exact terminating state distribution P_T for HyperGrid.

   For each state s', the terminating state probability is computed as:

   .. math::

      P_T(s') = u(s') P_F(s_f | s')

   where u(s') satisfies the recursion:

   .. math::

      u(s') = \sum_{s \in \text{Par}(s')} u(s) P_F(s' | s)

   with the base case u(s_0) = 1.

   :param env: The HyperGrid environment
   :param gflownet: The GFlowNet model

   :returns: The exact terminating state distribution as a tensor


.. py:data:: logger

.. py:function:: main(args)

   Trains a GFlowNet on the HyperGrid environment, potentially distributed.


.. py:data:: parser

.. py:function:: plot_results(env, gflownet, l1_distances, args)

.. py:function:: set_up_fm_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id)

   Returns an FM GFlowNet.


.. py:function:: set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id, strategy_rng)

   Returns a GFlowNet complete with the required estimators.


.. py:function:: set_up_logF_estimator(args, env, preprocessor, agent_group_list, my_agent_group_id, pf_module)

   Returns a LogStateFlowEstimator.


.. py:function:: set_up_pb_pf_estimators(args, env, preprocessor, agent_group_list, my_agent_group_id)

   Returns a pair of estimators for the forward and backward policies.
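
The recursion documented for ``get_exact_P_T`` can be illustrated with a small, dependency-free sketch. This is not the script's implementation. The names ``uniform_forward_policy`` and ``exact_terminating_distribution`` are hypothetical, and the uniform policy is only a stand-in for a trained forward policy. The sketch also assumes HyperGrid-style actions: increment one coordinate while it is below ``height - 1``, or exit.

```python
from itertools import product


def uniform_forward_policy(state, height):
    """Hypothetical stand-in for P_F: uniform over the valid actions.

    Returns (step_probs, exit_prob), where step_probs[i] is the probability
    of incrementing coordinate i and exit_prob plays the role of P_F(s_f | s).
    """
    valid_steps = [coord < height - 1 for coord in state]
    n_valid = sum(valid_steps) + 1  # +1 for the exit action (assumed always valid)
    step_probs = [1.0 / n_valid if ok else 0.0 for ok in valid_steps]
    return step_probs, 1.0 / n_valid


def exact_terminating_distribution(ndim, height, policy=uniform_forward_policy):
    """Compute P_T(s') = u(s') * P_F(s_f | s') for every grid state."""
    # Visit states by increasing coordinate sum, so every parent of a state
    # is processed before the state itself (a topological order of the grid).
    states = sorted(product(range(height), repeat=ndim), key=sum)
    u = {s: 0.0 for s in states}
    u[states[0]] = 1.0  # base case: u(s_0) = 1 at the origin
    p_t = {}
    for s in states:
        step_probs, exit_prob = policy(s, height)
        p_t[s] = u[s] * exit_prob
        # Push u(s) * P_F(child | s) to each child. Accumulated over all
        # parents, this equals the sum over Par(s') in the recursion.
        for i, p in enumerate(step_probs):
            if p > 0.0:
                child = s[:i] + (s[i] + 1,) + s[i + 1:]
                u[child] += u[s] * p
    return p_t


p_t = exact_terminating_distribution(ndim=2, height=3)
print(sum(p_t.values()))  # total probability mass, up to floating-point error
```

Because every trajectory eventually exits, the values of ``p_t`` should sum to one, which makes a convenient sanity check. With a trained model, the uniform policy would be replaced by the model's forward probabilities evaluated at each state.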