gfn.containers
==============

.. py:module:: gfn.containers


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/gfn/containers/base/index
   /autoapi/gfn/containers/message/index
   /autoapi/gfn/containers/replay_buffer/index
   /autoapi/gfn/containers/replay_buffer_manager/index
   /autoapi/gfn/containers/states_container/index
   /autoapi/gfn/containers/trajectories/index
   /autoapi/gfn/containers/transitions/index


Classes
-------

.. autoapisummary::

   gfn.containers.Container
   gfn.containers.NormBasedDiversePrioritizedReplayBuffer
   gfn.containers.ReplayBuffer
   gfn.containers.StatesContainer
   gfn.containers.TerminatingStateBuffer
   gfn.containers.Trajectories
   gfn.containers.Transitions


Package Contents
----------------

.. py:class:: Container

   Bases: :py:obj:`abc.ABC`

   Base class for state containers (states, transitions, or trajectories).

   .. py:method:: __getitem__(index)
      :abstractmethod:

      Returns a subset of the container based on the provided index.

      :param index: An integer, slice, tuple, sequence of indices or booleans, or a torch.Tensor specifying which elements to select.

      :returns: A new container containing the selected elements and associated data.

   .. py:method:: __len__()
      :abstractmethod:

      Returns the number of elements in the container.

      :returns: The number of elements in the container.

   .. py:property:: device
      :type: torch.device
      :abstractmethod:

      The device on which the container is stored.

      :returns: The device on which the container is stored.

   .. py:method:: extend(other)
      :abstractmethod:

      Extends the current container with elements from another container object.

      :param other: The other container whose elements will be added.

   .. py:method:: from_tensordict(env, td)
      :classmethod:
      :abstractmethod:

      Reconstruct a container from a TensorDict.

      :param env: The environment needed to reconstruct States/Actions.
      :param td: The TensorDict produced by :meth:`to_tensordict`.

      :returns: A new container instance.

   .. py:property:: has_log_probs
      :type: bool

      Whether the container has log probabilities.

      :returns: True if log probabilities are present and non-empty, False otherwise.

   .. py:method:: load(env, path)
      :classmethod:

      Loads a container from a ``.pt`` file saved by :meth:`save`.

      :param env: The environment needed to reconstruct States/Actions.
      :param path: File path to the saved container.

      :returns: A new container instance.

   .. py:property:: log_rewards
      :type: torch.Tensor
      :abstractmethod:

      The log rewards associated with the container.

      :returns: The log rewards tensor.

   .. py:method:: sample(n_samples)

      Randomly samples a subset of elements from the container.

      :param n_samples: The number of elements to sample.

      :returns: A new container with the sampled elements.

   .. py:method:: save(path)

      Saves the container to a single ``.pt`` file.

      :param path: File path (e.g. ``"trajectories.pt"``).

   .. py:property:: terminating_states
      :type: gfn.states.States
      :abstractmethod:

      The last (terminating) states of the container.

      :returns: The terminating states.

   .. py:method:: to_tensordict()
      :abstractmethod:

      Serialize the container's data into a TensorDict.

      :returns: A TensorDict containing all tensor data and scalar metadata. The ``env`` reference is not included; it must be supplied when reconstructing via :meth:`from_tensordict`.


.. py:class:: NormBasedDiversePrioritizedReplayBuffer(env, capacity = 1000, cutoff_distance = 0.0, p_norm_distance = 1.0, remote_manager_rank = None, remote_buffer_freq = 1)

   Bases: :py:obj:`ReplayBuffer`

   A replay buffer with diversity-based prioritization.

   .. attribute:: env

      The environment associated with the containers.

   .. attribute:: capacity

      The maximum number of items the buffer can hold.

   .. attribute:: training_container

      The buffer contents (Trajectories, Transitions, or StatesContainer).
      This is dynamically set based on the type of the first added object.

   .. attribute:: prioritized_capacity

      Whether to use prioritized capacity (keep highest-reward items).
      This is set to True by default.

   .. attribute:: prioritized_sampling

      Whether to sample items with probability proportional to their reward.

   .. attribute:: cutoff_distance

      Threshold used to determine whether a new terminating state is different enough from those already in the buffer.

   .. attribute:: p_norm_distance

      The p value of the p-norm used for pairwise distances (passed to ``torch.cdist``).

   .. py:method:: _diversity_repr(container)
      :staticmethod:

      Returns the tensor used for pairwise distances in diversity filtering.

      For conditional GFNs, concatenates the conditions with the state tensor so that identical states under different conditions are treated as distinct.

   .. py:method:: add(training_container)

      Adds a training object to the buffer with diversity-based prioritization.

      :param training_container: The Trajectories, Transitions, or StatesContainer object to add.

   .. py:attribute:: cutoff_distance
      :value: 0.0

   .. py:attribute:: p_norm_distance
      :value: 1.0


.. py:class:: ReplayBuffer(env, capacity = 1000, prioritized_capacity = False, prioritized_sampling = False, remote_manager_rank = None, remote_buffer_freq = 1)

   A replay buffer for storing containers.

   .. attribute:: env

      The environment associated with the containers.

   .. attribute:: capacity

      The maximum number of items the buffer can hold.

   .. attribute:: training_container

      The buffer contents (Trajectories, Transitions, or StatesContainer).
      This is dynamically set based on the type of the first added object.

   .. attribute:: prioritized_capacity

      Whether to use prioritized capacity (keep highest-reward items).

   .. attribute:: prioritized_sampling

      Whether to sample items with probability proportional to their reward.

   .. py:method:: __len__()

      Returns the number of items in the buffer.

      :returns: The number of items in the buffer.

   .. py:method:: __repr__()

      Returns a string representation of the ReplayBuffer.

      :returns: A string summary of the buffer.

   .. py:attribute:: _add_counter
      :value: 0

   .. py:method:: _add_objs(training_container)

      Adds a training object to the buffer, handling the capacity.

      :param training_container: The Trajectories, Transitions, or StatesContainer object to add.

   .. py:attribute:: _is_full
      :value: False

   .. py:method:: _send_objs(training_container)

      Sends a training container to the remote manager.

   .. py:method:: add(training_container)

      Adds a training container to the buffer.

      The buffer's container type is set dynamically from the type of the first container added.

      :param training_container: The Trajectories, Transitions, or StatesContainer object to add.

   .. py:attribute:: capacity
      :value: 1000

   .. py:property:: device
      :type: torch.device

      The device on which the buffer's data is stored.

      :returns: The device of the buffer's contents.

   .. py:attribute:: env

   .. py:method:: initialize(training_container)

      Initializes the buffer with the type of the first added object.

      :param training_container: The initial Trajectories, Transitions, or StatesContainer object, which sets the buffer type.

   .. py:method:: load(path)

      Loads buffer contents from a ``.pt`` file saved by :meth:`save`.

      :param path: File path to the saved buffer.

   .. py:attribute:: pending_container
      :type: ContainerUnion | None
      :value: None

   .. py:attribute:: prioritized_capacity
      :value: False

   .. py:attribute:: prioritized_sampling
      :value: False

   .. py:attribute:: remote_buffer_freq
      :value: 1

   .. py:attribute:: remote_manager_rank
      :value: None

   .. py:method:: sample(n_samples)

      Samples training objects from the buffer.

      :param n_samples: The number of items to sample.

      :returns: A sampled Trajectories, Transitions, or StatesContainer.

   .. py:method:: save(path)

      Saves the buffer to a single ``.pt`` file.

      :param path: File path (e.g. ``"replay_buffer.pt"``).

   .. py:attribute:: training_container
      :type: ContainerUnion | None
      :value: None


.. py:class:: StatesContainer(env, states = None, is_terminating = None, log_rewards = None)

   Bases: :py:obj:`gfn.containers.base.Container`, :py:obj:`Generic`\ [\ :py:obj:`StateType`\ ]

   Container for a batch of states, mainly used by Flow Matching GFlowNet (FMGFlowNet) algorithms.

   This class manages a collection of states and their corresponding properties.

   .. attribute:: env

      The environment where the states are defined.

   .. attribute:: states

      States with batch_shape (n_states,).

   .. attribute:: is_terminating

      Boolean tensor of shape (n_states,) indicating which states are terminating.

   .. attribute:: _log_rewards

      (Optional) Tensor of shape (n_states,) containing the log rewards for terminating states.

   .. py:method:: __getitem__(index)

      Returns a subset of the states along the batch dimension.

      :param index: Indices to select states.

      :returns: A new StatesContainer with the selected states and associated data.

   .. py:method:: __len__()

      Returns the number of states in the container.

      :returns: The number of states.

   .. py:method:: __repr__()

      Returns a string representation of the StatesContainer.

      :returns: A string summary of the container.

   .. py:attribute:: _log_rewards
      :value: None

   .. py:property:: device
      :type: torch.device

      The device on which the states are stored.

      :returns: The device of ``self.states``.

   .. py:attribute:: env

   .. py:method:: extend(other)

      Extends this container with another StatesContainer object.

      :param other: Another StatesContainer to append.

   .. py:method:: from_tensordict(env, td)
      :classmethod:

      Reconstruct a StatesContainer from a TensorDict.

   .. py:property:: intermediary_states
      :type: StateType

      The intermediary states (not initial states) of the StatesContainer.

      :returns: The intermediary states.

   .. py:attribute:: is_terminating

   .. py:property:: log_rewards
      :type: torch.Tensor

      The log rewards for all states.

      :returns: Log rewards tensor of shape (len(self.states),). Non-terminating states have value -inf.

      .. note::

         If not provided at initialization, log rewards are computed on demand for terminating states.

   .. py:attribute:: states

   .. py:property:: terminating_log_rewards
      :type: torch.Tensor

      The log rewards for terminating states only.

      :returns: The log rewards for terminating states.

   .. py:property:: terminating_states
      :type: StateType

      The last (terminating) states of the StatesContainer.

      :returns: The terminating states.

   .. py:method:: to_tensordict()

      Serialize the states container into a TensorDict.


.. py:class:: TerminatingStateBuffer(env, capacity = 1000, **kwargs)

   Bases: :py:obj:`ReplayBuffer`

   A replay buffer for storing terminating states.

   .. attribute:: env

      The environment associated with the containers.

   .. attribute:: capacity

      The maximum number of items the buffer can hold.

   .. attribute:: training_container

      The buffer contents (StatesContainer).

   .. py:method:: add(training_container)

      Adds a training container to the buffer.

      The buffer's container type is set dynamically from the type of the first container added.

      :param training_container: The Trajectories, Transitions, or StatesContainer object to add.

   .. py:attribute:: training_container


.. py:class:: Trajectories(env, states = None, actions = None, terminating_idx = None, is_backward = False, log_rewards = None, log_probs = None, estimator_outputs = None)

   Bases: :py:obj:`gfn.containers.base.Container`

   Container for complete trajectories (starting in :math:`s_0` and ending in :math:`s_f`).

   Trajectories are represented as a States object with a two-dimensional batch shape, and their actions as an Actions object with a two-dimensional batch shape.
   The first dimension is the time step and the second is the trajectory index.
   Because trajectories may differ in length, shorter trajectories are padded with the tensor representation of the terminal state (:math:`s_f` or :math:`s_0`, depending on the direction of the trajectory), and their actions are padded with dummy actions.
   The ``terminating_idx`` tensor stores the time step at which each trajectory ends.

   .. attribute:: env

      The environment where the states and actions are defined.

   .. attribute:: states

      States with batch_shape (max_length+1, batch_size).

   .. attribute:: actions

      Actions with batch_shape (max_length, batch_size).

   .. attribute:: terminating_idx

      Tensor of shape (batch_size,) indicating the time step at which each trajectory ends.

   .. attribute:: is_backward

      Whether the trajectories are backward or forward. When ``is_backward`` is False, ``states`` are ordered from initial to terminal states; when it is True, they are ordered from terminal to initial states.

   .. attribute:: _log_rewards

      (Optional) Tensor of shape (batch_size,) containing the log rewards of the trajectories.

   .. attribute:: log_probs

      (Optional) Tensor of shape (max_length, batch_size) containing the log probabilities of the trajectories' actions.

   .. attribute:: estimator_outputs

      (Optional) Tensor of shape (max_length, batch_size, ...) containing outputs of a function approximator for each step.

   .. py:method:: __getitem__(index)

      Returns a subset of the trajectories along the batch dimension.

      :param index: Indices to select trajectories.

      :returns: A new Trajectories object with the selected trajectories and associated data.

   .. py:method:: __len__()

      Returns the number of trajectories in the container.

      :returns: The number of trajectories.

   .. py:method:: __repr__()

      Returns a string representation of the Trajectories container.

      :returns: A string summary of the trajectories.

   .. py:attribute:: _log_rewards
      :value: None

   .. py:attribute:: actions

   .. py:property:: batch_size
      :type: int

      The number of trajectories in the container.

      :returns: The number of trajectories.

   .. py:property:: device
      :type: torch.device

      The device on which the trajectories are stored.

      :returns: The device of ``self.states``.

   .. py:attribute:: env

   .. py:attribute:: estimator_outputs
      :value: None

   .. py:method:: extend(other)

      Extends this Trajectories object with another Trajectories object.

      Extends each attribute in turn (actions, states, terminating_idx, log_probs, log_rewards).

      :param other: Another Trajectories object to append.

   .. py:method:: from_tensordict(env, td)
      :classmethod:

      Reconstruct Trajectories from a TensorDict.

   .. py:attribute:: is_backward
      :value: False

   .. py:attribute:: log_probs
      :value: None

   .. py:property:: log_rewards
      :type: torch.Tensor | None

      The log rewards for the trajectories.

      :returns: Log rewards tensor of shape (batch_size,).

      .. note::

         If not provided at initialization, log rewards are computed on demand for terminating states.

   .. py:property:: max_length
      :type: int

      The maximum length of the trajectories in the container.

      :returns: The maximum trajectory length.

   .. py:property:: n_trajectories
      :type: int

      Deprecated alias for :attr:`batch_size`.

   .. py:method:: reverse_backward_trajectories()

      Returns a reversed version of the backward trajectories.

   .. py:attribute:: states

   .. py:attribute:: terminating_idx

   .. py:property:: terminating_states
      :type: gfn.states.States

      The terminating states of the trajectories.

      :returns: The terminating states.

   .. py:method:: to_states_container()

      Returns a StatesContainer object from the current Trajectories.

      :returns: A StatesContainer object with the same states and log rewards as the current Trajectories.

   .. py:method:: to_tensordict()

      Serialize trajectories into a TensorDict.

   .. py:method:: to_transitions()

      Returns a Transitions object from the current Trajectories.

      :returns: A Transitions object with the same states, actions, and log_rewards as the current Trajectories.


.. py:class:: Transitions(env, states = None, actions = None, is_terminating = None, next_states = None, is_backward = False, log_rewards = None, log_probs = None)

   Bases: :py:obj:`gfn.containers.base.Container`

   Container for a batch of transitions.
   This class manages a collection of transitions (triplets of states, actions, and next states) and their corresponding properties.

   .. attribute:: env

      The environment where the states and actions are defined.

   .. attribute:: states

      States with batch_shape (n_transitions,).

   .. attribute:: actions

      Actions with batch_shape (n_transitions,). Each action leads from the corresponding entry of ``states`` to the corresponding entry of ``next_states``.

   .. attribute:: is_terminating

      Boolean tensor of shape (n_transitions,) indicating whether the action is the exit action.

   .. attribute:: next_states

      States with batch_shape (n_transitions,).

   .. attribute:: is_backward

      Whether the transitions are backward transitions. When ``is_backward`` is False, ``states`` are the parents and ``next_states`` are the children; when it is True, ``states`` are the children and ``next_states`` are the parents.

   .. attribute:: _log_rewards

      (Optional) Tensor of shape (n_transitions,) containing the log rewards of the transitions.

   .. attribute:: log_probs

      (Optional) Tensor of shape (n_transitions,) containing the log probabilities of the actions.

   .. py:method:: __getitem__(index)

      Returns a subset of the transitions along the batch dimension.

      :param index: Indices to select transitions.

      :returns: A new Transitions object with the selected transitions and associated data.

   .. py:method:: __len__()

      Returns the number of transitions in the container.

      :returns: The number of transitions.

   .. py:method:: __repr__()

      Returns a string representation of the Transitions container.

      :returns: A string summary of the transitions.

   .. py:attribute:: _log_rewards
      :value: None

   .. py:attribute:: actions

   .. py:property:: all_log_rewards
      :type: torch.Tensor

      Helper that computes the log rewards for all transitions.

      This applies to environments in which all states are terminating. Rewards are evaluated for every transition that does not end in the sink state, which is useful for the Modified Detailed Balance loss.

      :returns: Log rewards tensor of shape (n_transitions, 2) for the transitions.

   .. py:property:: device
      :type: torch.device

      The device on which the transitions are stored.

      :returns: The device of ``self.states``.

   .. py:attribute:: env

   .. py:method:: extend(other)

      Extends this Transitions object with another Transitions object.

      :param other: Another Transitions object to append.

   .. py:method:: from_tensordict(env, td)
      :classmethod:

      Reconstruct Transitions from a TensorDict.

   .. py:attribute:: is_backward
      :value: False

   .. py:attribute:: is_terminating

   .. py:attribute:: log_probs
      :value: None

   .. py:property:: log_rewards
      :type: torch.Tensor | None

      The log rewards for the transitions.

      :returns: Log rewards tensor of shape (n_transitions,). Non-terminating transitions have value -inf.

      .. note::

         If not provided at initialization, log rewards are computed on demand for terminating transitions.

   .. py:property:: n_transitions
      :type: int

      The number of transitions in the container.

      :returns: The number of transitions.

   .. py:attribute:: next_states

   .. py:attribute:: states

   .. py:property:: terminating_states
      :type: gfn.states.States

      The terminating states of the transitions.

      :returns: The terminating states.

   .. py:method:: to_tensordict()

      Serialize transitions into a TensorDict.
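The capacity and sampling options of :py:class:`ReplayBuffer` can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: ``SketchReplayBuffer`` is a hypothetical class, ``(item_id, log_reward)`` tuples stand in for containers, and two behaviors are assumptions (a non-prioritized buffer evicts its oldest items, and sampling is done with replacement).

.. code-block:: python

   import math
   import random
   from typing import List, Tuple

   Item = Tuple[str, float]  # hypothetical (item id, log reward) stand-in for a container entry


   class SketchReplayBuffer:
       """Illustrative stand-in for ReplayBuffer's capacity and sampling options."""

       def __init__(self, capacity: int = 1000, prioritized_capacity: bool = False,
                    prioritized_sampling: bool = False, seed: int = 0) -> None:
           self.capacity = capacity
           self.prioritized_capacity = prioritized_capacity
           self.prioritized_sampling = prioritized_sampling
           self.items: List[Item] = []
           self._rng = random.Random(seed)

       def __len__(self) -> int:
           return len(self.items)

       def add(self, new_items: List[Item]) -> None:
           self.items.extend(new_items)
           if self.prioritized_capacity:
               # Sort by log reward so that the highest-reward items end up last.
               self.items.sort(key=lambda item: item[1])
           if len(self.items) > self.capacity:
               # Assumed eviction rule: drop items from the front, i.e. the oldest
               # items, or the lowest-reward items once sorted.
               self.items = self.items[len(self.items) - self.capacity:]

       def sample(self, n_samples: int) -> List[Item]:
           # This sketch samples with replacement in both modes (an assumption).
           weights = None
           if self.prioritized_sampling:
               # Probability proportional to reward = exp(log reward); subtracting the
               # max log reward avoids overflow without changing the proportions.
               max_lr = max(lr for _, lr in self.items)
               weights = [math.exp(lr - max_lr) for _, lr in self.items]
           return self._rng.choices(self.items, weights=weights, k=n_samples)

With ``capacity=2``, adding ``[("a", 0.0), ("b", 5.0), ("c", 1.0)]`` keeps ``b`` and ``c`` (the two most recent items), whereas ``prioritized_capacity=True`` keeps ``c`` and ``b`` (the two highest log rewards, in ascending order).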
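The role of ``cutoff_distance`` and ``p_norm_distance`` in :py:class:`NormBasedDiversePrioritizedReplayBuffer` can be sketched in pure Python. The helper names are hypothetical, states are plain lists of floats rather than tensors, and the strict ``>`` comparison against the cutoff is an assumption. The sketch only shows the distance test and the documented idea behind ``_diversity_repr``; it does not reproduce the rest of the buffer's ``add()`` logic.

.. code-block:: python

   from typing import List, Optional, Sequence


   def p_norm_distance(x: Sequence[float], y: Sequence[float], p: float = 1.0) -> float:
       # The quantity torch.cdist computes for one pair of rows (finite p > 0).
       return sum(abs(a - b) ** p for a, b in zip(x, y)) ** (1.0 / p)


   def diversity_repr(state: Sequence[float], condition: Optional[Sequence[float]] = None) -> List[float]:
       # Append the condition (if any) so that equal states under different
       # conditions are not treated as duplicates.
       return list(state) + (list(condition) if condition is not None else [])


   def diverse_candidates(buffer: List[List[float]], candidates: List[List[float]],
                          cutoff_distance: float = 0.0, p: float = 1.0) -> List[List[float]]:
       """Keep candidates whose distance to every buffered item exceeds the cutoff."""
       kept = []
       for cand in candidates:
           # Assumed rule: strictly greater than the cutoff for every buffered item.
           if all(p_norm_distance(cand, item, p) > cutoff_distance for item in buffer):
               kept.append(cand)
       return kept

With a buffer ``[[0.0, 0.0], [1.0, 1.0]]``, ``cutoff_distance=1.0`` and ``p=1.0``, only ``[3.0, 3.0]`` among the candidates ``[[0.0, 0.0], [0.5, 0.0], [3.0, 3.0]]`` is kept. With the default cutoff of ``0.0``, only exact duplicates are rejected.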
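:py:attr:`StatesContainer.log_rewards` and :py:attr:`Transitions.log_rewards` are both documented as holding ``-inf`` for non-terminating entries, with terminating log rewards computed on demand when none were provided at initialization. A minimal pure-Python sketch of that convention follows; plain lists stand in for tensors and ``toy_log_reward`` is a hypothetical reward standing in for the environment's.

.. code-block:: python

   import math
   from typing import Callable, List, Sequence


   def masked_log_rewards(
       states: Sequence[int],
       is_terminating: Sequence[bool],
       log_reward: Callable[[int], float],
   ) -> List[float]:
       """Log reward per entry, with -inf for non-terminating entries."""
       # The conditional expression only calls log_reward for terminating
       # entries, so rewards are evaluated on demand and only where needed.
       return [log_reward(s) if term else -math.inf for s, term in zip(states, is_terminating)]


   def terminating_log_rewards(all_log_rewards: Sequence[float],
                               is_terminating: Sequence[bool]) -> List[float]:
       """Keep only the entries that correspond to terminating states."""
       return [lr for lr, term in zip(all_log_rewards, is_terminating) if term]


   def toy_log_reward(state: int) -> float:
       # Hypothetical reward R(s) = s + 1, returned in log space.
       return math.log(state + 1)

For instance, ``masked_log_rewards([0, 1, 2], [False, True, True], toy_log_reward)`` gives ``-inf`` for the first entry and ``log(2)`` and ``log(3)`` for the other two.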
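The padded, time-major layout described for :py:class:`Trajectories` can be sketched with plain Python lists. ``SINK``, ``DUMMY_ACTION`` and ``EXIT`` are hypothetical integer stand-ins for the sink-state tensor, the dummy action and the exit action, and only forward trajectories are handled. ``terminating_idx`` is assumed here to count the actions of each trajectory (exit action included), so that the terminating state sits at index ``terminating_idx - 1``.

.. code-block:: python

   from typing import List, Tuple

   SINK = -1          # hypothetical stand-in for the tensor representation of s_f
   DUMMY_ACTION = -2  # hypothetical stand-in for a padding ("dummy") action
   EXIT = 99          # hypothetical exit-action id


   def pad_trajectories(
       states: List[List[int]], actions: List[List[int]]
   ) -> Tuple[List[List[int]], List[List[int]], List[int]]:
       """Pack variable-length forward trajectories into time-major grids.

       states[i] lists the non-sink states s_0, ..., s_n of trajectory i and
       actions[i] the action taken from each of them (the last one is the exit
       action), so len(states[i]) == len(actions[i]).
       """
       batch_size = len(states)
       terminating_idx = [len(a) for a in actions]  # assumed: number of actions, incl. exit
       max_length = max(terminating_idx)

       # Grid of shape (max_length + 1, batch_size), padded with the sink state.
       state_grid = [[SINK] * batch_size for _ in range(max_length + 1)]
       # Grid of shape (max_length, batch_size), padded with dummy actions.
       action_grid = [[DUMMY_ACTION] * batch_size for _ in range(max_length)]

       for i in range(batch_size):
           for t, s in enumerate(states[i]):
               state_grid[t][i] = s
           for t, a in enumerate(actions[i]):
               action_grid[t][i] = a
       return state_grid, action_grid, terminating_idx


   def terminating_states(state_grid: List[List[int]], terminating_idx: List[int]) -> List[int]:
       # The last non-sink state of trajectory i sits one step before it ends.
       return [state_grid[t - 1][i] for i, t in enumerate(terminating_idx)]

For example, ``pad_trajectories([[0, 1, 2], [0, 3]], [[1, 1, EXIT], [3, EXIT]])`` produces a state grid with ``max_length + 1 = 4`` rows and ``terminating_idx == [3, 2]``, and the terminating states are ``[2, 3]``.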
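Flattening padded trajectories into transitions, in the spirit of :py:meth:`Trajectories.to_transitions`, amounts to pairing every non-dummy action with the states before and after it. The sketch below is self-contained and uses the same kind of hypothetical ``SINK``, ``DUMMY_ACTION`` and ``EXIT`` integer stand-ins. It handles forward trajectories only and emits transitions in time-major order; that ordering is an assumption, not documented library behavior.

.. code-block:: python

   from typing import List, NamedTuple

   SINK = -1          # hypothetical stand-in for s_f
   DUMMY_ACTION = -2  # hypothetical padding action
   EXIT = 99          # hypothetical exit-action id


   class Transition(NamedTuple):
       state: int
       action: int
       next_state: int
       is_terminating: bool  # True when the action is the exit action


   def to_transitions(state_grid: List[List[int]], action_grid: List[List[int]]) -> List[Transition]:
       """Flatten time-major padded trajectories, skipping dummy (padding) actions."""
       out = []
       for t, row in enumerate(action_grid):
           for i, action in enumerate(row):
               if action == DUMMY_ACTION:
                   continue  # padding step: no real transition here
               next_state = state_grid[t + 1][i]
               # Forward direction only: the exit action is the one leading to the sink.
               out.append(Transition(state_grid[t][i], action, next_state, next_state == SINK))
       return out

Applied to a two-trajectory grid with 3 and 2 actions, the function yields 5 transitions, exactly 2 of which are terminating (one exit per trajectory).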