bench_get_scores_all
Unified micro-benchmark runner for GFlowNet losses/get_scores.
Runs, in order:
- Trajectory Balance (TB) loss
- Log Partition Variance (LPV) loss
- Sub-trajectory Balance (SubTB) get_scores
- Detailed Balance (DB) get_scores
- Modified Detailed Balance (ModDB) get_scores
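For orientation, the textbook forms of the TB, LPV and DB objectives are written out below, for a trajectory $\tau = (s_0 \to \dots \to s_n = x)$ with forward policy $P_F$, backward policy $P_B$, reward $R$ and learned log-partition $\log Z_\theta$. These are the standard definitions from the GFlowNet literature, not a transcription of this module's code; how the library reduces them over a batch (mean, masking, weighting) is an assumption left open here. SubTB and ModDB are variations on the same balance residuals and are omitted.

```latex
% Trajectory Balance (per trajectory)
\mathcal{L}_{\mathrm{TB}}(\tau) = \Big( \log Z_\theta + \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t) - \log R(x) - \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1}) \Big)^2

% Log Partition Variance: the learned \log Z_\theta is replaced by a per-trajectory estimate
\zeta(\tau) = \log R(x) + \sum_{t=0}^{n-1} \log P_B(s_t \mid s_{t+1}) - \sum_{t=0}^{n-1} \log P_F(s_{t+1} \mid s_t),
\qquad
\bar{\zeta} = \frac{1}{N} \sum_{i=1}^{N} \zeta(\tau_i),
\qquad
\mathcal{L}_{\mathrm{LPV}} = \frac{1}{N} \sum_{i=1}^{N} \big( \zeta(\tau_i) - \bar{\zeta} \big)^2

% Detailed Balance (per transition s -> s'; for a terminating transition F(s') is replaced by R(s'))
\mathcal{L}_{\mathrm{DB}}(s, s') = \big( \log F(s) + \log P_F(s' \mid s) - \log F(s') - \log P_B(s \mid s') \big)^2
```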
For each benchmark, the script reports four timings:
- original (baseline, frozen copy)
- original+compile (torch.compile applied to the baseline function)
- current (eager)
- current+compile (torch.compile applied to the current function)
Two speedups are printed: current/original and current+compile/original.
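The two reported ratios can be illustrated with a small sketch. It assumes that a speedup is the baseline time divided by the candidate time, so a value above 1.0 means the candidate is faster, and that a missing compiled timing is passed as None; the helper names `speedup` and `report_speedups` are hypothetical and not part of the module.

```python
from typing import Dict, Optional


def speedup(original_ms: float, candidate_ms: Optional[float]) -> Optional[float]:
    """How many times faster candidate_ms is than original_ms (assumed convention)."""
    # No measurement (e.g. compilation disabled) or a degenerate timing: no ratio.
    if candidate_ms is None or candidate_ms <= 0.0:
        return None
    return original_ms / candidate_ms


def report_speedups(
    original_ms: float, current_ms: float, current_compiled_ms: Optional[float]
) -> Dict[str, Optional[float]]:
    # Both ratios are taken against the eager original baseline.
    return {
        "current/original": speedup(original_ms, current_ms),
        "current+compile/original": speedup(original_ms, current_compiled_ms),
    }
```

For example, `report_speedups(10.0, 4.0, 2.5)` gives 2.5 for the eager ratio and 4.0 for the compiled one.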
All benchmarks use embedded base sizes, scaled by a single --size-scale multiplier. Correctness is checked once per size before timing and is excluded from the timing loops.
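The ordering described above (check once per size, then time only the call itself) can be sketched without any framework. This assumes scalar results compared with `math.isclose` and arbitrary tolerances; the real script compares tensors, and `check_then_time` is a hypothetical name.

```python
import math
import timeit
from typing import Callable, Dict, Iterable


def check_then_time(
    sizes: Iterable[int],
    reference: Callable[[int], float],
    candidate: Callable[[int], float],
    repeat: int = 5,
) -> Dict[int, float]:
    """Return the best candidate time in ms per size, after a one-off check."""
    results: Dict[int, float] = {}
    for size in sizes:
        # Correctness: compared once per size, before and outside the timed loop.
        expected, actual = reference(size), candidate(size)
        if not math.isclose(expected, actual, rel_tol=1e-6, abs_tol=1e-9):
            raise AssertionError(f"size={size}: {expected!r} != {actual!r}")
        # Timing: only the candidate call itself is measured.
        runs = timeit.repeat(lambda: candidate(size), number=1, repeat=repeat)
        results[size] = min(runs) * 1e3
    return results
```

With tensor outputs, `torch.testing.assert_close` would take the place of `math.isclose`.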
Module Contents
- class bench_get_scores_all._DBDummyActions(tensor)
- Parameters:
tensor (torch.Tensor)
- __getitem__(idx)
- Return type:
- __len__()
- Return type:
int
- property batch_shape: torch.Size
- Return type:
torch.Size
- exit_action
- is_exit
- tensor
- class bench_get_scores_all._DBDummyEnv(log_reward_fn)
- Parameters:
log_reward_fn (Callable[[Any, Any | None], torch.Tensor])
- _log_reward_fn
- log_reward(states, conditions=None)
- Parameters:
states (Any)
conditions (Any | None)
- Return type:
torch.Tensor
- class bench_get_scores_all._DBDummyStates(tensor, is_sink_state=None)
- Parameters:
tensor (torch.Tensor)
is_sink_state (torch.Tensor | None)
- __getitem__(idx)
- Return type:
- __len__()
- Return type:
int
- property batch_shape: torch.Size
- Return type:
torch.Size
- property device: torch.device
- Return type:
torch.device
- is_sink_state
- tensor
- class bench_get_scores_all._DBDummyTransitions(states, next_states, actions, is_terminating, log_rewards, conditions=None)
- Parameters:
states (_DBDummyStates)
next_states (_DBDummyStates)
actions (_DBDummyActions)
is_terminating (torch.Tensor)
log_rewards (torch.Tensor)
conditions (torch.Tensor | None)
- __len__()
- Return type:
int
- actions
- conditions = None
- device
- is_backward = False
- is_terminating
- log_rewards
- n_transitions
- next_states
- states
- class bench_get_scores_all._ModDBDummyActions(tensor, is_exit=None)
- Parameters:
tensor (torch.Tensor)
is_exit (torch.Tensor | None)
- __getitem__(idx)
- Return type:
- __len__()
- Return type:
int
- exit_action
- is_exit
- tensor
- class bench_get_scores_all._ModDBDummyEstimator(log_action, log_exit)
- Parameters:
log_action (torch.Tensor)
log_exit (torch.Tensor)
- __call__(states, conditions=None)
- Parameters:
states (_ModDBDummyStates)
- _log_action
- _log_exit
- to_probability_distribution(states, module_output=None)
- Parameters:
states (_ModDBDummyStates)
- class bench_get_scores_all._ModDBDummyStates(tensor, is_sink_state=None)
- Parameters:
tensor (torch.Tensor)
is_sink_state (torch.Tensor | None)
- __getitem__(idx)
- Return type:
- __len__()
- Return type:
int
- property device: torch.device
- Return type:
torch.device
- is_sink_state
- tensor
- class bench_get_scores_all._ModDBDummyTransitions(states, next_states, actions, all_log_rewards, is_backward=False, log_probs=None, has_log_probs=False, conditions=None)
- Parameters:
states (_ModDBDummyStates)
next_states (_ModDBDummyStates)
actions (_ModDBDummyActions)
all_log_rewards (torch.Tensor)
is_backward (bool)
log_probs (torch.Tensor | None)
has_log_probs (bool)
conditions (torch.Tensor | None)
- __getitem__(idx)
- Return type:
- __len__()
- Return type:
int
- actions
- all_log_rewards
- conditions = None
- device
- has_log_probs = False
- is_backward = False
- log_probs = None
- n_transitions
- next_states
- states
- class bench_get_scores_all._ModDBFakeDist(log_action, log_exit)
- Parameters:
log_action (torch.Tensor)
log_exit (torch.Tensor)
- _log_action
- _log_exit
- log_prob(action_tensor)
- Parameters:
action_tensor (torch.Tensor)
- Return type:
torch.Tensor
- class bench_get_scores_all._SubTBDummyTrajectories(terminating_idx, max_length)
- Parameters:
terminating_idx (torch.Tensor)
max_length (int)
- __len__()
- Return type:
int
- batch_size
- max_length
- terminating_idx
- class bench_get_scores_all._TBTrajectoriesStub(log_rewards, conditions=None)
- Parameters:
log_rewards (torch.Tensor)
conditions (torch.Tensor | None)
- _log_rewards
- batch_size
- conditions = None
- property log_rewards: torch.Tensor
- Return type:
torch.Tensor
- bench_get_scores_all._build_lpv(T, N, device, dtype)
- Parameters:
T (int)
N (int)
device (torch.device)
dtype (torch.dtype)
- Return type:
Tuple[gfn.gflownet.trajectory_balance.LogPartitionVarianceGFlowNet, _TBTrajectoriesStub, torch.Tensor, torch.Tensor]
- bench_get_scores_all._build_tb(T, N, device, dtype)
- Parameters:
T (int)
N (int)
device (torch.device)
dtype (torch.dtype)
- Return type:
Tuple[gfn.gflownet.trajectory_balance.TBGFlowNet, _TBTrajectoriesStub, torch.Tensor, torch.Tensor]
- bench_get_scores_all._db_build_model_and_data(n_transitions, seed=0, device='cpu', forward_looking=False)
- Parameters:
n_transitions (int)
seed (int)
device (str | torch.device)
forward_looking (bool)
- Return type:
Tuple[gfn.gflownet.detailed_balance.DBGFlowNet, _DBDummyEnv, _DBDummyTransitions]
- bench_get_scores_all._db_original_get_scores(model, env, transitions, recalculate_all_logprobs=True)
- Parameters:
env (_DBDummyEnv)
transitions (_DBDummyTransitions)
recalculate_all_logprobs (bool)
- Return type:
torch.Tensor
- bench_get_scores_all._format_ms(value)
- Parameters:
value (float | None)
- Return type:
str
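One plausible body for a formatter with this signature is sketched below. The placeholder string "n/a" for a missing timing and the three-decimal precision are assumptions, and `format_ms` is a hypothetical stand-in for `_format_ms`.

```python
from typing import Optional


def format_ms(value: Optional[float]) -> str:
    # None marks a timing that was not collected (assumed: compilation disabled).
    if value is None:
        return "n/a"
    # Fixed precision keeps the columns of the printed table aligned.
    return f"{value:.3f} ms"
```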
- bench_get_scores_all._lpv_original_loss(model, trajectories, log_pf, log_pb)
- Parameters:
log_pf (torch.Tensor)
log_pb (torch.Tensor)
- bench_get_scores_all._maybe_compile(fn, enabled)
- Parameters:
fn (Callable[..., Any] | None)
enabled (bool)
- Return type:
Callable[..., Any] | None
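A minimal sketch of what the signature suggests: pass the function through unchanged when compilation is off or there is nothing to compile, and otherwise wrap it with `torch.compile`. The lazy import and the use of default `torch.compile` options are assumptions; `maybe_compile` is a hypothetical stand-in for `_maybe_compile`.

```python
from typing import Any, Callable, Optional


def maybe_compile(
    fn: Optional[Callable[..., Any]], enabled: bool
) -> Optional[Callable[..., Any]]:
    # Nothing to compile, or compilation switched off: return the input unchanged.
    if fn is None or not enabled:
        return fn
    # Imported lazily so the eager path does not need torch at import time.
    import torch

    # Default options; torch.compile returns a wrapper that compiles on first call.
    return torch.compile(fn)
```

Because `torch.compile` compiles on the first invocation, a benchmark should call the returned function at least once before timing it, otherwise compilation cost lands in the measurement.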
- bench_get_scores_all._moddb_build_model_and_data(n_transitions, seed=0, device='cpu')
- Parameters:
n_transitions (int)
seed (int)
device (str | torch.device)
- Return type:
Tuple[gfn.gflownet.detailed_balance.ModifiedDBGFlowNet, _ModDBDummyTransitions]
- bench_get_scores_all._moddb_original_get_scores(model, transitions, recalculate_all_logprobs=True)
- Parameters:
transitions (_ModDBDummyTransitions)
recalculate_all_logprobs (bool)
- Return type:
torch.Tensor
- bench_get_scores_all._parse_args()
- Return type:
argparse.Namespace
- bench_get_scores_all._run_db(sizes, device, repeat, compile_enabled, forward_looking)
- Parameters:
sizes (Iterable[int])
device (torch.device)
repeat (int)
compile_enabled (bool)
forward_looking (bool)
- bench_get_scores_all._run_moddb(sizes, device, repeat, compile_enabled)
- Parameters:
sizes (Iterable[int])
device (torch.device)
repeat (int)
compile_enabled (bool)
- bench_get_scores_all._run_subtb(sizes, device, repeat, compile_enabled)
- Parameters:
sizes (Iterable[Tuple[int, int]])
device (torch.device)
repeat (int)
compile_enabled (bool)
- bench_get_scores_all._run_tb_or_lpv(variant, sizes, T, device, dtype, repeat, compile_enabled)
- Parameters:
variant (str)
sizes (Iterable[int])
T (int)
device (torch.device)
dtype (torch.dtype)
repeat (int)
compile_enabled (bool)
- bench_get_scores_all._run_with_compile_variants(eager_fn, compile_enabled)
- Parameters:
eager_fn (Callable[[], Any])
compile_enabled (bool)
- Return type:
tuple[float, float | None]
- bench_get_scores_all._scale_int(value, scale)
- Parameters:
value (int)
scale (float)
- Return type:
int
- bench_get_scores_all._scale_pair(pair, scale)
- Parameters:
pair (Tuple[int, int])
scale (float)
- Return type:
Tuple[int, int]
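The two scaling helpers apply the single --size-scale multiplier to the embedded base sizes. The sketch below assumes rounding to the nearest integer with a floor of 1, and assumes the pair is the (max_len, n_traj) shape used by the SubTB builder; `scale_int` and `scale_pair` are hypothetical stand-ins for `_scale_int` and `_scale_pair`.

```python
from typing import Tuple


def scale_int(value: int, scale: float) -> int:
    # Round to the nearest integer (Python's round-half-to-even) and keep at least 1.
    return max(1, int(round(value * scale)))


def scale_pair(pair: Tuple[int, int], scale: float) -> Tuple[int, int]:
    # Scale both members, e.g. an assumed (max_len, n_traj) pair for SubTB.
    return scale_int(pair[0], scale), scale_int(pair[1], scale)
```

The floor of 1 keeps very small multipliers from producing empty batches.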
- bench_get_scores_all._select_dtype(name, device)
- Parameters:
name (str)
device (torch.device)
- Return type:
torch.dtype
- bench_get_scores_all._subtb_build_model_and_data(max_len, n_traj, seed=0, device='cpu')
- Parameters:
max_len (int)
n_traj (int)
seed (int)
device (str | torch.device)
- Return type:
Tuple[gfn.gflownet.sub_trajectory_balance.SubTBGFlowNet, _SubTBDummyTrajectories, list[torch.Tensor], list[torch.Tensor]]
- bench_get_scores_all._subtb_original_get_scores(model, env, trajectories)
- Parameters:
- Return type:
Tuple[list[torch.Tensor], list[torch.Tensor]]
- bench_get_scores_all._tb_original_get_scores(model, trajectories, log_pf, log_pb)
- Parameters:
log_pf (torch.Tensor)
log_pb (torch.Tensor)
- bench_get_scores_all._tb_original_loss(model, trajectories, log_pf, log_pb)
- Parameters:
log_pf (torch.Tensor)
log_pb (torch.Tensor)
- bench_get_scores_all._time_fn(fn)
- Parameters:
fn (Callable[[], Any])
- Return type:
float
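A timer with this shape (a zero-argument callable in, a float out) might look like the sketch below. The warm-up count, iteration count, median statistic and millisecond unit are all assumptions; the real `_time_fn` takes only `fn`, so `warmup` and `iters` here are hypothetical parameters of the stand-in `time_fn`.

```python
import statistics
import time
from typing import Any, Callable, List


def time_fn(fn: Callable[[], Any], warmup: int = 3, iters: int = 10) -> float:
    """Return the median wall-clock time of fn() in milliseconds."""
    # Warm-up calls absorb one-off costs such as torch.compile tracing.
    for _ in range(warmup):
        fn()
    samples: List[float] = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e3)
    # The median is less sensitive than the mean to occasional slow outliers.
    return statistics.median(samples)
```

On a CUDA device, kernels run asynchronously, so a faithful timer would call `torch.cuda.synchronize()` before reading the clock on each side of `fn()`; `torch.utils.benchmark.Timer` handles this automatically.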
- bench_get_scores_all.main()