gfn.containers.states_container
===============================

.. py:module:: gfn.containers.states_container


Attributes
----------

.. autoapisummary::

   gfn.containers.states_container.StateType


Classes
-------

.. autoapisummary::

   gfn.containers.states_container.StatesContainer


Module Contents
---------------

.. py:data:: StateType

.. py:class:: StatesContainer(env, states = None, is_terminating = None, log_rewards = None)

   Bases: :py:obj:`gfn.containers.base.Container`, :py:obj:`Generic`\ [\ :py:obj:`StateType`\ ]


   Container for a batch of states, mainly used by Flow Matching GFlowNet (FMGFlowNet) algorithms.

   This class manages a collection of states together with their per-state properties: a terminating flag and, optionally, a log reward.

   .. attribute:: env

      The environment where the states are defined.

   .. attribute:: states

      States with ``batch_shape`` ``(n_states,)``.

   .. attribute:: is_terminating

      Boolean tensor of shape ``(n_states,)`` indicating which states are terminating.

   .. attribute:: _log_rewards

      (Optional) Tensor of shape ``(n_states,)`` containing the log rewards for terminating states.


   .. py:method:: __getitem__(index)

      Returns a subset of the states along the batch dimension.

      :param index: Indices to select states.

      :returns: A new StatesContainer with the selected states and associated data.


   .. py:method:: __len__()

      Returns the number of states in the container.

      :returns: The number of states.


   .. py:method:: __repr__()

      Returns a string representation of the StatesContainer.

      :returns: A string summary of the container.


   .. py:attribute:: _log_rewards
      :value: None


   .. py:property:: device
      :type: torch.device

      The device on which the states are stored.

      :returns: The device of ``self.states``.


   .. py:attribute:: env


   .. py:method:: extend(other)

      Extends this container with another StatesContainer.

      :param other: Another StatesContainer to append.


   .. py:method:: from_tensordict(env, td)
      :classmethod:

      Reconstructs a StatesContainer from a TensorDict.


   .. py:property:: intermediary_states
      :type: StateType

      The intermediary states (not initial states) of the StatesContainer.

      :returns: The intermediary states.


   .. py:attribute:: is_terminating


   .. py:property:: log_rewards
      :type: torch.Tensor

      The log rewards for all states.

      :returns: Log rewards tensor of shape ``(len(self.states),)``. Non-terminating (intermediate) states have value ``-inf``.

      .. note::

         If not provided at initialization, log rewards are computed on demand for terminating states.


   .. py:attribute:: states


   .. py:property:: terminating_log_rewards
      :type: torch.Tensor

      The log rewards for terminating states only.

      :returns: The log rewards for terminating states.


   .. py:property:: terminating_states
      :type: StateType

      The last (terminating) states of the StatesContainer.

      :returns: The terminating states.


   .. py:method:: to_tensordict()

      Serializes the states container into a TensorDict.
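
The container semantics documented for ``StatesContainer`` can be illustrated with a small pure-Python model. It covers per-state terminating flags, log rewards that are ``-inf`` for intermediate states and computed on demand, boolean-mask selection of terminating states, and indexing and extension along the batch dimension. This is a sketch under stated assumptions, not the ``gfn`` implementation. The class ``ToyStatesContainer`` and its ``log_reward_fn`` callable, which stands in for the environment's reward computation, are hypothetical. Plain lists replace ``States`` objects and ``torch`` tensors, and ``device`` and TensorDict serialization are not modeled.

.. code-block:: python

   import math
   from dataclasses import dataclass, field
   from typing import Callable, List, Optional, Sequence, Union


   @dataclass
   class ToyStatesContainer:
       """Hypothetical list-based stand-in for StatesContainer (not the gfn API)."""

       log_reward_fn: Callable[[object], float]  # hypothetical: plays the role of the env
       states: List[object] = field(default_factory=list)
       is_terminating: List[bool] = field(default_factory=list)
       _log_rewards: Optional[List[float]] = None  # cache; None means "not computed yet"

       def __len__(self) -> int:
           # Number of states along the batch dimension.
           return len(self.states)

       def __getitem__(self, index: Union[int, slice, Sequence[int]]) -> "ToyStatesContainer":
           # Normalize the index to a list of positions, then select every field consistently.
           if isinstance(index, int):
               idx = [index]
           elif isinstance(index, slice):
               idx = list(range(len(self)))[index]
           else:
               idx = list(index)
           sub_log_rewards = (
               None if self._log_rewards is None else [self._log_rewards[i] for i in idx]
           )
           return ToyStatesContainer(
               self.log_reward_fn,
               [self.states[i] for i in idx],
               [self.is_terminating[i] for i in idx],
               sub_log_rewards,
           )

       def extend(self, other: "ToyStatesContainer") -> None:
           # Append the other container's states and flags in place.
           self.states.extend(other.states)
           self.is_terminating.extend(other.is_terminating)
           if self._log_rewards is not None and other._log_rewards is not None:
               # Both caches exist: concatenate them.
               self._log_rewards = self._log_rewards + other._log_rewards
           else:
               # Otherwise drop the cache so log rewards are recomputed on demand.
               self._log_rewards = None

       @property
       def log_rewards(self) -> List[float]:
           # Computed on first access: terminating states get a log reward, others -inf.
           if self._log_rewards is None:
               self._log_rewards = [
                   self.log_reward_fn(s) if term else -math.inf
                   for s, term in zip(self.states, self.is_terminating)
               ]
           return self._log_rewards

       @property
       def terminating_states(self) -> List[object]:
           # Boolean-mask selection, analogous to states[is_terminating] on a tensor.
           return [s for s, term in zip(self.states, self.is_terminating) if term]

       @property
       def terminating_log_rewards(self) -> List[float]:
           # Log rewards restricted to terminating states (no -inf entries).
           return [r for r, term in zip(self.log_rewards, self.is_terminating) if term]

For example, with ``c = ToyStatesContainer(lambda s: float(s), [1, 2, 3], [False, True, True])``, ``c.log_rewards`` is ``[-inf, 2.0, 3.0]`` and ``c.terminating_states`` is ``[2, 3]``. In this sketch, ``extend`` keeps cached log rewards only when both containers already have them; the real class may handle this case differently.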