gfn.utils.distributed
Attributes
logger
Classes
DistributedContext | Holds all distributed training/replay buffer groups and ranks.
DistributedContextmpi4py | Holds all distributed training/replay buffer groups and ranks, using mpi4py communicators.
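Functions
average_gradients(model) | All-reduces and averages the model's gradients across all ranks.
average_models(model, training_group=None) | Averages model weights across all ranks.
gather_distributed_data(local_tensor, world_size=None, rank=None, training_group=None) | Gathers data from all processes in a distributed setting.
initialize_distributed_compute(dist_backend, num_remote_buffers, num_agent_groups) | Initializes distributed compute using the ccl, mpi, or gloo backend.
initialize_distributed_compute_mpi4py(num_remote_buffers, num_agent_groups) | Initializes distributed compute using mpi4py.
report_load_imbalance(all_timing_dict, world_size) | Reports load imbalance and timing information from per-rank timing dictionaries.
report_time_info(all_timing_dict, world_size) | Reports timing information from per-rank timing dictionaries.
Module Contents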
- class gfn.utils.distributed.DistributedContext
Holds all distributed training/replay buffer groups and ranks.
- agent_group_id: int | None = None
- agent_group_size: int
- agent_groups: List[torch.distributed.ProcessGroup] | None = None
- assigned_buffer: int | None = None
- assigned_training_ranks: List[int] | None = None
- buffer_group: torch.distributed.ProcessGroup | None = None
- cleanup()
Cleans up the distributed process group.
- Return type:
None
- dc_mpi4py: DistributedContextmpi4py | None = None
- is_buffer_rank()
Check if the current rank is part of the buffer group.
- Return type:
bool
- is_training_rank()
Check if the current rank is part of the training group.
- Return type:
bool
- my_rank: int
- num_training_ranks: int
- train_global_group: torch.distributed.ProcessGroup | None = None
- world_size: int
- class gfn.utils.distributed.DistributedContextmpi4py
Holds all distributed training/replay buffer groups and ranks, using mpi4py communicators.
- agent_group_id: int | None = None
- agent_group_size: int
- agent_groups: List[mpi4py.MPI.Comm] | None = None
- assigned_buffer: int | None = None
- assigned_training_ranks: List[int] | None = None
- buffer_group: mpi4py.MPI.Comm | None = None
- is_buffer_rank()
Check if the current rank is part of the buffer group.
- Return type:
bool
- is_training_rank()
Check if the current rank is part of the training group.
- Return type:
bool
- my_rank: int
- num_training_ranks: int
- train_global_group: mpi4py.MPI.Comm
- world_size: int
- gfn.utils.distributed.average_gradients(model)
All-reduces and averages the model's gradients across all ranks.
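Gradient averaging with torch.distributed is commonly written as an all-reduce sum followed by a division by the number of ranks. The sketch below shows that pattern; it is not the library's source, the name average_gradients_sketch and its group parameter are hypothetical, and it assumes the default process group is already initialized.

```python
import torch
import torch.distributed as dist


def average_gradients_sketch(model: torch.nn.Module, group=None) -> None:
    """Sum every parameter gradient across ranks, then divide by the group size.

    Hypothetical stand-in for average_gradients; assumes torch.distributed
    has already been initialized on every participating rank.
    """
    world_size = dist.get_world_size(group=group)
    for param in model.parameters():
        if param.grad is None:
            continue  # this parameter took no part in the backward pass
        # In-place sum of the gradient over all ranks in the group...
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)
        # ...turned into a mean.
        param.grad /= world_size
```

Call it after loss.backward() and before optimizer.step(). Because all_reduce is a collective, every rank in the group must make the call, or the others will block.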
- gfn.utils.distributed.average_models(model, training_group=None)
Averages model weights across all ranks.
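Weight averaging follows the same all-reduce pattern, applied to the parameters themselves instead of their gradients. This is a minimal sketch under stated assumptions: average_models_sketch is a hypothetical name, torch.distributed is assumed initialized, and training_group=None is assumed to mean the default (global) group.

```python
import torch
import torch.distributed as dist


def average_models_sketch(model: torch.nn.Module, training_group=None) -> None:
    """Replace each parameter with its mean over the ranks in training_group.

    Hypothetical stand-in for average_models; assumes torch.distributed is
    initialized and that training_group=None selects the default group.
    """
    group_size = dist.get_world_size(group=training_group)
    for param in model.parameters():
        # Work on .data so autograd does not record the in-place update.
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM, group=training_group)
        param.data /= group_size
```

Periodic weight averaging like this lets ranks train independently between synchronization points, at the cost of their models drifting apart in between.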
- gfn.utils.distributed.gather_distributed_data(local_tensor, world_size=None, rank=None, training_group=None)
Gathers data from all processes in a distributed setting.
- Parameters:
local_tensor (torch.Tensor) – Data from the current process.
world_size (int | None) – Number of processes (optional; obtained from the environment if None).
rank (int | None) – Current process rank (optional; obtained from the environment if None).
- Returns:
On rank 0, the concatenated tensor from all processes; on other ranks, None.
- Return type:
torch.Tensor | None
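A gather-to-rank-0 that matches the documented return contract (a concatenated tensor on rank 0, None elsewhere) can be sketched with dist.gather. The name gather_to_rank0_sketch is hypothetical, the sketch uses only the default process group, and it assumes every rank sends a tensor of the same shape and dtype, which dist.gather requires.

```python
from typing import Optional

import torch
import torch.distributed as dist


def gather_to_rank0_sketch(local_tensor: torch.Tensor) -> Optional[torch.Tensor]:
    """Gather one tensor per rank onto rank 0 and concatenate along dim 0.

    Hypothetical stand-in for gather_distributed_data. Assumes the default
    process group is initialized and all ranks send same-shaped tensors.
    """
    world_size = dist.get_world_size()
    if dist.get_rank() == 0:
        # Rank 0 preallocates one receive buffer per process, in rank order.
        buffers = [torch.empty_like(local_tensor) for _ in range(world_size)]
        dist.gather(local_tensor, gather_list=buffers, dst=0)
        return torch.cat(buffers, dim=0)
    # Every other rank only sends, so it has nothing to return.
    dist.gather(local_tensor, dst=0)
    return None
```

If ranks produce different numbers of rows, one option is to exchange sizes first and pad, or to fall back to dist.gather_object at the cost of pickling.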
- gfn.utils.distributed.initialize_distributed_compute(dist_backend, num_remote_buffers, num_agent_groups)
Initializes distributed compute using the ccl, mpi, or gloo backend.
- Parameters:
dist_backend (str) – The backend to use for distributed compute (ccl, mpi, or gloo).
num_remote_buffers (int) – The number of remote buffers to use.
num_agent_groups (int) – The number of agent groups.
- Return type:
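This page does not spell out how ranks are split between training and replay-buffer roles. Purely to illustrate what the DistributedContext fields (agent_group_id, assigned_buffer, assigned_training_ranks) could hold, the sketch below shows one plausible layout. The contiguous block of training ranks, the trailing buffer ranks, the even split into agent groups, and the round-robin buffer assignment are all assumptions, and RankRole and plan_roles are hypothetical names, not part of gfn.utils.distributed.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class RankRole:
    """Hypothetical per-rank role; field names mirror DistributedContext."""
    my_rank: int
    is_training: bool
    agent_group_id: Optional[int] = None
    assigned_buffer: Optional[int] = None
    assigned_training_ranks: Optional[List[int]] = None


def plan_roles(world_size: int, num_remote_buffers: int, num_agent_groups: int) -> List[RankRole]:
    """Assumed layout: training ranks first, then num_remote_buffers buffer ranks."""
    num_training_ranks = world_size - num_remote_buffers
    if num_agent_groups <= 0 or num_training_ranks <= 0 or num_training_ranks % num_agent_groups:
        raise ValueError("training ranks must split evenly into agent groups")
    agent_group_size = num_training_ranks // num_agent_groups
    buffer_ranks = list(range(num_training_ranks, world_size))

    roles: List[RankRole] = []
    for rank in range(num_training_ranks):
        roles.append(RankRole(
            my_rank=rank,
            is_training=True,
            # Consecutive training ranks share an agent group.
            agent_group_id=rank // agent_group_size,
            # Training ranks are dealt round-robin onto buffer ranks, if any.
            assigned_buffer=buffer_ranks[rank % len(buffer_ranks)] if buffer_ranks else None,
        ))
    for rank in buffer_ranks:
        # Each buffer rank records which training ranks feed it.
        served = [r.my_rank for r in roles if r.assigned_buffer == rank]
        roles.append(RankRole(my_rank=rank, is_training=False, assigned_training_ranks=served))
    return roles
```

For example, plan_roles(6, 2, 2) places ranks 0 to 3 in two agent groups of two and has buffer ranks 4 and 5 serve training ranks [0, 2] and [1, 3].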
- gfn.utils.distributed.initialize_distributed_compute_mpi4py(num_remote_buffers, num_agent_groups)
Initializes distributed compute using mpi4py.
- Parameters:
num_remote_buffers (int) – The number of remote buffers to use.
num_agent_groups (int) – The number of agent groups.
- Return type:
- gfn.utils.distributed.logger
- gfn.utils.distributed.report_load_imbalance(all_timing_dict, world_size)
Reports load imbalance and timing information from per-rank timing dictionaries.
- Parameters:
all_timing_dict (List[Dict[str, List[float]]]) – A list with one timing dictionary per rank, structured as [rank0_dict, rank1_dict, …], where each rank_dict maps a step name to its per-iteration times: {"step_name": [iter0_time, iter1_time, iter2_time, …], …}.
world_size (int) – The total number of ranks in the distributed setup.
- Return type:
None
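Given the all_timing_dict structure above, one common way to quantify load imbalance is to total each step's time per rank and compare the slowest rank with the mean. The sketch below does that; the max/mean ratio is an assumed metric, not necessarily what report_load_imbalance logs, and load_imbalance_sketch is a hypothetical name that returns its numbers instead of logging them.

```python
from typing import Dict, List


def load_imbalance_sketch(
    all_timing_dict: List[Dict[str, List[float]]], world_size: int
) -> Dict[str, Dict[str, float]]:
    """Per step: total time on each rank, then min, max, mean and max/mean ratio."""
    if len(all_timing_dict) != world_size:
        raise ValueError("expected one timing dictionary per rank")
    report: Dict[str, Dict[str, float]] = {}
    for step in all_timing_dict[0]:
        # Total time each rank spent in this step over all iterations.
        totals = [sum(rank_dict.get(step, [])) for rank_dict in all_timing_dict]
        mean = sum(totals) / world_size
        report[step] = {
            "min": min(totals),
            "max": max(totals),
            "mean": mean,
            # 1.0 is perfectly balanced; larger values mean one rank lags.
            "imbalance": max(totals) / mean if mean > 0 else 1.0,
        }
    return report
```

With two ranks where rank 1 spends twice as long sampling, for example [{"sample": [1.0, 1.0], "train": [0.5, 0.5]}, {"sample": [2.0, 2.0], "train": [0.5, 0.5]}], the sampling step gets an imbalance of 4/3 while training stays at 1.0.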
- gfn.utils.distributed.report_time_info(all_timing_dict, world_size)
Reports timing information from per-rank timing dictionaries.
- Parameters:
all_timing_dict (List[Dict[str, List[float]]]) – A list with one timing dictionary per rank, structured as [rank0_dict, rank1_dict, …], where each rank_dict maps a step name to its per-iteration times: {"step_name": [iter0_time, iter1_time, iter2_time, …], …}.
world_size (int) – The total number of ranks in the distributed setup.
- Return type:
None
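A plain timing summary over the same structure can be sketched as the mean per-iteration time of each step, averaged over ranks. This is a hypothetical illustration: time_info_sketch is not part of the library, and averaging per-rank means (rather than pooling all iterations) is an assumed design choice.

```python
from statistics import mean
from typing import Dict, List


def time_info_sketch(
    all_timing_dict: List[Dict[str, List[float]]], world_size: int
) -> Dict[str, float]:
    """Mean per-iteration time of each step, averaged over the ranks that ran it."""
    summary: Dict[str, float] = {}
    for step in sorted(all_timing_dict[0]):
        # Average within each rank first, so ranks that logged more
        # iterations do not dominate the result.
        per_rank = [mean(d[step]) for d in all_timing_dict[:world_size] if d.get(step)]
        if not per_rank:
            continue  # no rank recorded any time for this step
        summary[step] = mean(per_rank)
        print(f"{step}: {summary[step]:.3f} s/iter over {len(per_rank)} rank(s)")
    return summary
```

Averaging per-rank means treats every rank equally; pooling all iterations instead would weight ranks by how many iterations they logged.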