bench_get_scores_all

Unified micro-benchmark runner for GFlowNet losses/get_scores.

Runs, in order:

  • Trajectory Balance (TB) loss
  • Log Partition Variance (LPV) loss
  • Sub-trajectory Balance (SubTB) get_scores
  • Detailed Balance (DB) get_scores
  • Modified Detailed Balance (ModDB) get_scores

For each benchmark, the script reports four timings:

  • original (baseline, frozen copy)
  • original+compile (torch.compile applied to the baseline function)
  • current (eager)
  • current+compile (torch.compile applied to the current function)

Two speedups are printed: current/original and current+compile/original.

All benchmarks use embedded base sizes, scaled by a single --size-scale multiplier. Correctness is checked once per size before timing and is excluded from the timing loops.

Classes

_DBDummyActions

_DBDummyEnv

_DBDummyStates

_DBDummyTransitions

_ModDBDummyActions

_ModDBDummyEstimator

_ModDBDummyStates

_ModDBDummyTransitions

_ModDBFakeDist

_SubTBDummyTrajectories

_TBTrajectoriesStub

Functions

_build_lpv(T, N, device, dtype)

_build_tb(T, N, device, dtype)

_db_build_model_and_data(n_transitions[, seed, ...])

_db_original_get_scores(model, env, transitions[, ...])

_format_ms(value)

_lpv_original_loss(model, trajectories, log_pf, log_pb)

_maybe_compile(fn, enabled)

_moddb_build_model_and_data(n_transitions[, seed, device])

_moddb_original_get_scores(model, transitions[, ...])

_parse_args()

_run_db(sizes, device, repeat, compile_enabled, ...)

_run_moddb(sizes, device, repeat, compile_enabled)

_run_subtb(sizes, device, repeat, compile_enabled)

_run_tb_or_lpv(variant, sizes, T, device, dtype, ...)

_run_with_compile_variants(eager_fn, compile_enabled)

_scale_int(value, scale)

_scale_pair(pair, scale)

_select_dtype(name, device)

_subtb_build_model_and_data(max_len, n_traj[, seed, ...])

_subtb_original_get_scores(model, env, trajectories)

_tb_original_get_scores(model, trajectories, log_pf, ...)

_tb_original_loss(model, trajectories, log_pf, log_pb)

_time_fn(fn)

main()

Module Contents

class bench_get_scores_all._DBDummyActions(tensor)
Parameters:

tensor (torch.Tensor)

__getitem__(idx)
Return type:

_DBDummyActions

__len__()
Return type:

int

property batch_shape: torch.Size
Return type:

torch.Size

exit_action
is_exit
tensor
class bench_get_scores_all._DBDummyEnv(log_reward_fn)
Parameters:

log_reward_fn (Callable[[Any, Any | None], torch.Tensor])

_log_reward_fn
log_reward(states, conditions=None)
Parameters:
  • states (Any)

  • conditions (Any | None)

Return type:

torch.Tensor

class bench_get_scores_all._DBDummyStates(tensor, is_sink_state=None)
Parameters:
  • tensor (torch.Tensor)

  • is_sink_state (torch.Tensor | None)

__getitem__(idx)
Return type:

_DBDummyStates

__len__()
Return type:

int

property batch_shape: torch.Size
Return type:

torch.Size

property device: torch.device
Return type:

torch.device

is_sink_state
tensor
class bench_get_scores_all._DBDummyTransitions(states, next_states, actions, is_terminating, log_rewards, conditions=None)
Parameters:
  • states

  • next_states

  • actions

  • is_terminating

  • log_rewards

  • conditions

__len__()
Return type:

int

actions
conditions = None
device
is_backward = False
is_terminating
log_rewards
n_transitions
next_states
states
class bench_get_scores_all._ModDBDummyActions(tensor, is_exit=None)
Parameters:
  • tensor (torch.Tensor)

  • is_exit (torch.Tensor | None)

__getitem__(idx)
Return type:

_ModDBDummyActions

__len__()
Return type:

int

exit_action
is_exit
tensor
class bench_get_scores_all._ModDBDummyEstimator(log_action, log_exit)
Parameters:
  • log_action (torch.Tensor)

  • log_exit (torch.Tensor)

__call__(states, conditions=None)
Parameters:

states (_ModDBDummyStates)

_log_action
_log_exit
to_probability_distribution(states, module_output=None)
Parameters:

states (_ModDBDummyStates)

class bench_get_scores_all._ModDBDummyStates(tensor, is_sink_state=None)
Parameters:
  • tensor (torch.Tensor)

  • is_sink_state (torch.Tensor | None)

__getitem__(idx)
Return type:

_ModDBDummyStates

__len__()
Return type:

int

property device: torch.device
Return type:

torch.device

is_sink_state
tensor
class bench_get_scores_all._ModDBDummyTransitions(states, next_states, actions, all_log_rewards, is_backward=False, log_probs=None, has_log_probs=False, conditions=None)
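Parameters:
  • states

  • next_states

  • actions

  • all_log_rewards

  • is_backward

  • log_probs

  • has_log_probs

  • conditions
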
__getitem__(idx)
Return type:

_ModDBDummyTransitions

__len__()
Return type:

int

actions
all_log_rewards
conditions = None
device
has_log_probs = False
is_backward = False
log_probs = None
n_transitions
next_states
states
class bench_get_scores_all._ModDBFakeDist(log_action, log_exit)
Parameters:
  • log_action (torch.Tensor)

  • log_exit (torch.Tensor)

_log_action
_log_exit
log_prob(action_tensor)
Parameters:

action_tensor (torch.Tensor)

Return type:

torch.Tensor

class bench_get_scores_all._SubTBDummyTrajectories(terminating_idx, max_length)
Parameters:
  • terminating_idx (torch.Tensor)

  • max_length (int)

__len__()
Return type:

int

batch_size
max_length
terminating_idx
class bench_get_scores_all._TBTrajectoriesStub(log_rewards, conditions=None)
Parameters:
  • log_rewards (torch.Tensor)

  • conditions (torch.Tensor | None)

_log_rewards
batch_size
conditions = None
property log_rewards: torch.Tensor
Return type:

torch.Tensor

bench_get_scores_all._build_lpv(T, N, device, dtype)
Parameters:
  • T (int)

  • N (int)

  • device (torch.device)

  • dtype (torch.dtype)

Return type:

Tuple[gfn.gflownet.trajectory_balance.LogPartitionVarianceGFlowNet, _TBTrajectoriesStub, torch.Tensor, torch.Tensor]

bench_get_scores_all._build_tb(T, N, device, dtype)
Parameters:
  • T (int)

  • N (int)

  • device (torch.device)

  • dtype (torch.dtype)

Return type:

Tuple[gfn.gflownet.trajectory_balance.TBGFlowNet, _TBTrajectoriesStub, torch.Tensor, torch.Tensor]

bench_get_scores_all._db_build_model_and_data(n_transitions, seed=0, device='cpu', forward_looking=False)
Parameters:
  • n_transitions (int)

  • seed (int)

  • device (str | torch.device)

  • forward_looking (bool)

Return type:

Tuple[gfn.gflownet.detailed_balance.DBGFlowNet, _DBDummyEnv, _DBDummyTransitions]

bench_get_scores_all._db_original_get_scores(model, env, transitions, recalculate_all_logprobs=True)
Parameters:
  • model

  • env

  • transitions

  • recalculate_all_logprobs

Return type:

torch.Tensor

bench_get_scores_all._format_ms(value)
Parameters:

value (float | None)

Return type:

str

bench_get_scores_all._lpv_original_loss(model, trajectories, log_pf, log_pb)
Parameters:
  • log_pf (torch.Tensor)

  • log_pb (torch.Tensor)

bench_get_scores_all._maybe_compile(fn, enabled)
Parameters:
  • fn (Callable[..., Any] | None)

  • enabled (bool)

Return type:

Callable[..., Any] | None

bench_get_scores_all._moddb_build_model_and_data(n_transitions, seed=0, device='cpu')
Parameters:
  • n_transitions (int)

  • seed (int)

  • device (str | torch.device)

Return type:

Tuple[gfn.gflownet.detailed_balance.ModifiedDBGFlowNet, _ModDBDummyTransitions]

bench_get_scores_all._moddb_original_get_scores(model, transitions, recalculate_all_logprobs=True)
Parameters:
  • model

  • transitions

  • recalculate_all_logprobs

Return type:

torch.Tensor

bench_get_scores_all._parse_args()
Return type:

argparse.Namespace

bench_get_scores_all._run_db(sizes, device, repeat, compile_enabled, forward_looking)
Parameters:
  • sizes (Iterable[int])

  • device (torch.device)

  • repeat (int)

  • compile_enabled (bool)

  • forward_looking (bool)

bench_get_scores_all._run_moddb(sizes, device, repeat, compile_enabled)
Parameters:
  • sizes (Iterable[int])

  • device (torch.device)

  • repeat (int)

  • compile_enabled (bool)

bench_get_scores_all._run_subtb(sizes, device, repeat, compile_enabled)
Parameters:
  • sizes (Iterable[Tuple[int, int]])

  • device (torch.device)

  • repeat (int)

  • compile_enabled (bool)

bench_get_scores_all._run_tb_or_lpv(variant, sizes, T, device, dtype, repeat, compile_enabled)
Parameters:
  • variant (str)

  • sizes (Iterable[int])

  • T (int)

  • device (torch.device)

  • dtype (torch.dtype)

  • repeat (int)

  • compile_enabled (bool)

bench_get_scores_all._run_with_compile_variants(eager_fn, compile_enabled)
Parameters:
  • eager_fn (Callable[[], Any])

  • compile_enabled (bool)

Return type:

tuple[float, float | None]

bench_get_scores_all._scale_int(value, scale)
Parameters:
  • value (int)

  • scale (float)

Return type:

int

bench_get_scores_all._scale_pair(pair, scale)
Parameters:
  • pair (Tuple[int, int])

  • scale (float)

Return type:

Tuple[int, int]

bench_get_scores_all._select_dtype(name, device)
Parameters:
  • name (str)

  • device (torch.device)

Return type:

torch.dtype

bench_get_scores_all._subtb_build_model_and_data(max_len, n_traj, seed=0, device='cpu')
Parameters:
  • max_len (int)

  • n_traj (int)

  • seed (int)

  • device (str | torch.device)

Return type:

Tuple[gfn.gflownet.sub_trajectory_balance.SubTBGFlowNet, _SubTBDummyTrajectories, list[torch.Tensor], list[torch.Tensor]]

bench_get_scores_all._subtb_original_get_scores(model, env, trajectories)
Parameters:

model (gfn.gflownet.sub_trajectory_balance.SubTBGFlowNet)

Return type:

Tuple[list[torch.Tensor], list[torch.Tensor]]

bench_get_scores_all._tb_original_get_scores(model, trajectories, log_pf, log_pb)
Parameters:
  • log_pf (torch.Tensor)

  • log_pb (torch.Tensor)

bench_get_scores_all._tb_original_loss(model, trajectories, log_pf, log_pb)
Parameters:
  • log_pf (torch.Tensor)

  • log_pb (torch.Tensor)

bench_get_scores_all._time_fn(fn)
Parameters:

fn (Callable[[], Any])

Return type:

float

bench_get_scores_all.main()