tutorials.examples.train_diffusion_rtb
Minimal end-to-end training script for Relative Trajectory Balance (RTB) fine-tuning of diffusion models.
The prior is pre-trained first (pre-training runs automatically if the prior checkpoint is missing), so fine-tuning starts from a learned prior.
The posterior is then fine-tuned from this prior (pf).
By default, the script uses the 25→9 GMM posterior target (gmm25_posterior9) with a learnable posterior forward policy and a fixed prior forward policy. The loss is RTB (no backward policy is used). The script outputs the prior weights alongside plots of samples from both the prior and posterior distributions.
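The RTB objective described above can be sketched as a squared log-ratio residual per trajectory. This is a minimal illustration, not the script's own code: the function name `rtb_loss` and the argument names `log_z`, `log_pf_posterior`, `log_pf_prior`, and `log_reward` are assumptions, and the per-step log-probabilities are assumed to be already computed along trajectories sampled from the posterior forward policy.

```python
import torch


def rtb_loss(log_z, log_pf_posterior, log_pf_prior, log_reward):
    """Mean squared RTB residual over a batch of trajectories (illustrative).

    log_pf_posterior, log_pf_prior: (batch, num_steps) per-step forward log-probs.
    log_reward: (batch,) log of the term that reweights the prior into the posterior.
    log_z: scalar estimate of the log normalizing constant (learnable in practice).
    """
    # The prior forward policy is fixed, so no gradient flows into it.
    log_prior_traj = log_pf_prior.sum(dim=-1).detach()
    log_post_traj = log_pf_posterior.sum(dim=-1)
    # Residual of log(Z * p_post(tau)) - log(r(x) * p_prior(tau)).
    residual = log_z + log_post_traj - log_prior_traj - log_reward
    return residual.pow(2).mean()
```

The residual is zero exactly when the posterior trajectory probability, scaled by the normalizing constant, matches the prior trajectory probability scaled by the reward term. Only forward log-probabilities appear, which is consistent with the note that no backward policy is involved.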
Functions
| add_arg_group | |
| build_forward_estimator | |
| forward_kwargs | |
| get_debug_metrics | Compute gradient norm for a module; return (total_norm, has_nan). |
| get_exploration_std | Return a callable exploration std schedule for state-space noise. |
| main | Runs the posterior finetuning pipeline, including prior pretraining if required. |
| plot_samples | Contour + scatter plot of samples against the posterior density. |
| pretrain_prior | Auto-pretrain the prior if the checkpoint is missing. |
| resolve_output_paths | Resolve all output paths relative to this script's directory. |
Module Contents
- tutorials.examples.train_diffusion_rtb.add_arg_group(parser, specs)
- Parameters:
parser (argparse.ArgumentParser)
specs (list[tuple[tuple[str, Ellipsis], dict]])
- Return type:
None
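The `specs` type, a list of `(flags, kwargs)` pairs, suggests that each pair is forwarded to `parser.add_argument`. The sketch below shows that reading of the signature. It is an assumption, not the script's actual implementation, and the function name `add_arg_group_sketch` and the flags `--lr` and `--n-steps` are hypothetical.

```python
import argparse


def add_arg_group_sketch(parser, specs):
    """Register each (flags, kwargs) pair in specs on the parser."""
    for flags, kwargs in specs:
        parser.add_argument(*flags, **kwargs)


parser = argparse.ArgumentParser()
add_arg_group_sketch(parser, [
    (("--lr",), {"type": float, "default": 1e-3}),
    (("--n-steps", "-n"), {"type": int, "default": 100}),
])
# Parse an explicit argument list instead of sys.argv for the demonstration.
args = parser.parse_args(["--lr", "0.01"])
```

Keeping the argument definitions as data makes it easy to reuse one list of specs across several parsers.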
- tutorials.examples.train_diffusion_rtb.build_forward_estimator(s_dim, num_steps, sigma, harmonics_dim, t_emb_dim, s_emb_dim, hidden_dim, joint_layers, zero_init, learn_variance, clipping, gfn_clip, t_scale, log_var_range, device)
- Parameters:
s_dim (int)
num_steps (int)
sigma (float)
harmonics_dim (int)
t_emb_dim (int)
s_emb_dim (int)
hidden_dim (int)
joint_layers (int)
zero_init (bool)
learn_variance (bool)
clipping (bool)
gfn_clip (float)
t_scale (float)
log_var_range (float)
device (torch.device)
- Return type:
- tutorials.examples.train_diffusion_rtb.forward_kwargs(args, s_dim, num_steps, sigma, device)
- Parameters:
args (argparse.Namespace)
s_dim (int)
num_steps (int)
sigma (float)
device (torch.device)
- Return type:
dict
- tutorials.examples.train_diffusion_rtb.get_debug_metrics(estimator)
Compute gradient norm for a module; return (total_norm, has_nan).
- Parameters:
estimator (torch.nn.Module)
- Return type:
tuple[torch.Tensor, bool]
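A gradient-norm check with this return shape can be sketched as follows. The function name `grad_norm_and_nan` is hypothetical, and the choice of the global L2 norm over all parameter gradients is an assumption, since the docstring does not say which norm the script computes.

```python
import torch


def grad_norm_and_nan(module):
    """Return (global L2 gradient norm, whether any gradient contains NaN)."""
    total_sq = torch.zeros(())
    has_nan = False
    for param in module.parameters():
        if param.grad is None:
            # Parameters that took no part in the backward pass are skipped.
            continue
        grad = param.grad.detach()
        if torch.isnan(grad).any():
            has_nan = True
        # Accumulate on CPU so parameters on different devices can be mixed.
        total_sq = total_sq + grad.float().pow(2).sum().cpu()
    return total_sq.sqrt(), has_nan


# Usage: for y = w . x, the gradient with respect to w is x itself.
layer = torch.nn.Linear(2, 1, bias=False)
layer(torch.tensor([[3.0, 4.0]])).sum().backward()
norm, has_nan = grad_norm_and_nan(layer)
```

Such a check is typically called after `backward()` and before the optimizer step, so that a NaN gradient can be logged or the step skipped.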
- tutorials.examples.train_diffusion_rtb.get_exploration_std(iteration, exploration_factor=0.1, warm_down_start=500, warm_down_end=4500, device=None, dtype=None)
Return a callable exploration std schedule for state-space noise.
When exploration is enabled, return a function of the step index that emits a fixed std for the current training iteration. The std is optionally decayed linearly, starting after warm_down_start iterations and reaching 0 at warm_down_end iterations.
- Parameters:
iteration (int)
exploration_factor (float)
warm_down_start (int)
warm_down_end (int)
device (torch.device | None)
dtype (torch.dtype | None)
- Return type:
torch.Tensor
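The warm-down described above can be sketched as a piecewise-linear function of the training iteration. This simplified version returns a plain float and leaves out the `device` and `dtype` handling and the per-step callable. The function name `exploration_std_at` is hypothetical, and the boundary behavior (holding the full value up to and including `warm_down_start`) is an assumption.

```python
def exploration_std_at(iteration, exploration_factor=0.1,
                       warm_down_start=500, warm_down_end=4500):
    """Exploration std for one training iteration with a linear warm-down."""
    if iteration <= warm_down_start:
        # Before the warm-down begins, use the full exploration std.
        return exploration_factor
    if iteration >= warm_down_end:
        # After the warm-down ends, exploration noise is switched off.
        return 0.0
    # Fraction of the warm-down window that is still ahead.
    remaining = (warm_down_end - iteration) / (warm_down_end - warm_down_start)
    return exploration_factor * remaining
```

With the default arguments, the std stays at 0.1 through iteration 500, is halved at the midpoint (iteration 2500), and is 0 from iteration 4500 onward.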
- tutorials.examples.train_diffusion_rtb.main(args)
Runs the posterior finetuning pipeline, including prior pretraining if required.
- Parameters:
args (argparse.Namespace)
- Return type:
None
- tutorials.examples.train_diffusion_rtb.plot_samples(xs, target, title, save_path, return_fig=False)
Contour + scatter plot of samples against the posterior density.
- Parameters:
xs (torch.Tensor)
title (str)
save_path (pathlib.Path | str)
return_fig (bool)
- tutorials.examples.train_diffusion_rtb.pretrain_prior(args, device, s_dim)
Auto-pretrain the prior if the checkpoint is missing. Saves to args.prior_ckpt_path and returns the resolved path.
- Parameters:
args (argparse.Namespace)
device (torch.device)
s_dim (int)
- Return type:
None
- tutorials.examples.train_diffusion_rtb.resolve_output_paths(args)
Resolve all output paths relative to this script’s directory.
- Parameters:
args (argparse.Namespace)
- Return type:
argparse.Namespace
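Resolving paths against a fixed base directory can be sketched with `pathlib`. The helper name `resolve_paths_sketch`, the explicit `base_dir` and `keys` parameters, and the `plot_path` attribute are hypothetical. Only `prior_ckpt_path` appears in this module's documentation. In the script itself the base would presumably be something like `Path(__file__).resolve().parent`, which is an assumption.

```python
import argparse
from pathlib import Path


def resolve_paths_sketch(args, base_dir, keys):
    """Rewrite the named path attributes of args relative to base_dir."""
    base_dir = Path(base_dir)
    for key in keys:
        value = getattr(args, key, None)
        if value is None:
            # Unset or missing options are left untouched.
            continue
        path = Path(value)
        if not path.is_absolute():
            # Only relative paths are anchored; absolute paths are kept.
            path = base_dir / path
        setattr(args, key, path)
    return args
```

Anchoring relative paths to the script's directory rather than the current working directory makes the outputs land in the same place regardless of where the script is launched from.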