gfn.utils.distributions
=======================

.. py:module:: gfn.utils.distributions


Classes
-------

.. autoapisummary::

   gfn.utils.distributions.GraphActionDistribution
   gfn.utils.distributions.IsotropicGaussian
   gfn.utils.distributions.UnsqueezedCategorical


Module Contents
---------------

.. py:class:: GraphActionDistribution(logits = None, probs = None, is_backward = False, debug = False)

   Bases: :py:obj:`torch.distributions.Distribution`


   A mixture of categorical distributions for graph actions.

   This class is used to sample graph actions and compute their log probabilities.
   A graph action is a tuple of (action_type, node_class, edge_class, edge_index).
   Each component of the tuple follows a categorical distribution, and the components
   other than action_type are conditioned on the action_type:

   - If the action_type is ADD_NODE, the node_class is sampled from a categorical distribution.
   - If the action_type is ADD_EDGE, the edge_class and edge_index are sampled from categorical distributions.
   - If the action_type is EXIT, no other components are sampled.


   .. py:attribute:: debug
      :value: False


   .. py:attribute:: dists


   .. py:attribute:: is_backward
      :value: False


   .. py:method:: log_prob(sample)

      Returns the log probabilities for a batch of action samples.

      Because sampling is hierarchical, the log probability of an action is the sum of
      the log probabilities of the components that were sampled. Depending on the
      action type, it is one of:

      - log_prob = log p(action_type=add_node) + log p(node_class)
      - log_prob = log p(action_type=add_edge) + log p(edge_class) + log p(edge_index)
      - log_prob = log p(action_type=remove_node) + log p(node_index)
      - log_prob = log p(action_type=remove_edge) + log p(edge_index)
      - log_prob = log p(action_type=exit)

      :param sample: A tensor of shape ``(*sample_shape, *batch_shape, 4)`` containing
         action samples, where the last dimension holds the action type, node class,
         edge class, and edge index.

      :returns: A tensor of shape ``(*sample_shape, *batch_shape)`` containing the log
         probabilities for each sample.


   .. py:method:: sample(sample_shape=torch.Size())

      Samples from the distribution.

      :param sample_shape: The shape of the sample.

      :returns: The sampled actions as a tensor of shape ``(*sample_shape, *batch_shape, 4)``.


.. py:class:: IsotropicGaussian(loc, scale, actions_detach = True)

   Bases: :py:obj:`torch.distributions.Distribution`


   Isotropic Gaussian distribution.

   This class is used to sample from, and compute the log probabilities of, isotropic
   Gaussian distributions given the mean (loc) and standard deviation (scale) of the
   distribution. It is used primarily in the diffusion samplers.

   .. attribute:: loc

      The mean of the Gaussian distribution (shape: ``(*batch_shape, s_dim)``).

   .. attribute:: scale

      The standard deviation of the Gaussian distribution, shared across all ``s_dim``
      dimensions (shape: ``(*batch_shape, 1)``).

   .. attribute:: actions_detach

      Whether to detach the actions from the computation graph.


   .. py:attribute:: actions_detach
      :value: True


   .. py:attribute:: loc


   .. py:method:: log_prob(actions)


   .. py:method:: sample(sample_shape = torch.Size())


   .. py:attribute:: scale


.. py:class:: UnsqueezedCategorical(probs=None, logits=None, validate_args=None, debug=False)

   Bases: :py:obj:`torch.distributions.Categorical`


   A :py:class:`torch.distributions.Categorical` whose samples carry an extra trailing
   dimension of size 1.

   This is useful for discrete environments with an action shape of ``(1,)``: samples
   have shape ``(batch_size, 1)`` instead of ``(batch_size,)``.


   .. py:attribute:: debug
      :value: False


   .. py:method:: log_prob(sample)

      Returns the log probabilities of an unsqueezed sample.

      :param sample: The sample for which to compute the log probabilities.

      :returns: The log probabilities of the sample as a tensor of shape
         ``(*sample_shape, *batch_shape)``.


   .. py:method:: sample(sample_shape=torch.Size())

      Sample actions with an unsqueezed final dimension.

      :param sample_shape: The shape of the sample.

      :returns: The sampled actions as a tensor of shape ``(*sample_shape, *batch_shape, 1)``.
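The hierarchical log probability described for ``GraphActionDistribution.log_prob`` can be illustrated in plain Python. This is a minimal sketch, not the library's implementation: it scores a single action instead of a batched tensor, covers only the ADD_NODE, ADD_EDGE and EXIT cases, and uses hypothetical integer codes (``ADD_NODE``, ``ADD_EDGE``, ``EXIT``) and a hypothetical helper ``hierarchical_log_prob``.

.. code-block:: python

   import math

   # Hypothetical integer codes for the action types; the library's own
   # action-type enum may use different names and values.
   ADD_NODE, ADD_EDGE, EXIT = 0, 1, 2


   def hierarchical_log_prob(action, type_probs, node_class_probs,
                             edge_class_probs, edge_index_probs):
       """Log probability of one graph action under hierarchical sampling.

       ``action`` is a tuple (action_type, node_class, edge_class, edge_index).
       Each ``*_probs`` argument is a list of probabilities for one component.
       Only the components that the action type actually samples contribute.
       """
       action_type, node_class, edge_class, edge_index = action
       log_p = math.log(type_probs[action_type])
       if action_type == ADD_NODE:
           log_p += math.log(node_class_probs[node_class])
       elif action_type == ADD_EDGE:
           log_p += math.log(edge_class_probs[edge_class])
           log_p += math.log(edge_index_probs[edge_index])
       # EXIT: no further components are sampled, so nothing is added.
       return log_p


   type_probs = [0.5, 0.3, 0.2]
   node_class_probs = [0.25, 0.75]
   edge_class_probs = [0.6, 0.4]
   edge_index_probs = [0.1, 0.9]

   # ADD_EDGE with edge_class=1 and edge_index=1: log(0.3) + log(0.4) + log(0.9)
   lp = hierarchical_log_prob((ADD_EDGE, 0, 1, 1), type_probs, node_class_probs,
                              edge_class_probs, edge_index_probs)
   print(round(math.exp(lp), 6))  # 0.108

The unused slots of the tuple (here node_class for an ADD_EDGE action) are simply ignored, which is why an EXIT action scores as log p(action_type=exit) alone.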
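This page does not spell out the density behind ``IsotropicGaussian.log_prob``. For reference, the standard isotropic Gaussian log density can be sketched in plain Python, assuming (not confirmed here) that the per-dimension terms are summed over the ``s_dim`` axis. The helpers ``isotropic_gaussian_log_prob`` and ``isotropic_gaussian_sample`` are hypothetical and work on a single unbatched point; the sampler uses the reparameterized form ``loc + scale * eps``, an illustrative choice rather than a description of ``IsotropicGaussian.sample``.

.. code-block:: python

   import math
   import random


   def isotropic_gaussian_log_prob(x, loc, scale):
       """Log density of ``x`` under N(loc, scale**2 * I).

       ``x`` and ``loc`` are equal-length lists with one entry per dimension;
       ``scale`` is a single standard deviation shared by every dimension.
       """
       d = len(x)
       sq_dist = sum((xi - mi) ** 2 for xi, mi in zip(x, loc))
       # Sum over dimensions of -0.5 * ((x - mu) / sigma)**2 - log(sigma) - 0.5 * log(2 * pi)
       return (-0.5 * sq_dist / scale ** 2
               - d * math.log(scale)
               - 0.5 * d * math.log(2 * math.pi))


   def isotropic_gaussian_sample(loc, scale, rng=random):
       """Draw one sample as loc + scale * eps with eps ~ N(0, I)."""
       return [mi + scale * rng.gauss(0.0, 1.0) for mi in loc]


   # At the mean with unit scale in 2 dimensions this equals -log(2 * pi).
   print(isotropic_gaussian_log_prob([0.0, 0.0], [0.0, 0.0], 1.0))

Because ``scale`` has one entry per batch element, the normalizing term grows as ``s_dim * log(scale)``; this is what distinguishes the isotropic case from a diagonal Gaussian with a separate scale per dimension.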
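The shape behaviour documented for ``UnsqueezedCategorical`` can be reproduced with a small subclass of :py:class:`torch.distributions.Categorical`. ``UnsqueezedCategoricalSketch`` below is hypothetical, ignores the ``debug`` flag, and only illustrates the shapes involved, assuming that ``log_prob`` removes the trailing dimension before delegating to the parent class.

.. code-block:: python

   import torch
   from torch.distributions import Categorical


   class UnsqueezedCategoricalSketch(Categorical):
       """Hypothetical re-implementation, for illustration only."""

       def sample(self, sample_shape=torch.Size()):
           # Categorical.sample returns (*sample_shape, *batch_shape);
           # append a trailing dimension of size 1.
           return super().sample(sample_shape).unsqueeze(-1)

       def log_prob(self, sample):
           # Drop the trailing size-1 dimension before scoring.
           return super().log_prob(sample.squeeze(-1))


   # Batch of 3 states, each with 2 discrete actions.
   probs = torch.tensor([[0.2, 0.8], [0.5, 0.5], [1.0, 0.0]])
   dist = UnsqueezedCategoricalSketch(probs=probs)
   actions = dist.sample()
   print(actions.shape)                 # torch.Size([3, 1])
   print(dist.log_prob(actions).shape)  # torch.Size([3])

Squeezing inside ``log_prob`` keeps the returned log probabilities at ``(*sample_shape, *batch_shape)``, the same shape a plain ``Categorical`` would return.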