tutorials.examples.train_hypergrid

The goal of this script is to reproduce some of the published results on the HyperGrid environment. Run one of the following commands to reproduce some of the results in [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259):

    python train_hypergrid.py --ndim 4 --height 8 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}

    python train_hypergrid.py --ndim 2 --height 64 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}
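The braces in these commands appear to denote alternatives: one value is picked per braced set. As a minimal sketch under that reading, and assuming that {--uniform_pb} marks an optional flag, the snippet below expands the first configuration into concrete command lines. The helper name `sweep_commands` is hypothetical and not part of the script.

```python
from itertools import product


def sweep_commands(ndim: int, height: int) -> list[str]:
    """Expand the brace notation into one concrete command per combination."""
    r0_values = ["0.1", "0.01", "0.001"]
    losses = ["TB", "DB"]
    # Assumption: "{--uniform_pb}" means the flag may be present or absent.
    uniform_pb_options = [False, True]

    commands = []
    for r0, uniform_pb, loss in product(r0_values, uniform_pb_options, losses):
        parts = [
            "python", "train_hypergrid.py",
            "--ndim", str(ndim),
            "--height", str(height),
            "--R0", r0,
            "--tied",
        ]
        if uniform_pb:
            parts.append("--uniform_pb")
        parts += ["--loss", loss]
        commands.append(" ".join(parts))
    return commands


if __name__ == "__main__":
    for command in sweep_commands(ndim=4, height=8):
        print(command)
```

Calling `sweep_commands(ndim=2, height=64)` gives the corresponding list for the second configuration.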

And run one of the following to reproduce some of the results in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782):

    python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}

SELECTIVE AVERAGING: This script also supports selective model averaging for distributed training, where instead of averaging all models, the worst performing models are replaced with averaged weights from the better performing ones. Use the following flags:

  • --use_selective_averaging: Enable selective averaging instead of standard averaging.

  • --replacement_ratio 0.2: Replace the worst 20% of models (adjustable, 0.0-1.0).

  • --averaging_strategy mean: How to combine good models ("mean", "weighted_mean", "best_only").

  • --momentum 0.0: Momentum factor for combining with previous weights (0.0-1.0, default 0.0).

Example with selective averaging:

    python train_hypergrid.py --distributed --use_selective_averaging --replacement_ratio 0.3 --averaging_strategy mean --momentum 0.1
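To make the idea of selective averaging concrete, here is a minimal, framework-free sketch. It is not the script's implementation. The function name `selective_average` is hypothetical. It also assumes that a higher score means a better model, that "weighted_mean" weights models by positive scores, and that momentum blends the old weights with the combined ones as `momentum * old + (1 - momentum) * combined`.

```python
def selective_average(weights, scores, replacement_ratio=0.2,
                      strategy="mean", momentum=0.0):
    """Replace the worst-scoring models with a combination of the better ones.

    weights: list of per-model parameter vectors (lists of floats).
    scores:  one score per model; higher is assumed to be better.
    """
    n = len(weights)
    # Always keep at least one "good" model to average from.
    n_replace = min(int(n * replacement_ratio), n - 1)
    result = [list(w) for w in weights]
    if n_replace <= 0:
        return result

    order = sorted(range(n), key=lambda i: scores[i])  # worst first
    worst, good = order[:n_replace], order[n_replace:]

    if strategy == "mean":
        coeffs = [1.0 / len(good)] * len(good)
    elif strategy == "weighted_mean":
        total = sum(scores[i] for i in good)  # assumes positive scores
        coeffs = [scores[i] / total for i in good]
    elif strategy == "best_only":
        good, coeffs = [order[-1]], [1.0]
    else:
        raise ValueError(f"unknown strategy: {strategy}")

    dim = len(weights[0])
    combined = [
        sum(c * weights[i][d] for c, i in zip(coeffs, good))
        for d in range(dim)
    ]
    for i in worst:
        result[i] = [
            momentum * old + (1.0 - momentum) * new
            for old, new in zip(weights[i], combined)
        ]
    return result
```

In this sketch the better-performing models keep their own weights and only the replaced models change, which is the distinction from standard averaging described above.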

This script also provides a function get_exact_P_T that computes the exact terminating state distribution for the HyperGrid environment, which is useful for evaluation and visualization.

Attributes

logger

parser

Classes

ModesReplayBufferManager

Functions

_make_optimizer_for(gflownet, args)

Build a fresh AdamW optimizer, including a logZ parameter group, for a (re)built GFlowNet.

_sample_new_strategy(args, rng)

Sample a new exploration strategy by independently sampling each parameter.

build_mode_discovery_figure(env, discovered_indices)

Build a matplotlib figure showing discovered vs undiscovered mode states.

get_exact_P_T(env, gflownet)

Evaluates the exact terminating state distribution P_T for HyperGrid.

main(args)

Trains a GFlowNet on the HyperGrid environment, potentially distributed.

plot_results(env, gflownet, l1_distances, args)

set_up_fm_gflownet(args, env, preprocessor, ...)

Returns an FM (flow matching) GFlowNet.

set_up_gflownet(args, env, preprocessor, ...)

Returns a GFlowNet complete with the required estimators.

set_up_logF_estimator(args, env, preprocessor, ...)

Returns a LogStateFlowEstimator.

set_up_pb_pf_estimators(args, env, preprocessor, ...)

Returns a pair of estimators for the forward and backward policies.

Module Contents

class tutorials.examples.train_hypergrid.ModesReplayBufferManager(env, rank, num_training_ranks, diverse_replay_buffer=False, capacity=10000, remote_manager_rank=None, w_retained=1.0, w_novelty=0.1, w_reward=1.0, w_mode_bonus=10.0, p_norm_novelty=2.0, cdist_max_bytes=268435456, ema_decay=0.5)

Bases: gfn.containers.replay_buffer_manager.ReplayBufferManager

Parameters:
  • env (gfn.gym.HyperGrid)

  • rank (int)

  • num_training_ranks (int)

  • diverse_replay_buffer (bool)

  • capacity (int)

  • remote_manager_rank (int | None)

  • w_retained (float)

  • w_novelty (float)

  • w_reward (float)

  • w_mode_bonus (float)

  • p_norm_novelty (float)

  • cdist_max_bytes (int)

  • ema_decay (float)

_compute_metadata()
Return type:

dict

_ema_decay: float
_score_ema: float | None = None
cdist_max_bytes = 268435456
discovered_modes
env
p_norm_novelty = 2.0
scoring_function(obj)
Parameters:

obj (gfn.containers.replay_buffer_manager.ContainerUnion)

Return type:

dict[str, float]

w_mode_bonus = 10.0
w_novelty = 0.1
w_retained = 1.0
w_reward = 1.0
tutorials.examples.train_hypergrid._make_optimizer_for(gflownet, args)

Build a fresh AdamW optimizer, including a logZ parameter group, for a (re)built GFlowNet.

Return type:

torch.optim.Optimizer

tutorials.examples.train_hypergrid._sample_new_strategy(args, rng)

Sample a new exploration strategy by independently sampling each parameter.

Each parameter (epsilon, temperature, n_noisy_layers) is sampled from a normal distribution with mean and std specified in args. Values are clamped to valid ranges.

Parameters:
  • args – Argument namespace containing the mean and std for each parameter:
      - epsilon, strategy_epsilon_std
      - temperature, strategy_temperature_std
      - n_noisy_layers, strategy_n_noisy_layers_std
      - strategy_noisy_std_init (optional, default 0.5)

  • rng (random.Random) – Random number generator instance to use for sampling.

Returns:

A dict with keys name, epsilon, temperature, n_noisy_layers, noisy_std_init.

Return type:

dict
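The sampling scheme described for _sample_new_strategy can be sketched as follows. This is an illustration only, not the script's code: the function name `sample_strategy_sketch`, the clamping ranges (epsilon to [0, 1], temperature to a small positive floor, n_noisy_layers to a non-negative integer) and the format of the `name` string are all assumptions.

```python
import random
from types import SimpleNamespace


def sample_strategy_sketch(args, rng: random.Random) -> dict:
    """Sample each exploration parameter independently from a normal, then clamp."""
    # Assumed valid range for epsilon: [0, 1].
    epsilon = rng.gauss(args.epsilon, args.strategy_epsilon_std)
    epsilon = min(max(epsilon, 0.0), 1.0)

    # Assumed valid range for temperature: strictly positive.
    temperature = rng.gauss(args.temperature, args.strategy_temperature_std)
    temperature = max(temperature, 1e-3)

    # Assumed: rounded to a non-negative integer layer count.
    n_noisy = rng.gauss(args.n_noisy_layers, args.strategy_n_noisy_layers_std)
    n_noisy_layers = max(int(round(n_noisy)), 0)

    # Optional argument with the documented default of 0.5.
    noisy_std_init = getattr(args, "strategy_noisy_std_init", 0.5)

    # Hypothetical naming scheme, used here only for readability.
    name = f"eps{epsilon:.3f}_temp{temperature:.3f}_noisy{n_noisy_layers}"
    return {
        "name": name,
        "epsilon": epsilon,
        "temperature": temperature,
        "n_noisy_layers": n_noisy_layers,
        "noisy_std_init": noisy_std_init,
    }


if __name__ == "__main__":
    demo_args = SimpleNamespace(
        epsilon=0.1, strategy_epsilon_std=0.05,
        temperature=1.0, strategy_temperature_std=0.2,
        n_noisy_layers=1, strategy_n_noisy_layers_std=0.5,
    )
    print(sample_strategy_sketch(demo_args, random.Random(0)))
```

Passing an explicit `random.Random` instance, as the documented signature does, keeps strategy sampling reproducible and independent of the global random state.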

tutorials.examples.train_hypergrid.build_mode_discovery_figure(env, discovered_indices)

Build a matplotlib figure showing discovered vs undiscovered mode states.

Projects mode states onto 2D planes of the first min(ndim, 3) dimensions. For ndim=1, shows a single row; for ndim=2, a single 2D heatmap; for ndim>=3, three pairwise projections of the first 3 dimensions.

Color coding:
  • Light gray: no mode state at this position

  • Red: mode state(s) exist but none discovered

  • Green: at least one mode state discovered

Returns None if env.all_states is unavailable.
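The color coding above amounts to a three-way classification of each cell in a 2D projection. The sketch below shows that classification without any plotting. It assumes mode states are given as integer coordinate tuples and that discovered indices refer to positions in that list. The function name `mode_discovery_grid` and these input conventions are hypothetical and may differ from the script.

```python
NO_MODE, UNDISCOVERED, DISCOVERED = 0, 1, 2  # light gray, red, green


def mode_discovery_grid(mode_states, discovered_indices, height, dims=(0, 1)):
    """Classify each cell of the projection onto the two dimensions in `dims`."""
    grid = [[NO_MODE] * height for _ in range(height)]
    discovered = set(discovered_indices)
    for idx, state in enumerate(mode_states):
        row, col = state[dims[0]], state[dims[1]]
        if idx in discovered:
            # One discovered mode state is enough to mark the cell green.
            grid[row][col] = DISCOVERED
        elif grid[row][col] == NO_MODE:
            # A mode state exists here, but none has been discovered so far.
            grid[row][col] = UNDISCOVERED
    return grid


if __name__ == "__main__":
    modes = [(0, 0, 0), (0, 0, 1), (2, 1, 0)]
    for row in mode_discovery_grid(modes, discovered_indices=[1], height=3):
        print(row)
```

Several mode states can project onto the same cell when ndim > 2, which is why a cell is marked as discovered as soon as any one of them has been found.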

tutorials.examples.train_hypergrid.get_exact_P_T(env, gflownet)

Evaluates the exact terminating state distribution P_T for HyperGrid.

For each state s', the terminating state probability is computed as:

\[P_T(s') = u(s') P_F(s_f | s')\]

where s_f is the sink state, Par(s') is the set of parents of s', and u(s') satisfies the recursion:

\[u(s') = \sum_{s \in \text{Par}(s')} u(s) P_F(s' | s)\]

with the base case u(s_0) = 1 for the initial state s_0.

Returns:

The exact terminating state distribution as a tensor

Return type:

torch.Tensor
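The recursion for u and the formula for P_T can be checked on a tiny grid. The sketch below is a plain-Python illustration, not the tensor-based get_exact_P_T. It assumes a forward policy that is uniform over the valid actions at each state (incrementing any coordinate that is below height - 1, or exiting). The function name `exact_terminating_distribution` is hypothetical.

```python
from itertools import product


def exact_terminating_distribution(ndim: int, height: int) -> dict:
    """Compute P_T(s) = u(s) * P_F(s_f | s) by dynamic programming."""
    # Every parent has a smaller coordinate sum than its child, so sorting
    # by the sum visits each state after all of its parents.
    states = sorted(product(range(height), repeat=ndim), key=sum)
    u = {s: 0.0 for s in states}
    u[(0,) * ndim] = 1.0  # base case u(s_0) = 1

    p_t = {}
    for s in states:
        children = [
            s[:i] + (s[i] + 1,) + s[i + 1:]
            for i in range(ndim)
            if s[i] + 1 < height
        ]
        # Assumed policy: uniform over all moves plus the exit action.
        p = 1.0 / (len(children) + 1)
        for child in children:
            u[child] += u[s] * p  # u(s') = sum over parents of u(s) P_F(s'|s)
        p_t[s] = u[s] * p  # P_T(s) = u(s) P_F(s_f | s)
    return p_t


if __name__ == "__main__":
    dist = exact_terminating_distribution(ndim=2, height=3)
    print(sum(dist.values()))
```

Because every trajectory eventually exits, the values of P_T sum to 1 up to floating-point error, which is a useful sanity check on any implementation of the recursion.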

tutorials.examples.train_hypergrid.logger
tutorials.examples.train_hypergrid.main(args)

Trains a GFlowNet on the HyperGrid environment, potentially distributed.

Return type:

dict

tutorials.examples.train_hypergrid.parser
tutorials.examples.train_hypergrid.plot_results(env, gflownet, l1_distances, args)
tutorials.examples.train_hypergrid.set_up_fm_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id)

Returns an FM (flow matching) GFlowNet.

tutorials.examples.train_hypergrid.set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id, strategy_rng)

Returns a GFlowNet complete with the required estimators.

tutorials.examples.train_hypergrid.set_up_logF_estimator(args, env, preprocessor, agent_group_list, my_agent_group_id, pf_module)

Returns a LogStateFlowEstimator.

tutorials.examples.train_hypergrid.set_up_pb_pf_estimators(args, env, preprocessor, agent_group_list, my_agent_group_id)

Returns a pair of estimators for the forward and backward policies.