gfn.containers.replay_buffer
============================

.. py:module:: gfn.containers.replay_buffer


Attributes
----------

.. autoapisummary::

   gfn.containers.replay_buffer.ContainerUnion


Classes
-------

.. autoapisummary::

   gfn.containers.replay_buffer.Container
   gfn.containers.replay_buffer.NormBasedDiversePrioritizedReplayBuffer
   gfn.containers.replay_buffer.ReplayBuffer
   gfn.containers.replay_buffer.TerminatingStateBuffer


Module Contents
---------------

.. py:class:: Container

   Bases: :py:obj:`Protocol`


   Base class for protocol classes.

   Protocol classes are defined as::

       class Proto(Protocol):
           def meth(self) -> int:
               ...

   Such classes are primarily used with static type checkers that recognize
   structural subtyping (static duck-typing).

   For example::

       class C:
           def meth(self) -> int:
               return 0

       def func(x: Proto) -> int:
           return x.meth()

       func(C())  # Passes static type check

   See PEP 544 for details. Protocol classes decorated with
   @typing.runtime_checkable act as simple-minded runtime protocols that check
   only the presence of given attributes, ignoring their type signatures.
   Protocol classes can be generic, they are defined as::

       class GenProto(Protocol[T]):
           def meth(self) -> T:
               ...


   .. py:method:: __getitem__(idx)


   .. py:method:: __len__()


   .. py:method:: extend(other)


   .. py:property:: log_rewards
      :type: torch.Tensor | None



   .. py:property:: terminating_states


.. py:data:: ContainerUnion

.. py:class:: NormBasedDiversePrioritizedReplayBuffer(env, capacity = 1000, cutoff_distance = 0.0, p_norm_distance = 1.0, remote_manager_rank = None, remote_buffer_freq = 1)

   Bases: :py:obj:`ReplayBuffer`


   A replay buffer with diversity-based prioritization.

   .. attribute:: env

      The environment associated with the containers.

   .. attribute:: capacity

      The maximum number of items the buffer can hold.

   .. attribute:: training_container

      The buffer contents (Trajectories, Transitions, or
      StatesContainer). This is dynamically set based on the type of the
      first added object.

   .. attribute:: prioritized_capacity

      Whether to use prioritized capacity (keep highest-reward
      items). This is set to True by default.

   .. attribute:: prioritized_sampling

      Whether to sample items with probability proportional
      to their reward.

   .. attribute:: cutoff_distance

      Threshold used to determine whether a new terminating
      state is different enough from those already in the buffer.

   .. attribute:: p_norm_distance

      p-norm value for distance calculation (used in torch.cdist).


   .. py:method:: _diversity_repr(container)
      :staticmethod:


      Returns the tensor used for pairwise distance in diversity filtering.

      For conditional GFNs, concatenates conditions with the state tensor so
      that identical states under different conditions are treated as distinct.



   .. py:method:: add(training_container)

      Adds a training object to the buffer with diversity-based prioritization.

      :param training_container: The Trajectories, Transitions, or StatesContainer
                                 object to add.



   .. py:attribute:: cutoff_distance
      :value: 0.0



   .. py:attribute:: p_norm_distance
      :value: 1.0



.. py:class:: ReplayBuffer(env, capacity = 1000, prioritized_capacity = False, prioritized_sampling = False, remote_manager_rank = None, remote_buffer_freq = 1)

   A replay buffer for storing containers.

   .. attribute:: env

      The environment associated with the containers.

   .. attribute:: capacity

      The maximum number of items the buffer can hold.

   .. attribute:: training_container

      The buffer contents (Trajectories, Transitions, or
      StatesContainer). This is dynamically set based on the type of the
      first added object.

   .. attribute:: prioritized_capacity

      Whether to use prioritized capacity (keep highest-reward
      items).

   .. attribute:: prioritized_sampling

      Whether to sample items with probability proportional
      to their reward.


   .. py:method:: __len__()

      Returns the number of items in the buffer.

      :returns: The number of items in the buffer.



   .. py:method:: __repr__()

      Returns a string representation of the ReplayBuffer.

      :returns: A string summary of the buffer.

   .. py:attribute:: _add_counter
      :value: 0



   .. py:method:: _add_objs(training_container)

      Adds a training object to the buffer, handling the capacity.

      :param training_container: The Trajectories, Transitions, or StatesContainer
                                 object to add.



   .. py:attribute:: _is_full
      :value: False



   .. py:method:: _send_objs(training_container)

      Sends a training container to the remote manager.



   .. py:method:: add(training_container)

      Adds a training container to the buffer.

      The type of the training container is dynamically set based on the type of the
      first added container.

      :param training_container: The Trajectories, Transitions, or StatesContainer
                                 object to add.



   .. py:attribute:: capacity
      :value: 1000



   .. py:property:: device
      :type: torch.device


      The device on which the buffer's data is stored.

      :returns: The device object of the buffer's contents.



   .. py:attribute:: env


   .. py:method:: initialize(training_container)

      Initializes the buffer with the type of the first added object.

      :param training_container: The initial Trajectories, Transitions, or
                                 StatesContainer object to set the buffer type.



   .. py:method:: load(path)

      Loads buffer contents from a ``.pt`` file saved by :meth:`save`.

      :param path: File path to the saved buffer.



   .. py:attribute:: pending_container
      :type: ContainerUnion | None
      :value: None



   .. py:attribute:: prioritized_capacity
      :value: False



   .. py:attribute:: prioritized_sampling
      :value: False



   .. py:attribute:: remote_buffer_freq
      :value: 1



   .. py:attribute:: remote_manager_rank
      :value: None



   .. py:method:: sample(n_samples)

      Samples training objects from the buffer.

      :param n_samples: The number of items to sample.

      :returns: A sampled Trajectories, Transitions, or StatesContainer.



   .. py:method:: save(path)

      Saves the buffer to a single ``.pt`` file.

      :param path: File path (e.g. ``"replay_buffer.pt"``).



   .. py:attribute:: training_container
      :type: ContainerUnion | None
      :value: None



.. py:class:: TerminatingStateBuffer(env, capacity = 1000, **kwargs)

   Bases: :py:obj:`ReplayBuffer`


   A replay buffer for storing terminating states.

   .. attribute:: env

      The environment associated with the containers.

   .. attribute:: capacity

      The maximum number of items the buffer can hold.

   .. attribute:: training_container

      The buffer contents (StatesContainer).


   .. py:method:: add(training_container)

      Adds a training container to the buffer.

      The type of the training container is dynamically set based on the type of the
      first added container.

      :param training_container: The Trajectories, Transitions, or StatesContainer
                                 object to add.



   .. py:attribute:: training_container
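
The options documented for these buffers (``capacity``, ``prioritized_capacity``,
``prioritized_sampling``, ``cutoff_distance``, ``p_norm_distance``) can be
illustrated with a minimal pure-Python sketch. This is not the ``gfn``
implementation: ``ToyPrioritizedBuffer``, the ``p_norm_distance`` helper function,
and the ``(state, log_reward)`` tuple layout are hypothetical stand-ins, whereas
the documented classes store container objects and mention ``torch.cdist`` for
distances. Two further assumptions are made here: a distance less than or equal
to the cutoff counts as "too close", and the reward is ``exp(log_reward)``.

.. code-block:: python

   import math
   import random
   from typing import List, Sequence, Tuple

   # Hypothetical item layout: (terminating state, log-reward).
   Item = Tuple[Tuple[float, ...], float]


   def p_norm_distance(a: Sequence[float], b: Sequence[float], p: float = 1.0) -> float:
       """Minkowski p-norm distance between two equal-length vectors."""
       return sum(abs(x - y) ** p for x, y in zip(a, b)) ** (1.0 / p)


   class ToyPrioritizedBuffer:
       """Illustrative stand-in only; not part of the gfn package."""

       def __init__(self, capacity: int = 1000, cutoff_distance: float = 0.0, p: float = 1.0):
           self.capacity = capacity
           self.cutoff_distance = cutoff_distance
           self.p = p
           self.items: List[Item] = []

       def __len__(self) -> int:
           return len(self.items)

       def add(self, new_items: Sequence[Item]) -> None:
           for state, log_reward in new_items:
               # Diversity filter (assumption: "<=" means too close). Each new
               # state is compared with everything stored so far, including
               # items accepted earlier in the same call.
               too_close = any(
                   p_norm_distance(state, kept, self.p) <= self.cutoff_distance
                   for kept, _ in self.items
               )
               if not too_close:
                   self.items.append((state, log_reward))
           # Prioritized capacity: sort ascending by log-reward and keep the
           # `capacity` highest-reward items.
           self.items.sort(key=lambda item: item[1])
           self.items = self.items[-self.capacity:]

       def sample(self, n_samples: int, rng: random.Random) -> List[Item]:
           # Prioritized sampling: weight each item by exp(log_reward).
           weights = [math.exp(log_reward) for _, log_reward in self.items]
           return rng.choices(self.items, weights=weights, k=n_samples)


   # Usage: the second state lies within the cutoff of the first and is dropped;
   # the lowest-reward survivor is then evicted to respect the capacity.
   buf = ToyPrioritizedBuffer(capacity=2, cutoff_distance=0.5, p=1.0)
   buf.add([
       ((0.0, 0.0), -1.0),
       ((0.1, 0.1), -0.5),
       ((3.0, 3.0), -2.0),
       ((5.0, 5.0), -0.1),
   ])
   batch = buf.sample(3, random.Random(0))

The sketch filters for diversity before enforcing capacity, so a near-duplicate
never displaces a distinct item. Whether the documented classes apply the two
steps in the same order is not stated above and should be checked against the
source.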