tutorials.examples.train_diffusion_rtb

Minimal end-to-end training script for fine-tuning diffusion models with Relative Trajectory Balance (RTB).

  • The prior is pre-trained first (pre-training runs automatically if the prior checkpoint is missing), so fine-tuning starts from a learned prior.

  • The posterior is fine-tuned from this prior (pf).

By default, the script uses the 25→9 GMM posterior target (gmm25_posterior9) with a learnable posterior forward policy and a fixed prior forward policy. The loss is RTB (no backward policy). The script outputs the prior weights alongside plots of samples from both the prior and the posterior distributions.
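For orientation, the RTB objective as it is commonly stated can be sketched on plain Python floats: the squared difference between log Z plus the posterior trajectory log-probability and the log-reward plus the prior trajectory log-probability. The function name and argument names below are hypothetical, and this is not confirmed to match how the script computes its loss.

```python
from typing import Sequence


def rtb_loss_sketch(
    log_z: float,
    log_pf_posterior: Sequence[float],
    log_pf_prior: Sequence[float],
    log_reward: float,
) -> float:
    """Squared RTB residual for one trajectory (hypothetical helper).

    log_pf_posterior / log_pf_prior hold the per-step forward log-probabilities
    of the same trajectory under the posterior and the prior policies.
    """
    residual = (
        log_z
        + sum(log_pf_posterior)
        - sum(log_pf_prior)
        - log_reward
    )
    return residual**2


# The residual vanishes when Z * p_post(trajectory) == r(x) * p_prior(trajectory).
balanced = rtb_loss_sketch(0.0, [-1.0, -2.0], [-1.5, -1.5], 0.0)
unbalanced = rtb_loss_sketch(1.0, [-1.0], [-2.0], 0.5)
```

No backward-policy term appears in the residual, which is consistent with the "no backward policy" note above.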

Functions

add_arg_group(parser, specs)

build_forward_estimator(s_dim, num_steps, sigma, ...)

forward_kwargs(args, s_dim, num_steps, sigma, device)

get_debug_metrics(estimator)

Compute gradient norm for a module; return (total_norm, has_nan).

get_exploration_std(iteration[, exploration_factor, ...])

Return a callable exploration std schedule for state-space noise.

main(args)

Runs the posterior fine-tuning pipeline, including prior pre-training if required.

plot_samples(xs, target, title, save_path[, return_fig])

Contour + scatter plot of samples against the posterior density.

pretrain_prior(args, device, s_dim)

Auto-pretrain the prior if the checkpoint is missing.

resolve_output_paths(args)

Resolve all output paths relative to this script's directory.

Module Contents

tutorials.examples.train_diffusion_rtb.add_arg_group(parser, specs)
Parameters:
  • parser (argparse.ArgumentParser)

  • specs (list[tuple[tuple[str, Ellipsis], dict]])

Return type:

None
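The specs type above suggests a list of (flags, keyword arguments) pairs. A minimal sketch of how a helper with this signature could consume such a list, assuming each pair is forwarded to `parser.add_argument`; the function name and the flags `--sigma` and `--num_steps` are hypothetical and not taken from the script:

```python
import argparse


def add_arg_group_sketch(parser, specs):
    """Register each (flags, kwargs) pair on the parser (hypothetical helper)."""
    for flags, kwargs in specs:
        parser.add_argument(*flags, **kwargs)


# Hypothetical specs: a tuple of flag strings plus add_argument keyword arguments.
example_specs = [
    (("--sigma",), {"type": float, "default": 2.0}),
    (("--num_steps", "-T"), {"type": int, "default": 100}),
]

parser = argparse.ArgumentParser()
add_arg_group_sketch(parser, example_specs)
args = parser.parse_args(["--sigma", "1.5"])
```

Keeping argument definitions as data makes it easy to share the same specs between several parsers.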

tutorials.examples.train_diffusion_rtb.build_forward_estimator(s_dim, num_steps, sigma, harmonics_dim, t_emb_dim, s_emb_dim, hidden_dim, joint_layers, zero_init, learn_variance, clipping, gfn_clip, t_scale, log_var_range, device)
Parameters:
  • s_dim (int)

  • num_steps (int)

  • sigma (float)

  • harmonics_dim (int)

  • t_emb_dim (int)

  • s_emb_dim (int)

  • hidden_dim (int)

  • joint_layers (int)

  • zero_init (bool)

  • learn_variance (bool)

  • clipping (bool)

  • gfn_clip (float)

  • t_scale (float)

  • log_var_range (float)

  • device (torch.device)

Return type:

gfn.estimators.PinnedBrownianMotionForward

tutorials.examples.train_diffusion_rtb.forward_kwargs(args, s_dim, num_steps, sigma, device)
Parameters:
  • args (argparse.Namespace)

  • s_dim (int)

  • num_steps (int)

  • sigma (float)

  • device (torch.device)

Return type:

dict

tutorials.examples.train_diffusion_rtb.get_debug_metrics(estimator)

Compute gradient norm for a module; return (total_norm, has_nan).

Parameters:

estimator (torch.nn.Module)

Return type:

tuple[torch.Tensor, bool]
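A minimal sketch of a gradient-norm check with the same return shape, assuming the norm is the global L2 norm over all parameter gradients; the function name is hypothetical and the actual implementation may differ:

```python
import torch
from torch import nn


def grad_norm_and_nan_sketch(module: nn.Module) -> tuple[torch.Tensor, bool]:
    """Return (global L2 gradient norm, whether any gradient contains NaN)."""
    grads = [p.grad.detach() for p in module.parameters() if p.grad is not None]
    if not grads:
        # No backward pass has populated gradients yet.
        return torch.tensor(0.0), False
    per_param = torch.stack([torch.linalg.vector_norm(g) for g in grads])
    total_norm = torch.linalg.vector_norm(per_param)
    has_nan = any(bool(torch.isnan(g).any()) for g in grads)
    return total_norm, has_nan


# Tiny usage example: y = w . x, so the gradient of y w.r.t. w equals x.
layer = nn.Linear(2, 1, bias=False)
layer(torch.ones(1, 2)).sum().backward()
norm, has_nan = grad_norm_and_nan_sketch(layer)
```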

tutorials.examples.train_diffusion_rtb.get_exploration_std(iteration, exploration_factor=0.1, warm_down_start=500, warm_down_end=4500, device=None, dtype=None)

Return a callable exploration std schedule for state-space noise.

When exploration is enabled, return a step-index function that emits a fixed std for the current training iteration. The std is optionally warmed down linearly, starting after warm_down_start iterations and reaching 0 at warm_down_end iterations.

Parameters:
  • iteration (int)

  • exploration_factor (float)

  • warm_down_start (int)

  • warm_down_end (int)

  • device (torch.device | None)

  • dtype (torch.dtype | None)

Return type:

torch.Tensor
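The warm-down described for get_exploration_std can be sketched on plain Python floats. This is an illustration under assumptions: the function name is hypothetical, the device and dtype handling is omitted, and the real function may treat the boundaries differently.

```python
def exploration_std_schedule_sketch(
    iteration,
    exploration_factor=0.1,
    warm_down_start=500,
    warm_down_end=4500,
):
    """Return a step-index function emitting a fixed std for this iteration."""
    if iteration <= warm_down_start:
        scale = 1.0  # full exploration before the warm-down begins
    elif iteration >= warm_down_end:
        scale = 0.0  # exploration fully switched off
    else:
        # Linear interpolation from 1 at warm_down_start to 0 at warm_down_end.
        scale = (warm_down_end - iteration) / (warm_down_end - warm_down_start)
    std = exploration_factor * scale

    def std_at_step(step_index):
        # Same std for every step of a trajectory sampled at this iteration.
        return std

    return std_at_step


early = exploration_std_schedule_sketch(0)(3)
midway = exploration_std_schedule_sketch(2500)(0)
late = exploration_std_schedule_sketch(4500)(10)
```

With the default arguments, the std stays at exploration_factor for the first 500 iterations and is halved midway through the warm-down window.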

tutorials.examples.train_diffusion_rtb.main(args)

Runs the posterior fine-tuning pipeline, including prior pre-training if required.

Parameters:

args (argparse.Namespace)

Return type:

None

tutorials.examples.train_diffusion_rtb.plot_samples(xs, target, title, save_path, return_fig=False)

Contour + scatter plot of samples against the posterior density.

Parameters:
  • xs (torch.Tensor)

  • target

  • title (str)

  • save_path (pathlib.Path | str)

  • return_fig (bool)

tutorials.examples.train_diffusion_rtb.pretrain_prior(args, device, s_dim)

Auto-pretrain the prior if the checkpoint is missing. Saves to args.prior_ckpt_path and returns the resolved path.

Parameters:
  • args (argparse.Namespace)

  • device (torch.device)

  • s_dim (int)

Return type:

None
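The "train only if the checkpoint is missing" pattern can be sketched independently of the model code. In this sketch the helper name and the `train_and_save` callback are hypothetical stand-ins for the actual pre-training routine, and returning the path is a choice made for the sketch:

```python
import tempfile
from pathlib import Path


def ensure_prior_checkpoint_sketch(ckpt_path, train_and_save):
    """Run train_and_save(ckpt_path) only when the checkpoint file is absent."""
    ckpt_path = Path(ckpt_path)
    if not ckpt_path.exists():
        ckpt_path.parent.mkdir(parents=True, exist_ok=True)
        train_and_save(ckpt_path)
    return ckpt_path


calls = []


def fake_train_and_save(path):
    # Placeholder for real pre-training: record the call and write a dummy file.
    calls.append(path)
    path.write_bytes(b"placeholder weights")


with tempfile.TemporaryDirectory() as tmp:
    target_path = Path(tmp) / "ckpt" / "prior.pt"
    first = ensure_prior_checkpoint_sketch(target_path, fake_train_and_save)
    second = ensure_prior_checkpoint_sketch(target_path, fake_train_and_save)
    n_calls = len(calls)
```

The second call finds the file written by the first one, so the training callback runs only once.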

tutorials.examples.train_diffusion_rtb.resolve_output_paths(args)

Resolve all output paths relative to this script’s directory.

Parameters:

args (argparse.Namespace)

Return type:

argparse.Namespace
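A minimal sketch of resolving path-valued arguments against a base directory, assuming relative paths are joined to the base and absolute paths are left untouched. The helper name, the explicit `base_dir` and `keys` parameters, and the example directory are hypothetical; a script would typically pass `Path(__file__).resolve().parent` as the base.

```python
import argparse
from pathlib import Path


def resolve_paths_sketch(args, base_dir, keys):
    """Rewrite the named attributes of args as paths under base_dir."""
    base_dir = Path(base_dir)
    for key in keys:
        value = getattr(args, key, None)
        if value is None:
            continue  # leave unset options alone
        path = Path(value)
        if not path.is_absolute():
            path = base_dir / path
        setattr(args, key, path)
    return args


# prior_ckpt_path is the attribute named in pretrain_prior; the values are made up.
ns = argparse.Namespace(prior_ckpt_path="output/prior.pt")
resolved = resolve_paths_sketch(ns, "script_dir", ["prior_ckpt_path"])
```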