tutorials.examples.train_hypergrid
The goal of this script is to reproduce some of the published results on the HyperGrid environment. Run one of the following commands to reproduce some of the results in [Trajectory balance: Improved credit assignment in GFlowNets](https://arxiv.org/abs/2201.13259):

    python train_hypergrid.py --ndim 4 --height 8 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}
    python train_hypergrid.py --ndim 2 --height 64 --R0 {0.1, 0.01, 0.001} --tied {--uniform_pb} --loss {TB, DB}

Run one of the following to reproduce some of the results in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782):

    python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}
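In these commands, braces list alternative values: each run takes one value from each set. A braced single flag such as {--uniform_pb} is read here as "optionally add this flag", which is an assumption about the notation. The sketch below uses only the standard library to enumerate every concrete command in both sweeps; the helper names tb_paper_commands and subtb_paper_commands are hypothetical and not part of the script.

```python
import itertools


def tb_paper_commands() -> list[str]:
    """Expand the trajectory balance sweep into concrete commands."""
    commands = []
    for (ndim, height), r0, uniform_pb, loss in itertools.product(
        [(4, 8), (2, 64)],         # the two grid shapes, one per command template
        ["0.1", "0.01", "0.001"],  # --R0 {0.1, 0.01, 0.001}
        [False, True],             # {--uniform_pb}: assumed to be an optional flag
        ["TB", "DB"],              # --loss {TB, DB}
    ):
        cmd = f"python train_hypergrid.py --ndim {ndim} --height {height} --R0 {r0} --tied"
        if uniform_pb:
            cmd += " --uniform_pb"
        commands.append(f"{cmd} --loss {loss}")
    return commands


def subtb_paper_commands() -> list[str]:
    """Expand the partial-episodes (SubTB) sweep into concrete commands."""
    return [
        f"python train_hypergrid.py --ndim {ndim} --height 12 --R0 {r0} --tied --loss {loss}"
        for ndim, r0, loss in itertools.product([2, 4], ["1e-3", "1e-4"], ["TB", "DB", "SubTB"])
    ]


if __name__ == "__main__":
    for cmd in tb_paper_commands() + subtb_paper_commands():
        print(cmd)
```

Keeping the sweep as data makes it easy to confirm that no combination is skipped: the first sweep expands to 24 runs and the second to 12.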
SELECTIVE AVERAGING: This script also supports selective model averaging for distributed training. Instead of averaging all models, it replaces the worst-performing models with averaged weights from the better-performing ones. Use the following flags:
- --use_selective_averaging: enable selective averaging instead of standard averaging.
- --replacement_ratio 0.2: replace the worst 20% of models (adjustable, 0.0 to 1.0).
- --averaging_strategy mean: how to combine the good models ("mean", "weighted_mean", or "best_only").
- --momentum 0.0: momentum factor for combining with the previous weights (0.0 to 1.0, default 0.0).
Example with selective averaging:

    python train_hypergrid.py --distributed --use_selective_averaging --replacement_ratio 0.3 --averaging_strategy mean --momentum 0.1
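The replacement step can be sketched in plain Python. This is not the script's implementation: each model is reduced to a flat list of floats, a higher score is assumed to mean a better model, "weighted_mean" is implemented here as a softmax-weighted mean over the good models' scores, and --momentum is interpreted as blending each replaced model's old weights with the combined ones. The function name selective_average is hypothetical.

```python
import math
from statistics import fmean

Weights = list[float]  # simplification: one flattened parameter vector per model


def selective_average(
    models: list[Weights],
    scores: list[float],
    replacement_ratio: float = 0.2,
    strategy: str = "mean",
    momentum: float = 0.0,
) -> list[Weights]:
    """Replace the worst models with a combination of the better ones.

    Assumes a higher score means a better model.
    """
    n = len(models)
    n_replace = int(n * replacement_ratio)
    result = [list(w) for w in models]
    if n_replace == 0 or n_replace >= n:
        return result  # nothing to replace, or no good models left to copy from

    order = sorted(range(n), key=lambda i: scores[i])  # worst first
    worst, good = order[:n_replace], order[n_replace:]

    if strategy == "mean":
        combined = [fmean(column) for column in zip(*(models[i] for i in good))]
    elif strategy == "weighted_mean":
        # Assumption: softmax weights over scores, which also handles negative scores.
        top = max(scores[i] for i in good)
        weights = [math.exp(scores[i] - top) for i in good]
        total = sum(weights)
        combined = [
            sum(w * models[i][k] for w, i in zip(weights, good)) / total
            for k in range(len(models[0]))
        ]
    elif strategy == "best_only":
        combined = list(models[order[-1]])  # copy the single best model
    else:
        raise ValueError(f"unknown averaging strategy: {strategy!r}")

    for i in worst:
        # Assumption: momentum keeps a fraction of the replaced model's own weights.
        result[i] = [momentum * old + (1.0 - momentum) * new for old, new in zip(models[i], combined)]
    return result
```

With four models and replacement_ratio=0.25, only the lowest-scoring model is overwritten; the others are returned unchanged.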
This script also provides a function get_exact_P_T that computes the exact terminating state distribution for the HyperGrid environment, which is useful for evaluation and visualization.
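Attributes

| Name |
|---|
| logger |
| parser |

Classes

| Name |
|---|
| ModesReplayBufferManager |

Functions

| Function | Description |
|---|---|
| _make_optimizer_for | Build a fresh AdamW optimizer for a (re)built GFlowNet with logZ group. |
| _sample_new_strategy | Sample a new exploration strategy by independently sampling each parameter. |
| build_mode_discovery_figure | Build a matplotlib figure showing discovered vs undiscovered mode states. |
| get_exact_P_T | Evaluates the exact terminating state distribution P_T for HyperGrid. |
| main | Trains a GFlowNet on the HyperGrid environment, optionally with distributed training. |
| plot_results | |
| set_up_fm_gflownet | Returns a flow-matching (FM) GFlowNet. |
| set_up_gflownet | Returns a GFlowNet complete with the required estimators. |
| set_up_logF_estimator | Returns a LogStateFlowEstimator. |
| set_up_pb_pf_estimators | Returns a pair of estimators for the forward and backward policies. |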
Module Contents
- class tutorials.examples.train_hypergrid.ModesReplayBufferManager(env, rank, num_training_ranks, diverse_replay_buffer=False, capacity=10000, remote_manager_rank=None, w_retained=1.0, w_novelty=0.1, w_reward=1.0, w_mode_bonus=10.0, p_norm_novelty=2.0, cdist_max_bytes=268435456, ema_decay=0.5)
Bases: gfn.containers.replay_buffer_manager.ReplayBufferManager
- Parameters:
env (gfn.gym.HyperGrid)
rank (int)
num_training_ranks (int)
diverse_replay_buffer (bool)
capacity (int)
remote_manager_rank (int | None)
w_retained (float)
w_novelty (float)
w_reward (float)
w_mode_bonus (float)
p_norm_novelty (float)
cdist_max_bytes (int)
ema_decay (float)
- _compute_metadata()
- Return type:
dict
- _ema_decay: float
- _score_ema: float | None = None
- cdist_max_bytes = 268435456
- discovered_modes
- env
- p_norm_novelty = 2.0
- scoring_function(obj)
- Parameters:
obj (gfn.containers.replay_buffer_manager.ContainerUnion)
- Return type:
dict[str, float]
- w_mode_bonus = 10.0
- w_novelty = 0.1
- w_retained = 1.0
- w_reward = 1.0
- tutorials.examples.train_hypergrid._make_optimizer_for(gflownet, args)
Build a fresh AdamW optimizer, with a separate logZ parameter group, for a (re)built GFlowNet.
- Return type:
torch.optim.Optimizer
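A fresh optimizer matters because PyTorch optimizers keep per-parameter state (such as Adam's moment estimates) tied to specific parameter tensors, so an optimizer built for the old network does not track a rebuilt one. The sketch below shows one way to give logZ its own parameter group. ToyGFlowNet, the argument names lr and lr_Z, and the choice to disable weight decay on logZ are all assumptions for illustration, not the script's code.

```python
import argparse

import torch
from torch import nn


class ToyGFlowNet(nn.Module):
    """Hypothetical stand-in for a GFlowNet: a policy network plus a scalar logZ."""

    def __init__(self) -> None:
        super().__init__()
        self.pf = nn.Linear(4, 3)
        self.logZ = nn.Parameter(torch.tensor(0.0))


def make_optimizer_for(gflownet: nn.Module, args: argparse.Namespace) -> torch.optim.Optimizer:
    """Build a fresh AdamW optimizer with a separate parameter group for logZ."""
    named = dict(gflownet.named_parameters())
    groups = [{"params": [p for name, p in named.items() if name != "logZ"], "lr": args.lr}]
    if "logZ" in named:
        # Assumption: logZ gets its own learning rate and is not decayed toward zero.
        groups.append({"params": [named["logZ"]], "lr": args.lr_Z, "weight_decay": 0.0})
    return torch.optim.AdamW(groups)


if __name__ == "__main__":
    args = argparse.Namespace(lr=1e-3, lr_Z=1e-1)  # hypothetical argument names
    optimizer = make_optimizer_for(ToyGFlowNet(), args)
    print([group["lr"] for group in optimizer.param_groups])  # [0.001, 0.1]
```

Models without a logZ parameter (for example, a flow-matching GFlowNet) simply end up with a single parameter group.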
- tutorials.examples.train_hypergrid._sample_new_strategy(args, rng)
Sample a new exploration strategy by independently sampling each parameter.
Each parameter (epsilon, temperature, n_noisy_layers) is sampled from a normal distribution with mean and std specified in args. Values are clamped to valid ranges.
- Parameters:
args – Argument namespace containing the mean and std for each parameter:
  - epsilon, strategy_epsilon_std
  - temperature, strategy_temperature_std
  - n_noisy_layers, strategy_n_noisy_layers_std
  - strategy_noisy_std_init (optional, default 0.5)
rng (random.Random) – Random number generator instance to use for sampling.
- Returns:
A dict with keys name, epsilon, temperature, n_noisy_layers, and noisy_std_init.
- Return type:
dict
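The sampling described above can be sketched with random.Random.gauss. The documentation says only that values are clamped to valid ranges, so the ranges used here (epsilon in [0, 1], temperature of at least 1e-3, a non-negative integer layer count) and the name format are assumptions; sample_new_strategy is a hypothetical stand-in for the private helper.

```python
import argparse
import random


def sample_new_strategy(args: argparse.Namespace, rng: random.Random) -> dict:
    """Sample each exploration parameter independently from a clamped normal."""
    epsilon = rng.gauss(args.epsilon, args.strategy_epsilon_std)
    temperature = rng.gauss(args.temperature, args.strategy_temperature_std)
    n_noisy = rng.gauss(args.n_noisy_layers, args.strategy_n_noisy_layers_std)

    # Assumed valid ranges; the real script may clamp differently.
    epsilon = min(max(epsilon, 0.0), 1.0)
    temperature = max(temperature, 1e-3)
    n_noisy_layers = max(round(n_noisy), 0)  # round() on a float returns an int
    noisy_std_init = getattr(args, "strategy_noisy_std_init", 0.5)  # optional, default 0.5

    return {
        "name": f"eps{epsilon:.3f}_temp{temperature:.3f}_noisy{n_noisy_layers}",  # hypothetical format
        "epsilon": epsilon,
        "temperature": temperature,
        "n_noisy_layers": n_noisy_layers,
        "noisy_std_init": noisy_std_init,
    }
```

Passing a dedicated random.Random instance, rather than using the module-level functions, keeps each sampled strategy reproducible from its seed.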
- tutorials.examples.train_hypergrid.build_mode_discovery_figure(env, discovered_indices)
Build a matplotlib figure showing discovered vs undiscovered mode states.
Projects mode states onto 2D planes of the first min(ndim, 3) dimensions. For ndim=1, shows a single row; for ndim=2, a single 2D heatmap; for ndim>=3, three pairwise projections of the first 3 dimensions.
- Color coding:
Light gray: no mode state at this position
Red: mode state(s) exist but none discovered
Green: at least one mode state discovered
Returns None if env.all_states is unavailable.
- Parameters:
env (gfn.gym.HyperGrid)
discovered_indices (set[int])
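The projection and color coding can be sketched without matplotlib by computing, for each plane, a grid of cell statuses (0 = no mode, 1 = mode but none discovered, 2 = at least one discovered). The input format is an assumption: a dict mapping state index to grid coordinates, whereas the real function reads states from env.all_states. The name projection_grids is hypothetical.

```python
from itertools import combinations

NO_MODE, UNDISCOVERED, DISCOVERED = 0, 1, 2  # light gray, red, green


def projection_grids(
    modes: dict[int, tuple[int, ...]],
    discovered: set[int],
    ndim: int,
    height: int,
) -> dict[tuple[int, ...], list[list[int]]]:
    """Project mode states onto planes of the first min(ndim, 3) dimensions."""
    if ndim == 1:
        planes = [(0,)]  # a single row
    else:
        # ndim == 2 gives [(0, 1)]; ndim >= 3 gives the three pairs of the first 3 dims.
        planes = list(combinations(range(min(ndim, 3)), 2))
    grids = {}
    for plane in planes:
        n_rows = 1 if len(plane) == 1 else height
        grid = [[NO_MODE] * height for _ in range(n_rows)]
        for index, coords in modes.items():
            row = 0 if len(plane) == 1 else coords[plane[0]]
            col = coords[plane[-1]]
            status = DISCOVERED if index in discovered else UNDISCOVERED
            # A cell turns green as soon as any mode projecting onto it was discovered.
            grid[row][col] = max(grid[row][col], status)
        grids[plane] = grid
    return grids
```

Each grid could then be drawn with matplotlib's imshow and a three-color matplotlib.colors.ListedColormap.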
- tutorials.examples.train_hypergrid.get_exact_P_T(env, gflownet)
Evaluates the exact terminating state distribution P_T for HyperGrid.
For each state s', the terminating state probability is computed as:
\[P_T(s') = u(s') P_F(s_f | s')\]
where u(s') satisfies the recursion over the set Par(s') of parents of s':
\[u(s') = \sum_{s \in \text{Par}(s')} u(s) P_F(s' | s)\]
with the base case u(s_0) = 1.
- Parameters:
env (gfn.gym.HyperGrid) – The HyperGrid environment
gflownet (gfn.gflownet.GFlowNet) – The GFlowNet model
- Returns:
The exact terminating state distribution as a tensor
- Return type:
torch.Tensor
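To make the recursion concrete, here is a pure-Python sketch on a small grid. It replaces the trained forward policy with a uniform policy over valid actions and assumes the usual HyperGrid action set (increment one coordinate that is below height - 1, or exit to s_f); the real function evaluates the GFlowNet's P_F instead. The name exact_P_T_uniform is hypothetical.

```python
from itertools import product


def exact_P_T_uniform(ndim: int, height: int) -> dict[tuple[int, ...], float]:
    """Exact P_T on a HyperGrid under a uniform forward policy (a stand-in for P_F)."""

    def n_actions(s: tuple[int, ...]) -> int:
        # One increment per coordinate that can still grow, plus the exit action s -> s_f.
        return sum(1 for x in s if x < height - 1) + 1

    # Sorting by coordinate sum guarantees every parent is processed before its children.
    states = sorted(product(range(height), repeat=ndim), key=sum)
    u: dict[tuple[int, ...], float] = {}
    for s in states:
        if sum(s) == 0:
            u[s] = 1.0  # base case u(s_0) = 1
            continue
        # u(s') = sum over parents s of u(s) * P_F(s' | s), with P_F(s' | s) = 1 / n_actions(s).
        total = 0.0
        for d in range(ndim):
            if s[d] > 0:
                parent = s[:d] + (s[d] - 1,) + s[d + 1:]
                total += u[parent] / n_actions(parent)
        u[s] = total
    # P_T(s') = u(s') * P_F(s_f | s')
    return {s: u[s] / n_actions(s) for s in states}
```

For ndim=1 and height=3 this yields P_T = 0.5, 0.25, 0.25 for the three states, and in general the values sum to 1 because every trajectory terminates exactly once. The resulting distribution can be compared with the normalized reward, for example with an L1 distance.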
- tutorials.examples.train_hypergrid.logger
- tutorials.examples.train_hypergrid.main(args)
Trains a GFlowNet on the HyperGrid environment, optionally with distributed training.
- Return type:
dict
- tutorials.examples.train_hypergrid.parser
- tutorials.examples.train_hypergrid.plot_results(env, gflownet, l1_distances, args)
- tutorials.examples.train_hypergrid.set_up_fm_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id)
Returns a flow-matching (FM) GFlowNet.
- tutorials.examples.train_hypergrid.set_up_gflownet(args, env, preprocessor, agent_group_list, my_agent_group_id, strategy_rng)
Returns a GFlowNet complete with the required estimators.
- tutorials.examples.train_hypergrid.set_up_logF_estimator(args, env, preprocessor, agent_group_list, my_agent_group_id, pf_module)
Returns a LogStateFlowEstimator.
- tutorials.examples.train_hypergrid.set_up_pb_pf_estimators(args, env, preprocessor, agent_group_list, my_agent_group_id)
Returns a pair of estimators for the forward and backward policies.