gfn.utils.distributed

Attributes

logger

Classes

DistributedContext

Holds all distributed training/replay buffer groups and ranks.

DistributedContextmpi4py

Holds all distributed training/replay buffer groups and ranks, using mpi4py communicators.

Functions

average_gradients(model)

Averages gradients across all ranks using all-reduce.

average_models(model[, training_group])

Averages model weights across all ranks.

gather_distributed_data(local_tensor[, world_size, ...])

Gather data from all processes in a distributed setting.

initialize_distributed_compute(dist_backend, ...)

Initializes distributed compute using ccl, mpi, or gloo backends.

initialize_distributed_compute_mpi4py(...)

Initializes distributed compute using mpi4py.

report_load_imbalance(all_timing_dict, world_size)

Reports load imbalance and timing information from a timing dictionary.

report_time_info(all_timing_dict, world_size)

Reports timing information from a timing dictionary.

Module Contents

class gfn.utils.distributed.DistributedContext

Holds all distributed training/replay buffer groups and ranks.

agent_group_id: int | None = None
agent_group_size: int
agent_groups: List[torch.distributed.ProcessGroup] | None = None
assigned_buffer: int | None = None
assigned_training_ranks: List[int] | None = None
buffer_group: torch.distributed.ProcessGroup | None = None
cleanup()

Cleans up the distributed process group.

Return type:

None

dc_mpi4py: DistributedContextmpi4py | None = None
is_buffer_rank()

Check if the current rank is part of the buffer group.

Return type:

bool

is_training_rank()

Check if the current rank is part of the training group.

Return type:

bool

my_rank: int
num_training_ranks: int
train_global_group: torch.distributed.ProcessGroup | None = None
world_size: int
class gfn.utils.distributed.DistributedContextmpi4py

Holds all distributed training/replay buffer groups and ranks, using mpi4py communicators.

agent_group_id: int | None = None
agent_group_size: int
agent_groups: List[mpi4py.MPI.Comm] | None = None
assigned_buffer: int | None = None
assigned_training_ranks: List[int] | None = None
buffer_group: mpi4py.MPI.Comm | None = None
is_buffer_rank()

Check if the current rank is part of the buffer group.

Return type:

bool

is_training_rank()

Check if the current rank is part of the training group.

Return type:

bool

my_rank: int
num_training_ranks: int
train_global_group: mpi4py.MPI.Comm
world_size: int
gfn.utils.distributed.average_gradients(model)

Averages gradients across all ranks using all-reduce.
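
To illustrate what gradient averaging across ranks computes, here is a minimal pure-Python sketch that simulates the operation on plain lists. It is not the library's implementation: the helper name all_reduce_mean is hypothetical, and each inner list stands in for one rank's flattened gradient.

```python
from typing import List


def all_reduce_mean(per_rank_grads: List[List[float]]) -> List[List[float]]:
    """Simulate a SUM all-reduce followed by division by the world size.

    Hypothetical helper for illustration only; each inner list plays the
    role of one rank's gradient vector.
    """
    world_size = len(per_rank_grads)
    # Element-wise sum over ranks (what a SUM all-reduce would produce).
    summed = [sum(values) for values in zip(*per_rank_grads)]
    averaged = [s / world_size for s in summed]
    # After an all-reduce, every rank holds the same result.
    return [list(averaged) for _ in range(world_size)]


grads = [[1.0, 2.0], [3.0, 6.0]]  # two simulated ranks
result = all_reduce_mean(grads)
print(result)
```

Every simulated rank ends up with the same averaged vector, which is the property that keeps model replicas in sync after an optimizer step.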

gfn.utils.distributed.average_models(model, training_group=None)

Averages model weights across all ranks.

gfn.utils.distributed.gather_distributed_data(local_tensor, world_size=None, rank=None, training_group=None)

Gather data from all processes in a distributed setting.

Parameters:
  • local_tensor (torch.Tensor) – Data from the current process.

  • world_size (int | None) – Number of processes (optional; read from the environment if None).

  • rank (int | None) – Current process rank (optional; read from the environment if None).

Returns:

On rank 0: the concatenated tensor from all processes. On other ranks: None.
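
The return convention described above (concatenated data on rank 0, None elsewhere) can be sketched without any distributed runtime. The function simulate_gather below is a hypothetical stand-in that uses plain lists in place of tensors; it only mirrors the documented contract and makes no claim about how the library performs the gather.

```python
from typing import List, Optional


def simulate_gather(
    per_rank_data: List[List[float]], rank: int
) -> Optional[List[float]]:
    """Return the rank-ordered concatenation on rank 0, and None elsewhere.

    Hypothetical helper: `per_rank_data[i]` stands in for the local data
    held by rank i.
    """
    if rank != 0:
        return None
    gathered: List[float] = []
    for chunk in per_rank_data:  # iterate in rank order
        gathered.extend(chunk)
    return gathered


data = [[1.0], [2.0, 3.0]]
on_rank0 = simulate_gather(data, rank=0)
on_rank1 = simulate_gather(data, rank=1)
print(on_rank0, on_rank1)
```

Callers therefore need to guard any use of the result with a rank check or a None check.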

gfn.utils.distributed.initialize_distributed_compute(dist_backend, num_remote_buffers, num_agent_groups)

Initializes distributed compute using ccl, mpi, or gloo backends.

Parameters:
  • dist_backend (str) – The backend to use for distributed compute.

  • num_remote_buffers (int) – The number of remote buffers to use.

  • num_agent_groups (int) – The number of agent groups.

Return type:

DistributedContext
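
A possible usage pattern, based only on the signatures listed in this module, is sketched below. Several things are assumptions: the argument values are illustrative, passing ctx.train_global_group as training_group is a guess at the intended pairing, and the function must be run under a distributed launcher appropriate to the chosen backend. The helper name run_worker is hypothetical.

```python
def run_worker(model, dist_backend: str = "gloo"):
    """Hypothetical per-process entry point; not part of the library."""
    # Imported lazily so the sketch can be read without the library installed.
    from gfn.utils.distributed import (
        average_models,
        initialize_distributed_compute,
    )

    ctx = initialize_distributed_compute(
        dist_backend=dist_backend,
        num_remote_buffers=1,  # illustrative value
        num_agent_groups=1,    # illustrative value
    )
    try:
        if ctx.is_training_rank():
            # ... a training step would go here ...
            # Assumption: the global training group is the one to average over.
            average_models(model, training_group=ctx.train_global_group)
        elif ctx.is_buffer_rank():
            pass  # replay-buffer logic would go here
    finally:
        # Tear down the process group even if the step above fails.
        ctx.cleanup()
    return ctx
```

Wrapping the body in try/finally keeps cleanup() from being skipped when a rank raises mid-step.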

gfn.utils.distributed.initialize_distributed_compute_mpi4py(num_remote_buffers, num_agent_groups)

Initializes distributed compute using mpi4py.

Parameters:
  • num_remote_buffers (int) – The number of remote buffers to use.

  • num_agent_groups (int) – The number of agent groups.

Return type:

DistributedContextmpi4py

gfn.utils.distributed.logger
gfn.utils.distributed.report_load_imbalance(all_timing_dict, world_size)

Reports load imbalance and timing information from a timing dictionary.

Parameters:
  • all_timing_dict (List[Dict[str, List[float]]]) – A list of dictionaries containing timing information for each rank, structured as [rank0_dict, rank1_dict, …], where each rank_dict is {"step_name": [iter0_time, iter1_time, iter2_time, …], …}.

  • world_size (int) – The total number of ranks in the distributed setup.

Return type:

None
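
As an illustration of the all_timing_dict structure, the sketch below computes one common load-imbalance metric per step: the slowest rank's mean time divided by the mean across ranks. The metric and the helper name step_imbalance are assumptions chosen for the example; the source does not state which statistics report_load_imbalance actually prints.

```python
from statistics import mean
from typing import Dict, List


def step_imbalance(all_timing_dict: List[Dict[str, List[float]]]) -> Dict[str, float]:
    """Return {step_name: max_rank_mean / mean_of_rank_means}.

    Hypothetical metric: 1.0 means perfectly balanced, larger values mean
    the slowest rank lags further behind the average.
    """
    result: Dict[str, float] = {}
    for step in all_timing_dict[0]:
        # Mean time per rank for this step, averaged over iterations.
        rank_means = [mean(rank_dict[step]) for rank_dict in all_timing_dict]
        avg = mean(rank_means)
        result[step] = max(rank_means) / avg if avg > 0 else 1.0
    return result


timings = [
    {"sample": [1.0, 1.0], "train": [2.0, 2.0]},  # rank 0
    {"sample": [3.0, 3.0], "train": [2.0, 2.0]},  # rank 1
]
imbalance = step_imbalance(timings)
print(imbalance)
```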

gfn.utils.distributed.report_time_info(all_timing_dict, world_size)

Reports timing information from a timing dictionary.

Parameters:
  • all_timing_dict (List[Dict[str, List[float]]]) – A list of dictionaries containing timing information for each rank, structured as [rank0_dict, rank1_dict, …], where each rank_dict is {"step_name": [iter0_time, iter1_time, iter2_time, …], …}.

  • world_size (int) – The total number of ranks in the distributed setup.

Return type:

None
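
One way a single rank might build its rank_dict in the shape described above is sketched here, using only the standard library. The make_timer helper and the step names are hypothetical; collecting the per-rank dictionaries into the all_timing_dict list across processes is not shown.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


def make_timer():
    """Return (rank_dict, timed), where `timed(step_name)` is a context
    manager that appends the elapsed seconds to rank_dict[step_name]."""
    rank_dict = defaultdict(list)

    @contextmanager
    def timed(step_name):
        start = time.perf_counter()
        try:
            yield
        finally:
            rank_dict[step_name].append(time.perf_counter() - start)

    return rank_dict, timed


rank_dict, timed = make_timer()
for _ in range(3):  # three iterations, so three entries per step
    with timed("sample"):
        sum(range(1000))
    with timed("train"):
        sum(range(1000))
print({name: len(times) for name, times in rank_dict.items()})
```

Each step name maps to one float per iteration, matching the {"step_name": [iter0_time, iter1_time, …]} layout.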