gfn.utils.distributed
=====================

.. py:module:: gfn.utils.distributed


Attributes
----------

.. autoapisummary::

   gfn.utils.distributed.logger


Classes
-------

.. autoapisummary::

   gfn.utils.distributed.DistributedContext
   gfn.utils.distributed.DistributedContextmpi4py


Functions
---------

.. autoapisummary::

   gfn.utils.distributed.average_gradients
   gfn.utils.distributed.average_models
   gfn.utils.distributed.gather_distributed_data
   gfn.utils.distributed.initialize_distributed_compute
   gfn.utils.distributed.initialize_distributed_compute_mpi4py
   gfn.utils.distributed.report_load_imbalance
   gfn.utils.distributed.report_time_info


Module Contents
---------------

.. py:class:: DistributedContext

   Holds all distributed training/replay buffer groups and ranks.


   .. py:attribute:: agent_group_id
      :type:  Optional[int]
      :value: None


   .. py:attribute:: agent_group_size
      :type:  int


   .. py:attribute:: agent_groups
      :type:  Optional[List[torch.distributed.ProcessGroup]]
      :value: None


   .. py:attribute:: assigned_buffer
      :type:  Optional[int]
      :value: None


   .. py:attribute:: assigned_training_ranks
      :type:  Optional[List[int]]
      :value: None


   .. py:attribute:: buffer_group
      :type:  Optional[torch.distributed.ProcessGroup]
      :value: None


   .. py:method:: cleanup()

      Cleans up the distributed process group.


   .. py:attribute:: dc_mpi4py
      :type:  Optional[DistributedContextmpi4py]
      :value: None


   .. py:method:: is_buffer_rank()

      Check if the current rank is part of the buffer group.


   .. py:method:: is_training_rank()

      Check if the current rank is part of the training group.


   .. py:attribute:: my_rank
      :type:  int


   .. py:attribute:: num_training_ranks
      :type:  int


   .. py:attribute:: train_global_group
      :type:  Optional[torch.distributed.ProcessGroup]
      :value: None


   .. py:attribute:: world_size
      :type:  int


.. py:class:: DistributedContextmpi4py

   Holds all distributed training/replay buffer groups and ranks.


   .. py:attribute:: agent_group_id
      :type:  Optional[int]
      :value: None


   .. py:attribute:: agent_group_size
      :type:  int


   .. py:attribute:: agent_groups
      :type:  Optional[List[mpi4py.MPI.Comm]]
      :value: None


   .. py:attribute:: assigned_buffer
      :type:  Optional[int]
      :value: None


   .. py:attribute:: assigned_training_ranks
      :type:  Optional[List[int]]
      :value: None


   .. py:attribute:: buffer_group
      :type:  Optional[mpi4py.MPI.Comm]
      :value: None


   .. py:method:: is_buffer_rank()

      Check if the current rank is part of the buffer group.


   .. py:method:: is_training_rank()

      Check if the current rank is part of the training group.


   .. py:attribute:: my_rank
      :type:  int


   .. py:attribute:: num_training_ranks
      :type:  int


   .. py:attribute:: train_global_group
      :type:  mpi4py.MPI.Comm


   .. py:attribute:: world_size
      :type:  int


.. py:function:: average_gradients(model)

   All-reduces gradients across all models.


.. py:function:: average_models(model, training_group=None)

   Averages model weights across all ranks.


.. py:function:: gather_distributed_data(local_tensor, world_size=None, rank=None, training_group=None)

   Gather data from all processes in a distributed setting.

   :param local_tensor: Data from the current process (List or Tensor).
   :param world_size: Number of processes (optional; read from the environment if ``None``).
   :param rank: Current process rank (optional; read from the environment if ``None``).

   :returns: On rank 0, the concatenated tensor from all processes. On other ranks, ``None``.


.. py:function:: initialize_distributed_compute(dist_backend, num_remote_buffers, num_agent_groups)

   Initializes distributed compute using ccl, mpi, or gloo backends.

   :param dist_backend: The backend to use for distributed compute.
   :param num_remote_buffers: The number of remote buffers to use.
   :param num_agent_groups: The number of agent groups.


.. py:function:: initialize_distributed_compute_mpi4py(num_remote_buffers, num_agent_groups)

   Initializes distributed compute using mpi4py.

   :param num_remote_buffers: The number of remote buffers to use.
   :param num_agent_groups: The number of agent groups.


.. py:data:: logger

.. py:function:: report_load_imbalance(all_timing_dict, world_size)

   Reports load imbalance and timing information from a timing dictionary.

   :param all_timing_dict: A list of dictionaries containing timing information for each rank,
       structured as ``[rank0_dict, rank1_dict, ...]``, where each rank dictionary has the form
       ``{"step_name": [iter0_time, iter1_time, iter2_time, ...], ...}``.
   :param world_size: The total number of ranks in the distributed setup.


.. py:function:: report_time_info(all_timing_dict, world_size)

   Reports timing information from a timing dictionary.

   :param all_timing_dict: A list of dictionaries containing timing information for each rank,
       structured as ``[rank0_dict, rank1_dict, ...]``, where each rank dictionary has the form
       ``{"step_name": [iter0_time, iter1_time, iter2_time, ...], ...}``.
   :param world_size: The total number of ranks in the distributed setup.
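
The ``all_timing_dict`` layout accepted by ``report_load_imbalance`` and ``report_time_info`` can be illustrated with a small standalone sketch. The helper ``summarize_imbalance`` and its max-over-mean imbalance ratio are hypothetical, chosen only for illustration. They are not part of ``gfn.utils.distributed`` and may differ from what the library actually reports.

```python
from statistics import mean


def summarize_imbalance(all_timing_dict):
    """Hypothetical helper (not the library's): per-step timing summary.

    For each step, returns the mean iteration time on every rank and an
    imbalance ratio defined here as (slowest rank mean) / (mean over ranks).
    """
    summary = {}
    # Collect every step name that appears on any rank.
    step_names = sorted({name for rank_dict in all_timing_dict for name in rank_dict})
    for name in step_names:
        # Mean iteration time of this step on each rank that recorded it.
        per_rank = [
            mean(rank_dict[name]) for rank_dict in all_timing_dict if rank_dict.get(name)
        ]
        avg = mean(per_rank)
        summary[name] = {
            "per_rank_mean": per_rank,
            "imbalance": max(per_rank) / avg if avg > 0 else 1.0,
        }
    return summary


# One dictionary per rank; each maps a step name to its per-iteration times.
all_timing_dict = [
    {"sample": [1.0, 1.0], "train": [2.0, 2.0]},  # rank 0
    {"sample": [3.0, 3.0], "train": [2.0, 2.0]},  # rank 1
]
summary = summarize_imbalance(all_timing_dict)
```

A ratio of 1.0 means all ranks spent the same time on a step under this assumed metric. Larger values indicate that the slowest rank is holding the others back.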
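
The implementation of ``average_gradients`` is not shown on this page. As a rough sketch, assuming it follows the common pattern of an all-reduce sum followed by division by the number of ranks, the arithmetic can be simulated in plain Python. The function ``all_reduce_mean`` is hypothetical, and lists stand in for per-rank gradient tensors.

```python
def all_reduce_mean(per_rank_grads):
    """Simulate an all-reduce (sum) followed by division by the world size.

    Assumption: this mirrors the usual gradient-averaging pattern; it does not
    call torch.distributed or mpi4py and is not the library's implementation.
    """
    world_size = len(per_rank_grads)
    # Element-wise sum across ranks: what an all-reduce with a SUM op produces.
    summed = [sum(values) for values in zip(*per_rank_grads)]
    averaged = [value / world_size for value in summed]
    # Every rank ends up holding its own copy of the same averaged result.
    return [list(averaged) for _ in range(world_size)]


# Gradients for a two-parameter model on two ranks.
per_rank_grads = [
    [1.0, 2.0],  # rank 0
    [3.0, 6.0],  # rank 1
]
result = all_reduce_mean(per_rank_grads)
```

After the simulated reduction, both ranks hold identical gradients, which is what keeps replicas in sync when each rank then applies the same optimizer step.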