gfn.containers.replay_buffer_manager
Attributes
DATA_TAG
METADATA_TAG
Classes
ReplayBufferManager: Receives training containers on the manager rank and replies with scores.
Module Contents
- gfn.containers.replay_buffer_manager.DATA_TAG = 0
- gfn.containers.replay_buffer_manager.METADATA_TAG = 1
- class gfn.containers.replay_buffer_manager.ReplayBufferManager(env, rank, num_training_ranks, scoring_function=None, diverse_replay_buffer=False, capacity=10000, remote_manager_rank=None, communication_backend='mpi', timing=False, store_locally=True, baseline_strategy='min', baseline_percentile=0.1, baseline_ema_alpha=0.1)
Receives training containers on the manager rank and replies with scores.
The manager optionally stores incoming data in a local replay buffer (store_locally=True) and injects a baseline_log_reward into every score response, so that workers with baseline_filtering enabled can skip sending payloads that would be evicted immediately. The baseline source is controlled by baseline_strategy: "min" and "percentile" read from the local buffer once it reaches capacity, while "ema" reads from a running exponential moving average (EMA) of incoming batch minima. Any strategy falls back to the EMA when the local buffer is unavailable (e.g. store_locally=False).
- Parameters:
env (gfn.env.Env)
rank (int)
num_training_ranks (int)
scoring_function (Optional[Callable[..., dict[str, float]]])
diverse_replay_buffer (bool)
capacity (int)
remote_manager_rank (int | None)
communication_backend (str)
timing (bool)
store_locally (bool)
baseline_strategy (str)
baseline_percentile (float)
baseline_ema_alpha (float)
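A worker that receives baseline_log_reward can drop items that would be evicted on arrival. The sketch below assumes the filter is a strict comparison of each item's log-reward against the baseline and that a missing baseline means "send everything". The helper name filter_by_baseline and the "score" key in the example response are hypothetical, and whether the real baseline_filtering works per item or per whole payload is not stated here.

```python
from typing import Optional, Sequence


def filter_by_baseline(
    log_rewards: Sequence[float], baseline_log_reward: Optional[float]
) -> list[int]:
    """Return indices of items worth sending to the manager (hypothetical helper)."""
    if baseline_log_reward is None:
        # No baseline available yet: keep everything.
        return list(range(len(log_rewards)))
    # Assumption: an item at or below the baseline would be evicted on arrival.
    return [i for i, r in enumerate(log_rewards) if r > baseline_log_reward]


# Example score response; the "score" key is made up for illustration.
score_response = {"score": 0.7, "baseline_log_reward": -2.0}
keep = filter_by_baseline([-3.5, -1.2, -2.0, 0.4], score_response.get("baseline_log_reward"))
print(keep)  # [1, 3]
```

A whole payload can be skipped when the returned list is empty, which matches the "skip sending payloads" behaviour described for the class.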
- _baseline_ema: float | None = None
- _comm_stats: dict[int, dict]
- abstract _compute_metadata()
- Return type:
dict
- _handle_message_async(sender_rank, msg, msg_data_len=0)
Dispatch a message using non-blocking isend for responses.
- Parameters:
sender_rank (int)
msg_data_len (int)
- _handle_message_sync(sender_rank, msg, msg_data_len=0)
Dispatch a message using blocking send for responses.
This is simpler than the async variant and uses no CPU while the send is in flight, which makes it preferable when all ranks share a CPU.
- Parameters:
sender_rank (int)
msg_data_len (int)
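The difference between the two dispatch paths is mostly about when control returns to the caller. The sketch below uses concurrent.futures as a stand-in for MPI isend/send; slow_send and respond are hypothetical helpers, and a thread pool does not reproduce the CPU behaviour of MPI progress, so treat it only as an illustration of the control flow.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor


def slow_send(dest: int, payload: dict) -> None:
    """Hypothetical stand-in for a transport-level send."""
    time.sleep(0.01)  # simulate time spent on the wire


def respond(dest: int, payload: dict, async_send: bool,
            executor: ThreadPoolExecutor, pending: list) -> None:
    if async_send:
        # Non-blocking (like isend): start the send, keep the handle, return at once.
        pending.append(executor.submit(slow_send, dest, payload))
    else:
        # Blocking (like send): return only after the send has finished.
        slow_send(dest, payload)


pending: list[Future] = []
with ThreadPoolExecutor(max_workers=2) as executor:
    respond(1, {"score": 1.0}, True, executor, pending)
    respond(2, {"score": 0.5}, False, executor, pending)
# Leaving the with-block waits for any outstanding non-blocking sends.
print(len(pending), all(f.done() for f in pending))  # 1 True
```

Only the non-blocking path leaves a handle behind, which is why the async variant needs a pending list that is pruned periodically.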
- _inject_baseline_log_reward(score_dict, incoming)
Inject baseline_log_reward into score_dict according to baseline_strategy.
Updates the EMA tracker from incoming.log_rewards (the source under "ema", and the fallback when the local buffer is unavailable), then picks the baseline from the buffer ("min" or "percentile", once the buffer is at capacity) or from the EMA. Non-finite rewards are excluded so that containers holding -inf rewards (e.g. Transitions) do not poison the statistics.
- Parameters:
score_dict (dict[str, float])
- Return type:
None
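The EMA update and baseline selection can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's code: update_ema and pick_baseline are hypothetical names, baseline_percentile is assumed to be a fraction in [0, 1] applied as a nearest-rank percentile, the EMA is assumed to be seeded with the first batch minimum, and a buffer that exists but is not yet full is assumed to fall back to the EMA (one reading of "from the buffer or the EMA").

```python
import math
from typing import Optional, Sequence


def update_ema(ema: Optional[float], log_rewards: Sequence[float], alpha: float) -> Optional[float]:
    """Fold the minimum finite log-reward of a batch into a running EMA."""
    finite = [r for r in log_rewards if math.isfinite(r)]
    if not finite:
        return ema  # nothing usable in this batch (e.g. all -inf)
    batch_min = min(finite)
    if ema is None:
        return batch_min  # assumption: the first batch minimum seeds the EMA
    return alpha * batch_min + (1.0 - alpha) * ema


def pick_baseline(strategy: str, buffer_log_rewards: Optional[Sequence[float]],
                  capacity: int, percentile: float, ema: Optional[float]) -> Optional[float]:
    """Read the baseline from a full buffer ("min"/"percentile"), else from the EMA."""
    buffer_full = buffer_log_rewards is not None and len(buffer_log_rewards) >= capacity
    if strategy in ("min", "percentile") and buffer_full:
        finite = sorted(r for r in buffer_log_rewards if math.isfinite(r))
        if finite:
            if strategy == "min":
                return finite[0]
            # Assumption: nearest-rank percentile with percentile as a fraction in [0, 1].
            return finite[min(len(finite) - 1, int(percentile * len(finite)))]
    return ema  # "ema", no buffer (store_locally=False), or buffer not yet full


ema = update_ema(None, [-1.0, -4.0, float("-inf")], alpha=0.1)  # seeds at -4.0
ema = update_ema(ema, [-2.0, -3.0], alpha=0.1)                 # approximately -3.9
buffer = [-5.0, -1.0, -3.0, float("-inf")]
print(pick_baseline("min", buffer, capacity=4, percentile=0.5, ema=ema))         # -5.0
print(pick_baseline("percentile", buffer, capacity=4, percentile=0.5, ema=ema))  # -3.0
```

Filtering with math.isfinite before taking minima is what keeps a single -inf reward from dragging the EMA or the buffer statistics to -inf.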
- _pending_sends: list[gfn.utils.distributed.AsyncSendHandle] = []
- _print_timing_summary()
Print communication and timing stats at shutdown.
- Return type:
None
- _prune_completed_sends()
Remove completed non-blocking sends from the pending list.
- Return type:
None
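Pruning keeps the pending list from growing without bound while still holding a reference to every send that has not finished (with MPI non-blocking sends, the send buffer must stay untouched until the request completes). FakeSendHandle and its test() method are hypothetical stand-ins; the real AsyncSendHandle interface is not documented on this page.

```python
from dataclasses import dataclass


@dataclass
class FakeSendHandle:
    """Hypothetical stand-in for an async send handle such as an MPI request."""
    completed: bool = False

    def test(self) -> bool:
        # Non-blocking completion check, in the spirit of MPI_Test.
        return self.completed


def prune_completed_sends(pending: list) -> None:
    """Drop finished sends in place; unfinished ones keep their buffers alive."""
    pending[:] = [h for h in pending if not h.test()]


pending = [FakeSendHandle(True), FakeSendHandle(False), FakeSendHandle(True)]
prune_completed_sends(pending)
print(len(pending))  # 1
```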
- _recv_object()
- _timing_data: dict[str, list[float]]
- baseline_ema_alpha = 0.1
- baseline_percentile = 0.1
- baseline_strategy = 'min'
- capacity = 10000
- communication_backend = 'mpi'
- default_scoring_function(obj, sender_rank=-1)
Placeholder scoring function used when no scoring_function is provided.
- Parameters:
sender_rank (int)
- Return type:
dict[str, float]
- diverse_replay_buffer = False
- exit_counter = 0
- static get_metadata(manager_rank, backend)
Sends a get-metadata request to the replay buffer manager and returns its reply.
Uses METADATA_TAG so the response is never confused with pending data/score messages on the default tag.
- Parameters:
manager_rank (int)
backend (str)
- Return type:
dict
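A separate tag helps because receives that match on tag let a rank pull the metadata reply even when a score reply from the same manager is still queued. The toy Mailbox below imitates MPI-style (source, tag) matching; it is an illustration rather than the gfn transport, and the dictionary keys in the messages are made up.

```python
from collections import defaultdict, deque

DATA_TAG = 0
METADATA_TAG = 1


class Mailbox:
    """Toy point-to-point mailbox that matches messages by (source, tag), like MPI recv."""

    def __init__(self):
        self._queues = defaultdict(deque)

    def send(self, source: int, tag: int, obj) -> None:
        self._queues[(source, tag)].append(obj)

    def recv(self, source: int, tag: int):
        # FIFO per (source, tag); messages with other tags are left in place.
        return self._queues[(source, tag)].popleft()


# Messages arriving at a training rank from manager rank 0.
inbox = Mailbox()
inbox.send(0, DATA_TAG, {"score": 0.8})            # a score reply still pending
inbox.send(0, METADATA_TAG, {"buffer_size": 512})  # the reply to get_metadata

metadata = inbox.recv(0, METADATA_TAG)  # skips over the pending score reply
score = inbox.recv(0, DATA_TAG)
print(metadata, score)  # {'buffer_size': 512} {'score': 0.8}
```

If both replies shared one tag, the first receive would return the score reply and a caller expecting metadata would get the wrong object.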
- is_running = True
- num_training_ranks
- rank
- remote_manager_rank = None
- run(async_send=True)
Runs on the remote buffer manager rank: waits for training data, computes scores, and sends them back.
- Parameters:
async_send (bool) – If True (default), use non-blocking isend for responses. If False, use blocking send for responses.
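The run loop can be pictured as a receive-and-dispatch cycle that ends once every training rank has signalled termination. This sketch assumes that exit_counter counts termination signals and that the loop exits when it reaches num_training_ranks; run_manager, TERMINATE and the message format are hypothetical, and a deque replaces the blocking receive.

```python
from collections import deque

TERMINATE = "terminate"  # hypothetical termination message


def run_manager(messages, num_training_ranks, score_fn):
    """Toy loop: score data messages until every training rank has sent TERMINATE."""
    inbox = deque(messages)  # stands in for blocking receives
    exit_counter = 0
    replies = []
    while exit_counter < num_training_ranks and inbox:
        sender, msg = inbox.popleft()
        if msg == TERMINATE:
            exit_counter += 1  # one more training rank has finished
        else:
            replies.append((sender, score_fn(msg)))  # reply goes back to the sender
    return replies, exit_counter


msgs = [(1, [1.0, 3.0]), (2, [4.0]), (1, TERMINATE), (2, [6.0]), (2, TERMINATE)]
replies, exits = run_manager(msgs, 2, lambda batch: {"mean": sum(batch) / len(batch)})
print(replies)  # [(1, {'mean': 2.0}), (2, {'mean': 4.0}), (2, {'mean': 6.0})]
print(exits)    # 2
```

Counting signals instead of stopping on the first one lets rank 2 keep sending data after rank 1 has finished.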
- scoring_function
- static send_termination_signal(manager_rank, backend)
Sends a termination signal to the replay buffer manager.
- Parameters:
manager_rank (int)
backend (str)
- Return type:
None
- store_locally = True
- timing = False