tutorials.examples.train_diffusion_rtb
======================================

.. py:module:: tutorials.examples.train_diffusion_rtb

.. autoapi-nested-parse::

   Minimal end-to-end Relative Trajectory Balance (RTB) fine-tuning training script for
   diffusion models.

   - Prior is pre-trained (auto-runs if the prior checkpoint is missing), so finetuning
     starts from a learned prior.
   - Posterior is fine-tuned from this prior (pf). By default, uses the 25→9 GMM
     posterior target (`gmm25_posterior9`) with a learnable posterior forward policy
     and a fixed prior forward policy. Loss is RTB (no backward policy).

   This script outputs the prior weights alongside plots of samples from both the prior
   and posterior distributions.


Functions
---------

.. autoapisummary::

   tutorials.examples.train_diffusion_rtb.add_arg_group
   tutorials.examples.train_diffusion_rtb.build_forward_estimator
   tutorials.examples.train_diffusion_rtb.forward_kwargs
   tutorials.examples.train_diffusion_rtb.get_debug_metrics
   tutorials.examples.train_diffusion_rtb.get_exploration_std
   tutorials.examples.train_diffusion_rtb.main
   tutorials.examples.train_diffusion_rtb.plot_samples
   tutorials.examples.train_diffusion_rtb.pretrain_prior
   tutorials.examples.train_diffusion_rtb.resolve_output_paths


Module Contents
---------------

.. py:function:: add_arg_group(parser, specs)

.. py:function:: build_forward_estimator(s_dim, num_steps, sigma, harmonics_dim, t_emb_dim, s_emb_dim, hidden_dim, joint_layers, zero_init, learn_variance, clipping, gfn_clip, t_scale, log_var_range, device)

.. py:function:: forward_kwargs(args, s_dim, num_steps, sigma, device)

.. py:function:: get_debug_metrics(estimator)

   Compute gradient norm for a module; return (total_norm, has_nan).


.. py:function:: get_exploration_std(iteration, exploration_factor = 0.1, warm_down_start = 500, warm_down_end = 4500, device = None, dtype = None)

   Return a callable exploration std schedule for state-space noise.

   When exploration is enabled, return a step-index function that emits a fixed std
   for the current training iteration, optionally linearly warmed down after
   warm_down_start iters toward 0 by warm_down_end iters.


.. py:function:: main(args)

   Runs the posterior finetuning pipeline, including prior pretraining if required.


.. py:function:: plot_samples(xs, target, title, save_path, return_fig = False)

   Contour + scatter plot of samples against the posterior density.


.. py:function:: pretrain_prior(args, device, s_dim)

   Auto-pretrain the prior if the checkpoint is missing.

   Saves to args.prior_ckpt_path and returns the resolved path.


.. py:function:: resolve_output_paths(args)

   Resolve all output paths relative to this script's directory.
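

The linear warm-down described for ``get_exploration_std`` can be illustrated with a minimal sketch. This is not the module's implementation: ``exploration_std_sketch`` and ``std_fn`` are hypothetical names, the sketch works on plain floats instead of tensors (so it omits ``device`` and ``dtype``), and it assumes the std stays at ``exploration_factor`` until ``warm_down_start`` and reaches 0 at ``warm_down_end``. How the real function behaves when exploration is disabled is not covered here.

```python
from typing import Callable


def exploration_std_sketch(
    iteration: int,
    exploration_factor: float = 0.1,
    warm_down_start: int = 500,
    warm_down_end: int = 4500,
) -> Callable[[int], float]:
    """Hypothetical stand-in for get_exploration_std (floats only, no tensors)."""
    if iteration <= warm_down_start:
        # Before the warm-down window: full exploration noise.
        scale = 1.0
    elif iteration >= warm_down_end:
        # After the warm-down window: exploration noise is switched off.
        scale = 0.0
    else:
        # Inside the window: interpolate linearly from 1 down to 0.
        scale = 1.0 - (iteration - warm_down_start) / (warm_down_end - warm_down_start)

    std = exploration_factor * scale

    def std_fn(step_idx: int) -> float:
        # The std is fixed for this training iteration, so it does not
        # depend on the step index within the trajectory.
        return std

    return std_fn
```

Under these assumptions, the callable is rebuilt once per training iteration and then queried at every trajectory step, for example ``exploration_std_sketch(2500)(0)`` for the first step of iteration 2500, halfway through the default warm-down window.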