spawn_policy

Attributes

logger

Classes

AsyncSelectiveAveragingPolicy

Asynchronous selective averaging with background, non-blocking comms.

AsyncSelectiveAveragingPolicympi4pyFast

Asynchronous selective averaging, version 2: uses MPI one-sided communication to fetch the selectively averaged parameters from a random set of ranks.

AsyncSelectiveAveragingPolicympi4pyGeneral

Asynchronous selective averaging, version 2: uses MPI one-sided communication to fetch the selectively averaged parameters from a random set of ranks.

AverageAllPolicy

Standard model averaging across all ranks every N iterations.

AverageAllPolicympi4py

Standard model averaging across all ranks every N iterations.

SpawnPolicy

Helper class that provides a standard way to create an ABC using inheritance.

Functions

_compute_worst_ranks_by_ratio(metrics, replacement_ratio)

Return the set of globally worst ranks to replace.

Module Contents

class spawn_policy.AsyncSelectiveAveragingPolicy(model_builder, average_every, replacement_ratio=0.2, averaging_strategy='mean', momentum=0.0, poll_interval_s=0.01, threshold=None, cooldown=200, timing=None)

Bases: SpawnPolicy

Asynchronous selective averaging with background, non-blocking comms.

At each averaging cadence, every rank sends its metric to rank 0 via isend. Once rank 0 has received metrics from all ranks, it chooses the set of ranks to replace and instructs donors and replacers through point-to-point messages. Donors then stream their parameters to the replacers. Each replacer aggregates the received parameters in a background thread, and its main thread applies the averaged weights at the next safe call, without barriers.
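
The hand-off between a replacer's background thread and its main thread can be pictured with the minimal, single-process sketch below. It is not the class's actual implementation: `PendingWeights`, `publish`, `take`, and `aggregate` are hypothetical names, parameters are plain lists rather than tensors, and uniform averaging is assumed.

```python
import threading
from typing import Dict, List, Optional

Params = Dict[str, List[float]]


class PendingWeights:
    """Hypothetical holder for weights produced off the main thread."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._new_weights: Optional[Params] = None

    def publish(self, weights: Params) -> None:
        # Called from the background thread once aggregation is complete.
        with self._lock:
            self._new_weights = weights

    def take(self) -> Optional[Params]:
        # Called from the main thread at a safe point; returns None if
        # nothing is pending, so the caller never waits on communication.
        with self._lock:
            weights, self._new_weights = self._new_weights, None
        return weights


def aggregate(donor_params: List[Params]) -> Params:
    """Element-wise mean of the donors' parameters (uniform weights assumed)."""
    count = len(donor_params)
    return {
        name: [sum(col) / count for col in zip(*(p[name] for p in donor_params))]
        for name in donor_params[0]
    }


pending = PendingWeights()
donors = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}]

worker = threading.Thread(target=lambda: pending.publish(aggregate(donors)))
worker.start()
worker.join()  # joined here only to make the example deterministic

model = {"w": [0.0, 0.0]}
new_weights = pending.take()
if new_weights is not None:
    model.update(new_weights)  # the "next safe call" applies the result
```

Taking the pending weights also clears them, so a second call returns `None` until the background thread publishes again.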

Parameters:
  • model_builder (Callable[[], Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer]])

  • average_every (int)

  • replacement_ratio (float)

  • averaging_strategy (str)

  • momentum (float)

  • poll_interval_s (float)

  • threshold (Optional[float])

  • cooldown (int)

  • timing (Optional[dict])

_OP_NONE = 0
_OP_ROLE_DONOR = 1
_OP_ROLE_REPLACER = 2
_TAG_CONTROL = 7002
_TAG_METRIC = 7001
_TAG_PARAM_BASE = 8000
__call__(iteration, model, optimizer, local_metric=None, group=dist.group.WORLD)

Possibly perform a spawn/averaging step on this iteration.

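Parameters:
  • iteration (int)

  • model (gfn.gflownet.base.GFlowNet)

  • optimizer (torch.optim.Optimizer)

  • local_metric (Optional[float])

  • group (process group, default dist.group.WORLD)

Return type: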

Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]

_background_loop()
Return type:

None

_bg_thread: threading.Thread | None = None
static _compute_averaging_weights(all_metrics, ranks_to_average, averaging_strategy)
Parameters:
  • all_metrics (torch.Tensor)

  • ranks_to_average (Set[int])

  • averaging_strategy (str)

Return type:

Optional[torch.Tensor]
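
As an illustration of what the default 'mean' strategy might compute, the sketch below assigns every donor rank the same weight. It is an assumption-laden stand-in, not the method itself: `uniform_averaging_weights` is a hypothetical name, it returns a plain dict instead of a tensor, and returning `None` for an empty donor set is only a guess motivated by the `Optional` return type.

```python
from typing import Dict, Optional, Sequence, Set


def uniform_averaging_weights(
    all_metrics: Sequence[float],
    ranks_to_average: Set[int],
    averaging_strategy: str = "mean",
) -> Optional[Dict[int, float]]:
    """Hypothetical stand-in: map each donor rank to its averaging weight."""
    if not ranks_to_average:
        return None  # assumption: no donors means no weights
    if averaging_strategy != "mean":
        # Other strategies are not described in this reference.
        raise ValueError(f"unsupported strategy in this sketch: {averaging_strategy!r}")
    # A plain mean ignores all_metrics and weights every donor equally.
    weight = 1.0 / len(ranks_to_average)
    return {rank: weight for rank in sorted(ranks_to_average)}


weights = uniform_averaging_weights([0.1, 0.7, 0.4, 0.9], {1, 3})
```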

_control_role_buf: torch.Tensor | None = None
_control_work: torch.distributed.Work | None = None
static _determine_ranks_for_averaging(all_metrics, world_size, replacement_ratio, averaging_strategy)
Parameters:
  • all_metrics (torch.Tensor)

  • world_size (int)

  • replacement_ratio (float)

  • averaging_strategy (str)

Return type:

Tuple[Set[int], Set[int]]

_ensure_initialized(model)
Parameters:

model (gfn.gflownet.base.GFlowNet)

Return type:

None

_handle_donor(iteration, replacers)
Parameters:
  • iteration (int)

  • replacers (List[int])

Return type:

None

_handle_replacer(iteration, donors, weights)
Parameters:
  • iteration (int)

  • donors (List[int])

  • weights (Optional[torch.Tensor])

Return type:

None

_initialized = False
_last_iter_sent: int = -1
_last_trigger_iter: int = -200
_model: gfn.gflownet.base.GFlowNet | None = None
_model_builder
_new_weights: Dict[str, torch.Tensor] | None = None
_pending_lock
_rank0_buckets: Dict[int, Dict[int, float]]
_rank0_dispatch_controls(iteration, ranks_to_replace, ranks_to_average, weights)
Parameters:
  • iteration (int)

  • ranks_to_replace (Set[int])

  • ranks_to_average (Set[int])

  • weights (Optional[torch.Tensor])

Return type:

None

_rank0_metric_buffers: Dict[int, torch.Tensor]
_rank0_metric_handles: Dict[int, torch.distributed.Work]
_rank0_poll_metrics()
Return type:

None

_rank0_post_metric_recvs()
Return type:

None

_rank0_record_metric(iteration, rank, metric, world_size)
Parameters:
  • iteration (int)

  • rank (int)

  • metric (float)

  • world_size (int)

Return type:

None
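
The shape of `_rank0_buckets` (iteration, then rank, then metric) suggests that rank 0 groups incoming metrics by iteration and acts only once every rank has reported. A minimal sketch of that bookkeeping follows; `MetricBuckets` and `record` are hypothetical names, and dropping a bucket once it is complete is an assumption.

```python
from typing import Dict, List, Optional


class MetricBuckets:
    """Hypothetical per-iteration metric store kept on rank 0."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size
        self.buckets: Dict[int, Dict[int, float]] = {}

    def record(self, iteration: int, rank: int, metric: float) -> Optional[List[float]]:
        """Store one metric; return all metrics once the bucket is full."""
        bucket = self.buckets.setdefault(iteration, {})
        bucket[rank] = metric
        if len(bucket) < self.world_size:
            return None  # still waiting for other ranks
        del self.buckets[iteration]  # assumption: complete buckets are dropped
        return [bucket[r] for r in range(self.world_size)]


store = MetricBuckets(world_size=3)
first = store.record(10, 2, 0.3)      # incomplete
second = store.record(10, 0, 0.1)     # incomplete
other = store.record(20, 1, 0.9)      # a different iteration, tracked separately
complete = store.record(10, 1, 0.2)   # last metric for iteration 10 arrives
```

Because metrics arrive asynchronously, messages for different iterations can interleave; keying the buckets by iteration keeps them apart.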

_shutdown
static _validate_params(replacement_ratio, averaging_strategy, momentum, threshold, cooldown)
Parameters:
  • replacement_ratio (float)

  • averaging_strategy (str)

  • momentum (float)

  • threshold (Optional[float])

  • cooldown (int)

Return type:

None

averaging_strategy = ''
cooldown: int = 200
momentum
poll_interval_s
replacement_ratio
shutdown()
Return type:

None

threshold: float | None = None
class spawn_policy.AsyncSelectiveAveragingPolicympi4pyFast(model_builder, model, average_every, threshold_metric=0.0, replacement_ratio=0.2, averaging_strategy='mean', momentum=0.0, age_range=(50, 150), replacement_mode='age', group=MPI.COMM_WORLD)

Bases: SpawnPolicy

Asynchronous selective averaging, version 2: uses MPI one-sided communication to fetch the selectively averaged parameters from a random set of ranks.

Parameters:
  • model_builder (Callable[[], Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer]])

  • model (gfn.gflownet.base.GFlowNet)

  • average_every (int)

  • threshold_metric (float)

  • replacement_ratio (float)

  • averaging_strategy (str)

  • momentum (float)

  • age_range (Tuple[int, int])

  • replacement_mode (str)

  • group (mpi4py.MPI.Comm)

__call__(iteration, model, optimizer, local_metric, expose_params=True, group=MPI.COMM_WORLD)

Possibly perform a spawn/averaging step on this iteration.

Parameters:
  • iteration (int)

  • model (gfn.gflownet.base.GFlowNet)

  • optimizer (torch.optim.Optimizer)

  • local_metric (float)

  • expose_params (bool)

  • group (mpi4py.MPI.Comm)

Return type:

Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]

_copy_model_params_to_buf(model)
Parameters:

model (gfn.gflownet.base.GFlowNet)

Return type:

None

_count = 0
_ensure_initialized(model)
Parameters:

model (gfn.gflownet.base.GFlowNet)

Return type:

None

_expose = False
_expose_model_parameters(model)
Parameters:

model (gfn.gflownet.base.GFlowNet)

Return type:

None

_get_donors(n, k, d)
Return type:

List[int]

_get_model_params_from_donors(donors, layer_name, f)
Parameters:

donors (List[int])

Return type:

torch.Tensor

_is_worst_rank_this_iteration(iteration, local_metric)
Parameters:
  • iteration (int)

  • local_metric (float)

Return type:

Tuple[bool, Optional[List[float]]]

_model: gfn.gflownet.base.GFlowNet | None = None
_model_builder
_should_replace(iteration, local_metric)
Parameters:
  • iteration (int)

  • local_metric (float)

Return type:

Tuple[bool, Optional[List[float]]]

age = 0
age_range = (50, 150)
agents_killed = 0
averaging_ranks = 0
averaging_strategy = ''
capture_comm(name, size)
Parameters:
  • name (str)

  • size (int)

Return type:

None

comm_size
debug_mode = False
is_agent_dying(local_metric, threshold_metric, check_policy=0)
Parameters:
  • local_metric (float)

  • threshold_metric (float)

Return type:

bool

max_age
momentum
myrank
num_replacements = 0
print_stats()
Return type:

None

print_time()
Return type:

None

replacement_mode = ''
replacement_ratio
reset_age()
Return type:

None

shutdown()
Return type:

None

stats
threshold_metric
timing
total_iterations = 0
train_comm_group
class spawn_policy.AsyncSelectiveAveragingPolicympi4pyGeneral(model_builder, model, average_every, threshold_metric=0.0, replacement_ratio=0.2, averaging_strategy='mean', momentum=0.0, poll_interval_s=0.01, age_range=(50, 150), replacement_mode='age', group=MPI.COMM_WORLD)

Bases: SpawnPolicy

Asynchronous selective averaging, version 2: uses MPI one-sided communication to fetch the selectively averaged parameters from a random set of ranks.

Parameters:
  • model_builder (Callable[[], Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer]])

  • model (gfn.gflownet.base.GFlowNet)

  • average_every (int)

  • threshold_metric (float)

  • replacement_ratio (float)

  • averaging_strategy (str)

  • momentum (float)

  • poll_interval_s (float)

  • age_range (Tuple[int, int])

  • replacement_mode (str)

  • group (mpi4py.MPI.Comm)

__call__(iteration, model, optimizer, local_metric, expose_params=True, group=MPI.COMM_WORLD)

Possibly perform a spawn/averaging step on this iteration.

Parameters:
  • iteration (int)

  • model (gfn.gflownet.base.GFlowNet)

  • optimizer (torch.optim.Optimizer)

  • local_metric (float)

  • expose_params (bool)

  • group (mpi4py.MPI.Comm)

Return type:

Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]

_average_received_params()
Return type:

Dict[str, torch.Tensor]

_copy_model_params_to_buf(model)
Parameters:

model (gfn.gflownet.base.GFlowNet)

Return type:

None

_count = 0
_ensure_initialized(model)
Parameters:

model (gfn.gflownet.base.GFlowNet)

Return type:

None

_expose = False
_expose_model_parameters(model)
Parameters:

model (gfn.gflownet.base.GFlowNet)

Return type:

None

_get_donors(n, k, d)
Return type:

List[int]

_get_model_params_from_donors(donors, layer_name, f)
Parameters:

donors (List[int])

Return type:

Dict[str, torch.Tensor]

_is_worst_rank_this_iteration(iteration, local_metric)
Parameters:
  • iteration (int)

  • local_metric (float)

Return type:

Tuple[bool, Optional[List[float]]]

_model: gfn.gflownet.base.GFlowNet | None = None
_model_builder
_should_replace(iteration, local_metric)
Parameters:
  • iteration (int)

  • local_metric (float)

Return type:

Tuple[bool, Optional[List[float]]]

age = 0
age_range = (50, 150)
agents_killed = 0
averaging_ranks = 0
averaging_strategy = ''
capture_comm(name, size)
Parameters:
  • name (str)

  • size (int)

Return type:

None

comm_size
debug_mode = False
is_agent_dying(local_metric, threshold_metric, check_agent=0)
Parameters:
  • local_metric (float)

  • threshold_metric (float)

Return type:

bool

max_age
momentum
myrank
num_replacements = 0
print_stats()
Return type:

None

print_time()
Return type:

None

replacement_mode = ''
replacement_ratio
reset_age()
Return type:

None

shutdown()
Return type:

None

stats
threshold_metric
timing
total_iterations = 0
train_comm_group
class spawn_policy.AverageAllPolicy(average_every)

Bases: SpawnPolicy

Standard model averaging across all ranks every N iterations.

Parameters:

average_every (int)

__call__(iteration, model, optimizer, local_metric=None, group=dist.group.WORLD)

Possibly perform a spawn/averaging step on this iteration.

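Parameters:
  • iteration (int)

  • model (gfn.gflownet.base.GFlowNet)

  • optimizer (torch.optim.Optimizer)

  • local_metric (Optional[float])

  • group (process group, default dist.group.WORLD)

Return type: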

Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]
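
Averaging across all ranks every N iterations can be simulated in a single process, as in the sketch below. It assumes that averaging triggers when `iteration % average_every == 0`, which this reference does not confirm; `average_all` and `maybe_average` are hypothetical helpers, and a list of per-rank dicts stands in for the collective operation a real run would use.

```python
from typing import Dict, List

Params = Dict[str, List[float]]


def average_all(rank_params: List[Params]) -> Params:
    """Element-wise mean of every parameter across all simulated ranks."""
    world_size = len(rank_params)
    return {
        name: [sum(col) / world_size for col in zip(*(p[name] for p in rank_params))]
        for name in rank_params[0]
    }


def maybe_average(iteration: int, average_every: int, rank_params: List[Params]) -> bool:
    """Average in place on cadence iterations; report whether it happened."""
    if iteration % average_every != 0:  # assumed cadence rule
        return False
    mean = average_all(rank_params)
    for params in rank_params:
        for name, values in mean.items():
            params[name] = list(values)  # each rank gets its own copy
    return True


ranks = [{"w": [0.0, 2.0]}, {"w": [4.0, 6.0]}]
skipped = maybe_average(3, 5, ranks)   # off-cadence: parameters untouched
applied = maybe_average(5, 5, ranks)   # on-cadence: every rank gets the mean
```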

class spawn_policy.AverageAllPolicympi4py(average_every)

Bases: SpawnPolicy

Standard model averaging across all ranks every N iterations.

Parameters:

average_every (int)

__call__(iteration, model, optimizer, local_metric=None, group=MPI.COMM_WORLD)

Possibly perform a spawn/averaging step on this iteration.

Parameters:
  • iteration (int)

  • model (gfn.gflownet.base.GFlowNet)

  • optimizer (torch.optim.Optimizer)

  • local_metric (Optional[float])

  • group (mpi4py.MPI.Comm)

Return type:

Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]

class spawn_policy.SpawnPolicy(average_every)

Bases: abc.ABC

Helper class that provides a standard way to create an ABC using inheritance.

Parameters:

average_every (int)

abstract __call__(iteration, model, local_metric=None, group=dist.group.WORLD)

Possibly perform a spawn/averaging step on this iteration.

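Parameters:
  • iteration (int)

  • model (gfn.gflownet.base.GFlowNet)

  • local_metric (Optional[float])

  • group (process group, default dist.group.WORLD)

Return type: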

Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]

average_every
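
A new policy would presumably be written by subclassing the abstract base and implementing `__call__`. The sketch below mirrors that pattern with stand-in classes so that it runs without gfn or torch: `SpawnPolicySketch` and `NoOpPolicy` are hypothetical, the `"averaged"` key in the returned dict is invented for illustration, and the signature follows the concrete subclasses above, which take an `optimizer` argument.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple


class SpawnPolicySketch(ABC):
    """Stand-in for the abstract base; stores the averaging cadence."""

    def __init__(self, average_every: int) -> None:
        self.average_every = average_every

    @abstractmethod
    def __call__(
        self,
        iteration: int,
        model: Any,
        optimizer: Any,
        local_metric: Optional[float] = None,
    ) -> Tuple[Any, Any, Dict[str, Any]]:
        """Possibly perform a spawn/averaging step on this iteration."""


class NoOpPolicy(SpawnPolicySketch):
    """Hypothetical policy that never averages and returns its inputs."""

    def __call__(self, iteration, model, optimizer, local_metric=None):
        return model, optimizer, {"averaged": False}


policy = NoOpPolicy(average_every=10)
result = policy(0, "model", "optimizer")
```

Instantiating the abstract base directly raises `TypeError`, which is how the ABC machinery forces subclasses to provide `__call__`.
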
spawn_policy._compute_worst_ranks_by_ratio(metrics, replacement_ratio)

Return the set of globally worst ranks to replace.

Parameters:
  • metrics (List[float])

  • replacement_ratio (float)

Return type:

Set[int]
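
One plausible reading of this helper is sketched below. Two assumptions are made loudly here because the reference does not state them: a lower metric counts as worse, and the number of ranks to replace is `len(metrics) * replacement_ratio` rounded down. `worst_ranks_by_ratio` is a hypothetical name, not the module's function.

```python
import math
from typing import List, Set


def worst_ranks_by_ratio(metrics: List[float], replacement_ratio: float) -> Set[int]:
    """Return the indices of the lowest-metric ranks (assumed to be the worst)."""
    n_replace = math.floor(len(metrics) * replacement_ratio)  # assumed rounding
    # Sort ranks by metric, breaking ties by rank index for determinism.
    order = sorted(range(len(metrics)), key=lambda r: (metrics[r], r))
    return set(order[:n_replace])


half = worst_ranks_by_ratio([0.9, 0.1, 0.5, 0.3], 0.5)
none = worst_ranks_by_ratio([0.9, 0.1, 0.5, 0.3], 0.2)
```

With four ranks and a ratio of 0.2, the rounded-down count is zero, so no rank is replaced under these assumptions.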

spawn_policy.logger