tutorials.examples.train_with_compile

Benchmark the runtime impact of different torch.compile strategies on several GFlowNet losses (Trajectory Balance, Detailed Balance, and Sub-Trajectory Balance, abbreviated SubTB) across multiple environments.

Four compile modes are compared for each (env, loss) pair:

  1. Pure eager execution

  2. Compile loss only

  3. Compile estimator modules only (try_compile_gflownet)

  4. Compile both the loss wrapper and estimator modules

The script reuses the components defined in the example training scripts: train_hypergrid.py, train_line.py, train_bitsequence_recurrent.py, train_bit_sequences.py, train_graph_ring.py, and train_diffusion_sampler.py.

Attributes

COMPILE_MODES

COMPILE_MODE_COLORS

DEFAULT_COMPILE_ORDER

DEFAULT_FLOW_ORDER

ENVIRONMENT_BENCHMARKS

FLOW_VARIANTS

HYPERGRID_DEFAULTS

LARGE_MODEL_FIELDS

LOSS_LINE_ALPHA

PROJECT_ROOT

VARIANT_COLORS

_torch_dynamo

Classes

CompileMode

EnvironmentBenchmark

FlowVariant

TrainingComponents

Functions

_apply_large_model_scaling(args)

_build_bitsequence_mlp_components(args, device, variant)

_build_bitsequence_recurrent_components(args, device, ...)

_build_diffusion_components(args, device, variant)

_build_graph_ring_components(args, device, variant)

_build_hypergrid_components(args, device, variant)

_build_line_components(args, device, variant)

_mps_backend_available()

_normalize_keys(requested, valid, label)

_summarize_iteration_times(times)

main()

maybe_compile_estimators(components, compile_mode, ...)

parse_args()

plot_benchmark(results, output_path[, run_label])

prepare_loss_fn(gflownet, compile_mode, variant, args)

resolve_device(requested)

run_case(args, device, env_cfg, variant, compile_mode)

sample_trajectories(components, batch_size)

summarize_results(results)

synchronize_if_needed(device)

training_loop(components, loss_fn, args, *, n_iters, ...)

Module Contents

tutorials.examples.train_with_compile.COMPILE_MODES: dict[str, CompileMode]
tutorials.examples.train_with_compile.COMPILE_MODE_COLORS: dict[str, str]
class tutorials.examples.train_with_compile.CompileMode
compile_estimators: bool
compile_loss: bool
description: str
key: Literal['eager', 'loss', 'estimators', 'both']
label: str
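
As a rough illustration of how the fields above could fit together, the sketch below models CompileMode as a frozen dataclass and builds a registry keyed by the four literals. The labels and descriptions are placeholders, and the mapping from each key to the two boolean flags is inferred from the mode list at the top of this page, not copied from the script.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass(frozen=True)
class CompileMode:
    """Illustrative stand-in for the documented CompileMode class."""

    key: Literal["eager", "loss", "estimators", "both"]
    label: str
    description: str
    compile_loss: bool
    compile_estimators: bool


# Hypothetical registry contents: labels and descriptions are placeholders,
# and the flag values are inferred from the four modes listed on this page.
COMPILE_MODES: dict[str, CompileMode] = {
    mode.key: mode
    for mode in (
        CompileMode("eager", "Eager", "Pure eager execution", False, False),
        CompileMode("loss", "Loss", "Compile loss only", True, False),
        CompileMode("estimators", "Estimators", "Compile estimator modules only", False, True),
        CompileMode("both", "Both", "Compile loss wrapper and estimators", True, True),
    )
}

DEFAULT_COMPILE_ORDER = ["eager", "loss", "estimators", "both"]

# Print the 2x2 grid of flags in the default benchmark order.
for key in DEFAULT_COMPILE_ORDER:
    mode = COMPILE_MODES[key]
    print(f"{key:<11} loss={mode.compile_loss!s:<5} estimators={mode.compile_estimators}")
```

Keeping the two flags separate from the key lets downstream code branch on compile_loss and compile_estimators independently instead of comparing strings.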
tutorials.examples.train_with_compile.DEFAULT_COMPILE_ORDER = ['eager', 'loss', 'estimators', 'both']
tutorials.examples.train_with_compile.DEFAULT_FLOW_ORDER = ['tb', 'modified_dbg', 'subtb']
tutorials.examples.train_with_compile.ENVIRONMENT_BENCHMARKS: dict[str, EnvironmentBenchmark]
class tutorials.examples.train_with_compile.EnvironmentBenchmark
builder: Callable[[argparse.Namespace, torch.device, FlowVariant], TrainingComponents]
color: str
description: str
key: Literal['hypergrid', 'line', 'bitseq_recurrent', 'bitseq_mlp', 'diffusion', 'graph_ring']
label: str
supported_flows: list[str]
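
The supported_flows field suggests that not every environment is benchmarked with every loss. A minimal sketch of how the (environment, flow, compile mode) grid might be enumerated follows. The support table is invented for illustration (the real values live in ENVIRONMENT_BENCHMARKS), and enumerate_cases is a hypothetical helper, not a function of this module.

```python
from itertools import product

# Hypothetical support table; the actual lists come from
# ENVIRONMENT_BENCHMARKS[...].supported_flows and may differ.
SUPPORTED_FLOWS = {
    "hypergrid": ["tb", "modified_dbg", "subtb"],
    "diffusion": ["tb"],
}
FLOW_ORDER = ["tb", "modified_dbg", "subtb"]
COMPILE_ORDER = ["eager", "loss", "estimators", "both"]


def enumerate_cases(env_keys, flow_keys, compile_keys):
    """Yield (env, flow, compile_mode) triples, skipping unsupported flows."""
    for env_key, flow_key, compile_key in product(env_keys, flow_keys, compile_keys):
        if flow_key not in SUPPORTED_FLOWS[env_key]:
            continue
        yield env_key, flow_key, compile_key


cases = list(enumerate_cases(["hypergrid", "diffusion"], FLOW_ORDER, COMPILE_ORDER))
print(len(cases))  # 3 flows x 4 modes + 1 flow x 4 modes = 16
```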
tutorials.examples.train_with_compile.FLOW_VARIANTS: dict[str, FlowVariant]
class tutorials.examples.train_with_compile.FlowVariant
description: str
key: Literal['tb', 'modified_dbg', 'subtb']
label: str
requires_logf: bool
tutorials.examples.train_with_compile.HYPERGRID_DEFAULTS: Dict[str, Any]
tutorials.examples.train_with_compile.LARGE_MODEL_FIELDS = ['hidden_dim', 'n_hidden', 'line_hidden_dim', 'line_n_hidden', 'bitseq_embedding_dim',...
tutorials.examples.train_with_compile.LOSS_LINE_ALPHA = 0.5
tutorials.examples.train_with_compile.PROJECT_ROOT
class tutorials.examples.train_with_compile.TrainingComponents
env: gfn.env.Env
gflownet: gfn.gflownet.PFBasedGFlowNet
notes: str = ''
optimizer: torch.optim.Optimizer
recalc_logprobs: bool = True
sampler: gfn.samplers.Sampler | None
sampler_kwargs: Dict[str, Any]
use_training_samples: bool = False
tutorials.examples.train_with_compile.VARIANT_COLORS: dict[str, str]
tutorials.examples.train_with_compile._apply_large_model_scaling(args)
Parameters:

args (argparse.Namespace)

Return type:

None

tutorials.examples.train_with_compile._build_bitsequence_mlp_components(args, device, variant)
Parameters:
  • args (argparse.Namespace)

  • device (torch.device)

  • variant (FlowVariant)

Return type:

TrainingComponents

tutorials.examples.train_with_compile._build_bitsequence_recurrent_components(args, device, variant)
Parameters:
  • args (argparse.Namespace)

  • device (torch.device)

  • variant (FlowVariant)

Return type:

TrainingComponents

tutorials.examples.train_with_compile._build_diffusion_components(args, device, variant)
Parameters:
  • args (argparse.Namespace)

  • device (torch.device)

  • variant (FlowVariant)

Return type:

TrainingComponents

tutorials.examples.train_with_compile._build_graph_ring_components(args, device, variant)
Parameters:
  • args (argparse.Namespace)

  • device (torch.device)

  • variant (FlowVariant)

Return type:

TrainingComponents

tutorials.examples.train_with_compile._build_hypergrid_components(args, device, variant)
Parameters:
  • args (argparse.Namespace)

  • device (torch.device)

  • variant (FlowVariant)

Return type:

TrainingComponents

tutorials.examples.train_with_compile._build_line_components(args, device, variant)
Parameters:
  • args (argparse.Namespace)

  • device (torch.device)

  • variant (FlowVariant)

Return type:

TrainingComponents

tutorials.examples.train_with_compile._mps_backend_available()
Return type:

bool

tutorials.examples.train_with_compile._normalize_keys(requested, valid, label)
Parameters:
  • requested (list[str])

  • valid (Dict[str, Any])

  • label (str)

Return type:

list[str]
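
The page documents only the signature of _normalize_keys, so the following is a guess at the kind of validation such a helper typically performs: normalize case, reject unknown keys with a message that names the label, and drop duplicates while preserving order. The function name normalize_keys and all of its behavior are assumptions.

```python
from typing import Any, Dict


def normalize_keys(requested: list[str], valid: Dict[str, Any], label: str) -> list[str]:
    """Validate and de-duplicate user-supplied keys (assumed behavior)."""
    normalized: list[str] = []
    for raw in requested:
        key = raw.strip().lower()
        if key not in valid:
            options = ", ".join(sorted(valid))
            raise ValueError(f"Unknown {label} '{raw}'. Valid options: {options}")
        if key not in normalized:  # keep the first occurrence, preserve order
            normalized.append(key)
    return normalized


# Stand-in registry; only the keys matter for validation.
flows = {"tb": object(), "modified_dbg": object(), "subtb": object()}
print(normalize_keys(["TB", "subtb", "tb"], flows, "flow variant"))
```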

tutorials.examples.train_with_compile._summarize_iteration_times(times)
Parameters:

times (list[float])

Return type:

tuple[float, float]
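
The return type tuple[float, float] is consistent with a (mean, spread) summary, though the page does not say which statistics are returned. The sketch below assumes mean and population standard deviation. Both the choice of statistics and the handling of an empty list are assumptions.

```python
import statistics


def summarize_iteration_times(times: list[float]) -> tuple[float, float]:
    """Return (mean, standard deviation) of per-iteration times (assumed)."""
    if not times:
        return 0.0, 0.0
    mean = statistics.fmean(times)
    # pstdev is defined for a single sample (it returns 0.0), unlike stdev.
    std = statistics.pstdev(times)
    return mean, std


mean_s, std_s = summarize_iteration_times([0.010, 0.012, 0.011, 0.013])
print(f"{mean_s * 1e3:.2f} ms/iter (+/- {std_s * 1e3:.2f} ms)")
```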

tutorials.examples.train_with_compile._torch_dynamo = None
tutorials.examples.train_with_compile.main()
Return type:

None

tutorials.examples.train_with_compile.maybe_compile_estimators(components, compile_mode, dynamo_mode)
Parameters:
  • components

  • compile_mode

  • dynamo_mode

Return type:

bool

tutorials.examples.train_with_compile.parse_args()
Return type:

argparse.Namespace

tutorials.examples.train_with_compile.plot_benchmark(results, output_path, run_label=None)
Parameters:
  • results (list[dict[str, Any]])

  • output_path (str)

  • run_label (str | None)

Return type:

None

tutorials.examples.train_with_compile.prepare_loss_fn(gflownet, compile_mode, variant, args)
Parameters:
  • gflownet

  • compile_mode

  • variant

  • args

Return type:

Callable[[gfn.env.Env, Any, bool], torch.Tensor]

tutorials.examples.train_with_compile.resolve_device(requested)
Parameters:

requested (str)

Return type:

torch.device
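
Device resolution usually comes down to a small decision table. The sketch below keeps that logic free of torch so it can run anywhere: availability is passed in as booleans, and the result is a plain string. The "auto" option, the CUDA-then-MPS-then-CPU preference order, and the name pick_device_name are all assumptions, since this page gives only the signature of resolve_device.

```python
def pick_device_name(requested: str, cuda_available: bool, mps_available: bool) -> str:
    """Resolve a requested device string to a concrete backend name."""
    requested = requested.lower()
    if requested == "auto":  # assumed option: pick the best available backend
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        return "cpu"
    if requested.startswith("cuda") and not cuda_available:
        raise RuntimeError("CUDA was requested but is not available")
    if requested == "mps" and not mps_available:
        raise RuntimeError("MPS was requested but is not available")
    return requested


print(pick_device_name("auto", cuda_available=False, mps_available=True))
```

In a torch-based script the two flags would presumably come from torch.cuda.is_available() and a check such as _mps_backend_available(), with the resulting name wrapped in torch.device.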

tutorials.examples.train_with_compile.run_case(args, device, env_cfg, variant, compile_mode)
Parameters:
  • args

  • device

  • env_cfg

  • variant

  • compile_mode

Return type:

dict[str, Any]

tutorials.examples.train_with_compile.sample_trajectories(components, batch_size)
Parameters:
  • components

  • batch_size

Return type:

Any

tutorials.examples.train_with_compile.summarize_results(results)
Parameters:

results (list[dict[str, Any]])

Return type:

None

tutorials.examples.train_with_compile.synchronize_if_needed(device)
Parameters:

device (torch.device)

Return type:

None

tutorials.examples.train_with_compile.training_loop(components, loss_fn, args, *, n_iters, track_time)
Parameters:
  • components (TrainingComponents)

  • loss_fn (Callable[[gfn.env.Env, Any, bool], torch.Tensor])

  • args (argparse.Namespace)

  • n_iters (int)

  • track_time (bool)

Return type:

tuple[float | None, Dict[str, list[float]]]
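
The return type pairs an optional elapsed time with a dictionary of float lists, which matches a loop that times iterations only when track_time is set. The following is a minimal sketch of that pattern, not the script's implementation: timed_loop, step_fn, the sync callback, and the history keys "losses" and "iter_times" are hypothetical. The sync callback stands in for a helper like synchronize_if_needed, which on a GPU would presumably block until queued kernels finish so that wall-clock timings are meaningful.

```python
import time
from typing import Callable, Dict, Optional


def timed_loop(
    step_fn: Callable[[int], float],
    *,
    n_iters: int,
    track_time: bool,
    sync: Callable[[], None] = lambda: None,
) -> tuple[Optional[float], Dict[str, list[float]]]:
    """Run step_fn n_iters times and optionally time each iteration."""
    history: Dict[str, list[float]] = {"losses": [], "iter_times": []}
    total = 0.0
    for it in range(n_iters):
        if track_time:
            sync()  # drain pending device work before starting the clock
            start = time.perf_counter()
        loss = step_fn(it)
        if track_time:
            sync()  # wait for this iteration's work before stopping the clock
            elapsed = time.perf_counter() - start
            history["iter_times"].append(elapsed)
            total += elapsed
        history["losses"].append(loss)
    return (total if track_time else None), history


# Toy step function standing in for sample -> loss -> backward -> optimizer step.
elapsed, history = timed_loop(lambda it: 1.0 / (it + 1), n_iters=5, track_time=True)
print(f"total {elapsed:.6f}s over {len(history['iter_times'])} iterations")
```

Running the same loop once with track_time=False before the timed run is a common way to keep one-time compilation cost out of the measured iterations.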