tutorials.examples.train_conditional
====================================

.. py:module:: tutorials.examples.train_conditional

.. autoapi-nested-parse::

   Conditional GFlowNet training on the HyperGrid environment.

   This script demonstrates how to train conditional GFlowNets that learn
   different distributions based on a continuous condition variable on the
   HyperGrid environment. The condition interpolates between two extremes:

   - Condition = 0: Uniform distribution (all states get reward R0+R1+R2)
   - Condition = 1: Original HyperGrid multi-modal distribution
   - Condition ∈ (0,1): Linear interpolation between uniform and original

   During training:

   - Condition values are sampled uniformly from [0, 1] for each batch
   - The GFlowNet learns to generate different distributions based on the condition
   - LogZ is modeled as a function of condition only (not states)

   During validation:

   - Fresh trajectories are sampled for multiple condition values [0, 0.25, 0.5, 0.75, 1]
   - L1 distance is computed between empirical and true distributions
   - Mode discovery is tracked for condition=1

   Example usage::

      python train_conditional.py --ndim 2 --height 8 --epsilon 0.1


Attributes
----------

.. autoapisummary::

   tutorials.examples.train_conditional.DEFAULT_SEED
   tutorials.examples.train_conditional.GFN_FNS
   tutorials.examples.train_conditional.parser


Functions
---------

.. autoapisummary::

   tutorials.examples.train_conditional.build_conditional_logF_scalar_estimator
   tutorials.examples.train_conditional.build_conditional_pf_pb
   tutorials.examples.train_conditional.build_db_gflownet
   tutorials.examples.train_conditional.build_db_mod_gflownet
   tutorials.examples.train_conditional.build_fm_gflownet
   tutorials.examples.train_conditional.build_subTB_gflownet
   tutorials.examples.train_conditional.build_tb_gflownet
   tutorials.examples.train_conditional.evaluate_conditional_sampling
   tutorials.examples.train_conditional.main
   tutorials.examples.train_conditional.train


Module Contents
---------------

.. py:data:: DEFAULT_SEED
   :type: int
   :value: 4444

.. py:data:: GFN_FNS

.. py:function:: build_conditional_logF_scalar_estimator(env)

   Build conditional log flow estimator.

   :param env: The ConditionalHyperGrid environment

   :returns: A conditional scalar estimator for log flow


.. py:function:: build_conditional_pf_pb(env)

   Build conditional policy forward and backward estimators.

   :param env: The ConditionalHyperGrid environment

   :returns: A tuple of (forward policy estimator, backward policy estimator)


.. py:function:: build_db_gflownet(env)

.. py:function:: build_db_mod_gflownet(env)

.. py:function:: build_fm_gflownet(env)

.. py:function:: build_subTB_gflownet(env)

.. py:function:: build_tb_gflownet(env)

   Build a Trajectory Balance GFlowNet.

   :param env: The ConditionalHyperGrid environment

   :returns: A TBGFlowNet instance


.. py:function:: evaluate_conditional_sampling(env, gflownet, device, n_eval_samples=10000)

   Evaluate the conditional sampling distributions with detailed metrics.


.. py:data:: parser

.. py:function:: main(args)

.. py:function:: train(env, gflownet, seed, device, n_iterations=10, batch_size=1000, validation_interval=100, validation_samples=20000, lr=0.001, lr_logz=0.01, epsilon=0.0)
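
The condition-dependent reward and the L1 validation metric described in the
module docstring can be sketched in plain Python. This is a minimal
illustration under stated assumptions, not the tutorial's implementation: the
helper names (``original_reward``, ``conditional_reward``,
``true_distribution``, ``l1_distance``) are hypothetical, the constants
``r0=0.1``, ``r1=0.5``, ``r2=2.0`` and the threshold form of the original
HyperGrid reward are assumed, and summing (rather than averaging) the absolute
differences in the L1 distance is also an assumption.

.. code-block:: python

   import itertools


   def original_reward(state, height, r0=0.1, r1=0.5, r2=2.0):
       """Multi-modal HyperGrid-style reward (assumed form and constants)."""
       near_edge = True  # all coordinates satisfy |x/(H-1) - 0.5| > 0.25
       in_band = True    # all coordinates satisfy 0.3 < |x/(H-1) - 0.5| < 0.4
       for x in state:
           ax = abs(x / (height - 1) - 0.5)
           near_edge = near_edge and ax > 0.25
           in_band = in_band and 0.3 < ax < 0.4
       return r0 + r1 * near_edge + r2 * in_band


   def conditional_reward(state, condition, height, r0=0.1, r1=0.5, r2=2.0):
       """Linear interpolation: condition=0 is uniform, condition=1 is original."""
       uniform = r0 + r1 + r2
       original = original_reward(state, height, r0, r1, r2)
       return (1.0 - condition) * uniform + condition * original


   def true_distribution(condition, ndim, height):
       """Normalize the conditional reward over every grid state."""
       states = list(itertools.product(range(height), repeat=ndim))
       rewards = [conditional_reward(s, condition, height) for s in states]
       total = sum(rewards)
       return {s: r / total for s, r in zip(states, rewards)}


   def l1_distance(p, q):
       """Sum of absolute probability differences over the union of states."""
       keys = set(p) | set(q)
       return sum(abs(p.get(s, 0.0) - q.get(s, 0.0)) for s in keys)

With these helpers, ``true_distribution(0.0, ndim=2, height=8)`` assigns
probability 1/64 to each of the 64 states, while
``true_distribution(1.0, ndim=2, height=8)`` concentrates mass on the states
where the assumed reward terms are active. An empirical distribution built
from sampled terminal states could be compared against either with
``l1_distance``.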