tutorials.examples.train_conditional

Conditional GFlowNet training on the HyperGrid environment.

This script demonstrates how to train conditional GFlowNets on the HyperGrid environment, where the learned distribution depends on a continuous condition variable. The condition interpolates between two extremes:

  • Condition = 0: Uniform distribution (all states get reward R0+R1+R2)

  • Condition = 1: Original HyperGrid multi-modal distribution

  • Condition ∈ (0,1): Linear interpolation between uniform and original
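The interpolation can be sketched in plain Python. This is a minimal sketch, not the ConditionalHyperGrid code. It assumes the standard HyperGrid reward with hypothetical constants R0 = 0.1, R1 = 0.5 and R2 = 2.0, and it assumes the linear interpolation is applied to rewards rather than to normalized distributions. All function names are hypothetical.

```python
import itertools

# Hypothetical constants for the standard HyperGrid reward.
R0, R1, R2 = 0.1, 0.5, 2.0


def original_reward(state, height):
    """Multi-modal HyperGrid reward for one state (a tuple of ints in [0, height))."""
    ax = [abs(x / (height - 1) - 0.5) for x in state]
    outer = all(a > 0.25 for a in ax)  # ax never exceeds 0.5
    inner = all(0.3 < a < 0.4 for a in ax)
    return R0 + R1 * outer + R2 * inner


def conditional_reward(state, height, condition):
    """Assumed reward-space interpolation: uniform at c = 0, original at c = 1."""
    uniform = R0 + R1 + R2  # every state gets this reward at condition 0
    return (1.0 - condition) * uniform + condition * original_reward(state, height)


def true_distribution(ndim, height, condition):
    """Normalize the conditional reward over all height ** ndim grid states."""
    states = list(itertools.product(range(height), repeat=ndim))
    rewards = [conditional_reward(s, height, condition) for s in states]
    z = sum(rewards)
    return {s: r / z for s, r in zip(states, rewards)}
```

With ndim = 2 and height = 8, true_distribution(2, 8, 0.0) is uniform over the 64 states, while condition 1.0 puts the most mass on high-reward states such as (1, 1).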

During training:

  • Condition values are sampled uniformly from [0, 1] for each batch.

  • The GFlowNet learns to generate different distributions depending on the condition.

  • LogZ is modeled as a function of the condition only (not of the states).
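Summing the interpolated reward over all states shows why log Z can depend on the condition alone. Assuming the reward-space interpolation R_c(s) = (1 - c) U + c R(s) with U = R0 + R1 + R2 (an assumption, not taken from the script), the partition function is Z(c) = (1 - c) N U + c Z_1, where N is the number of states and Z_1 is the partition function of the original reward. The helper names below are hypothetical, and drawing one condition per trajectory is also an assumption; the script may share one value across a batch.

```python
import math
import random


def log_partition(condition, n_states, uniform_reward, z_original):
    """Exact log Z(c), assuming rewards interpolate linearly in c.

    Summing (1 - c) * U + c * R(s) over n_states states gives
    Z(c) = (1 - c) * n_states * U + c * Z_original, a function of c only.
    """
    return math.log((1.0 - condition) * n_states * uniform_reward + condition * z_original)


def sample_conditions(batch_size, rng):
    """Draw one condition per trajectory uniformly from [0, 1)."""
    return [rng.random() for _ in range(batch_size)]


rng = random.Random(4444)  # DEFAULT_SEED value from this module
conditions = sample_conditions(4, rng)
log_z_targets = [log_partition(c, 64, 2.6, 20.0) for c in conditions]
```

A learned log Z(c) is trained toward these per-condition values, so it needs only the condition as input.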

During validation:

  • Fresh trajectories are sampled for multiple condition values [0, 0.25, 0.5, 0.75, 1].

  • The L1 distance is computed between the empirical and true distributions.

  • Mode discovery is tracked for condition = 1.
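The validation metrics can be sketched as follows, treating states as tuples of grid coordinates. The helpers are hypothetical. l1_distance sums absolute differences; dividing by the number of states gives the mean-absolute-difference variant, and which normalization the script reports is not stated here. The set of mode states is passed in explicitly rather than assumed.

```python
from collections import Counter


def empirical_distribution(terminal_states, all_states):
    """Fraction of sampled terminal states that landed on each grid state."""
    counts = Counter(terminal_states)
    n = len(terminal_states)
    return {s: counts.get(s, 0) / n for s in all_states}


def l1_distance(p, q):
    """Sum of absolute differences between two pmfs over the same states."""
    return sum(abs(p[s] - q[s]) for s in p)


def modes_found(terminal_states, mode_states):
    """Set of mode states visited at least once."""
    return set(terminal_states) & set(mode_states)
```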

Example usage: python train_conditional.py --ndim 2 --height 8 --epsilon 0.1
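The --epsilon flag is not documented on this page. Assuming it sets an epsilon-exploration rate that feeds the epsilon parameter of train, the sketch below shows a common scheme: blend the policy's action probabilities with a uniform distribution over valid actions. mix_with_uniform is a hypothetical helper, not part of the script.

```python
def mix_with_uniform(probs, mask, epsilon):
    """Blend policy probs with a uniform distribution over valid (mask=True) actions."""
    n_valid = sum(1 for m in mask if m)
    return [
        (1.0 - epsilon) * p + (epsilon / n_valid if m else 0.0)
        for p, m in zip(probs, mask)
    ]


# With epsilon = 0.1, a deterministic policy still explores the other valid action.
mixed = mix_with_uniform([0.0, 1.0, 0.0], [False, True, True], 0.1)
```

Invalid actions keep probability zero, so exploration never proposes an illegal move.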

Attributes

DEFAULT_SEED

GFN_FNS

parser

Functions

build_conditional_logF_scalar_estimator(env)

Build conditional log flow estimator.

build_conditional_pf_pb(env)

Build conditional policy forward and backward estimators.

build_db_gflownet(env)

build_db_mod_gflownet(env)

build_fm_gflownet(env)

build_subTB_gflownet(env)

build_tb_gflownet(env)

Build a Trajectory Balance GFlowNet.

evaluate_conditional_sampling(env, gflownet, device[, ...])

Evaluate the conditional sampling distributions with detailed metrics.

main(args)

train(env, gflownet, seed, device[, n_iterations, ...])

Module Contents

tutorials.examples.train_conditional.DEFAULT_SEED: int = 4444
tutorials.examples.train_conditional.GFN_FNS
tutorials.examples.train_conditional.build_conditional_logF_scalar_estimator(env)

Build conditional log flow estimator.

Parameters:

env (gfn.gym.ConditionalHyperGrid) – The ConditionalHyperGrid environment

Returns:

A conditional scalar estimator for log flow

Return type:

gfn.estimators.ConditionalScalarEstimator
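For context, a conditional log-flow estimator typically enters flow-based objectives such as detailed balance (DB). The standard DB constraints are shown below with the condition c added to every term; this is an assumption about how conditioning enters, not a transcription of this script.

```latex
% Detailed balance for a transition s -> s' under condition c
\log F(s \mid c) + \log P_F(s' \mid s, c) = \log F(s' \mid c) + \log P_B(s \mid s', c)

% Terminating transition from x into the sink state s_f
\log F(x \mid c) + \log P_F(s_f \mid x, c) = \log R(x; c)
```

A DB loss penalizes the squared difference between the two sides of each equation.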

tutorials.examples.train_conditional.build_conditional_pf_pb(env)

Build conditional policy forward and backward estimators.

Parameters:

env (gfn.gym.ConditionalHyperGrid) – The ConditionalHyperGrid environment

Returns:

A tuple of (forward policy estimator, backward policy estimator)

Return type:

tuple[gfn.estimators.ConditionalDiscretePolicyEstimator, gfn.estimators.ConditionalDiscretePolicyEstimator]
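A minimal sketch of how a conditional discrete policy can turn a state and a condition into action probabilities. It assumes the usual HyperGrid action layout: forward actions increment one of the ndim coordinates or exit (ndim + 1 actions), and backward actions decrement one coordinate (ndim actions). The linear layer stands in for whatever network the real estimators use; all names and weights are hypothetical.

```python
import math


def masked_softmax(logits, mask):
    """Softmax over valid actions; invalid actions get probability 0."""
    top = max(l for l, m in zip(logits, mask) if m)
    exps = [math.exp(l - top) if m else 0.0 for l, m in zip(logits, mask)]
    total = sum(exps)
    return [e / total for e in exps]


def hypergrid_forward_mask(state, height):
    """Increment any coordinate below height - 1; the last action (exit) is assumed always valid."""
    return [x < height - 1 for x in state] + [True]


def hypergrid_backward_mask(state):
    """Decrement any coordinate above 0; there is no exit action going backward."""
    return [x > 0 for x in state]


def conditional_logits(state, condition, weights, bias):
    """Hypothetical linear layer on the concatenated input [state..., condition]."""
    inputs = list(state) + [condition]
    return [sum(w * v for w, v in zip(row, inputs)) + b for row, b in zip(weights, bias)]
```

Feeding the condition as an extra input is the simplest way to make both policies depend on c; the forward and backward estimators then differ only in output size and mask.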

tutorials.examples.train_conditional.build_db_gflownet(env)
tutorials.examples.train_conditional.build_db_mod_gflownet(env)
tutorials.examples.train_conditional.build_fm_gflownet(env)
tutorials.examples.train_conditional.build_subTB_gflownet(env)
tutorials.examples.train_conditional.build_tb_gflownet(env)

Build a Trajectory Balance GFlowNet.

Parameters:

env (gfn.gym.ConditionalHyperGrid) – The ConditionalHyperGrid environment

Returns:

A TBGFlowNet instance

Return type:

gfn.gflownet.TBGFlowNet
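The trajectory balance (TB) objective with a conditional log Z is, per trajectory, the squared residual of log Z(c) + Σ log P_F - log R(x; c) - Σ log P_B. A minimal sketch with a hypothetical helper name, taking the per-step log-probabilities as precomputed lists:

```python
import math


def tb_loss(log_z_of_c, log_pf_steps, log_pb_steps, log_reward):
    """Squared trajectory-balance residual for one trajectory under condition c.

    log_z_of_c is the conditional estimate log Z(c); the per-step
    log-probabilities come from P_F(. | ., c) and P_B(. | ., c).
    """
    residual = log_z_of_c + sum(log_pf_steps) - log_reward - sum(log_pb_steps)
    return residual ** 2


# A balanced trajectory: Z = 4, two forward steps of probability 0.5, R = 1.
balanced = tb_loss(math.log(4.0), [math.log(0.5)] * 2, [0.0, 0.0], 0.0)
```

In a batch, each trajectory's residual uses log Z evaluated at that trajectory's own condition, and the per-trajectory losses are typically averaged.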

tutorials.examples.train_conditional.evaluate_conditional_sampling(env, gflownet, device, n_eval_samples=10000)

Evaluate the conditional sampling distributions with detailed metrics.

tutorials.examples.train_conditional.main(args)
tutorials.examples.train_conditional.parser
tutorials.examples.train_conditional.train(env, gflownet, seed, device, n_iterations=10, batch_size=1000, validation_interval=100, validation_samples=20000, lr=0.001, lr_logz=0.01, epsilon=0.0)

Parameters:

env (gfn.gym.ConditionalHyperGrid)