tutorials.examples.train_with_compile
Benchmark the runtime impact of different torch.compile strategies on several GFlowNet losses (Trajectory Balance, Detailed Balance, SubTB) and environments.
Four compile modes are compared for each (env, loss) pair:
- Pure eager execution
- Compile the loss only
- Compile the estimator modules only (try_compile_gflownet)
- Compile both the loss wrapper and the estimator modules
The script reuses the components defined in the example training scripts: train_hypergrid.py, train_line.py, train_bitsequence_recurrent.py, train_bit_sequences.py, train_graph_ring.py, and train_diffusion_sampler.py.
Module Contents
- tutorials.examples.train_with_compile.COMPILE_MODES: dict[str, CompileMode]
- tutorials.examples.train_with_compile.COMPILE_MODE_COLORS: dict[str, str]
- class tutorials.examples.train_with_compile.CompileMode
  - compile_estimators: bool
  - compile_loss: bool
  - description: str
  - key: Literal['eager', 'loss', 'estimators', 'both']
  - label: str
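As a rough illustration of how the four compile modes in the module overview could map onto the `CompileMode` fields, here is a minimal sketch. The field names and key values come from this page; the dataclass layout, the `label` strings, and the boolean flag assignments are assumptions inferred from the mode descriptions, not copied from the script.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass(frozen=True)
class CompileMode:
    """Illustrative stand-in mirroring the documented CompileMode fields."""

    key: Literal["eager", "loss", "estimators", "both"]
    label: str  # hypothetical display strings, see below
    description: str
    compile_loss: bool
    compile_estimators: bool


# Assumed registry contents: the flags are inferred from the mode descriptions.
COMPILE_MODES: dict[str, CompileMode] = {
    mode.key: mode
    for mode in (
        CompileMode("eager", "Eager", "Pure eager execution", False, False),
        CompileMode("loss", "Loss", "Compile loss only", True, False),
        CompileMode(
            "estimators", "Estimators", "Compile estimator modules only", False, True
        ),
        CompileMode(
            "both",
            "Both",
            "Compile both the loss wrapper and estimator modules",
            True,
            True,
        ),
    )
}
```

Keying the registry by `mode.key` keeps lookups such as `COMPILE_MODES["both"]` in step with the order given by `DEFAULT_COMPILE_ORDER`.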
- tutorials.examples.train_with_compile.DEFAULT_COMPILE_ORDER = ['eager', 'loss', 'estimators', 'both']
- tutorials.examples.train_with_compile.DEFAULT_FLOW_ORDER = ['tb', 'modified_dbg', 'subtb']
- tutorials.examples.train_with_compile.ENVIRONMENT_BENCHMARKS: dict[str, EnvironmentBenchmark]
- class tutorials.examples.train_with_compile.EnvironmentBenchmark
  - builder: Callable[[argparse.Namespace, torch.device, FlowVariant], TrainingComponents]
  - color: str
  - description: str
  - key: Literal['hypergrid', 'line', 'bitseq_recurrent', 'bitseq_mlp', 'diffusion', 'graph_ring']
  - label: str
  - supported_flows: list[str]
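The `supported_flows` field suggests that not every environment is benchmarked with every loss. One way the (environment, flow, compile mode) grid could be expanded is sketched below. The per-environment flow lists are hypothetical placeholders, and `expand_cases` is an illustrative helper that does not appear in the module.

```python
from itertools import product

# Hypothetical per-environment flow support; the real values live in
# ENVIRONMENT_BENCHMARKS[...].supported_flows and are not shown on this page.
supported_flows = {
    "hypergrid": ["tb", "modified_dbg", "subtb"],
    "diffusion": ["tb"],
}
flow_order = ["tb", "modified_dbg", "subtb"]  # DEFAULT_FLOW_ORDER
compile_order = ["eager", "loss", "estimators", "both"]  # DEFAULT_COMPILE_ORDER


def expand_cases(envs, flows, modes):
    """Yield (env, flow, mode) triples, skipping flows an env does not support."""
    for env_key, flow_key, mode_key in product(envs, flows, modes):
        if flow_key in supported_flows[env_key]:
            yield env_key, flow_key, mode_key


cases = list(expand_cases(supported_flows, flow_order, compile_order))
```

Each resulting triple would correspond to one `run_case` call.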
- tutorials.examples.train_with_compile.FLOW_VARIANTS: dict[str, FlowVariant]
- class tutorials.examples.train_with_compile.FlowVariant
  - description: str
  - key: Literal['tb', 'modified_dbg', 'subtb']
  - label: str
  - requires_logf: bool
- tutorials.examples.train_with_compile.HYPERGRID_DEFAULTS: Dict[str, Any]
- tutorials.examples.train_with_compile.LARGE_MODEL_FIELDS = ['hidden_dim', 'n_hidden', 'line_hidden_dim', 'line_n_hidden', 'bitseq_embedding_dim',...
- tutorials.examples.train_with_compile.LOSS_LINE_ALPHA = 0.5
- tutorials.examples.train_with_compile.PROJECT_ROOT
- class tutorials.examples.train_with_compile.TrainingComponents
  - env: gfn.env.Env
  - gflownet: gfn.gflownet.PFBasedGFlowNet
  - notes: str = ''
  - optimizer: torch.optim.Optimizer
  - recalc_logprobs: bool = True
  - sampler: gfn.samplers.Sampler | None
  - sampler_kwargs: Dict[str, Any]
  - use_training_samples: bool = False
- tutorials.examples.train_with_compile.VARIANT_COLORS: dict[str, str]
- tutorials.examples.train_with_compile._apply_large_model_scaling(args)
  - Parameters:
    - args (argparse.Namespace)
  - Return type: None
- tutorials.examples.train_with_compile._build_bitsequence_mlp_components(args, device, variant)
  - Parameters:
    - args (argparse.Namespace)
    - device (torch.device)
    - variant (FlowVariant)
  - Return type: TrainingComponents
- tutorials.examples.train_with_compile._build_bitsequence_recurrent_components(args, device, variant)
  - Parameters:
    - args (argparse.Namespace)
    - device (torch.device)
    - variant (FlowVariant)
  - Return type: TrainingComponents
- tutorials.examples.train_with_compile._build_diffusion_components(args, device, variant)
  - Parameters:
    - args (argparse.Namespace)
    - device (torch.device)
    - variant (FlowVariant)
  - Return type: TrainingComponents
- tutorials.examples.train_with_compile._build_graph_ring_components(args, device, variant)
  - Parameters:
    - args (argparse.Namespace)
    - device (torch.device)
    - variant (FlowVariant)
  - Return type: TrainingComponents
- tutorials.examples.train_with_compile._build_hypergrid_components(args, device, variant)
  - Parameters:
    - args (argparse.Namespace)
    - device (torch.device)
    - variant (FlowVariant)
  - Return type: TrainingComponents
- tutorials.examples.train_with_compile._build_line_components(args, device, variant)
  - Parameters:
    - args (argparse.Namespace)
    - device (torch.device)
    - variant (FlowVariant)
  - Return type: TrainingComponents
- tutorials.examples.train_with_compile._mps_backend_available()
  - Return type: bool
- tutorials.examples.train_with_compile._normalize_keys(requested, valid, label)
  - Parameters:
    - requested (list[str])
    - valid (Dict[str, Any])
    - label (str)
  - Return type: list[str]
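The signature of `_normalize_keys` suggests a helper that validates user-supplied keys (for example environment or flow names) against a registry. The sketch below shows one plausible shape for such a helper. Its behaviour is assumed rather than taken from the script: the `"all"` shortcut, the lower-casing, the de-duplication, and the `ValueError` are all illustrative choices.

```python
from typing import Any, Dict


def normalize_keys(requested: list[str], valid: Dict[str, Any], label: str) -> list[str]:
    """Lower-case, de-duplicate, and validate user-supplied keys.

    Assumed behaviour, not taken from the script: unknown keys raise
    ValueError, and the special value "all" expands to every valid key.
    """
    if any(item.strip().lower() == "all" for item in requested):
        return list(valid)
    normalized: list[str] = []
    for item in requested:
        key = item.strip().lower()
        if key not in valid:
            raise ValueError(
                f"Unknown {label} '{item}'. Choose from: {', '.join(valid)}"
            )
        if key not in normalized:  # keep first occurrence, preserve order
            normalized.append(key)
    return normalized
```

Passing `label` into the error message lets the same helper produce readable errors for environments, flows, and compile modes alike.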
- tutorials.examples.train_with_compile._summarize_iteration_times(times)
  - Parameters:
    - times (list[float])
  - Return type: tuple[float, float]
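The page does not say which two statistics `_summarize_iteration_times` returns. A natural guess is the mean and standard deviation of the per-iteration times, which the following sketch computes with the standard library. Treat both the choice of statistics and the empty-list fallback as assumptions.

```python
import statistics


def summarize_iteration_times(times: list[float]) -> tuple[float, float]:
    """Return (mean, standard deviation) of per-iteration times in seconds.

    Assumption: the two floats are the mean and the population standard
    deviation; the real helper may report different statistics.
    """
    if not times:
        return 0.0, 0.0
    mean = statistics.fmean(times)
    std = statistics.pstdev(times) if len(times) > 1 else 0.0
    return mean, std
```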
- tutorials.examples.train_with_compile._torch_dynamo = None
- tutorials.examples.train_with_compile.main()
  - Return type: None
- tutorials.examples.train_with_compile.maybe_compile_estimators(components, compile_mode, dynamo_mode)
  - Parameters:
    - components (TrainingComponents)
    - compile_mode (CompileMode)
    - dynamo_mode (str)
  - Return type: bool
- tutorials.examples.train_with_compile.parse_args()
  - Return type: argparse.Namespace
- tutorials.examples.train_with_compile.plot_benchmark(results, output_path, run_label=None)
  - Parameters:
    - results (list[dict[str, Any]])
    - output_path (str)
    - run_label (str | None)
  - Return type: None
- tutorials.examples.train_with_compile.prepare_loss_fn(gflownet, compile_mode, variant, args)
  - Parameters:
    - gflownet (gfn.gflownet.PFBasedGFlowNet)
    - compile_mode (CompileMode)
    - variant (FlowVariant)
    - args (argparse.Namespace)
  - Return type: Callable[[gfn.env.Env, Any, bool], torch.Tensor]
- tutorials.examples.train_with_compile.resolve_device(requested)
  - Parameters:
    - requested (str)
  - Return type: torch.device
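`resolve_device` turns a requested device string into a `torch.device`. The selection policy is not documented here, so the sketch below only illustrates one common approach, written without a PyTorch dependency so the logic can be read in isolation. The function name `pick_device_name`, the `"auto"` value, and the fallback order are assumptions.

```python
def pick_device_name(requested: str, cuda_available: bool, mps_available: bool) -> str:
    """Map a requested device string to a concrete device name.

    Assumed policy: "auto" prefers CUDA, then MPS, then CPU; an explicit
    request for an unavailable backend raises instead of silently falling back.
    """
    requested = requested.strip().lower()
    if requested == "auto":
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        return "cpu"
    if requested.startswith("cuda") and not cuda_available:
        raise RuntimeError("CUDA requested but not available")
    if requested == "mps" and not mps_available:
        raise RuntimeError("MPS requested but not available")
    return requested
```

With PyTorch installed, the two flags would typically come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, and the returned name would be wrapped in `torch.device(...)`.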
- tutorials.examples.train_with_compile.run_case(args, device, env_cfg, variant, compile_mode)
  - Parameters:
    - args (argparse.Namespace)
    - device (torch.device)
    - env_cfg (EnvironmentBenchmark)
    - variant (FlowVariant)
    - compile_mode (CompileMode)
  - Return type: dict[str, Any]
- tutorials.examples.train_with_compile.sample_trajectories(components, batch_size)
  - Parameters:
    - components (TrainingComponents)
    - batch_size (int)
  - Return type: Any
- tutorials.examples.train_with_compile.summarize_results(results)
  - Parameters:
    - results (list[dict[str, Any]])
  - Return type: None
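A benchmark like this is usually read as speedups relative to the eager baseline. The sketch below shows how such a comparison could be derived from a list of result dictionaries. The dictionary keys (`"env"`, `"flow"`, `"compile_mode"`, `"elapsed"`) and the helper name are hypothetical, since this page does not list what `run_case` actually returns.

```python
def speedups_vs_eager(results: list[dict]) -> dict[tuple[str, str, str], float]:
    """Compute eager_time / mode_time for each (env, flow, mode) result.

    The keys "env", "flow", "compile_mode", and "elapsed" are hypothetical;
    the result dictionaries produced by run_case may use different names.
    """
    # Index the eager timings by (env, flow) so each mode finds its baseline.
    baseline = {
        (r["env"], r["flow"]): r["elapsed"]
        for r in results
        if r["compile_mode"] == "eager"
    }
    speedups: dict[tuple[str, str, str], float] = {}
    for r in results:
        base = baseline.get((r["env"], r["flow"]))
        if base is None or r["elapsed"] <= 0:
            continue  # no eager baseline, or an unusable timing
        speedups[(r["env"], r["flow"], r["compile_mode"])] = base / r["elapsed"]
    return speedups
```

A value above 1.0 means the compiled configuration ran faster than eager for that (environment, flow) pair.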
- tutorials.examples.train_with_compile.synchronize_if_needed(device)
  - Parameters:
    - device (torch.device)
  - Return type: None
- tutorials.examples.train_with_compile.training_loop(components, loss_fn, args, *, n_iters, track_time)
  - Parameters:
    - components (TrainingComponents)
    - loss_fn (Callable[[gfn.env.Env, Any, bool], torch.Tensor])
    - args (argparse.Namespace)
    - n_iters (int)
    - track_time (bool)
  - Return type: tuple[float | None, Dict[str, list[float]]]
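The `n_iters` and `track_time` parameters, together with `synchronize_if_needed`, point to a loop that can run untimed (for example during compile warm-up) or timed with device synchronization around each step. The following is a minimal, framework-free sketch of that pattern. `timed_loop`, the `step` and `sync` callables, and the history keys are all illustrative assumptions rather than the module's actual implementation.

```python
import time
from typing import Callable, Optional


def timed_loop(
    step: Callable[[], float],
    n_iters: int,
    track_time: bool,
    sync: Callable[[], None] = lambda: None,
) -> tuple[Optional[float], dict[str, list[float]]]:
    """Run `step` n_iters times, optionally timing each iteration.

    `step` stands in for one sample/loss/backward/optimizer update and returns
    the loss value; `sync` stands in for a device synchronization call.
    """
    history: dict[str, list[float]] = {"losses": [], "iter_times": []}
    total: Optional[float] = 0.0 if track_time else None
    for _ in range(n_iters):
        if track_time:
            sync()  # drain queued device work before starting the clock
            start = time.perf_counter()
        loss = step()
        if track_time:
            sync()  # wait for this iteration's work before stopping the clock
            elapsed = time.perf_counter() - start
            history["iter_times"].append(elapsed)
            total += elapsed
        history["losses"].append(loss)
    return total, history
```

A typical use would be one short call with `track_time=False` to absorb compilation cost, followed by a longer call with `track_time=True` whose timings are the ones reported.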