gfn.utils.distributions

Classes

GraphActionDistribution

A mixture of categorical distributions for graph actions.

IsotropicGaussian

Isotropic Gaussian distribution.

UnsqueezedCategorical

A torch.distributions.Categorical that unsqueezes the last dimension.

Module Contents

class gfn.utils.distributions.GraphActionDistribution(logits=None, probs=None, is_backward=False, debug=False)

Bases: torch.distributions.Distribution

A mixture of categorical distributions for graph actions.

This class is used to sample graph actions and compute their log probabilities. A graph action is a tuple of (action_type, node_class, edge_class, edge_index). Each component of the tuple follows a categorical distribution, and the remaining components are conditioned on the action_type:

  • If the action_type is ADD_NODE, then the node_class is sampled from a categorical distribution.

  • If the action_type is ADD_EDGE, then the edge_class and edge_index are sampled from categorical distributions.

  • If the action_type is EXIT, then no other components are sampled.

Parameters:
  • logits (tensordict.TensorDict | None)

  • probs (tensordict.TensorDict | None)

  • is_backward (bool)

  • debug (bool)

debug = False
dists
is_backward = False
log_prob(sample)

Returns the log probabilities for a batch of action samples.

Because sampling is hierarchical, the log probability of an action is the sum of the log probabilities of its individual components. It takes one of the following forms:

  • log_prob = log p(action_type=add_node) + log p(node_class)

  • log_prob = log p(action_type=add_edge) + log p(edge_class) + log p(edge_index)

  • log_prob = log p(action_type=remove_node) + log p(node_index)

  • log_prob = log p(action_type=remove_edge) + log p(edge_index)

  • log_prob = log p(action_type=exit)
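The hierarchical sum above can be illustrated with plain torch.distributions.Categorical objects. This is a minimal sketch, not the class's actual implementation: the action-type indices (ADD_NODE, ADD_EDGE, EXIT), the helper hierarchical_log_prob, and the convention of filling unused components with 0 are all assumptions made for this example, and only the add_node, add_edge, and exit cases are covered.

```python
import torch
from torch.distributions import Categorical

# Hypothetical action-type indices, chosen for this sketch only.
ADD_NODE, ADD_EDGE, EXIT = 0, 1, 2


def hierarchical_log_prob(type_dist, node_class_dist, edge_class_dist,
                          edge_index_dist, sample):
    """Sum the component log-probs that are relevant to each action type.

    sample: long tensor of shape (batch, 4) with columns
    (action_type, node_class, edge_class, edge_index). Components that are
    unused for a given action type are assumed to hold a valid index (0).
    """
    action_type, node_class, edge_class, edge_index = sample.unbind(dim=-1)
    log_p = type_dist.log_prob(action_type)
    zeros = torch.zeros_like(log_p)

    # ADD_NODE contributes log p(node_class).
    is_add_node = action_type == ADD_NODE
    log_p = log_p + torch.where(
        is_add_node, node_class_dist.log_prob(node_class), zeros
    )

    # ADD_EDGE contributes log p(edge_class) + log p(edge_index).
    is_add_edge = action_type == ADD_EDGE
    edge_terms = (
        edge_class_dist.log_prob(edge_class)
        + edge_index_dist.log_prob(edge_index)
    )
    log_p = log_p + torch.where(is_add_edge, edge_terms, zeros)

    # EXIT contributes nothing beyond log p(action_type).
    return log_p


torch.manual_seed(0)
batch = 3
type_dist = Categorical(logits=torch.randn(batch, 3))
node_class_dist = Categorical(logits=torch.randn(batch, 4))
edge_class_dist = Categorical(logits=torch.randn(batch, 2))
edge_index_dist = Categorical(logits=torch.randn(batch, 5))

sample = torch.tensor([
    [ADD_NODE, 2, 0, 0],
    [ADD_EDGE, 0, 1, 3],
    [EXIT, 0, 0, 0],
])
log_p = hierarchical_log_prob(
    type_dist, node_class_dist, edge_class_dist, edge_index_dist, sample
)
```

Here log_p has one entry per batch element, and the entry for the EXIT action equals the log probability of the action type alone.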

Parameters:

sample (torch.Tensor) – A tensor of shape (*sample_shape, *batch_shape, 4) containing action samples, where the last dimension holds, in order, the action type, node class, edge class, and edge index.

Returns:

A tensor of shape (*sample_shape, *batch_shape) containing the log probabilities for each sample.

Return type:

torch.Tensor

sample(sample_shape=torch.Size())

Samples from the distribution.

Parameters:

sample_shape – The shape of the sample.

Return type:

torch.Tensor

Returns the sampled actions as a tensor of shape (*sample_shape, *batch_shape, 4).

class gfn.utils.distributions.IsotropicGaussian(loc, scale, actions_detach=True)

Bases: torch.distributions.Distribution

Isotropic Gaussian distribution.

This class is used to sample from isotropic Gaussian distributions and compute their log probabilities, given the mean (loc) and standard deviation (scale) of the distribution. It is used primarily in the diffusion samplers.

Parameters:
  • loc (torch.Tensor)

  • scale (torch.Tensor)

  • actions_detach (bool)

loc

The mean of the Gaussian distribution (shape: (*batch_shape, s_dim))

scale

The standard deviation of the Gaussian distribution, shared across all dimensions (shape: (*batch_shape, 1))

actions_detach

Whether to detach the sampled actions from the computation graph.
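The shapes documented for loc and scale can be illustrated with standard PyTorch distributions. The sketch below is not this class's implementation. It assumes that the log probability of an action is the sum of the per-dimension Gaussian log-densities, and it builds the equivalent distribution from torch.distributions.Normal and Independent, then checks the result against the closed-form isotropic Gaussian log-density. The sizes batch and s_dim are arbitrary values chosen for the example.

```python
import math

import torch
from torch.distributions import Independent, Normal

torch.manual_seed(0)
batch, s_dim = 4, 3

loc = torch.zeros(batch, s_dim)        # shape (*batch_shape, s_dim)
scale = torch.full((batch, 1), 0.5)    # shape (*batch_shape, 1), one std per row

# Normal broadcasts scale to (batch, s_dim); Independent treats the last
# dimension as the event, so log_prob sums over the s_dim coordinates.
dist = Independent(Normal(loc, scale), 1)

actions = dist.sample()                # shape (batch, s_dim)
log_p = dist.log_prob(actions)         # shape (batch,)

# Closed form: -||x - mu||^2 / (2 sigma^2) - d log(sigma) - (d / 2) log(2 pi)
sigma = scale.squeeze(-1)
closed_form = (
    -0.5 * (((actions - loc) / scale) ** 2).sum(dim=-1)
    - s_dim * torch.log(sigma)
    - 0.5 * s_dim * math.log(2 * math.pi)
)
```

Under these assumptions log_p and closed_form agree up to floating-point error.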

actions_detach = True
loc
log_prob(actions)
Parameters:

actions (torch.Tensor)

Return type:

torch.Tensor

sample(sample_shape=torch.Size())
Parameters:

sample_shape (torch.Size)

Return type:

torch.Tensor

scale
class gfn.utils.distributions.UnsqueezedCategorical(probs=None, logits=None, validate_args=None, debug=False)

Bases: torch.distributions.Categorical

A torch.distributions.Categorical that unsqueezes the last dimension.

This is useful for discrete environments that have an action shape of (1,), as the samples will have a shape of (batch_size, 1) instead of (batch_size,).

debug = False
log_prob(sample)

Returns the log probabilities of an unsqueezed sample.

Parameters:

sample (torch.Tensor) – The sample for which to compute the log probabilities.

Return type:

torch.Tensor

Returns the log probabilities of the sample as a tensor of shape (*sample_shape, *batch_shape).

sample(sample_shape=torch.Size())

Sample actions with an unsqueezed final dimension.

Parameters:

sample_shape – The shape of the sample.

Return type:

torch.Tensor

Returns the sampled actions as a tensor of shape (*sample_shape, *batch_shape, 1).
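The unsqueezing behavior described for UnsqueezedCategorical can be sketched as a small subclass of torch.distributions.Categorical. The class name UnsqueezedCategoricalSketch is hypothetical and the code is an illustration of the idea, not the library's source: sample appends a trailing dimension of size 1, and log_prob removes it again before delegating to the parent class.

```python
import torch
from torch.distributions import Categorical


class UnsqueezedCategoricalSketch(Categorical):
    """Illustrative stand-in: a Categorical with a trailing size-1 action dim."""

    def sample(self, sample_shape=torch.Size()):
        # (*sample_shape, *batch_shape) -> (*sample_shape, *batch_shape, 1)
        return super().sample(sample_shape).unsqueeze(-1)

    def log_prob(self, sample):
        # Drop the trailing dimension before delegating to Categorical.
        return super().log_prob(sample.squeeze(-1))


# Uniform distribution over 3 actions for a batch of 5 states.
dist = UnsqueezedCategoricalSketch(logits=torch.zeros(5, 3))
actions = dist.sample()          # shape (5, 1)
log_p = dist.log_prob(actions)   # shape (5,)
```

With uniform logits over three actions, every entry of log_p equals log(1/3), and the sampled actions match an environment whose action shape is (1,).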