spawn_policy
============

.. py:module:: spawn_policy


Attributes
----------

.. autoapisummary::

   spawn_policy.logger


Classes
-------

.. autoapisummary::

   spawn_policy.AsyncSelectiveAveragingPolicy
   spawn_policy.AsyncSelectiveAveragingPolicympi4pyFast
   spawn_policy.AsyncSelectiveAveragingPolicympi4pyGeneral
   spawn_policy.AverageAllPolicy
   spawn_policy.AverageAllPolicympi4py
   spawn_policy.SpawnPolicy


Functions
---------

.. autoapisummary::

   spawn_policy._compute_worst_ranks_by_ratio


Module Contents
---------------

.. py:class:: AsyncSelectiveAveragingPolicy(model_builder, average_every, replacement_ratio = 0.2, averaging_strategy = 'mean', momentum = 0.0, poll_interval_s = 0.01, threshold = None, cooldown = 200, timing = None)

   Bases: :py:obj:`SpawnPolicy`


   Asynchronous selective averaging with background, non-blocking communication.

   On each averaging cadence, every rank sends its metric to rank 0 via ``isend``.
   Once rank 0 has received all metrics, it aggregates them, decides on a
   replacement set, and instructs donors and replacers with point-to-point
   messages. Donors stream parameters to replacers. Replacers aggregate in a
   background thread; the main thread applies the averaged weights at the next
   safe call, without barriers.


   .. py:attribute:: _OP_NONE
      :value: 0


   .. py:attribute:: _OP_ROLE_DONOR
      :value: 1


   .. py:attribute:: _OP_ROLE_REPLACER
      :value: 2


   .. py:attribute:: _TAG_CONTROL
      :value: 7002


   .. py:attribute:: _TAG_METRIC
      :value: 7001


   .. py:attribute:: _TAG_PARAM_BASE
      :value: 8000


   .. py:method:: __call__(iteration, model, optimizer, local_metric = None, group=dist.group.WORLD)

      Possibly perform a spawn/averaging step on this iteration.


   .. py:method:: _background_loop()


   .. py:attribute:: _bg_thread
      :type: Optional[threading.Thread]
      :value: None


   .. py:method:: _compute_averaging_weights(all_metrics, ranks_to_average, averaging_strategy)
      :staticmethod:


   .. py:attribute:: _control_role_buf
      :type: Optional[torch.Tensor]
      :value: None


   .. py:attribute:: _control_work
      :type: Optional[torch.distributed.Work]
      :value: None


   .. py:method:: _determine_ranks_for_averaging(all_metrics, world_size, replacement_ratio, averaging_strategy)
      :staticmethod:


   .. py:method:: _ensure_initialized(model)


   .. py:method:: _handle_donor(iteration, replacers)


   .. py:method:: _handle_replacer(iteration, donors, weights)


   .. py:attribute:: _initialized
      :value: False


   .. py:attribute:: _last_iter_sent
      :type: int
      :value: -1


   .. py:attribute:: _last_trigger_iter
      :type: int
      :value: -200


   .. py:attribute:: _model
      :type: Optional[gfn.gflownet.base.GFlowNet]
      :value: None


   .. py:attribute:: _model_builder


   .. py:attribute:: _new_weights
      :type: Optional[Dict[str, torch.Tensor]]
      :value: None


   .. py:attribute:: _pending_lock


   .. py:attribute:: _rank0_buckets
      :type: Dict[int, Dict[int, float]]


   .. py:method:: _rank0_dispatch_controls(iteration, ranks_to_replace, ranks_to_average, weights)


   .. py:attribute:: _rank0_metric_buffers
      :type: Dict[int, torch.Tensor]


   .. py:attribute:: _rank0_metric_handles
      :type: Dict[int, torch.distributed.Work]


   .. py:method:: _rank0_poll_metrics()


   .. py:method:: _rank0_post_metric_recvs()


   .. py:method:: _rank0_record_metric(iteration, rank, metric, world_size)


   .. py:attribute:: _shutdown


   .. py:method:: _validate_params(replacement_ratio, averaging_strategy, momentum, threshold, cooldown)
      :staticmethod:


   .. py:attribute:: averaging_strategy
      :value: ''


   .. py:attribute:: cooldown
      :type: int
      :value: 200


   .. py:attribute:: momentum


   .. py:attribute:: poll_interval_s


   .. py:attribute:: replacement_ratio


   .. py:method:: shutdown()


   .. py:attribute:: threshold
      :type: Optional[float]
      :value: None


.. py:class:: AsyncSelectiveAveragingPolicympi4pyFast(model_builder, model, average_every, threshold_metric = 0.0, replacement_ratio = 0.2, averaging_strategy = 'mean', momentum = 0.0, age_range = (50, 150), replacement_mode = 'age', group = MPI.COMM_WORLD)

   Bases: :py:obj:`SpawnPolicy`


   Asynchronous selective averaging, version 2. Uses MPI one-sided communication
   to get the selectively averaged parameters from a random set of ranks.


   .. py:method:: __call__(iteration, model, optimizer, local_metric, expose_params = True, group = MPI.COMM_WORLD)

      Possibly perform a spawn/averaging step on this iteration.


   .. py:method:: _copy_model_params_to_buf(model)


   .. py:attribute:: _count
      :value: 0


   .. py:method:: _ensure_initialized(model)


   .. py:attribute:: _expose
      :value: False


   .. py:method:: _expose_model_parameters(model)


   .. py:method:: _get_donors(n, k, d)


   .. py:method:: _get_model_params_from_donors(donors, layer_name, f)


   .. py:method:: _is_worst_rank_this_iteration(iteration, local_metric)


   .. py:attribute:: _model
      :type: Optional[gfn.gflownet.base.GFlowNet]
      :value: None


   .. py:attribute:: _model_builder


   .. py:method:: _should_replace(iteration, local_metric)


   .. py:attribute:: age
      :value: 0


   .. py:attribute:: age_range
      :value: (50, 150)


   .. py:attribute:: agents_killed
      :value: 0


   .. py:attribute:: averaging_ranks
      :value: 0


   .. py:attribute:: averaging_strategy
      :value: ''


   .. py:method:: capture_comm(name, size)


   .. py:attribute:: comm_size


   .. py:attribute:: debug_mode
      :value: False


   .. py:method:: is_agent_dying(local_metric, threshold_metric, check_policy=0)


   .. py:attribute:: max_age


   .. py:attribute:: momentum


   .. py:attribute:: myrank


   .. py:attribute:: num_replacements
      :value: 0


   .. py:method:: print_stats()


   .. py:method:: print_time()


   .. py:attribute:: replacement_mode
      :value: ''


   .. py:attribute:: replacement_ratio


   .. py:method:: reset_age()


   .. py:method:: shutdown()


   .. py:attribute:: stats


   .. py:attribute:: threshold_metric


   .. py:attribute:: timing


   .. py:attribute:: total_iterations
      :value: 0


   .. py:attribute:: train_comm_group


.. py:class:: AsyncSelectiveAveragingPolicympi4pyGeneral(model_builder, model, average_every, threshold_metric = 0.0, replacement_ratio = 0.2, averaging_strategy = 'mean', momentum = 0.0, poll_interval_s = 0.01, age_range = (50, 150), replacement_mode = 'age', group = MPI.COMM_WORLD)

   Bases: :py:obj:`SpawnPolicy`


   Asynchronous selective averaging, version 2. Uses MPI one-sided communication
   to get the selectively averaged parameters from a random set of ranks.


   .. py:method:: __call__(iteration, model, optimizer, local_metric, expose_params = True, group = MPI.COMM_WORLD)

      Possibly perform a spawn/averaging step on this iteration.


   .. py:method:: _average_received_params()


   .. py:method:: _copy_model_params_to_buf(model)


   .. py:attribute:: _count
      :value: 0


   .. py:method:: _ensure_initialized(model)


   .. py:attribute:: _expose
      :value: False


   .. py:method:: _expose_model_parameters(model)


   .. py:method:: _get_donors(n, k, d)


   .. py:method:: _get_model_params_from_donors(donors, layer_name, f)


   .. py:method:: _is_worst_rank_this_iteration(iteration, local_metric)


   .. py:attribute:: _model
      :type: Optional[gfn.gflownet.base.GFlowNet]
      :value: None


   .. py:attribute:: _model_builder


   .. py:method:: _should_replace(iteration, local_metric)


   .. py:attribute:: age
      :value: 0


   .. py:attribute:: age_range
      :value: (50, 150)


   .. py:attribute:: agents_killed
      :value: 0


   .. py:attribute:: averaging_ranks
      :value: 0


   .. py:attribute:: averaging_strategy
      :value: ''


   .. py:method:: capture_comm(name, size)


   .. py:attribute:: comm_size


   .. py:attribute:: debug_mode
      :value: False


   .. py:method:: is_agent_dying(local_metric, threshold_metric, check_agent=0)


   .. py:attribute:: max_age


   .. py:attribute:: momentum


   .. py:attribute:: myrank


   .. py:attribute:: num_replacements
      :value: 0


   .. py:method:: print_stats()


   .. py:method:: print_time()


   .. py:attribute:: replacement_mode
      :value: ''


   .. py:attribute:: replacement_ratio


   .. py:method:: reset_age()


   .. py:method:: shutdown()


   .. py:attribute:: stats


   .. py:attribute:: threshold_metric


   .. py:attribute:: timing


   .. py:attribute:: total_iterations
      :value: 0


   .. py:attribute:: train_comm_group


.. py:class:: AverageAllPolicy(average_every)

   Bases: :py:obj:`SpawnPolicy`


   Standard model averaging across all ranks every N iterations.


   .. py:method:: __call__(iteration, model, optimizer, local_metric = None, group=dist.group.WORLD)

      Possibly perform a spawn/averaging step on this iteration.


.. py:class:: AverageAllPolicympi4py(average_every)

   Bases: :py:obj:`SpawnPolicy`


   Standard model averaging across all ranks every N iterations.


   .. py:method:: __call__(iteration, model, optimizer, local_metric = None, group = MPI.COMM_WORLD)

      Possibly perform a spawn/averaging step on this iteration.


.. py:class:: SpawnPolicy(average_every)

   Bases: :py:obj:`abc.ABC`


   Abstract base class for policies that may perform a spawn/averaging step on a
   given iteration.


   .. py:method:: __call__(iteration, model, local_metric = None, group=dist.group.WORLD)
      :abstractmethod:


      Possibly perform a spawn/averaging step on this iteration.


   .. py:attribute:: average_every


.. py:function:: _compute_worst_ranks_by_ratio(metrics, replacement_ratio)

   Return the set of globally worst ranks to replace.


.. py:data:: logger
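
The one-line summary of ``_compute_worst_ranks_by_ratio`` does not say how "worst" is defined or how many ranks are chosen. The following is a minimal sketch of one plausible reading, not the module's implementation. It assumes that ``metrics`` is a per-rank sequence of floats, that a lower metric is worse, and that ``floor(replacement_ratio * world_size)`` ranks are replaced. The name ``worst_ranks_by_ratio`` is hypothetical.

.. code-block:: python

   import math
   from typing import Sequence, Set


   def worst_ranks_by_ratio(metrics: Sequence[float], replacement_ratio: float) -> Set[int]:
       """Hypothetical stand-in: pick the lowest-scoring fraction of ranks."""
       if not 0.0 <= replacement_ratio <= 1.0:
           raise ValueError("replacement_ratio must be in [0, 1]")
       # Assumption: replace floor(ratio * world_size) ranks.
       num_to_replace = math.floor(replacement_ratio * len(metrics))
       # Assumption: a lower metric is worse; ties are broken by rank index.
       order = sorted(range(len(metrics)), key=lambda rank: (metrics[rank], rank))
       return set(order[:num_to_replace])


   # Five ranks, ratio 0.4: the two lowest-scoring ranks (1 and 4) are selected.
   print(worst_ranks_by_ratio([0.9, 0.1, 0.5, 0.7, 0.3], 0.4))

With a small world size the floor can yield an empty set (for example two ranks at a ratio of 0.2), so a real policy may need a minimum of one replacement; whether the module does this is not stated here.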
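
The description of ``AsyncSelectiveAveragingPolicy`` states that replacers aggregate in a background thread and that the main thread applies the averaged weights at the next safe call, without barriers. The sketch below illustrates that hand-off pattern only. It uses plain floats instead of tensors and omits all communication. ``PendingWeights``, ``publish``, ``take`` and ``average_donors`` are hypothetical names; only the idea of a lock-protected pending slot is borrowed from the ``_pending_lock`` and ``_new_weights`` attributes, whose actual use is an assumption.

.. code-block:: python

   import threading
   from typing import Dict, List, Optional


   class PendingWeights:
       """Hypothetical lock-protected slot shared by a worker and the main thread."""

       def __init__(self) -> None:
           self._lock = threading.Lock()
           self._new_weights: Optional[Dict[str, float]] = None

       def publish(self, weights: Dict[str, float]) -> None:
           # Called from the background thread once averaging is done.
           with self._lock:
               self._new_weights = weights

       def take(self) -> Optional[Dict[str, float]]:
           # Called from the main thread; returns None if nothing is pending.
           with self._lock:
               weights, self._new_weights = self._new_weights, None
           return weights


   def average_donors(donors: List[Dict[str, float]], slot: PendingWeights) -> None:
       # Uniform mean over donor parameters (a stand-in for streamed tensors).
       averaged = {k: sum(d[k] for d in donors) / len(donors) for k in donors[0]}
       slot.publish(averaged)


   slot = PendingWeights()
   model = {"w": 0.0, "b": 0.0}
   donors = [{"w": 1.0, "b": 2.0}, {"w": 3.0, "b": 4.0}]

   worker = threading.Thread(target=average_donors, args=(donors, slot))
   worker.start()
   worker.join()  # joined here only to make the example deterministic

   new_weights = slot.take()  # the "next safe call" on the main thread
   if new_weights is not None:
       model.update(new_weights)
   print(model)

Because ``take`` never waits for the worker, a training loop can call it on every iteration and simply continue when it returns ``None``.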
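
``AverageAllPolicy`` and ``AverageAllPolicympi4py`` are described as averaging the model across all ranks every N iterations. A single-process sketch of that behaviour follows, with ranks simulated as dictionaries of floats. It assumes that averaging triggers when ``iteration`` is a multiple of ``average_every``, which the documentation does not confirm, and ``maybe_average_all`` is a hypothetical name that replaces the real collective call.

.. code-block:: python

   from typing import Dict, List


   def maybe_average_all(
       rank_params: List[Dict[str, float]], iteration: int, average_every: int
   ) -> bool:
       """Hypothetical single-process stand-in for an all-reduce mean."""
       # Assumption: averaging triggers when the iteration is a multiple of N.
       if average_every <= 0 or iteration % average_every != 0:
           return False
       world_size = len(rank_params)
       for name in rank_params[0]:
           mean = sum(params[name] for params in rank_params) / world_size
           for params in rank_params:
               params[name] = mean  # every rank ends up with the same value
       return True


   ranks = [{"w": 1.0}, {"w": 2.0}, {"w": 6.0}]
   print(maybe_average_all(ranks, iteration=3, average_every=5), ranks)   # skipped
   print(maybe_average_all(ranks, iteration=10, average_every=5), ranks)  # averaged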