gfn.containers.transitions
==========================

.. py:module:: gfn.containers.transitions


Classes
-------

.. autoapisummary::

   gfn.containers.transitions.Transitions


Module Contents
---------------

.. py:class:: Transitions(env, states = None, actions = None, is_terminating = None, next_states = None, is_backward = False, log_rewards = None, log_probs = None)

   Bases: :py:obj:`gfn.containers.base.Container`

   Container for a batch of transitions.

   This class manages a collection of transitions (triplets of states, actions, and next states) and their corresponding properties.

   .. attribute:: env

      The environment in which the states and actions are defined.

   .. attribute:: states

      States with batch_shape (n_transitions,).

   .. attribute:: actions

      Actions with batch_shape (n_transitions,). Each action takes the corresponding transition from `states` to `next_states`.

   .. attribute:: is_terminating

      Boolean tensor of shape (n_transitions,) indicating whether each action is the exit action.

   .. attribute:: next_states

      States with batch_shape (n_transitions,).

   .. attribute:: is_backward

      Whether the transitions are backward transitions. If ``is_backward`` is False, `states` are the parents and `next_states` are the children. If ``is_backward`` is True, `states` are the children and `next_states` are the parents.

   .. attribute:: _log_rewards

      (Optional) Tensor of shape (n_transitions,) containing the log rewards of the transitions.

   .. attribute:: log_probs

      (Optional) Tensor of shape (n_transitions,) containing the log probabilities of the actions.

   .. py:method:: __getitem__(index)

      Returns a subset of the transitions along the batch dimension.

      :param index: Indices of the transitions to select.
      :returns: A new Transitions object containing the selected transitions and their associated data.

   .. py:method:: __len__()

      Returns the number of transitions in the container.

      :returns: The number of transitions.

   .. py:method:: __repr__()

      Returns a string representation of the Transitions container.

      :returns: A string summary of the transitions.

   .. py:attribute:: _log_rewards
      :value: None

   .. py:attribute:: actions

   .. py:property:: all_log_rewards
      :type: torch.Tensor

      A helper that computes the log rewards for all transitions.

      This is applicable to environments where all states are terminating. It evaluates the rewards for all transitions that do not end in the sink state, which is useful for the Modified Detailed Balance loss.

      :returns: Log rewards tensor of shape (n_transitions, 2) for the transitions.

   .. py:property:: device
      :type: torch.device

      The device on which the transitions are stored.

      :returns: The device of `self.states`.

   .. py:attribute:: env

   .. py:method:: extend(other)

      Extends this Transitions object with another Transitions object.

      :param other: Another Transitions object to append.

   .. py:method:: from_tensordict(env, td)
      :classmethod:

      Reconstructs Transitions from a TensorDict.

   .. py:attribute:: is_backward
      :value: False

   .. py:attribute:: is_terminating

   .. py:attribute:: log_probs
      :value: None

   .. py:property:: log_rewards
      :type: torch.Tensor | None

      The log rewards for the transitions.

      :returns: Log rewards tensor of shape (n_transitions,). Non-terminating transitions have value -inf.

      .. note::

         If not provided at initialization, log rewards are computed on demand for terminating transitions.

   .. py:property:: n_transitions
      :type: int

      The number of transitions in the container.

      :returns: The number of transitions.

   .. py:attribute:: next_states

   .. py:attribute:: states

   .. py:property:: terminating_states
      :type: gfn.states.States

      The terminating states of the transitions.

      :returns: The terminating states.

   .. py:method:: to_tensordict()

      Serializes the transitions into a TensorDict.
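
The batch semantics of ``__len__``, ``__getitem__`` and ``extend`` described for ``Transitions`` can be illustrated with a minimal sketch. This is not the ``gfn`` implementation: ``MiniTransitions`` is a hypothetical class that stores plain Python lists instead of batched ``States`` and ``torch.Tensor`` objects, and the sketch assumes that ``extend`` appends in place.

.. code-block:: python

   from dataclasses import dataclass
   from typing import List, Sequence, Union


   @dataclass
   class MiniTransitions:
       """Hypothetical stand-in for Transitions, using lists instead of tensors."""

       states: List[int]
       actions: List[int]
       next_states: List[int]
       is_terminating: List[bool]

       def __len__(self) -> int:
           # The number of transitions is the size of the single batch dimension.
           return len(self.states)

       def __getitem__(self, index: Union[int, slice, Sequence[int]]) -> "MiniTransitions":
           # Normalize the index to a list of positions along the batch dimension.
           if isinstance(index, int):
               idx = [index]
           elif isinstance(index, slice):
               idx = list(range(len(self)))[index]
           else:
               idx = list(index)
           # Select the same positions from every per-transition field.
           return MiniTransitions(
               states=[self.states[i] for i in idx],
               actions=[self.actions[i] for i in idx],
               next_states=[self.next_states[i] for i in idx],
               is_terminating=[self.is_terminating[i] for i in idx],
           )

       def extend(self, other: "MiniTransitions") -> None:
           # Assumption: append the other batch in place, field by field.
           self.states.extend(other.states)
           self.actions.extend(other.actions)
           self.next_states.extend(other.next_states)
           self.is_terminating.extend(other.is_terminating)


   # Usage: three forward transitions; -1 is a hypothetical sink-state placeholder.
   batch = MiniTransitions(
       states=[0, 1, 2],
       actions=[1, 1, 0],
       next_states=[1, 2, -1],
       is_terminating=[False, False, True],
   )
   subset = batch[[0, 2]]  # keep the first and last transitions
   batch.extend(subset)    # batch now holds 5 transitions

Keeping every per-transition field aligned on one batch dimension is what lets a single index select a consistent subset of states, actions, next states and flags.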
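
The ``log_rewards`` convention (``-inf`` for non-terminating transitions) and the role swap controlled by ``is_backward`` can be sketched in the same way. The helpers ``log_rewards_with_neg_inf`` and ``parent_child_pairs`` are hypothetical, states are plain integers, and the sketch assumes that the reward of an exit transition is evaluated at its source state (its `states` entry), since the exit action leads to the sink state.

.. code-block:: python

   import math
   from typing import Callable, List, Sequence, Tuple


   def log_rewards_with_neg_inf(
       states: Sequence[int],
       is_terminating: Sequence[bool],
       log_reward_fn: Callable[[int], float],
   ) -> List[float]:
       # Assumption: an exit transition is rewarded at its source state;
       # every non-terminating transition gets -inf.
       return [
           log_reward_fn(s) if term else -math.inf
           for s, term in zip(states, is_terminating)
       ]


   def parent_child_pairs(
       states: Sequence[int],
       next_states: Sequence[int],
       is_backward: bool,
   ) -> List[Tuple[int, int]]:
       # Forward transitions: `states` are parents, `next_states` are children.
       # Backward transitions swap the two roles.
       if is_backward:
           return list(zip(next_states, states))
       return list(zip(states, next_states))


   # Usage: three transitions, only the last one takes the exit action.
   log_r = log_rewards_with_neg_inf(
       states=[0, 1, 2],
       is_terminating=[False, False, True],
       log_reward_fn=lambda s: math.log(s + 1.0),
   )
   # The first two entries are -inf; the last one is log(3).

Filling with ``-inf`` means that ``exp(log_reward)`` is 0 for non-terminating transitions, so they contribute no reward.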