gfn.utils.distributed
=====================

.. py:module:: gfn.utils.distributed


Attributes
----------

.. autoapisummary::

   gfn.utils.distributed.Group
   gfn.utils.distributed.default_backend
   gfn.utils.distributed.logger


Classes
-------

.. autoapisummary::

   gfn.utils.distributed.AsyncSendHandle
   gfn.utils.distributed.DistributedContext
   gfn.utils.distributed.DistributedContextMPI4Py


Functions
---------

.. autoapisummary::

   gfn.utils.distributed._first_env
   gfn.utils.distributed._get_MPI
   gfn.utils.distributed.all_gather
   gfn.utils.distributed.all_reduce
   gfn.utils.distributed.average_gradients
   gfn.utils.distributed.average_models
   gfn.utils.distributed.barrier
   gfn.utils.distributed.broadcast
   gfn.utils.distributed.gather_distributed_data
   gfn.utils.distributed.get_rank
   gfn.utils.distributed.get_world_size
   gfn.utils.distributed.initialize_distributed_compute
   gfn.utils.distributed.initialize_distributed_compute_mpi4py
   gfn.utils.distributed.isend
   gfn.utils.distributed.recv
   gfn.utils.distributed.report_load_imbalance
   gfn.utils.distributed.report_time_info
   gfn.utils.distributed.send


Module Contents
---------------

.. py:class:: AsyncSendHandle

   Handle for a non-blocking send operation.

   The underlying buffer must remain alive until the send completes.
   Call :meth:`is_complete` to poll or :meth:`wait` to block.


   .. py:attribute:: _buffer
      :type:  Any


   .. py:attribute:: _request
      :type:  Any


   .. py:method:: is_complete()

      Check if the send has completed without blocking.


   .. py:method:: wait()

      Block until the send completes.


.. py:class:: DistributedContext

   Holds all distributed training/replay buffer groups and ranks.


   .. py:attribute:: agent_group_id
      :type:  Optional[int]
      :value: None


   .. py:attribute:: agent_group_size
      :type:  int


   .. py:attribute:: agent_groups
      :type:  Optional[List[torch.distributed.ProcessGroup]]
      :value: None


   .. py:attribute:: assigned_buffer
      :type:  Optional[int]
      :value: None


   .. py:attribute:: assigned_training_ranks
      :type:  Optional[List[int]]
      :value: None


   .. py:attribute:: buffer_group
      :type:  Optional[torch.distributed.ProcessGroup]
      :value: None


   .. py:method:: cleanup()

      Cleans up the distributed process group.


   .. py:attribute:: coordinator_rank
      :type:  Optional[int]
      :value: None


   .. py:attribute:: dc_mpi4py
      :type:  Optional[DistributedContextMPI4Py]
      :value: None


   .. py:method:: get_train_group(backend = 'mpi')


   .. py:method:: is_buffer_rank()

      Check if the current rank is part of the buffer group.


   .. py:method:: is_coordinator_rank()

      Check if the current rank is the coordinator.


   .. py:method:: is_training_rank()

      Check if the current rank is part of the training group.


   .. py:attribute:: my_rank
      :type:  int


   .. py:attribute:: num_training_ranks
      :type:  int


   .. py:attribute:: train_global_group
      :type:  Optional[torch.distributed.ProcessGroup]
      :value: None


   .. py:attribute:: world_size
      :type:  int


.. py:class:: DistributedContextMPI4Py

   Holds all distributed training/replay buffer groups and ranks.


   .. py:attribute:: agent_group_id
      :type:  Optional[int]
      :value: None


   .. py:attribute:: agent_group_size
      :type:  int


   .. py:attribute:: agent_groups
      :type:  Optional[List[mpi4py.MPI.Comm]]
      :value: None


   .. py:attribute:: assigned_buffer
      :type:  Optional[int]
      :value: None


   .. py:attribute:: assigned_training_ranks
      :type:  Optional[List[int]]
      :value: None


   .. py:attribute:: buffer_group
      :type:  Optional[mpi4py.MPI.Comm]
      :value: None


   .. py:method:: is_buffer_rank()

      Check if the current rank is part of the buffer group.


   .. py:method:: is_training_rank()

      Check if the current rank is part of the training group.


   .. py:attribute:: my_rank
      :type:  int


   .. py:attribute:: num_training_ranks
      :type:  int


   .. py:attribute:: train_global_group
      :type:  Optional[mpi4py.MPI.Comm]
      :value: None


   .. py:attribute:: world_size
      :type:  int


.. py:data:: Group

.. py:function:: _first_env(*names, default = None)

   Return the first non-empty environment variable from a list of names.


.. py:function:: _get_MPI()

   Lazily import and return the mpi4py MPI module.


.. py:function:: all_gather(output_list, tensor, backend = default_backend, group = None)

   Backend-agnostic all-gather.

   The MPI backend round-trips through CPU/numpy and copies results back
   to the output tensors (preserving their devices).

   :param output_list: List of pre-allocated tensors (one per rank) to
                       receive gathered data into.
   :param tensor: The local tensor to send.
   :param backend: ``"torch"`` or ``"mpi"``.
   :param group: Process group (torch ProcessGroup or MPI communicator).


.. py:function:: all_reduce(tensor, op = 'SUM', backend = default_backend, group = None)

   Backend-agnostic in-place all-reduce.

   The MPI backend round-trips through CPU/numpy and copies the result
   back to the original tensor (preserving its device).

   :param tensor: The tensor to reduce in-place.
   :param op: Reduction operation. One of ``"SUM"``, ``"MAX"``, ``"MIN"``.
   :param backend: ``"torch"`` or ``"mpi"``.
   :param group: Process group (torch ProcessGroup or MPI communicator).


.. py:function:: average_gradients(model)

   All-reduce gradients across all models.


.. py:function:: average_models(model, training_group=None)

   Averages model weights across all ranks.


.. py:function:: barrier(backend = default_backend, group = None)

   Backend-agnostic barrier synchronization.

   :param backend: ``"torch"`` or ``"mpi"``.
   :param group: Process group (torch ProcessGroup or MPI communicator).


.. py:function:: broadcast(tensor, src, backend = default_backend, group = None)

   Backend-agnostic broadcast.

   The MPI backend round-trips through CPU/numpy and copies the result
   back to the tensor (preserving its device).

   :param tensor: The tensor to broadcast. On the source rank this is the
                  data to send; on other ranks the buffer is overwritten
                  with received data.
   :param src: Source rank.
   :param backend: ``"torch"`` or ``"mpi"``.
   :param group: Process group (torch ProcessGroup or MPI communicator).


.. py:data:: default_backend
   :value: 'mpi'


.. py:function:: gather_distributed_data(local_tensor, world_size = None, rank = None, training_group=None)

   Gather data from all processes in a distributed setting.

   :param local_tensor: Data from the current process (List or Tensor).
   :param world_size: Number of processes (optional; read from the
                      environment if ``None``).
   :param rank: Current process rank (optional; read from the environment
                if ``None``).

   :returns: On rank 0, the concatenated tensor from all processes.
             On other ranks, ``None``.


.. py:function:: get_rank(backend = default_backend, group = None)

   Backend-agnostic rank query.

   :param backend: ``"torch"`` or ``"mpi"``.
   :param group: Process group (torch ProcessGroup or MPI communicator).


.. py:function:: get_world_size(backend = default_backend, group = None)

   Backend-agnostic world size query.

   :param backend: ``"torch"`` or ``"mpi"``.
   :param group: Process group (torch ProcessGroup or MPI communicator).


.. py:function:: initialize_distributed_compute(dist_backend, num_remote_buffers, num_agent_groups, use_coordinator = False)

   Initializes distributed compute using ccl, mpi, or gloo backends.

   :param dist_backend: The backend to use for distributed compute.
   :param num_remote_buffers: The number of remote buffers to use.
   :param num_agent_groups: The number of agent groups.
   :param use_coordinator: If True, the last rank becomes a coordinator that
                           aggregates mode discoveries across buffer managers.


.. py:function:: initialize_distributed_compute_mpi4py(num_remote_buffers, num_agent_groups, num_coordinators = 0)

   Initializes distributed compute using mpi4py.

   :param num_remote_buffers: The number of remote buffers to use.
   :param num_agent_groups: The number of agent groups.
   :param num_coordinators: Number of coordinator ranks (0 or 1).


.. py:function:: isend(data, dst_rank, backend = default_backend, tag = 0)

   Non-blocking send of a byte tensor to ``dst_rank``.

   Returns an :class:`AsyncSendHandle` whose internal buffer **must** be
   kept alive until :meth:`~AsyncSendHandle.is_complete` returns ``True``
   or :meth:`~AsyncSendHandle.wait` is called.

   :param data: Tensor to send (will be cast to uint8).
   :param dst_rank: Destination rank (global).
   :param backend: ``"torch"`` or ``"mpi"``.
   :param tag: MPI/torch tag for message matching.

   :returns: A handle that tracks the outstanding send.


.. py:data:: logger

.. py:function:: recv(src_rank = None, backend = default_backend, tag = 0)

   Receive a byte tensor from ``src_rank`` (or any rank if ``None``).

   Returns ``(source_rank, data)`` where ``data`` is a uint8 tensor.
   See :func:`send` for protocol details per backend.

   :param src_rank: Source rank to receive from, or ``None`` for any source.
   :param backend: ``"torch"`` or ``"mpi"``.
   :param tag: MPI/torch tag for message matching. Must match the tag
               used by the corresponding :func:`send` call.

   :returns: Tuple of (source rank, received uint8 tensor).


.. py:function:: report_load_imbalance(all_timing_dict, world_size)

   Reports load imbalance and timing information from a timing dictionary.

   :param all_timing_dict: A list of dictionaries containing timing
                           information for each rank, structured as
                           ``[rank0_dict, rank1_dict, ...]`` where each
                           ``rank_dict`` is
                           ``{"step_name": [iter0_time, iter1_time, iter2_time, ...], ...}``.
   :param world_size: The total number of ranks in the distributed setup.


.. py:function:: report_time_info(all_timing_dict, world_size)

   Reports timing information from a timing dictionary.

   :param all_timing_dict: A list of dictionaries containing timing
                           information for each rank, structured as
                           ``[rank0_dict, rank1_dict, ...]`` where each
                           ``rank_dict`` is
                           ``{"step_name": [iter0_time, iter1_time, iter2_time, ...], ...}``.
   :param world_size: The total number of ranks in the distributed setup.


.. py:function:: send(data, dst_rank, backend = default_backend, tag = 0)

   Send a byte tensor to ``dst_rank``.

   This is **byte-level transport**: the payload is always sent as raw
   uint8 bytes. Both backends guarantee that ``recv`` on the other end
   will return an identical uint8 tensor.

   Protocol differences between backends:

   - **torch**: Uses a length-prefixed two-message protocol. First a
     1-element int64 tensor containing the payload length is sent
     (with tag ``2*tag``), then the payload itself (with tag ``2*tag+1``).
     This lets the receiver allocate the right buffer size before the
     data arrives.
   - **mpi**: Sends a single message with the given ``tag``. The receiver
     uses ``MPI.Probe`` to discover the incoming message size before
     calling ``Recv``, so no separate length message is needed.

   Because the wire protocols differ, sender and receiver **must** use
   the same backend.

   :param data: Tensor to send (will be cast to uint8).
   :param dst_rank: Destination rank (global).
   :param backend: ``"torch"`` or ``"mpi"``.
   :param tag: MPI/torch tag for message matching. Use distinct tags
               to multiplex independent message channels on the same
               rank pair.