spawn_policy¶
Attributes¶
Classes¶
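AsyncSelectiveAveragingPolicy | Asynchronous selective averaging with background, non-blocking comms.
AsyncSelectiveAveragingPolicympi4pyFast | Asynchronous selective averaging, version 2. Uses MPI one-sided comms to get the selectively averaged parameters from a random set of ranks.
AsyncSelectiveAveragingPolicympi4pyGeneral | Asynchronous selective averaging, version 2. Uses MPI one-sided comms to get the selectively averaged parameters from a random set of ranks.
AverageAllPolicy | Standard model averaging across all ranks every N iterations.
AverageAllPolicympi4py | Standard model averaging across all ranks every N iterations.
SpawnPolicy | Helper class that provides a standard way to create an ABC using inheritance.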
Functions¶
_compute_worst_ranks_by_ratio(metrics, replacement_ratio) | Return the set of globally worst ranks to replace.
Module Contents¶
- class spawn_policy.AsyncSelectiveAveragingPolicy(model_builder, average_every, replacement_ratio=0.2, averaging_strategy='mean', momentum=0.0, poll_interval_s=0.01, threshold=None, cooldown=200, timing=None)¶
Bases: SpawnPolicy
Asynchronous selective averaging with background, non-blocking comms.
At each cadence, every rank sends its metric to rank 0 via isend. Once rank 0 has received metrics from all ranks, it aggregates them, chooses a replacement set, and instructs donors and replacers through point-to-point messages. Donors stream their parameters to replacers. Each replacer aggregates the incoming parameters in a background thread, and its main thread applies the averaged weights at the next safe call, without barriers.
- Parameters:
model_builder (Callable[[], Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer]])
average_every (int)
replacement_ratio (float)
averaging_strategy (str)
momentum (float)
poll_interval_s (float)
threshold (Optional[float])
cooldown (int)
timing (Optional[dict])
- _OP_NONE = 0¶
- _OP_ROLE_DONOR = 1¶
- _OP_ROLE_REPLACER = 2¶
- _TAG_CONTROL = 7002¶
- _TAG_METRIC = 7001¶
- _TAG_PARAM_BASE = 8000¶
- __call__(iteration, model, optimizer, local_metric=None, group=dist.group.WORLD)¶
Possibly perform a spawn/averaging step on this iteration.
- Parameters:
iteration (int)
model (gfn.gflownet.base.GFlowNet)
optimizer (torch.optim.Optimizer)
local_metric (Optional[float])
- Return type:
Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]
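Whether a call "possibly" performs a step presumably depends on the cadence (average_every) and on the cooldown attribute. The source does not show the rule, so the following is only a sketch under an assumed rule: a step may run on iterations that are multiples of average_every, provided at least cooldown iterations have passed since the last trigger. The function name should_trigger is hypothetical, and the role of threshold is left out because the source does not describe it.

```python
def should_trigger(iteration: int, average_every: int,
                   last_trigger_iter: int, cooldown: int) -> bool:
    """Return True when a spawn/averaging step may run (assumed rule)."""
    # Assumption: steps are only considered on multiples of average_every.
    on_cadence = iteration % average_every == 0
    # Assumption: at least `cooldown` iterations must separate two triggers.
    cooled_down = iteration - last_trigger_iter >= cooldown
    return on_cadence and cooled_down
```

Under this reading, the documented defaults _last_trigger_iter = -200 and cooldown = 200 would mean the cooldown never blocks the very first eligible iteration.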
- _background_loop()¶
- Return type:
None
- _bg_thread: threading.Thread | None = None¶
- static _compute_averaging_weights(all_metrics, ranks_to_average, averaging_strategy)¶
- Parameters:
all_metrics (torch.Tensor)
ranks_to_average (Set[int])
averaging_strategy (str)
- Return type:
Optional[torch.Tensor]
- _control_role_buf: torch.Tensor | None = None¶
- _control_work: torch.distributed.Work | None = None¶
- static _determine_ranks_for_averaging(all_metrics, world_size, replacement_ratio, averaging_strategy)¶
- Parameters:
all_metrics (torch.Tensor)
world_size (int)
replacement_ratio (float)
averaging_strategy (str)
- Return type:
Tuple[Set[int], Set[int]]
- _ensure_initialized(model)¶
- Parameters:
model (gfn.gflownet.base.GFlowNet)
- Return type:
None
- _handle_donor(iteration, replacers)¶
- Parameters:
iteration (int)
replacers (List[int])
- Return type:
None
- _handle_replacer(iteration, donors, weights)¶
- Parameters:
iteration (int)
donors (List[int])
weights (Optional[torch.Tensor])
- Return type:
None
- _initialized = False¶
- _last_iter_sent: int = -1¶
- _last_trigger_iter: int = -200¶
- _model: gfn.gflownet.base.GFlowNet | None = None¶
- _model_builder¶
- _new_weights: Dict[str, torch.Tensor] | None = None¶
- _pending_lock¶
- _rank0_buckets: Dict[int, Dict[int, float]]¶
- _rank0_dispatch_controls(iteration, ranks_to_replace, ranks_to_average, weights)¶
- Parameters:
iteration (int)
ranks_to_replace (Set[int])
ranks_to_average (Set[int])
weights (Optional[torch.Tensor])
- Return type:
None
- _rank0_metric_buffers: Dict[int, torch.Tensor]¶
- _rank0_metric_handles: Dict[int, torch.distributed.Work]¶
- _rank0_poll_metrics()¶
- Return type:
None
- _rank0_post_metric_recvs()¶
- Return type:
None
- _rank0_record_metric(iteration, rank, metric, world_size)¶
- Parameters:
iteration (int)
rank (int)
metric (float)
world_size (int)
- Return type:
None
- _shutdown¶
- static _validate_params(replacement_ratio, averaging_strategy, momentum, threshold, cooldown)¶
- Parameters:
replacement_ratio (float)
averaging_strategy (str)
momentum (float)
threshold (Optional[float])
cooldown (int)
- Return type:
None
- averaging_strategy = ''¶
- cooldown: int = 200¶
- momentum¶
- poll_interval_s¶
- replacement_ratio¶
- shutdown()¶
- Return type:
None
- threshold: float | None = None¶
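The class description says rank 0 aggregates only once it has metrics from every rank, and _rank0_buckets is typed as Dict[int, Dict[int, float]]. One way to read that is a per-iteration bucket keyed by rank. The sketch below illustrates this bookkeeping in plain Python, without any communication. The class MetricBuckets and its method record are hypothetical names, and it assumes ranks are numbered 0 to world_size - 1.

```python
from typing import Dict, List, Optional


class MetricBuckets:
    """Collect one metric per rank for each iteration (hypothetical helper)."""

    def __init__(self, world_size: int) -> None:
        self.world_size = world_size
        # iteration -> {rank -> metric}, mirroring Dict[int, Dict[int, float]].
        self.buckets: Dict[int, Dict[int, float]] = {}

    def record(self, iteration: int, rank: int, metric: float) -> Optional[List[float]]:
        """Store a metric; return all metrics in rank order once complete."""
        bucket = self.buckets.setdefault(iteration, {})
        bucket[rank] = metric
        if len(bucket) < self.world_size:
            return None  # still waiting on some ranks
        del self.buckets[iteration]  # bucket is complete, so free it
        return [bucket[r] for r in range(self.world_size)]
```

Because buckets are keyed by iteration, metrics from a fast rank that is already on a later cadence do not mix with those of slower ranks.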
- class spawn_policy.AsyncSelectiveAveragingPolicympi4pyFast(model_builder, model, average_every, threshold_metric=0.0, replacement_ratio=0.2, averaging_strategy='mean', momentum=0.0, age_range=(50, 150), replacement_mode='age', group=MPI.COMM_WORLD)¶
Bases: SpawnPolicy
Asynchronous selective averaging, version 2. Uses MPI one-sided comms to get the selectively averaged parameters from a random set of ranks.
- Parameters:
model_builder (Callable[[], Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer]])
model (gfn.gflownet.base.GFlowNet)
average_every (int)
threshold_metric (float)
replacement_ratio (float)
averaging_strategy (str)
momentum (float)
age_range (Tuple[int, int])
replacement_mode (str)
group (mpi4py.MPI.Comm)
- __call__(iteration, model, optimizer, local_metric, expose_params=True, group=MPI.COMM_WORLD)¶
Possibly perform a spawn/averaging step on this iteration.
- Parameters:
iteration (int)
model (gfn.gflownet.base.GFlowNet)
optimizer (torch.optim.Optimizer)
local_metric (float)
expose_params (bool)
group (mpi4py.MPI.Comm)
- Return type:
Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]
- _copy_model_params_to_buf(model)¶
- Parameters:
model (gfn.gflownet.base.GFlowNet)
- Return type:
None
- _count = 0¶
- _ensure_initialized(model)¶
- Parameters:
model (gfn.gflownet.base.GFlowNet)
- Return type:
None
- _expose = False¶
- _expose_model_parameters(model)¶
- Parameters:
model (gfn.gflownet.base.GFlowNet)
- Return type:
None
- _get_donors(n, k, d)¶
- Return type:
List[int]
- _get_model_params_from_donors(donors, layer_name, f)¶
- Parameters:
donors (List[int])
- Return type:
torch.Tensor
- _is_worst_rank_this_iteration(iteration, local_metric)¶
- Parameters:
iteration (int)
local_metric (float)
- Return type:
Tuple[bool, Optional[List[float]]]
- _model: gfn.gflownet.base.GFlowNet | None = None¶
- _model_builder¶
- _should_replace(iteration, local_metric)¶
- Parameters:
iteration (int)
local_metric (float)
- Return type:
Tuple[bool, Optional[List[float]]]
- age = 0¶
- age_range = (50, 150)¶
- agents_killed = 0¶
- averaging_ranks = 0¶
- averaging_strategy = ''¶
- capture_comm(name, size)¶
- Parameters:
name (str)
size (int)
- Return type:
None
- comm_size¶
- debug_mode = False¶
- is_agent_dying(local_metric, threshold_metric, check_policy=0)¶
- Parameters:
local_metric (float)
threshold_metric (float)
- Return type:
bool
- max_age¶
- momentum¶
- myrank¶
- num_replacements = 0¶
- print_stats()¶
- Return type:
None
- print_time()¶
- Return type:
None
- replacement_mode = ''¶
- replacement_ratio¶
- reset_age()¶
- Return type:
None
- shutdown()¶
- Return type:
None
- stats¶
- threshold_metric¶
- timing¶
- total_iterations = 0¶
- train_comm_group¶
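The class description mentions fetching parameters from "a random set of ranks". As an illustration only, the sketch below samples distinct donor ranks uniformly while excluding the calling rank. The function pick_random_donors and its parameter names are hypothetical; the meaning of the n, k and d arguments of _get_donors is not documented in the source and is not assumed here.

```python
import random
from typing import List, Optional


def pick_random_donors(world_size: int, my_rank: int, num_donors: int,
                       rng: Optional[random.Random] = None) -> List[int]:
    """Sample distinct donor ranks uniformly, excluding the caller's rank."""
    rng = rng or random.Random()
    candidates = [r for r in range(world_size) if r != my_rank]
    # Never ask for more donors than there are other ranks.
    return rng.sample(candidates, min(num_donors, len(candidates)))
```

Passing a seeded random.Random makes the selection reproducible, which can help when debugging a multi-rank run.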
- class spawn_policy.AsyncSelectiveAveragingPolicympi4pyGeneral(model_builder, model, average_every, threshold_metric=0.0, replacement_ratio=0.2, averaging_strategy='mean', momentum=0.0, poll_interval_s=0.01, age_range=(50, 150), replacement_mode='age', group=MPI.COMM_WORLD)¶
Bases: SpawnPolicy
Asynchronous selective averaging, version 2. Uses MPI one-sided comms to get the selectively averaged parameters from a random set of ranks.
- Parameters:
model_builder (Callable[[], Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer]])
model (gfn.gflownet.base.GFlowNet)
average_every (int)
threshold_metric (float)
replacement_ratio (float)
averaging_strategy (str)
momentum (float)
poll_interval_s (float)
age_range (Tuple[int, int])
replacement_mode (str)
group (mpi4py.MPI.Comm)
- __call__(iteration, model, optimizer, local_metric, expose_params=True, group=MPI.COMM_WORLD)¶
Possibly perform a spawn/averaging step on this iteration.
- Parameters:
iteration (int)
model (gfn.gflownet.base.GFlowNet)
optimizer (torch.optim.Optimizer)
local_metric (float)
expose_params (bool)
group (mpi4py.MPI.Comm)
- Return type:
Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]
- _average_received_params()¶
- Return type:
Dict[str, torch.Tensor]
- _copy_model_params_to_buf(model)¶
- Parameters:
model (gfn.gflownet.base.GFlowNet)
- Return type:
None
- _count = 0¶
- _ensure_initialized(model)¶
- Parameters:
model (gfn.gflownet.base.GFlowNet)
- Return type:
None
- _expose = False¶
- _expose_model_parameters(model)¶
- Parameters:
model (gfn.gflownet.base.GFlowNet)
- Return type:
None
- _get_donors(n, k, d)¶
- Return type:
List[int]
- _get_model_params_from_donors(donors, layer_name, f)¶
- Parameters:
donors (List[int])
- Return type:
Dict[str, torch.Tensor]
- _is_worst_rank_this_iteration(iteration, local_metric)¶
- Parameters:
iteration (int)
local_metric (float)
- Return type:
Tuple[bool, Optional[List[float]]]
- _model: gfn.gflownet.base.GFlowNet | None = None¶
- _model_builder¶
- _should_replace(iteration, local_metric)¶
- Parameters:
iteration (int)
local_metric (float)
- Return type:
Tuple[bool, Optional[List[float]]]
- age = 0¶
- age_range = (50, 150)¶
- agents_killed = 0¶
- averaging_ranks = 0¶
- averaging_strategy = ''¶
- capture_comm(name, size)¶
- Parameters:
name (str)
size (int)
- Return type:
None
- comm_size¶
- debug_mode = False¶
- is_agent_dying(local_metric, threshold_metric, check_agent=0)¶
- Parameters:
local_metric (float)
threshold_metric (float)
- Return type:
bool
- max_age¶
- momentum¶
- myrank¶
- num_replacements = 0¶
- print_stats()¶
- Return type:
None
- print_time()¶
- Return type:
None
- replacement_mode = ''¶
- replacement_ratio¶
- reset_age()¶
- Return type:
None
- shutdown()¶
- Return type:
None
- stats¶
- threshold_metric¶
- timing¶
- total_iterations = 0¶
- train_comm_group¶
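The momentum parameter is not explained in the source. A common convention, assumed here and not confirmed, is to blend the local parameters with the averaged donor parameters as momentum * local + (1 - momentum) * averaged, so the default momentum=0.0 would adopt the averaged parameters outright. The sketch uses plain lists of floats in place of tensors; the Params alias and the blend function are hypothetical.

```python
from typing import Dict, List

# Stand-in for a state dict: parameter name -> flat list of values.
Params = Dict[str, List[float]]


def blend(local: Params, averaged: Params, momentum: float) -> Params:
    """Assumed rule: momentum * local + (1 - momentum) * averaged."""
    if not 0.0 <= momentum <= 1.0:
        raise ValueError("momentum must be in [0, 1]")
    return {
        name: [momentum * l + (1.0 - momentum) * a
               for l, a in zip(local[name], averaged[name])]
        for name in local
    }
```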
- class spawn_policy.AverageAllPolicy(average_every)¶
Bases: SpawnPolicy
Standard model averaging across all ranks every N iterations.
- Parameters:
average_every (int)
- __call__(iteration, model, optimizer, local_metric=None, group=dist.group.WORLD)¶
Possibly perform a spawn/averaging step on this iteration.
- Parameters:
iteration (int)
model (gfn.gflownet.base.GFlowNet)
optimizer (torch.optim.Optimizer)
local_metric (Optional[float])
- Return type:
Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]
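To make "averaging across all ranks every N iterations" concrete, here is a single-process stand-in that takes the per-rank parameters as a list and returns their element-wise mean. It is a sketch, not the library's code: in a distributed run the sum would typically come from an all-reduce (for example torch.distributed.all_reduce) followed by a division by the world size. The names should_average, average_params and Params are hypothetical, and skipping iteration 0 is an assumption.

```python
from typing import Dict, List

# Stand-in for a state dict: parameter name -> flat list of values.
Params = Dict[str, List[float]]


def should_average(iteration: int, average_every: int) -> bool:
    # Assumption: averaging fires on positive multiples of average_every.
    return iteration > 0 and iteration % average_every == 0


def average_params(per_rank: List[Params]) -> Params:
    """Element-wise mean of every named parameter across all ranks."""
    world_size = len(per_rank)
    averaged: Params = {}
    for name in per_rank[0]:
        # Group the i-th value of this parameter from every rank together.
        columns = zip(*(params[name] for params in per_rank))
        averaged[name] = [sum(values) / world_size for values in columns]
    return averaged
```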
- class spawn_policy.AverageAllPolicympi4py(average_every)¶
Bases: SpawnPolicy
Standard model averaging across all ranks every N iterations.
- Parameters:
average_every (int)
- __call__(iteration, model, optimizer, local_metric=None, group=MPI.COMM_WORLD)¶
Possibly perform a spawn/averaging step on this iteration.
- Parameters:
iteration (int)
model (gfn.gflownet.base.GFlowNet)
optimizer (torch.optim.Optimizer)
local_metric (Optional[float])
group (mpi4py.MPI.Comm)
- Return type:
Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]
- class spawn_policy.SpawnPolicy(average_every)¶
Bases: abc.ABC
Helper class that provides a standard way to create an ABC using inheritance.
- Parameters:
average_every (int)
- abstract __call__(iteration, model, local_metric=None, group=dist.group.WORLD)¶
Possibly perform a spawn/averaging step on this iteration.
- Parameters:
iteration (int)
model (gfn.gflownet.base.GFlowNet)
local_metric (Optional[float])
- Return type:
Tuple[gfn.gflownet.base.GFlowNet, torch.optim.Optimizer, dict]
- average_every¶
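A custom policy would presumably subclass SpawnPolicy and implement __call__. The sketch below shows the pattern with a simplified stand-in base class so that it runs without the library. SpawnPolicySketch and NoOpPolicy are hypothetical names, the keys of the returned dict are made up, the group argument is omitted, and the signature follows the concrete policies on this page (which take an optimizer) rather than the abstract signature shown above.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple


class SpawnPolicySketch(ABC):
    """Simplified stand-in for spawn_policy.SpawnPolicy (hypothetical)."""

    def __init__(self, average_every: int) -> None:
        self.average_every = average_every

    @abstractmethod
    def __call__(self, iteration: int, model: Any, optimizer: Any,
                 local_metric: Optional[float] = None) -> Tuple[Any, Any, dict]:
        """Possibly perform a spawn/averaging step on this iteration."""


class NoOpPolicy(SpawnPolicySketch):
    """Hypothetical policy that never modifies the model."""

    def __call__(self, iteration, model, optimizer, local_metric=None):
        # Return the inputs unchanged plus an info dict, matching the
        # (model, optimizer, dict) return shape documented for the policies.
        info = {"averaged": False, "iteration": iteration}
        return model, optimizer, info
```

Because __call__ is abstract, instantiating the base class directly raises TypeError; only subclasses that implement it can be created.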
- spawn_policy._compute_worst_ranks_by_ratio(metrics, replacement_ratio)¶
Return the set of globally worst ranks to replace.
- Parameters:
metrics (List[float])
replacement_ratio (float)
- Return type:
Set[int]
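As a rough illustration of selecting the "globally worst ranks", the sketch below treats the list index as the rank and returns the ranks with the lowest metrics. Two assumptions are not confirmed by the source: that a lower metric is worse, and that the number of ranks replaced is floor(len(metrics) * replacement_ratio). The function name worst_ranks_by_ratio is hypothetical.

```python
import math
from typing import List, Set


def worst_ranks_by_ratio(metrics: List[float], replacement_ratio: float) -> Set[int]:
    """Pick the ranks (list indices) with the lowest metrics.

    Assumptions: lower is worse, and the count is
    floor(len(metrics) * replacement_ratio).
    """
    if not 0.0 <= replacement_ratio <= 1.0:
        raise ValueError("replacement_ratio must be in [0, 1]")
    num_to_replace = math.floor(len(metrics) * replacement_ratio)
    # Sort ranks by metric, breaking ties by rank so the result is deterministic.
    ranked = sorted(range(len(metrics)), key=lambda r: (metrics[r], r))
    return set(ranked[:num_to_replace])
```

If the metric were a loss, where higher is worse, the sort key would need to be negated.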
- spawn_policy.logger¶