gfn.utils.distributions
Classes
- GraphActionDistribution: A mixture of categorical distributions for graph actions.
- IsotropicGaussian: Isotropic Gaussian distribution.
- UnsqueezedCategorical: A torch.distributions.Categorical that unsqueezes the last dimension.
Module Contents
- class gfn.utils.distributions.GraphActionDistribution(logits=None, probs=None, is_backward=False, debug=False)
Bases: torch.distributions.Distribution
A mixture of categorical distributions for graph actions.
This class samples graph actions and computes their log probabilities. A graph action is a tuple (action_type, node_class, edge_class, edge_index). Each component follows a categorical distribution, and every component other than action_type is conditioned on the sampled action_type:
- If the action_type is ADD_NODE, then the node_class is sampled from a categorical distribution.
- If the action_type is ADD_EDGE, then the edge_class and edge_index are sampled from categorical distributions.
- If the action_type is EXIT, then no other components are sampled.
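The hierarchical sampling scheme above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the library's implementation: the function name sample_graph_actions, the integer codes for the action types, the use of four separate logit tensors instead of a TensorDict, and filling unused components with 0 are all hypothetical choices made for this sketch. Every component is drawn and then masked by the sampled action type; this matches conditional sampling here because the component logits do not themselves depend on the sampled type.

```python
import torch
from torch.distributions import Categorical

# Hypothetical action-type codes for this sketch only.
ADD_NODE, ADD_EDGE, EXIT = 0, 1, 2


def sample_graph_actions(type_logits, node_class_logits, edge_class_logits, edge_index_logits):
    """Sample (action_type, node_class, edge_class, edge_index) for a batch.

    Each logits tensor has shape (batch_size, num_choices). Components that the
    sampled action type does not use are set to 0 (an assumed placeholder).
    """
    # Top level: choose the action type.
    action_type = Categorical(logits=type_logits).sample()

    # Lower level: draw every component, then keep only the ones the type needs.
    node_class = Categorical(logits=node_class_logits).sample()
    edge_class = Categorical(logits=edge_class_logits).sample()
    edge_index = Categorical(logits=edge_index_logits).sample()

    zeros = torch.zeros_like(action_type)
    node_class = torch.where(action_type == ADD_NODE, node_class, zeros)
    edge_class = torch.where(action_type == ADD_EDGE, edge_class, zeros)
    edge_index = torch.where(action_type == ADD_EDGE, edge_index, zeros)

    # Last dimension: action type, node class, edge class, edge index.
    return torch.stack([action_type, node_class, edge_class, edge_index], dim=-1)


actions = sample_graph_actions(
    torch.zeros(3, 3), torch.zeros(3, 5), torch.zeros(3, 2), torch.zeros(3, 10)
)
print(actions.shape)  # torch.Size([3, 4])
```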
- Parameters:
logits (tensordict.TensorDict | None)
probs (tensordict.TensorDict | None)
is_backward (bool)
debug (bool)
- debug = False
- dists
- is_backward = False
- log_prob(sample)
Returns the log probabilities for a batch of action samples.
Because sampling is hierarchical, the log probability of an action is the sum of the log probabilities of its individual components. It is one of:
log_prob = log p(action_type=add_node) + log p(node_class | action_type=add_node)
log_prob = log p(action_type=add_edge) + log p(edge_class | action_type=add_edge) + log p(edge_index | action_type=add_edge)
log_prob = log p(action_type=remove_node) + log p(node_index | action_type=remove_node)
log_prob = log p(action_type=remove_edge) + log p(edge_index | action_type=remove_edge)
log_prob = log p(action_type=exit)
- Parameters:
sample (torch.Tensor) – A tensor of shape (*sample_shape, *batch_shape, 4) containing action samples, where the last dimension holds the action type, node class, edge class, and edge index.
- Returns:
A tensor of shape (*sample_shape, *batch_shape) containing the log probabilities for each sample.
- Return type:
torch.Tensor
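The sum-of-components rule for log_prob can be illustrated with a small sketch. It is not the library's code: the function name graph_action_log_prob, the integer action-type codes, and the separate logit tensors are assumptions, and the remove_node / remove_edge cases are omitted for brevity. Each row's log probability is the action-type term plus the terms for only those components that the action type actually uses.

```python
import torch
from torch.distributions import Categorical

# Hypothetical action-type codes for this sketch only.
ADD_NODE, ADD_EDGE, EXIT = 0, 1, 2


def graph_action_log_prob(sample, type_logits, node_class_logits, edge_class_logits, edge_index_logits):
    """Log probability of (batch_size, 4) action samples under a hierarchical categorical."""
    action_type, node_class, edge_class, edge_index = sample.unbind(dim=-1)

    # Every action pays the cost of choosing its type.
    log_prob = Categorical(logits=type_logits).log_prob(action_type)

    # Component terms, added only where the action type uses them.
    lp_node = Categorical(logits=node_class_logits).log_prob(node_class)
    lp_edge = (
        Categorical(logits=edge_class_logits).log_prob(edge_class)
        + Categorical(logits=edge_index_logits).log_prob(edge_index)
    )
    log_prob = log_prob + torch.where(action_type == ADD_NODE, lp_node, torch.zeros_like(lp_node))
    log_prob = log_prob + torch.where(action_type == ADD_EDGE, lp_edge, torch.zeros_like(lp_edge))
    return log_prob


# One ADD_NODE, one ADD_EDGE and one EXIT action under uniform logits.
sample = torch.tensor([[0, 2, 0, 0], [1, 0, 1, 7], [2, 0, 0, 0]])
lp = graph_action_log_prob(
    sample, torch.zeros(3, 3), torch.zeros(3, 5), torch.zeros(3, 2), torch.zeros(3, 10)
)
print(lp.shape)  # torch.Size([3])
```

With uniform logits the three values are log(1/3 · 1/5), log(1/3 · 1/2 · 1/10) and log(1/3): the EXIT row contributes only its action-type term.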
- class gfn.utils.distributions.IsotropicGaussian(loc, scale, actions_detach=True)
Bases: torch.distributions.Distribution
Isotropic Gaussian distribution.
This class samples from isotropic Gaussian distributions and computes their log probabilities, given the mean (loc) and standard deviation (scale) of the distribution. It is used primarily by the diffusion samplers.
- Parameters:
loc (torch.Tensor)
scale (torch.Tensor)
actions_detach (bool)
- actions_detach
Whether to detach the actions from the autograd computation graph.
- actions_detach = True
- loc
- log_prob(actions)
- Parameters:
actions (torch.Tensor)
- Return type:
torch.Tensor
- sample(sample_shape=torch.Size())
- Parameters:
sample_shape (torch.Size)
- Return type:
torch.Tensor
- scale
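A Gaussian with independent coordinates has a joint log density equal to the sum of the per-coordinate Normal log densities over the last dimension, which is what torch.distributions.Independent(Normal(loc, scale), 1) computes. The sketch below is a hypothetical stand-in (IsotropicGaussianSketch is not a library name). It also assumes that actions_detach means "detach sampled actions from the autograd graph"; the library's exact use of that flag may differ.

```python
import torch
from torch.distributions import Independent, Normal


class IsotropicGaussianSketch:
    """Hypothetical stand-in for an isotropic Gaussian over the last dimension."""

    def __init__(self, loc: torch.Tensor, scale: torch.Tensor, actions_detach: bool = True):
        self.loc = loc
        self.scale = scale  # standard deviation, broadcast against loc
        self.actions_detach = actions_detach
        self._normal = Normal(loc, scale)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        # rsample is reparameterised, so gradients flow to loc/scale unless detached.
        actions = self._normal.rsample(sample_shape)
        # Assumed meaning of actions_detach: cut sampled actions from the graph.
        return actions.detach() if self.actions_detach else actions

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        # Coordinates are independent, so sum per-dimension log densities.
        return self._normal.log_prob(actions).sum(dim=-1)


loc, scale = torch.zeros(4, 3), torch.full((4, 3), 0.5)
dist = IsotropicGaussianSketch(loc, scale)
x = dist.sample()
print(x.shape, dist.log_prob(x).shape)  # torch.Size([4, 3]) torch.Size([4])
```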
- class gfn.utils.distributions.UnsqueezedCategorical(probs=None, logits=None, validate_args=None, debug=False)
Bases: torch.distributions.Categorical
A torch.distributions.Categorical that unsqueezes the last dimension.
This is useful for discrete environments that have an action shape of (1,), as the samples will have a shape of (batch_size, 1) instead of (batch_size,).
- debug = False
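The unsqueezing behavior can be sketched as a thin subclass of torch.distributions.Categorical. This is a hypothetical re-creation (UnsqueezedCategoricalSketch is not the library class, and it ignores the debug flag): sample appends a trailing dimension of size 1, and log_prob removes it before delegating to the parent class.

```python
import torch
from torch.distributions import Categorical


class UnsqueezedCategoricalSketch(Categorical):
    """Hypothetical sketch: a Categorical whose samples carry a trailing action dim of size 1."""

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        # Categorical.sample returns (*sample_shape, *batch_shape); add the action dim.
        return super().sample(sample_shape).unsqueeze(-1)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        # Drop the trailing action dim so the parent sees (*sample_shape, *batch_shape).
        return super().log_prob(value.squeeze(-1))


dist = UnsqueezedCategoricalSketch(logits=torch.zeros(5, 4))
actions = dist.sample()
print(actions.shape)                 # torch.Size([5, 1])
print(dist.log_prob(actions).shape)  # torch.Size([5])
```

Squeezing inside log_prob lets callers pass the (batch_size, 1) samples straight back without reshaping them first.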