tutorials.examples.train_conditional
Conditional GFlowNet training on the HyperGrid environment.
This script demonstrates how to train conditional GFlowNets on the HyperGrid environment, where a continuous condition variable selects the distribution the GFlowNet should learn. The condition interpolates between two extremes:
Condition = 0: Uniform distribution (all states get reward R0+R1+R2)
Condition = 1: Original HyperGrid multi-modal distribution
Condition ∈ (0,1): Linear interpolation between uniform and original
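The interpolation described above can be sketched as follows. This is a minimal sketch, not the environment's implementation: it assumes the interpolation acts on rewards, assumes the standard HyperGrid reward form with illustrative constants R0=0.1, R1=0.5, R2=2.0, and the helper names `hypergrid_reward` and `conditional_reward` are hypothetical.

```python
from itertools import product

# Illustrative constants; assumptions, not values read from the script.
R0, R1, R2 = 0.1, 0.5, 2.0


def hypergrid_reward(state, height):
    """Standard HyperGrid reward (assumed form) for a tuple of integer coordinates."""
    # Distance of each normalized coordinate from the grid centre, in [0, 0.5].
    dist = [abs(x / (height - 1) - 0.5) for x in state]
    outer = all(d > 0.25 for d in dist)       # broad regions near the corners
    inner = all(0.3 < d < 0.4 for d in dist)  # narrow high-reward bands (the modes)
    return R0 + R1 * outer + R2 * inner


def conditional_reward(state, height, condition):
    """Linearly interpolate between a flat reward and the original reward."""
    uniform = R0 + R1 + R2  # every state gets this reward at condition = 0
    return (1.0 - condition) * uniform + condition * hypergrid_reward(state, height)


if __name__ == "__main__":
    height, ndim = 8, 2
    for c in (0.0, 0.5, 1.0):
        rewards = [
            conditional_reward(s, height, c)
            for s in product(range(height), repeat=ndim)
        ]
        print(f"condition={c}: min={min(rewards):.3f} max={max(rewards):.3f}")
```

At condition 0 the minimum and maximum coincide (a flat reward); as the condition grows toward 1 the gap between mode and non-mode states widens.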
During training:
- Condition values are sampled uniformly from [0, 1] for each batch
- The GFlowNet learns to generate different distributions based on the condition
- LogZ is modeled as a function of condition only (not states)
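One way to see why logZ can depend on the condition alone: the partition function sums the reward over every terminal state, so the state is summed out and only the condition remains. The sketch below computes the exact logZ by enumeration on a small grid. It assumes the same linearly interpolated reward and illustrative constants as the module description suggests; the function names are hypothetical, and drawing one condition per trajectory (rather than one per batch) is also an assumption.

```python
import math
import random
from itertools import product

R0, R1, R2 = 0.1, 0.5, 2.0  # illustrative constants (assumed)


def reward(state, height, condition):
    """Assumed conditional reward: flat at condition 0, HyperGrid-shaped at 1."""
    dist = [abs(x / (height - 1) - 0.5) for x in state]
    original = (
        R0
        + R1 * all(d > 0.25 for d in dist)
        + R2 * all(0.3 < d < 0.4 for d in dist)
    )
    return (1.0 - condition) * (R0 + R1 + R2) + condition * original


def true_log_z(ndim, height, condition):
    """Exact log partition function: states are summed out, the condition is not."""
    states = product(range(height), repeat=ndim)
    return math.log(sum(reward(s, height, condition) for s in states))


def sample_batch_conditions(batch_size, rng):
    """Draw conditions uniformly from [0, 1)."""
    return [rng.random() for _ in range(batch_size)]


if __name__ == "__main__":
    rng = random.Random(0)
    for c in sample_batch_conditions(4, rng):
        print(f"condition={c:.3f}  logZ={true_log_z(2, 8, c):.4f}")
```

A learned logZ estimator would approximate `true_log_z` as a function of its last argument; enumeration like this is only feasible for small grids.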
During validation:
- Fresh trajectories are sampled for multiple condition values [0, 0.25, 0.5, 0.75, 1]
- L1 distance is computed between empirical and true distributions
- Mode discovery is tracked for condition=1
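The two validation metrics above can be sketched in a few lines. This is a minimal illustration with hypothetical helper names, not the script's evaluation code; in particular, whether the script sums or averages the absolute differences is not stated here, so the sum is an assumption.

```python
from collections import Counter
from itertools import product


def empirical_distribution(samples, ndim, height):
    """Relative frequency of each terminal state among the sampled states."""
    counts = Counter(samples)
    n = len(samples)
    # Counter returns 0 for states that were never sampled.
    return {s: counts[s] / n for s in product(range(height), repeat=ndim)}


def l1_distance(p, q):
    """Sum of absolute differences between two distributions over the same states."""
    return sum(abs(p[s] - q[s]) for s in p)


def modes_found(samples, modes):
    """Number of distinct mode states that appear at least once in the samples."""
    return len(set(samples) & set(modes))


if __name__ == "__main__":
    # Toy 2x2 grid with a uniform target, as would apply at condition = 0.
    samples = [(0, 0), (0, 0), (1, 1), (0, 1)]
    empirical = empirical_distribution(samples, ndim=2, height=2)
    target = {s: 0.25 for s in empirical}
    print("L1 distance:", l1_distance(empirical, target))
    print("modes found:", modes_found(samples, modes=[(1, 1), (1, 0)]))
```

In a validation loop, these would be evaluated once per condition value, with the target distribution obtained by normalizing the reward for that condition.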
Example usage: python train_conditional.py --ndim 2 --height 8 --epsilon 0.1
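For readers unfamiliar with how such flags are typically wired up, the sketch below shows a hypothetical argparse parser covering only the three flags in the usage line. It is not the module's actual `parser`: the defaults for `--ndim` and `--height` and all help strings are assumptions, and the real script very likely defines further options.

```python
import argparse


def build_parser():
    """Hypothetical parser exposing only the three flags from the usage line."""
    parser = argparse.ArgumentParser(
        description="Conditional GFlowNet on HyperGrid (illustrative sketch)"
    )
    # Defaults below are assumptions chosen to mirror the usage example.
    parser.add_argument("--ndim", type=int, default=2, help="number of grid dimensions")
    parser.add_argument("--height", type=int, default=8, help="grid side length")
    parser.add_argument("--epsilon", type=float, default=0.0, help="exploration parameter")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args.ndim, args.height, args.epsilon)
```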
Attributes
Functions
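build_conditional_logF_scalar_estimator | Build conditional log flow estimator. |
build_conditional_pf_pb | Build conditional policy forward and backward estimators. |
build_db_gflownet | |
build_db_mod_gflownet | |
build_fm_gflownet | |
build_subTB_gflownet | |
build_tb_gflownet | Build a Trajectory Balance GFlowNet. |
evaluate_conditional_sampling | Evaluate the conditional sampling distributions with detailed metrics. |
main | |
train | |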
Module Contents
- tutorials.examples.train_conditional.DEFAULT_SEED: int = 4444
- tutorials.examples.train_conditional.GFN_FNS
- tutorials.examples.train_conditional.build_conditional_logF_scalar_estimator(env)
Build conditional log flow estimator.
- Parameters:
env (gfn.gym.ConditionalHyperGrid) – The ConditionalHyperGrid environment
- Returns:
A conditional scalar estimator for log flow
- Return type:
- tutorials.examples.train_conditional.build_conditional_pf_pb(env)
Build conditional policy forward and backward estimators.
- Parameters:
env (gfn.gym.ConditionalHyperGrid) – The ConditionalHyperGrid environment
- Returns:
A tuple of (forward policy estimator, backward policy estimator)
- Return type:
tuple[gfn.estimators.ConditionalDiscretePolicyEstimator, gfn.estimators.ConditionalDiscretePolicyEstimator]
- tutorials.examples.train_conditional.build_db_gflownet(env)
- tutorials.examples.train_conditional.build_db_mod_gflownet(env)
- tutorials.examples.train_conditional.build_fm_gflownet(env)
- tutorials.examples.train_conditional.build_subTB_gflownet(env)
- tutorials.examples.train_conditional.build_tb_gflownet(env)
Build a Trajectory Balance GFlowNet.
- Parameters:
env (gfn.gym.ConditionalHyperGrid) – The ConditionalHyperGrid environment
- Returns:
A TBGFlowNet instance
- Return type:
- tutorials.examples.train_conditional.evaluate_conditional_sampling(env, gflownet, device, n_eval_samples=10000)
Evaluate the conditional sampling distributions with detailed metrics.
- tutorials.examples.train_conditional.main(args)
- tutorials.examples.train_conditional.parser
- tutorials.examples.train_conditional.train(env, gflownet, seed, device, n_iterations=10, batch_size=1000, validation_interval=100, validation_samples=20000, lr=0.001, lr_logz=0.01, epsilon=0.0)
- Parameters: