gfn.utils.distributed¶
Attributes¶
Classes¶
AsyncSendHandle | Handle for a non-blocking send operation. |
DistributedContext | Holds all distributed training/replay buffer groups and ranks. |
DistributedContextMPI4Py | Holds all distributed training/replay buffer groups and ranks. |
Functions¶
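_first_env(*names, default=None) | Return the first non-empty environment variable from a list of names. |
_get_MPI() | Lazily import and return the mpi4py MPI module. |
all_gather(output_list, tensor, backend, group) | Backend-agnostic all-gather. |
all_reduce(tensor, op, backend, group) | Backend-agnostic in-place all-reduce. |
average_gradients(model) | All-Reduce gradients across all models. |
average_models(model, training_group) | Averages model weights across all ranks. |
barrier(backend, group) | Backend-agnostic barrier synchronization. |
broadcast(tensor, src, backend, group) | Backend-agnostic broadcast. |
gather_distributed_data(local_tensor, world_size, rank, training_group) | Gather data from all processes in a distributed setting. |
get_rank(backend, group) | Backend-agnostic rank query. |
get_world_size(backend, group) | Backend-agnostic world size query. |
initialize_distributed_compute(dist_backend, num_remote_buffers, num_agent_groups, use_coordinator) | Initializes distributed compute using ccl, mpi, or gloo backends. |
initialize_distributed_compute_mpi4py(num_remote_buffers, num_agent_groups, num_coordinators) | Initializes distributed compute using mpi4py. |
isend(data, dst_rank, backend, tag) | Non-blocking send of a byte tensor to dst_rank. |
recv(src_rank, backend, tag) | Receive a byte tensor from src_rank (or any rank if None). |
report_load_imbalance(all_timing_dict, world_size) | Reports load imbalance and timing information from a timing dictionary. |
report_time_info(all_timing_dict, world_size) | Reports timing information from a timing dictionary. |
send(data, dst_rank, backend, tag) | Send a byte tensor to dst_rank. |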
Module Contents¶
- class gfn.utils.distributed.AsyncSendHandle¶
Handle for a non-blocking send operation.
The underlying buffer must remain alive until the send completes. Call is_complete() to poll or wait() to block.
- _buffer: Any¶
- _request: Any¶
- is_complete()¶
Check if the send has completed without blocking.
- Return type:
bool
- wait()¶
Block until the send completes.
- Return type:
None
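The poll-or-block pattern described for this handle can be illustrated with a small standard-library stand-in. This is only a sketch: `ToyAsyncSendHandle` and `toy_isend` are hypothetical names, a background thread plays the role of the real request object, and nothing here reflects how the library implements the class.

```python
import threading
import time
from typing import Any, List


class ToyAsyncSendHandle:
    """Hypothetical stand-in mirroring the is_complete()/wait() interface."""

    def __init__(self, buffer: Any, request: threading.Thread) -> None:
        self._buffer = buffer    # reference held so the payload stays alive
        self._request = request  # here: a thread standing in for a real request

    def is_complete(self) -> bool:
        # Non-blocking check: the "send" is done once the worker has exited.
        return not self._request.is_alive()

    def wait(self) -> None:
        # Blocking wait until the worker has finished.
        self._request.join()


def toy_isend(payload: bytes, outbox: List[bytes]) -> ToyAsyncSendHandle:
    """Start a fake send in the background and return a handle to it."""

    def _deliver() -> None:
        time.sleep(0.05)         # pretend the transfer takes a moment
        outbox.append(payload)

    worker = threading.Thread(target=_deliver)
    worker.start()
    return ToyAsyncSendHandle(payload, worker)


outbox: List[bytes] = []
handle = toy_isend(b"hello", outbox)

# Poll while doing other work, then block to be certain the send finished.
while not handle.is_complete():
    time.sleep(0.01)
handle.wait()
```

Holding the payload on the handle is what keeps the buffer alive: as long as the caller keeps the handle, the data being sent cannot be garbage-collected.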
- class gfn.utils.distributed.DistributedContext¶
Holds all distributed training/replay buffer groups and ranks.
- agent_group_id: int | None = None¶
- agent_group_size: int¶
- agent_groups: List[torch.distributed.ProcessGroup] | None = None¶
- assigned_buffer: int | None = None¶
- assigned_training_ranks: List[int] | None = None¶
- buffer_group: torch.distributed.ProcessGroup | None = None¶
- cleanup()¶
Cleans up the distributed process group.
- Return type:
None
- coordinator_rank: int | None = None¶
- dc_mpi4py: DistributedContextMPI4Py | None = None¶
- get_train_group(backend='mpi')¶
- Parameters:
backend (str)
- is_buffer_rank()¶
Check if the current rank is part of the buffer group.
- Return type:
bool
- is_coordinator_rank()¶
Check if the current rank is the coordinator.
- Return type:
bool
- is_training_rank()¶
Check if the current rank is part of the training group.
- Return type:
bool
- my_rank: int¶
- num_training_ranks: int¶
- train_global_group: torch.distributed.ProcessGroup | None = None¶
- world_size: int¶
- class gfn.utils.distributed.DistributedContextMPI4Py¶
Holds all distributed training/replay buffer groups and ranks.
- agent_group_id: int | None = None¶
- agent_group_size: int¶
- agent_groups: List[mpi4py.MPI.Comm] | None = None¶
- assigned_buffer: int | None = None¶
- assigned_training_ranks: List[int] | None = None¶
- buffer_group: mpi4py.MPI.Comm | None = None¶
- is_buffer_rank()¶
Check if the current rank is part of the buffer group.
- Return type:
bool
- is_training_rank()¶
Check if the current rank is part of the training group.
- Return type:
bool
- my_rank: int¶
- num_training_ranks: int¶
- train_global_group: mpi4py.MPI.Comm | None = None¶
- world_size: int¶
- gfn.utils.distributed.Group¶
- gfn.utils.distributed._first_env(*names, default=None)¶
Return the first non-empty environment variable from a list of names.
- Parameters:
names (str)
default (str | None)
- Return type:
str | None
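A lookup of this kind is useful because different launchers export the same information under different variable names. The following is a minimal sketch of the documented behavior, not the library's implementation: `first_env` and its `environ` parameter are hypothetical and exist only to make the example testable, and the variable names in the usage are illustrative.

```python
import os
from typing import Mapping, Optional


def first_env(
    *names: str,
    default: Optional[str] = None,
    environ: Optional[Mapping[str, str]] = None,
) -> Optional[str]:
    """Return the first non-empty value among `names`, else `default`."""
    env = os.environ if environ is None else environ
    for name in names:
        value = env.get(name)
        if value:  # skips unset variables and empty strings alike
            return value
    return default


# "RANK" is set but empty, so the lookup falls through to the next name.
fake_env = {"RANK": "", "PMI_RANK": "3"}
rank = first_env("RANK", "PMI_RANK", default="0", environ=fake_env)
print(rank)  # prints 3
```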
- gfn.utils.distributed._get_MPI()¶
Lazily import and return the mpi4py MPI module.
- gfn.utils.distributed.all_gather(output_list, tensor, backend=default_backend, group=None)¶
Backend-agnostic all-gather.
The MPI backend round-trips through CPU/numpy and copies results back to the output tensors (preserving their devices).
- Parameters:
output_list (list[torch.Tensor]) – List of pre-allocated tensors (one per rank) to receive gathered data into.
tensor (torch.Tensor) – The local tensor to send.
backend (str) – "torch" or "mpi".
group (Group | None) – Process group (torch ProcessGroup or MPI communicator).
- Return type:
None
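The all-gather contract (each rank supplies one tensor and ends up with every rank's tensor in its pre-allocated output list) can be simulated in a single process. The sketch below uses plain Python lists in place of tensors; `simulate_all_gather` is a hypothetical helper and performs no real communication.

```python
from typing import List


def simulate_all_gather(
    output_lists: List[List[List[float]]],
    local_tensors: List[List[float]],
) -> None:
    """Fill each rank's pre-allocated outputs with every rank's local data."""
    world_size = len(local_tensors)
    for rank in range(world_size):
        # One pre-allocated slot per rank, as the parameter description requires.
        assert len(output_lists[rank]) == world_size
        for src in range(world_size):
            # Copy in place so the caller's buffers are reused, not replaced.
            output_lists[rank][src][:] = local_tensors[src]


world_size = 3
local_tensors = [[float(r), float(r) * 10] for r in range(world_size)]
# Each rank allocates world_size zero-filled buffers of the right length.
output_lists = [[[0.0, 0.0] for _ in range(world_size)] for _ in range(world_size)]

simulate_all_gather(output_lists, local_tensors)
```

After the call, every rank holds the same list, ordered by source rank.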
- gfn.utils.distributed.all_reduce(tensor, op='SUM', backend=default_backend, group=None)¶
Backend-agnostic in-place all-reduce.
The MPI backend round-trips through CPU/numpy and copies the result back to the original tensor (preserving its device).
- Parameters:
tensor (torch.Tensor) – The tensor to reduce in-place.
op (str) – Reduction operation. One of "SUM", "MAX", "MIN".
backend (str) – "torch" or "mpi".
group (Group | None) – Process group (torch ProcessGroup or MPI communicator).
- Return type:
None
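To make the in-place semantics concrete, the sketch below simulates an all-reduce over per-rank Python lists: the values are combined element-wise and every rank's list is overwritten with the result. `simulate_all_reduce` is a hypothetical helper for illustration only; it involves no real backend.

```python
from typing import Callable, Dict, Iterable, List

# The three reduction names listed for the `op` parameter.
_OPS: Dict[str, Callable[[Iterable[float]], float]] = {
    "SUM": sum,
    "MAX": max,
    "MIN": min,
}


def simulate_all_reduce(per_rank: List[List[float]], op: str = "SUM") -> None:
    """Reduce element-wise across ranks and overwrite every rank's list."""
    reduce_fn = _OPS[op]
    # zip(*per_rank) groups the i-th element of every rank together.
    reduced = [reduce_fn(column) for column in zip(*per_rank)]
    for local in per_rank:
        local[:] = reduced  # in place: callers keep the same list object


ranks = [[1.0, 5.0], [2.0, 3.0], [4.0, 0.0]]
simulate_all_reduce(ranks, op="SUM")
print(ranks[0])  # prints [7.0, 8.0]
```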
- gfn.utils.distributed.average_gradients(model)¶
All-Reduce gradients across all models.
- gfn.utils.distributed.average_models(model, training_group=None)¶
Averages model weights across all ranks.
- gfn.utils.distributed.barrier(backend=default_backend, group=None)¶
Backend-agnostic barrier synchronization.
- Parameters:
backend (str) – "torch" or "mpi".
group (Group | None) – Process group (torch ProcessGroup or MPI communicator).
- Return type:
None
- gfn.utils.distributed.broadcast(tensor, src, backend=default_backend, group=None)¶
Backend-agnostic broadcast.
The MPI backend round-trips through CPU/numpy and copies the result back to the tensor (preserving its device).
- Parameters:
tensor (torch.Tensor) – The tensor to broadcast. On the source rank this is the data to send; on other ranks the buffer is overwritten with received data.
src (int) – Source rank.
backend (str) – "torch" or "mpi".
group (Group | None) – Process group (torch ProcessGroup or MPI communicator).
- Return type:
None
- gfn.utils.distributed.default_backend = 'mpi'¶
- gfn.utils.distributed.gather_distributed_data(local_tensor, world_size=None, rank=None, training_group=None)¶
Gather data from all processes in a distributed setting.
- Parameters:
local_tensor (torch.Tensor) – Data from the current process.
world_size (int | None) – Number of processes (optional; read from the environment if None).
rank (int | None) – Current process rank (optional; read from the environment if None).
- Returns:
On rank 0: concatenated tensor from all processes. On other ranks: None.
- Return type:
torch.Tensor | None
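The asymmetric return value (data on rank 0, None elsewhere) is easy to overlook. The sketch below mimics it with plain lists; `simulate_gather_to_rank0` is a hypothetical helper that takes all ranks' data at once, which a real distributed call would of course not do.

```python
from typing import List, Optional


def simulate_gather_to_rank0(
    per_rank_data: List[List[float]], rank: int
) -> Optional[List[float]]:
    """Return the concatenation of all ranks' data on rank 0, else None."""
    if rank != 0:
        return None
    gathered: List[float] = []
    for chunk in per_rank_data:  # concatenated in rank order
        gathered.extend(chunk)
    return gathered


per_rank_data = [[1.0, 2.0], [3.0], [4.0, 5.0]]
results = [simulate_gather_to_rank0(per_rank_data, r) for r in range(3)]
```

Callers therefore need to guard any use of the result with a rank check.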
- gfn.utils.distributed.get_rank(backend=default_backend, group=None)¶
Backend-agnostic rank query.
- Parameters:
backend (str) – "torch" or "mpi".
group (Group | None) – Process group (torch ProcessGroup or MPI communicator).
- Return type:
int
- gfn.utils.distributed.get_world_size(backend=default_backend, group=None)¶
Backend-agnostic world size query.
- Parameters:
backend (str) – "torch" or "mpi".
group (Group | None) – Process group (torch ProcessGroup or MPI communicator).
- Return type:
int
- gfn.utils.distributed.initialize_distributed_compute(dist_backend, num_remote_buffers, num_agent_groups, use_coordinator=False)¶
Initializes distributed compute using ccl, mpi, or gloo backends.
- Parameters:
dist_backend (str) – The backend to use for distributed compute.
num_remote_buffers (int) – The number of remote buffers to use.
num_agent_groups (int) – The number of agent groups.
use_coordinator (bool) – If True, the last rank becomes a coordinator that aggregates mode discoveries across buffer managers.
- Return type:
- gfn.utils.distributed.initialize_distributed_compute_mpi4py(num_remote_buffers, num_agent_groups, num_coordinators=0)¶
Initializes distributed compute using mpi4py.
- Parameters:
num_remote_buffers (int) – The number of remote buffers to use.
num_agent_groups (int) – The number of agent groups.
num_coordinators (int) – Number of coordinator ranks (0 or 1).
- Return type:
- gfn.utils.distributed.isend(data, dst_rank, backend=default_backend, tag=0)¶
Non-blocking send of a byte tensor to dst_rank.
Returns an AsyncSendHandle whose internal buffer must be kept alive until is_complete() returns True or wait() is called.
- Parameters:
data (torch.Tensor) – Tensor to send (will be cast to uint8).
dst_rank (int) – Destination rank (global).
backend (str) – "torch" or "mpi".
tag (int) – MPI/torch tag for message matching.
- Returns:
A handle that tracks the outstanding send.
- Return type:
AsyncSendHandle
- gfn.utils.distributed.logger¶
- gfn.utils.distributed.recv(src_rank=None, backend=default_backend, tag=0)¶
Receive a byte tensor from src_rank (or any rank if None).
Returns (source_rank, data) where data is a uint8 tensor. See send() for protocol details per backend.
- Parameters:
src_rank (int | None) – Source rank to receive from, or None for any source.
backend (str) – "torch" or "mpi".
tag (int) – MPI/torch tag for message matching. Must match the tag used by the corresponding send() call.
- Returns:
Tuple of (source rank, received uint8 tensor).
- Return type:
tuple[int, torch.Tensor]
- gfn.utils.distributed.report_load_imbalance(all_timing_dict, world_size)¶
Reports load imbalance and timing information from a timing dictionary.
- Parameters:
all_timing_dict (List[Dict[str, List[float]]]) – A list of dictionaries containing timing information for each rank. Structure: [rank0_dict, rank1_dict, …], where each rank_dict is {"step_name": [iter0_time, iter1_time, iter2_time, …], …}.
world_size (int) – The total number of ranks in the distributed setup.
- Return type:
None
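The nested timing structure can be summarized per step as follows. This is a sketch only: `summarize_imbalance` is a hypothetical helper, and the imbalance metric used here (slowest rank's total time divided by the mean total across ranks) is an assumption; the source does not say which statistics the library reports.

```python
from statistics import mean
from typing import Dict, List


def summarize_imbalance(
    all_timing_dict: List[Dict[str, List[float]]], world_size: int
) -> Dict[str, Dict[str, float]]:
    """Per step: min/max/mean of per-rank total time, plus max/mean ratio."""
    assert len(all_timing_dict) == world_size
    summary: Dict[str, Dict[str, float]] = {}
    for step in all_timing_dict[0]:
        # Total time each rank spent in this step across all iterations.
        totals = [sum(rank_dict[step]) for rank_dict in all_timing_dict]
        avg = mean(totals)
        summary[step] = {
            "min": min(totals),
            "max": max(totals),
            "mean": avg,
            # Assumed metric: 1.0 means perfectly balanced ranks.
            "imbalance": max(totals) / avg if avg > 0 else 1.0,
        }
    return summary


timings = [
    {"sample": [1.0, 1.0], "train": [0.5, 0.5]},  # rank 0
    {"sample": [2.0, 2.0], "train": [0.5, 0.5]},  # rank 1
]
summary = summarize_imbalance(timings, world_size=2)
```

In this example the "train" step is balanced (ratio 1.0), while rank 1 spends twice as long as rank 0 in "sample".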
- gfn.utils.distributed.report_time_info(all_timing_dict, world_size)¶
Reports timing information from a timing dictionary.
- Parameters:
all_timing_dict (List[Dict[str, List[float]]]) – A list of dictionaries containing timing information for each rank. Structure: [rank0_dict, rank1_dict, …], where each rank_dict is {"step_name": [iter0_time, iter1_time, iter2_time, …], …}.
world_size (int) – The total number of ranks in the distributed setup.
- Return type:
None
- gfn.utils.distributed.send(data, dst_rank, backend=default_backend, tag=0)¶
Send a byte tensor to dst_rank.
This is byte-level transport: the payload is always sent as raw uint8 bytes. Both backends guarantee that recv on the other end will return an identical uint8 tensor.
Protocol differences between backends:
torch: Uses a length-prefixed two-message protocol. First a 1-element int64 tensor containing the payload length is sent (tag=2*tag), then the payload itself (tag=2*tag+1). This lets the receiver allocate the right buffer size before the data arrives.
mpi: Sends a single message with the given tag. The receiver uses MPI.Probe to discover the incoming message size before calling Recv, so no separate length message is needed.
Because the wire protocols differ, sender and receiver must use the same backend.
- Parameters:
data (torch.Tensor) – Tensor to send (will be cast to uint8).
dst_rank (int) – Destination rank (global).
backend (str) – "torch" or "mpi".
tag (int) – MPI/torch tag for message matching. Use distinct tags to multiplex independent message channels on the same rank pair.
- Return type:
None
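The length-prefixed two-message scheme described for the torch backend can be modeled in a single process. In the sketch below, in-memory queues stand in for the network, `toy_send` and `toy_recv` are hypothetical names, and the 8-byte little-endian header is an assumption standing in for the 1-element int64 tensor; only the tag mapping (2*tag for the length, 2*tag+1 for the payload) comes from the description above.

```python
import queue
import struct
from collections import defaultdict
from typing import DefaultDict, Tuple

# (destination rank, wire tag) -> FIFO of raw messages; stands in for the network.
mailboxes: DefaultDict[Tuple[int, int], "queue.Queue[bytes]"] = defaultdict(queue.Queue)


def toy_send(data: bytes, dst_rank: int, tag: int = 0) -> None:
    """Send the payload length first, then the payload, on adjacent wire tags."""
    header = struct.pack("<q", len(data))         # one signed 64-bit integer
    mailboxes[(dst_rank, 2 * tag)].put(header)    # length message
    mailboxes[(dst_rank, 2 * tag + 1)].put(data)  # payload message


def toy_recv(my_rank: int, tag: int = 0) -> bytes:
    """Read the length, allocate a buffer of that size, then read the payload."""
    (length,) = struct.unpack("<q", mailboxes[(my_rank, 2 * tag)].get())
    buffer = bytearray(length)                    # sized before data arrives
    payload = mailboxes[(my_rank, 2 * tag + 1)].get()
    assert len(payload) == length
    buffer[:] = payload
    return bytes(buffer)


# Two logical channels to rank 1: user tags 0 and 1 map to wire tags (0, 1) and (2, 3).
toy_send(b"weights", dst_rank=1, tag=0)
toy_send(b"trajectories", dst_rank=1, tag=1)

# Channels are independent, so they can be drained in either order.
second = toy_recv(my_rank=1, tag=1)
first = toy_recv(my_rank=1, tag=0)
```

Doubling the user tag keeps the length and payload messages of different channels from colliding, which is why distinct tags can safely multiplex independent streams between the same pair of ranks.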