gfn.utils.distributed

Attributes

Group

default_backend

logger

Classes

AsyncSendHandle

Handle for a non-blocking send operation.

DistributedContext

Holds all distributed training/replay buffer groups and ranks.

DistributedContextMPI4Py

Holds all distributed training/replay buffer groups and ranks.

Functions

_first_env(*names[, default])

Return the first non-empty environment variable from a list of names.

_get_MPI()

Lazily import and return the mpi4py MPI module.

all_gather(output_list, tensor[, backend, group])

Backend-agnostic all-gather.

all_reduce(tensor[, op, backend, group])

Backend-agnostic in-place all-reduce.

average_gradients(model)

All-Reduce gradients across all models.

average_models(model[, training_group])

Averages model weights across all ranks.

barrier([backend, group])

Backend-agnostic barrier synchronization.

broadcast(tensor, src[, backend, group])

Backend-agnostic broadcast.

gather_distributed_data(local_tensor[, world_size, ...])

Gather data from all processes in a distributed setting.

get_rank([backend, group])

Backend-agnostic rank query.

get_world_size([backend, group])

Backend-agnostic world size query.

initialize_distributed_compute(dist_backend, ...[, ...])

Initializes distributed compute using ccl, mpi, or gloo backends.

initialize_distributed_compute_mpi4py(...[, ...])

Initializes distributed compute using mpi4py.

isend(data, dst_rank[, backend, tag])

Non-blocking send of a byte tensor to dst_rank.

recv([src_rank, backend, tag])

Receive a byte tensor from src_rank (or any rank if None).

report_load_imbalance(all_timing_dict, world_size)

Reports load imbalance and timing information from a timing dictionary.

report_time_info(all_timing_dict, world_size)

Reports timing information from a timing dictionary.

send(data, dst_rank[, backend, tag])

Send a byte tensor to dst_rank.

Module Contents

class gfn.utils.distributed.AsyncSendHandle

Handle for a non-blocking send operation.

The underlying buffer must remain alive until the send completes. Call is_complete() to poll or wait() to block.

_buffer: Any
_request: Any
is_complete()

Check if the send has completed without blocking.

Return type:

bool

wait()

Block until the send completes.

Return type:

None
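The polling and blocking behaviour described above can be illustrated with a small stand-alone sketch. It is not the library's implementation: `FakeRequest` and `HandleSketch` are hypothetical names, and the request object is only assumed to expose MPI-style `Test()` and `Wait()` methods.

```python
from dataclasses import dataclass
from typing import Any


class FakeRequest:
    """Hypothetical stand-in for a backend request object (not a real API)."""

    def __init__(self, polls_until_done: int) -> None:
        self._remaining = polls_until_done

    def Test(self) -> bool:
        # Non-blocking completion check: each poll moves the send one step closer.
        if self._remaining > 0:
            self._remaining -= 1
        return self._remaining == 0

    def Wait(self) -> None:
        # Blocking completion: once this returns, the send is finished.
        self._remaining = 0


@dataclass
class HandleSketch:
    _request: Any
    _buffer: Any  # held only so the payload outlives the in-flight send

    def is_complete(self) -> bool:
        return self._request.Test()

    def wait(self) -> None:
        self._request.Wait()


payload = bytearray(b"hello")
handle = HandleSketch(_request=FakeRequest(polls_until_done=2), _buffer=payload)
first_poll = handle.is_complete()  # still in flight after one poll
handle.wait()                      # block until finished
done = handle.is_complete()
```

Keeping a reference to the payload inside the handle is one simple way to ensure the buffer is not freed while the send is outstanding.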

class gfn.utils.distributed.DistributedContext

Holds all distributed training/replay buffer groups and ranks.

agent_group_id: int | None = None
agent_group_size: int
agent_groups: List[torch.distributed.ProcessGroup] | None = None
assigned_buffer: int | None = None
assigned_training_ranks: List[int] | None = None
buffer_group: torch.distributed.ProcessGroup | None = None
cleanup()

Cleans up the distributed process group.

Return type:

None

coordinator_rank: int | None = None
dc_mpi4py: DistributedContextMPI4Py | None = None
get_train_group(backend='mpi')
Parameters:

backend (str)

is_buffer_rank()

Check if the current rank is part of the buffer group.

Return type:

bool

is_coordinator_rank()

Check if the current rank is the coordinator.

Return type:

bool

is_training_rank()

Check if the current rank is part of the training group.

Return type:

bool

my_rank: int
num_training_ranks: int
train_global_group: torch.distributed.ProcessGroup | None = None
world_size: int
class gfn.utils.distributed.DistributedContextMPI4Py

Holds all distributed training/replay buffer groups and ranks.

agent_group_id: int | None = None
agent_group_size: int
agent_groups: List[mpi4py.MPI.Comm] | None = None
assigned_buffer: int | None = None
assigned_training_ranks: List[int] | None = None
buffer_group: mpi4py.MPI.Comm | None = None
is_buffer_rank()

Check if the current rank is part of the buffer group.

Return type:

bool

is_training_rank()

Check if the current rank is part of the training group.

Return type:

bool

my_rank: int
num_training_ranks: int
train_global_group: mpi4py.MPI.Comm | None = None
world_size: int
gfn.utils.distributed.Group
gfn.utils.distributed._first_env(*names, default=None)

Return the first non-empty environment variable from a list of names.

Parameters:
  • names (str)

  • default (str | None)

Return type:

str | None
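A minimal sketch of the lookup described above, assuming that "non-empty" means the variable is set and is not the empty string. `first_env_sketch` and the `SKETCH_RANK_*` variable names are hypothetical and used only for illustration.

```python
import os
from typing import Optional


def first_env_sketch(*names: str, default: Optional[str] = None) -> Optional[str]:
    """Return the first non-empty value among the named environment variables."""
    for name in names:
        value = os.environ.get(name)
        if value:  # skips both unset and empty-string variables
            return value
    return default


# Illustrative variable names only; real launchers use their own names.
os.environ["SKETCH_RANK_A"] = ""
os.environ["SKETCH_RANK_B"] = "3"
os.environ.pop("SKETCH_RANK_C", None)

rank = first_env_sketch("SKETCH_RANK_A", "SKETCH_RANK_B", default="0")
fallback = first_env_sketch("SKETCH_RANK_C", default="0")
```

A helper of this kind is useful when different launchers export the same information under different variable names.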

gfn.utils.distributed._get_MPI()

Lazily import and return the mpi4py MPI module.

gfn.utils.distributed.all_gather(output_list, tensor, backend=default_backend, group=None)

Backend-agnostic all-gather.

The MPI backend round-trips through CPU/numpy and copies results back to the output tensors (preserving their devices).

Parameters:
  • output_list (list[torch.Tensor]) – List of pre-allocated tensors (one per rank) to receive gathered data into.

  • tensor (torch.Tensor) – The local tensor to send.

  • backend (str) – "torch" or "mpi".

  • group (Group | None) – Process group (torch ProcessGroup or MPI communicator).

Return type:

None

gfn.utils.distributed.all_reduce(tensor, op='SUM', backend=default_backend, group=None)

Backend-agnostic in-place all-reduce.

The MPI backend round-trips through CPU/numpy and copies the result back to the original tensor (preserving its device).

Parameters:
  • tensor (torch.Tensor) – The tensor to reduce in-place.

  • op (str) – Reduction operation. One of "SUM", "MAX", "MIN".

  • backend (str) – "torch" or "mpi".

  • group (Group | None) – Process group (torch ProcessGroup or MPI communicator).

Return type:

None
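The in-place semantics can be illustrated without any distributed runtime. The sketch below simulates all ranks in a single process using plain Python lists in place of tensors; `simulated_all_reduce` is a hypothetical helper, not part of this module.

```python
from typing import Callable, Dict, List

# The three reduction names listed for the op parameter.
REDUCERS: Dict[str, Callable[[List[float]], float]] = {
    "SUM": sum,
    "MAX": max,
    "MIN": min,
}


def simulated_all_reduce(per_rank: List[List[float]], op: str = "SUM") -> None:
    """Reduce element-wise across ranks and overwrite every rank's buffer."""
    reduce_fn = REDUCERS[op]
    reduced = [reduce_fn(list(column)) for column in zip(*per_rank)]
    for local in per_rank:
        local[:] = reduced  # in place: each rank ends up with the same result


ranks_sum = [[1.0, 5.0], [2.0, 4.0], [3.0, 0.0]]
simulated_all_reduce(ranks_sum, op="SUM")

ranks_max = [[1.0, 5.0], [2.0, 4.0], [3.0, 0.0]]
simulated_all_reduce(ranks_max, op="MAX")
```

After the call every simulated rank holds the same reduced values, which mirrors the behaviour documented for the real function: nothing is returned and the input buffer is overwritten.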

gfn.utils.distributed.average_gradients(model)

All-Reduce gradients across all models.

gfn.utils.distributed.average_models(model, training_group=None)

Averages model weights across all ranks.

gfn.utils.distributed.barrier(backend=default_backend, group=None)

Backend-agnostic barrier synchronization.

Parameters:
  • backend (str) – "torch" or "mpi".

  • group (Group | None) – Process group (torch ProcessGroup or MPI communicator).

Return type:

None

gfn.utils.distributed.broadcast(tensor, src, backend=default_backend, group=None)

Backend-agnostic broadcast.

The MPI backend round-trips through CPU/numpy and copies the result back to the tensor (preserving its device).

Parameters:
  • tensor (torch.Tensor) – The tensor to broadcast. On the source rank this is the data to send; on other ranks the buffer is overwritten with received data.

  • src (int) – Source rank.

  • backend (str) – "torch" or "mpi".

  • group (Group | None) – Process group (torch ProcessGroup or MPI communicator).

Return type:

None

gfn.utils.distributed.default_backend = 'mpi'
gfn.utils.distributed.gather_distributed_data(local_tensor, world_size=None, rank=None, training_group=None)

Gather data from all processes in a distributed setting.

Parameters:
  • local_tensor (torch.Tensor) – Data from the current process (the docstring refers to this argument as local_data and describes it as a List or Tensor).

  • world_size (int | None) – Number of processes (optional; read from the environment if None).

  • rank (int | None) – Current process rank (optional; read from the environment if None).

Returns:

On rank 0: concatenated tensor from all processes. On other ranks: None.
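The rank-dependent return value can be shown with a single-process simulation. This is a sketch only: `simulated_gather` is a hypothetical helper, lists stand in for tensors, and concatenation in rank order is an assumption rather than documented behaviour.

```python
from typing import List, Optional


def simulated_gather(per_rank_data: List[List[int]], rank: int) -> Optional[List[int]]:
    """Return the concatenated data on rank 0 and None on every other rank."""
    if rank != 0:
        return None
    gathered: List[int] = []
    for chunk in per_rank_data:  # assumed order: rank 0, rank 1, ...
        gathered.extend(chunk)
    return gathered


per_rank_data = [[0, 1], [2, 3], [4, 5]]
on_rank0 = simulated_gather(per_rank_data, rank=0)
on_rank1 = simulated_gather(per_rank_data, rank=1)
```

Callers therefore need to check for None before using the result on ranks other than 0.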

gfn.utils.distributed.get_rank(backend=default_backend, group=None)

Backend-agnostic rank query.

Parameters:
  • backend (str) – "torch" or "mpi".

  • group (Group | None) – Process group (torch ProcessGroup or MPI communicator).

Return type:

int

gfn.utils.distributed.get_world_size(backend=default_backend, group=None)

Backend-agnostic world size query.

Parameters:
  • backend (str) – "torch" or "mpi".

  • group (Group | None) – Process group (torch ProcessGroup or MPI communicator).

Return type:

int

gfn.utils.distributed.initialize_distributed_compute(dist_backend, num_remote_buffers, num_agent_groups, use_coordinator=False)

Initializes distributed compute using ccl, mpi, or gloo backends.

Parameters:
  • dist_backend (str) – The backend to use for distributed compute.

  • num_remote_buffers (int) – The number of remote buffers to use.

  • num_agent_groups (int) – The number of agent groups.

  • use_coordinator (bool) – If True, the last rank becomes a coordinator that aggregates mode discoveries across buffer managers.

Return type:

DistributedContext

gfn.utils.distributed.initialize_distributed_compute_mpi4py(num_remote_buffers, num_agent_groups, num_coordinators=0)

Initializes distributed compute using mpi4py.

Parameters:
  • num_remote_buffers (int) – The number of remote buffers to use.

  • num_agent_groups (int) – The number of agent groups.

  • num_coordinators (int) – Number of coordinator ranks (0 or 1).

Return type:

DistributedContextMPI4Py

gfn.utils.distributed.isend(data, dst_rank, backend=default_backend, tag=0)

Non-blocking send of a byte tensor to dst_rank.

Returns an AsyncSendHandle whose internal buffer must be kept alive until is_complete() returns True or wait() returns.

Parameters:
  • data (torch.Tensor) – Tensor to send (will be cast to uint8).

  • dst_rank (int) – Destination rank (global).

  • backend (str) – "torch" or "mpi".

  • tag (int) – MPI/torch tag for message matching.

Returns:

A handle that tracks the outstanding send.

Return type:

AsyncSendHandle

gfn.utils.distributed.logger
gfn.utils.distributed.recv(src_rank=None, backend=default_backend, tag=0)

Receive a byte tensor from src_rank (or any rank if None).

Returns (source_rank, data) where data is a uint8 tensor. See send() for protocol details per backend.

Parameters:
  • src_rank (int | None) – Source rank to receive from, or None for any source.

  • backend (str) – "torch" or "mpi".

  • tag (int) – MPI/torch tag for message matching. Must match the tag used by the corresponding send() call.

Returns:

Tuple of (source rank, received uint8 tensor).

Return type:

tuple[int, torch.Tensor]

gfn.utils.distributed.report_load_imbalance(all_timing_dict, world_size)
Reports load imbalance and timing information from a timing dictionary.

Parameters:
  • all_timing_dict (List[Dict[str, List[float]]]) – A list of dictionaries containing timing information for each rank. Structure: [rank0_dict, rank1_dict, …], where each rank_dict is {"step_name": [iter0_time, iter1_time, iter2_time, …], …}.

  • world_size (int) – The total number of ranks in the distributed setup.

Return type:

None
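To make the expected input shape concrete, the sketch below builds a small all_timing_dict and computes one possible imbalance measure per step: the slowest rank's mean time divided by the mean over all ranks. The metric and the `imbalance_sketch` helper are assumptions for illustration; the quantities the real function reports are not specified here.

```python
from statistics import mean
from typing import Dict, List

# One dictionary per rank; each maps a step name to its per-iteration times.
all_timing_dict: List[Dict[str, List[float]]] = [
    {"sample": [1.0, 1.0], "train": [2.0, 2.0]},  # rank 0
    {"sample": [3.0, 3.0], "train": [2.0, 2.0]},  # rank 1
]


def imbalance_sketch(timings: List[Dict[str, List[float]]]) -> Dict[str, float]:
    """Hypothetical metric: slowest rank's mean time over the all-rank mean."""
    result: Dict[str, float] = {}
    for step in timings[0]:
        per_rank_mean = [mean(rank_dict[step]) for rank_dict in timings]
        result[step] = max(per_rank_mean) / mean(per_rank_mean)
    return result


ratios = imbalance_sketch(all_timing_dict)
```

Under this measure a ratio of 1.0 means all ranks spend the same time on a step, and larger values mean at least one rank is slower than average.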

gfn.utils.distributed.report_time_info(all_timing_dict, world_size)
Reports timing information from a timing dictionary.

Parameters:
  • all_timing_dict (List[Dict[str, List[float]]]) – A list of dictionaries containing timing information for each rank. Structure: [rank0_dict, rank1_dict, …], where each rank_dict is {"step_name": [iter0_time, iter1_time, iter2_time, …], …}.

  • world_size (int) – The total number of ranks in the distributed setup.

Return type:

None

gfn.utils.distributed.send(data, dst_rank, backend=default_backend, tag=0)

Send a byte tensor to dst_rank.

This is byte-level transport: the payload is always sent as raw uint8 bytes. Both backends guarantee that recv() on the other end returns an identical uint8 tensor.

Protocol differences between backends:

  • torch: Uses a length-prefixed two-message protocol. First a 1-element int64 tensor containing the payload length is sent (tag=2*tag), then the payload itself (tag=2*tag+1). This lets the receiver allocate the right buffer size before the data arrives.

  • mpi: Sends a single message with the given tag. The receiver uses MPI.Probe to discover the incoming message size before calling Recv, so no separate length message is needed.

Because the wire protocols differ, sender and receiver must use the same backend.

Parameters:
  • data (torch.Tensor) – Tensor to send (will be cast to uint8).

  • dst_rank (int) – Destination rank (global).

  • backend (str) – "torch" or "mpi".

  • tag (int) – MPI/torch tag for message matching. Use distinct tags to multiplex independent message channels on the same rank pair.

Return type:

None
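The length-prefixed two-message protocol described for the torch backend can be sketched with an in-memory mailbox in place of a real transport. Everything here is illustrative: `mailbox`, `send_sketch`, and `recv_sketch` are hypothetical names, the payload is plain `bytes` rather than a uint8 tensor, and the explicit `src_rank` argument exists only because the simulation has no notion of a calling rank.

```python
import struct
from collections import defaultdict, deque
from typing import Deque, Dict, Tuple

# Hypothetical in-memory "wire": one FIFO per (destination rank, wire tag).
mailbox: Dict[Tuple[int, int], Deque[Tuple[int, bytes]]] = defaultdict(deque)


def send_sketch(payload: bytes, src_rank: int, dst_rank: int, tag: int = 0) -> None:
    header = struct.pack("<q", len(payload))  # one int64 holding the length
    mailbox[(dst_rank, 2 * tag)].append((src_rank, header))       # length message
    mailbox[(dst_rank, 2 * tag + 1)].append((src_rank, payload))  # payload message


def recv_sketch(my_rank: int, tag: int = 0) -> Tuple[int, bytes]:
    src_rank, header = mailbox[(my_rank, 2 * tag)].popleft()
    (length,) = struct.unpack("<q", header)
    buffer = bytearray(length)  # sized from the header before reading the payload
    _, payload = mailbox[(my_rank, 2 * tag + 1)].popleft()
    buffer[:] = payload
    return src_rank, bytes(buffer)


# Two independent channels between the same rank pair, told apart by tag.
send_sketch(b"abc", src_rank=0, dst_rank=1, tag=0)
send_sketch(b"zz", src_rank=0, dst_rank=1, tag=3)
second_channel = recv_sketch(my_rank=1, tag=3)
first_channel = recv_sketch(my_rank=1, tag=0)
```

Mapping a user tag to the wire tags 2*tag and 2*tag+1 keeps the length and payload messages of different channels from colliding, which is why messages on distinct tags can be received in any order.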