tutorials.examples.train_with_compile
=====================================

.. py:module:: tutorials.examples.train_with_compile

.. autoapi-nested-parse::

   Benchmark the runtime impact of different ``torch.compile`` strategies on several
   GFlowNet losses (Trajectory Balance, Modified Detailed Balance, and Sub-Trajectory
   Balance) and environments.

   Four compile modes are compared for each (environment, loss) pair:

   0) Pure eager execution.
   1) Compile the loss only.
   2) Compile the estimator modules only (``try_compile_gflownet``).
   3) Compile both the loss wrapper and the estimator modules.

   The script reuses the components defined in the example training scripts
   ``train_hypergrid.py``, ``train_line.py``, ``train_bitsequence_recurrent.py``,
   ``train_bit_sequences.py``, ``train_graph_ring.py``, and
   ``train_diffusion_sampler.py``.



Attributes
----------

.. autoapisummary::

   tutorials.examples.train_with_compile.COMPILE_MODES
   tutorials.examples.train_with_compile.COMPILE_MODE_COLORS
   tutorials.examples.train_with_compile.DEFAULT_COMPILE_ORDER
   tutorials.examples.train_with_compile.DEFAULT_FLOW_ORDER
   tutorials.examples.train_with_compile.ENVIRONMENT_BENCHMARKS
   tutorials.examples.train_with_compile.FLOW_VARIANTS
   tutorials.examples.train_with_compile.HYPERGRID_DEFAULTS
   tutorials.examples.train_with_compile.LARGE_MODEL_FIELDS
   tutorials.examples.train_with_compile.LOSS_LINE_ALPHA
   tutorials.examples.train_with_compile.PROJECT_ROOT
   tutorials.examples.train_with_compile.VARIANT_COLORS
   tutorials.examples.train_with_compile._torch_dynamo


Classes
-------

.. autoapisummary::

   tutorials.examples.train_with_compile.CompileMode
   tutorials.examples.train_with_compile.EnvironmentBenchmark
   tutorials.examples.train_with_compile.FlowVariant
   tutorials.examples.train_with_compile.TrainingComponents


Functions
---------

.. autoapisummary::

   tutorials.examples.train_with_compile._apply_large_model_scaling
   tutorials.examples.train_with_compile._build_bitsequence_mlp_components
   tutorials.examples.train_with_compile._build_bitsequence_recurrent_components
   tutorials.examples.train_with_compile._build_diffusion_components
   tutorials.examples.train_with_compile._build_graph_ring_components
   tutorials.examples.train_with_compile._build_hypergrid_components
   tutorials.examples.train_with_compile._build_line_components
   tutorials.examples.train_with_compile._mps_backend_available
   tutorials.examples.train_with_compile._normalize_keys
   tutorials.examples.train_with_compile._summarize_iteration_times
   tutorials.examples.train_with_compile.main
   tutorials.examples.train_with_compile.maybe_compile_estimators
   tutorials.examples.train_with_compile.parse_args
   tutorials.examples.train_with_compile.plot_benchmark
   tutorials.examples.train_with_compile.prepare_loss_fn
   tutorials.examples.train_with_compile.resolve_device
   tutorials.examples.train_with_compile.run_case
   tutorials.examples.train_with_compile.sample_trajectories
   tutorials.examples.train_with_compile.summarize_results
   tutorials.examples.train_with_compile.synchronize_if_needed
   tutorials.examples.train_with_compile.training_loop


Module Contents
---------------

.. py:data:: COMPILE_MODES
   :type: dict[str, CompileMode]


.. py:data:: COMPILE_MODE_COLORS
   :type: dict[str, str]


.. py:class:: CompileMode

   .. py:attribute:: compile_estimators
      :type: bool


   .. py:attribute:: compile_loss
      :type: bool


   .. py:attribute:: description
      :type: str


   .. py:attribute:: key
      :type: Literal['eager', 'loss', 'estimators', 'both']


   .. py:attribute:: label
      :type: str


.. py:data:: DEFAULT_COMPILE_ORDER
   :value: ['eager', 'loss', 'estimators', 'both']


.. py:data:: DEFAULT_FLOW_ORDER
   :value: ['tb', 'modified_dbg', 'subtb']


.. py:data:: ENVIRONMENT_BENCHMARKS
   :type: dict[str, EnvironmentBenchmark]


.. py:class:: EnvironmentBenchmark

   .. py:attribute:: builder
      :type: Callable[[argparse.Namespace, torch.device, FlowVariant], TrainingComponents]


   .. py:attribute:: color
      :type: str


   .. py:attribute:: description
      :type: str


   .. py:attribute:: key
      :type: Literal['hypergrid', 'line', 'bitseq_recurrent', 'bitseq_mlp', 'diffusion', 'graph_ring']


   .. py:attribute:: label
      :type: str


   .. py:attribute:: supported_flows
      :type: list[str]


.. py:data:: FLOW_VARIANTS
   :type: dict[str, FlowVariant]


.. py:class:: FlowVariant

   .. py:attribute:: description
      :type: str


   .. py:attribute:: key
      :type: Literal['tb', 'modified_dbg', 'subtb']


   .. py:attribute:: label
      :type: str


   .. py:attribute:: requires_logf
      :type: bool


.. py:data:: HYPERGRID_DEFAULTS
   :type: Dict[str, Any]


.. py:data:: LARGE_MODEL_FIELDS
   :value: ['hidden_dim', 'n_hidden', 'line_hidden_dim', 'line_n_hidden', 'bitseq_embedding_dim',...


.. py:data:: LOSS_LINE_ALPHA
   :value: 0.5


.. py:data:: PROJECT_ROOT


.. py:class:: TrainingComponents

   .. py:attribute:: env
      :type: gfn.env.Env


   .. py:attribute:: gflownet
      :type: gfn.gflownet.PFBasedGFlowNet


   .. py:attribute:: notes
      :type: str
      :value: ''


   .. py:attribute:: optimizer
      :type: torch.optim.Optimizer


   .. py:attribute:: recalc_logprobs
      :type: bool
      :value: True


   .. py:attribute:: sampler
      :type: gfn.samplers.Sampler | None


   .. py:attribute:: sampler_kwargs
      :type: Dict[str, Any]


   .. py:attribute:: use_training_samples
      :type: bool
      :value: False


.. py:data:: VARIANT_COLORS
   :type: dict[str, str]


.. py:function:: _apply_large_model_scaling(args)

.. py:function:: _build_bitsequence_mlp_components(args, device, variant)

.. py:function:: _build_bitsequence_recurrent_components(args, device, variant)

.. py:function:: _build_diffusion_components(args, device, variant)

.. py:function:: _build_graph_ring_components(args, device, variant)

.. py:function:: _build_hypergrid_components(args, device, variant)

.. py:function:: _build_line_components(args, device, variant)

.. py:function:: _mps_backend_available()

.. py:function:: _normalize_keys(requested, valid, label)

.. py:function:: _summarize_iteration_times(times)

.. py:data:: _torch_dynamo
   :value: None


.. py:function:: main()

.. py:function:: maybe_compile_estimators(components, compile_mode, dynamo_mode)

.. py:function:: parse_args()

.. py:function:: plot_benchmark(results, output_path, run_label = None)

.. py:function:: prepare_loss_fn(gflownet, compile_mode, variant, args)

.. py:function:: resolve_device(requested)

.. py:function:: run_case(args, device, env_cfg, variant, compile_mode)

.. py:function:: sample_trajectories(components, batch_size)

.. py:function:: summarize_results(results)

.. py:function:: synchronize_if_needed(device)

.. py:function:: training_loop(components, loss_fn, args, *, n_iters, track_time)
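
The four compile modes differ only in which pieces get wrapped by the compiler: nothing, the loss callable, the estimator modules, or both. The following sketch mirrors the documented ``CompileMode`` fields in a standalone dataclass to show that selection logic. It is an illustration under assumptions, not the script's code: ``CompileModeSketch``, ``MODES``, ``apply_compile_mode``, and the mode labels are hypothetical (the descriptions come from the module docstring), and the compiler is injected as ``compile_fn`` so the logic runs without PyTorch. In a real run ``compile_fn`` would be ``torch.compile``; per the module docstring, the script itself compiles estimators through ``try_compile_gflownet``.

.. code-block:: python

   from dataclasses import dataclass
   from typing import Callable, Dict, Literal, Tuple

   ModeKey = Literal["eager", "loss", "estimators", "both"]


   @dataclass(frozen=True)
   class CompileModeSketch:
       """Hypothetical stand-in mirroring the documented CompileMode attributes."""

       key: ModeKey
       label: str
       compile_loss: bool
       compile_estimators: bool
       description: str


   # Descriptions follow the module docstring; the labels are illustrative only.
   MODES: Dict[str, CompileModeSketch] = {
       m.key: m
       for m in [
           CompileModeSketch("eager", "Eager", False, False, "Pure eager execution"),
           CompileModeSketch("loss", "Loss", True, False, "Compile loss only"),
           CompileModeSketch("estimators", "Estimators", False, True,
                             "Compile estimator modules only"),
           CompileModeSketch("both", "Both", True, True,
                             "Compile both the loss wrapper and estimator modules"),
       ]
   }


   def apply_compile_mode(
       loss_fn: Callable,
       estimators: Dict[str, Callable],
       mode: CompileModeSketch,
       compile_fn: Callable[[Callable], Callable],
   ) -> Tuple[Callable, Dict[str, Callable]]:
       """Wrap only the pieces the mode asks for; everything else stays eager.

       Hypothetical helper: in practice ``compile_fn`` would be ``torch.compile``.
       """
       if mode.compile_estimators:
           estimators = {name: compile_fn(mod) for name, mod in estimators.items()}
       if mode.compile_loss:
           loss_fn = compile_fn(loss_fn)
       return loss_fn, estimators

Keeping ``compile_loss`` and ``compile_estimators`` as independent switches is what lets the benchmark attribute a speedup (or slowdown) to the estimator forward passes, to the loss computation, or to their combination.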
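
Because CUDA kernels launch asynchronously, a per-iteration wall-clock measurement only reflects the real cost of a step if the host waits for queued device work before reading the clock. This is presumably the role of ``synchronize_if_needed(device)``, with ``_summarize_iteration_times(times)`` reducing the collected samples. The sketch below shows the pattern without a PyTorch dependency; ``time_iterations``, ``summarize_times``, the ``sync`` callback (``torch.cuda.synchronize`` on a CUDA device, a no-op on CPU), the warm-up count, and the chosen statistics are all assumptions rather than the script's actual implementation.

.. code-block:: python

   import statistics
   import time
   from typing import Callable, Dict, List


   def time_iterations(
       step: Callable[[], None],
       n_iters: int,
       sync: Callable[[], None] = lambda: None,
       warmup: int = 2,
   ) -> List[float]:
       """Run ``step`` repeatedly and return per-iteration wall times in seconds.

       ``sync`` should block until queued device work finishes (for CUDA,
       ``torch.cuda.synchronize``); the default no-op suits CPU-only runs.
       The first ``warmup`` iterations are run but not recorded.
       """
       times: List[float] = []
       for i in range(warmup + n_iters):
           sync()  # drain earlier work so it is not billed to this step
           start = time.perf_counter()
           step()
           sync()  # wait for asynchronous kernels before stopping the clock
           elapsed = time.perf_counter() - start
           if i >= warmup:
               times.append(elapsed)
       return times


   def summarize_times(times: List[float]) -> Dict[str, float]:
       """Hypothetical summary: mean, median, sample std dev, and total time."""
       return {
           "mean": statistics.mean(times),
           "median": statistics.median(times),
           "std": statistics.stdev(times) if len(times) > 1 else 0.0,
           "total": sum(times),
       }

Discarding the first few iterations matters for this benchmark in particular: with ``torch.compile``, the first calls usually include graph capture and compilation, which would otherwise inflate the averages of the compiled modes.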
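
The benchmark grid is the cross product of environments, flow variants, and compile modes, presumably filtered by each environment's ``supported_flows``. The sketch below shows one way to validate user-selected keys against the defaults documented on this page (the values of ``DEFAULT_FLOW_ORDER`` and ``DEFAULT_COMPILE_ORDER`` are copied here) and to enumerate the cases that ``run_case`` would then execute. ``normalize_keys`` and ``iter_cases`` are hypothetical; in particular, the behavior of the real ``_normalize_keys(requested, valid, label)`` (defaulting, case folding, de-duplication) is an assumption.

.. code-block:: python

   from typing import Dict, Iterator, List, Optional, Sequence, Tuple

   # Values as documented for this module.
   DEFAULT_COMPILE_ORDER = ["eager", "loss", "estimators", "both"]
   DEFAULT_FLOW_ORDER = ["tb", "modified_dbg", "subtb"]


   def normalize_keys(
       requested: Optional[Sequence[str]], valid: Sequence[str], label: str
   ) -> List[str]:
       """Hypothetical validator: default to all valid keys, dedupe, reject unknowns."""
       if not requested:
           return list(valid)
       selected: List[str] = []
       for key in (k.strip().lower() for k in requested):
           if key not in valid:
               raise ValueError(f"Unknown {label} {key!r}; expected one of {list(valid)}")
           if key not in selected:
               selected.append(key)
       return selected


   def iter_cases(
       supported_flows: Dict[str, List[str]],
       flows: List[str],
       modes: List[str],
   ) -> Iterator[Tuple[str, str, str]]:
       """Yield (env, flow, mode) triples, skipping flows an environment lacks."""
       for env_key, env_flows in supported_flows.items():
           for flow in flows:
               if flow not in env_flows:
                   continue
               for mode in modes:
                   yield env_key, flow, mode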