bench_get_scores_all
====================

.. py:module:: bench_get_scores_all

.. autoapi-nested-parse::

   Unified micro-benchmark runner for GFlowNet losses/get_scores.

   Runs, in order:

   - Trajectory Balance (TB) loss
   - Log Partition Variance (LPV) loss
   - Sub-trajectory Balance (SubTB) ``get_scores``
   - Detailed Balance (DB) ``get_scores``
   - Modified Detailed Balance (ModDB) ``get_scores``

   For each benchmark, the script reports four timings:

   - original (baseline, frozen copy)
   - original+compile (``torch.compile`` applied to the baseline function)
   - current (eager)
   - current+compile (``torch.compile`` applied to the current function)

   Two speedups are printed: current/original and current+compile/original.

   All benchmarks use embedded base sizes, scaled by a single ``--size-scale``
   multiplier. Correctness is checked once per size before timing and is
   excluded from the timing loops.


Classes
-------

.. autoapisummary::

   bench_get_scores_all._DBDummyActions
   bench_get_scores_all._DBDummyEnv
   bench_get_scores_all._DBDummyStates
   bench_get_scores_all._DBDummyTransitions
   bench_get_scores_all._ModDBDummyActions
   bench_get_scores_all._ModDBDummyEstimator
   bench_get_scores_all._ModDBDummyStates
   bench_get_scores_all._ModDBDummyTransitions
   bench_get_scores_all._ModDBFakeDist
   bench_get_scores_all._SubTBDummyTrajectories
   bench_get_scores_all._TBTrajectoriesStub


Functions
---------

.. autoapisummary::

   bench_get_scores_all._build_lpv
   bench_get_scores_all._build_tb
   bench_get_scores_all._db_build_model_and_data
   bench_get_scores_all._db_original_get_scores
   bench_get_scores_all._format_ms
   bench_get_scores_all._lpv_original_loss
   bench_get_scores_all._maybe_compile
   bench_get_scores_all._moddb_build_model_and_data
   bench_get_scores_all._moddb_original_get_scores
   bench_get_scores_all._parse_args
   bench_get_scores_all._run_db
   bench_get_scores_all._run_moddb
   bench_get_scores_all._run_subtb
   bench_get_scores_all._run_tb_or_lpv
   bench_get_scores_all._run_with_compile_variants
   bench_get_scores_all._scale_int
   bench_get_scores_all._scale_pair
   bench_get_scores_all._select_dtype
   bench_get_scores_all._subtb_build_model_and_data
   bench_get_scores_all._subtb_original_get_scores
   bench_get_scores_all._tb_original_get_scores
   bench_get_scores_all._tb_original_loss
   bench_get_scores_all._time_fn
   bench_get_scores_all.main


Module Contents
---------------

.. py:class:: _DBDummyActions(tensor)

   .. py:method:: __getitem__(idx)

   .. py:method:: __len__()

   .. py:property:: batch_shape
      :type: torch.Size

   .. py:attribute:: exit_action

   .. py:attribute:: is_exit

   .. py:attribute:: tensor


.. py:class:: _DBDummyEnv(log_reward_fn)

   .. py:attribute:: _log_reward_fn

   .. py:method:: log_reward(states, conditions = None)


.. py:class:: _DBDummyStates(tensor, is_sink_state = None)

   .. py:method:: __getitem__(idx)

   .. py:method:: __len__()

   .. py:property:: batch_shape
      :type: torch.Size

   .. py:property:: device
      :type: torch.device

   .. py:attribute:: is_sink_state

   .. py:attribute:: tensor


.. py:class:: _DBDummyTransitions(states, next_states, actions, is_terminating, log_rewards, conditions = None)

   .. py:method:: __len__()

   .. py:attribute:: actions

   .. py:attribute:: conditions
      :value: None

   .. py:attribute:: device

   .. py:attribute:: is_backward
      :value: False

   .. py:attribute:: is_terminating

   .. py:attribute:: log_rewards

   .. py:attribute:: n_transitions

   .. py:attribute:: next_states

   .. py:attribute:: states


.. py:class:: _ModDBDummyActions(tensor, is_exit = None)

   .. py:method:: __getitem__(idx)

   .. py:method:: __len__()

   .. py:attribute:: exit_action

   .. py:attribute:: is_exit

   .. py:attribute:: tensor


.. py:class:: _ModDBDummyEstimator(log_action, log_exit)

   .. py:method:: __call__(states, conditions=None)

   .. py:attribute:: _log_action

   .. py:attribute:: _log_exit

   .. py:method:: to_probability_distribution(states, module_output=None)


.. py:class:: _ModDBDummyStates(tensor, is_sink_state = None)

   .. py:method:: __getitem__(idx)

   .. py:method:: __len__()

   .. py:property:: device
      :type: torch.device

   .. py:attribute:: is_sink_state

   .. py:attribute:: tensor


.. py:class:: _ModDBDummyTransitions(states, next_states, actions, all_log_rewards, is_backward = False, log_probs = None, has_log_probs = False, conditions = None)

   .. py:method:: __getitem__(idx)

   .. py:method:: __len__()

   .. py:attribute:: actions

   .. py:attribute:: all_log_rewards

   .. py:attribute:: conditions
      :value: None

   .. py:attribute:: device

   .. py:attribute:: has_log_probs
      :value: False

   .. py:attribute:: is_backward
      :value: False

   .. py:attribute:: log_probs
      :value: None

   .. py:attribute:: n_transitions

   .. py:attribute:: next_states

   .. py:attribute:: states


.. py:class:: _ModDBFakeDist(log_action, log_exit)

   .. py:attribute:: _log_action

   .. py:attribute:: _log_exit

   .. py:method:: log_prob(action_tensor)


.. py:class:: _SubTBDummyTrajectories(terminating_idx, max_length)

   .. py:method:: __len__()

   .. py:attribute:: batch_size

   .. py:attribute:: max_length

   .. py:attribute:: terminating_idx


.. py:class:: _TBTrajectoriesStub(log_rewards, conditions = None)

   .. py:attribute:: _log_rewards

   .. py:attribute:: batch_size

   .. py:attribute:: conditions
      :value: None

   .. py:property:: log_rewards
      :type: torch.Tensor


.. py:function:: _build_lpv(T, N, device, dtype)

.. py:function:: _build_tb(T, N, device, dtype)

.. py:function:: _db_build_model_and_data(n_transitions, seed = 0, device = 'cpu', forward_looking = False)

.. py:function:: _db_original_get_scores(model, env, transitions, recalculate_all_logprobs = True)

.. py:function:: _format_ms(value)

.. py:function:: _lpv_original_loss(model, trajectories, log_pf, log_pb)

.. py:function:: _maybe_compile(fn, enabled)

.. py:function:: _moddb_build_model_and_data(n_transitions, seed = 0, device = 'cpu')

.. py:function:: _moddb_original_get_scores(model, transitions, recalculate_all_logprobs = True)

.. py:function:: _parse_args()

.. py:function:: _run_db(sizes, device, repeat, compile_enabled, forward_looking)

.. py:function:: _run_moddb(sizes, device, repeat, compile_enabled)

.. py:function:: _run_subtb(sizes, device, repeat, compile_enabled)

.. py:function:: _run_tb_or_lpv(variant, sizes, T, device, dtype, repeat, compile_enabled)

.. py:function:: _run_with_compile_variants(eager_fn, compile_enabled)

.. py:function:: _scale_int(value, scale)

.. py:function:: _scale_pair(pair, scale)

.. py:function:: _select_dtype(name, device)

.. py:function:: _subtb_build_model_and_data(max_len, n_traj, seed = 0, device = 'cpu')

.. py:function:: _subtb_original_get_scores(model, env, trajectories)

.. py:function:: _tb_original_get_scores(model, trajectories, log_pf, log_pb)

.. py:function:: _tb_original_loss(model, trajectories, log_pf, log_pb)

.. py:function:: _time_fn(fn)

.. py:function:: main()
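
The timing scheme in the module docstring has three parts: four variants per benchmark, two speedups against the original, and a correctness check kept outside the timed loops. The following is a hypothetical plain-Python sketch of that scheme. It is not the script's implementation. The names ``time_fn``, ``run_four_variants``, ``speedups``, ``format_ms`` and ``check_equal`` are illustrative stand-ins for private helpers such as ``_time_fn``, ``_maybe_compile``, ``_run_with_compile_variants`` and ``_format_ms``. Several details are assumptions: the median over warmed-up runs, the ``sync`` hook, and defining a speedup as original time divided by variant time.

```python
import statistics
import time
from typing import Any, Callable, Dict, Optional

# Hypothetical sketch: PyTorch is optional so the example also runs without it.
try:
    import torch

    DEFAULT_COMPILER = getattr(torch, "compile", None)  # torch.compile needs torch>=2.0
except ImportError:
    DEFAULT_COMPILER = None


def time_fn(fn: Callable[[], Any], repeat: int = 5, warmup: int = 1,
            sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall-clock time of fn() in milliseconds (assumed statistic).

    Warm-up calls are not timed: for a torch.compile'd function the first
    call triggers compilation. Pass sync (e.g. torch.cuda.synchronize) so
    queued asynchronous GPU work finishes inside the timed window.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeat):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


def check_equal(expected: Any, actual: Any) -> None:
    """Plain equality check; tensors would need torch.testing.assert_close."""
    assert expected == actual, f"mismatch: {expected!r} != {actual!r}"


def run_four_variants(original_fn, current_fn, compile_enabled: bool,
                      repeat: int = 5, check=None, sync=None,
                      compiler=DEFAULT_COMPILER) -> Dict[str, float]:
    """Time original, original+compile, current, current+compile (in ms)."""
    # Correctness is checked once, before and outside the timed loops.
    if check is not None:
        check(original_fn(), current_fn())
    compiled = compile_enabled and compiler is not None
    variants = [("original", original_fn)]
    if compiled:
        variants.append(("original+compile", compiler(original_fn)))
    variants.append(("current", current_fn))
    if compiled:
        variants.append(("current+compile", compiler(current_fn)))
    return {name: time_fn(fn, repeat=repeat, sync=sync) for name, fn in variants}


def speedups(timings_ms: Dict[str, float]) -> Dict[str, float]:
    """Original time divided by variant time: values above 1 mean faster."""
    base = timings_ms["original"]
    return {
        name: (base / timings_ms[name]) if timings_ms[name] > 0 else float("inf")
        for name in ("current", "current+compile")
        if name in timings_ms
    }


def format_ms(value: Optional[float]) -> str:
    """Render a timing; None marks a variant that was skipped."""
    return "n/a" if value is None else f"{value:.3f} ms"


if __name__ == "__main__":
    n = 20_000

    def original():  # slow baseline: explicit summation
        return sum(i * i for i in range(n))

    def current():  # closed form of the same sum of squares
        return (n - 1) * n * (2 * n - 1) // 6

    timings = run_four_variants(original, current, compile_enabled=False,
                                check=check_equal)
    for name in ("original", "original+compile", "current", "current+compile"):
        print(f"{name:>16}: {format_ms(timings.get(name))}")
    for name, ratio in speedups(timings).items():
        print(f"{name:>16} speedup vs original: {ratio:.1f}x")
```

With PyTorch 2.x installed, ``compile_enabled=True`` wraps both functions in ``torch.compile``, and the warm-up call absorbs the one-off compilation cost. On a GPU, passing ``sync=torch.cuda.synchronize`` keeps asynchronous kernel launches from making a variant look faster than it is.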