gfn.utils.graphs¶
Classes¶
| GeometricBatch | A batch of graphs. |
Functions¶
| compare_data_objects(a, b) | Compare two Data objects along the main fields. |
|  | Check whether every tensor attribute in a points to the same storage in b. |
| from_edge_indices(ei0, ei1, n_nodes, is_directed) | Return the index (or indices) corresponding to the provided edge(s). |
| get_edge_indices(n_nodes, is_directed, device) | Get the source and target node indices for the edges. |
|  | Helper function to check if two GraphStates objects share storage. |
| hash_graph(data, directed) | Hash a PyG Data object (edge_index, edge_attr, x). |
Module Contents¶
- class gfn.utils.graphs.GeometricBatch¶
Bases: torch_geometric.data.Batch
A batch of graphs.
This class extends torch_geometric.data.Batch to support extending a batch with another batch, and to support stacking a list of Data objects into a single batch.
- tensor¶
The underlying torch_geometric.data.Data object.
- batch_shape¶
The shape of the batch.
- batch_ptrs¶
A tensor of pointers to the start of each graph in the batch.
- extend(other)¶
Extends the current batch with another batch.
- Parameters:
other (GeometricBatch) – The batch to extend with.
- Return type:
None
- classmethod stack(data_list)¶
Stacks a list of Data objects into a single GeometricBatch.
- Parameters:
data_list (list[torch_geometric.data.Data]) – A list of Data objects to stack.
- Returns:
A new GeometricBatch containing the stacked graphs.
- Return type:
GeometricBatch
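The sketch below shows the bookkeeping involved in stacking and extending graph batches. It is a minimal pure-PyTorch version that assumes each graph is a hypothetical (x, edge_index) pair. Node features are concatenated. Each graph's edge_index is shifted by the number of nodes already stacked. A pointer tensor records where each graph's nodes start.

This mirrors the offset scheme of PyG's Batch.from_data_list and its ptr attribute. It is not GeometricBatch's implementation, and the helper names are hypothetical. Unlike the in-place extend documented above, this extend_batch returns a new tuple.

```python
import torch


def stack_graphs(graphs):
    """Stack (x, edge_index) pairs into one batch (hypothetical helper).

    x: [num_nodes, num_features]; edge_index: [2, num_edges] with local node ids.
    Returns concatenated x, offset edge_index, and ptr (node start of each graph).
    Assumes `graphs` is non-empty.
    """
    xs, edge_indices, ptr = [], [], [0]
    for x, edge_index in graphs:
        xs.append(x)
        # Shift local node ids by the number of nodes already in the batch.
        edge_indices.append(edge_index + ptr[-1])
        ptr.append(ptr[-1] + x.shape[0])
    return torch.cat(xs), torch.cat(edge_indices, dim=1), torch.tensor(ptr)


def extend_batch(batch, other):
    """Append `other` to `batch`, returning a new (x, edge_index, ptr) tuple."""
    x, edge_index, ptr = batch
    other_x, other_edge_index, other_ptr = other
    offset = int(ptr[-1])  # total number of nodes already in `batch`
    return (
        torch.cat([x, other_x]),
        torch.cat([edge_index, other_edge_index + offset], dim=1),
        # Drop other's leading 0 and shift its pointers past `batch`'s nodes.
        torch.cat([ptr, other_ptr[1:] + offset]),
    )


g1 = (torch.zeros(2, 1), torch.tensor([[0], [1]]))
g2 = (torch.ones(3, 1), torch.tensor([[0, 1], [1, 2]]))
x, edge_index, ptr = stack_graphs([g1, g2])
print(edge_index.tolist())  # [[0, 2, 3], [1, 3, 4]]
print(ptr.tolist())  # [0, 2, 5]
```

Extending a one-graph batch with another one-graph batch gives the same result as stacking both graphs at once.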
- gfn.utils.graphs.compare_data_objects(a, b)¶
Compare two Data objects along the main fields.
- Parameters:
a (torch_geometric.data.Data)
b (torch_geometric.data.Data)
- Return type:
bool
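The docstring does not list the "main fields". The minimal sketch below assumes they are x, edge_index and edge_attr, the fields hash_graph names. It uses plain dicts in place of Data objects, and compare_main_fields is a hypothetical name.

```python
import torch

# Assumption: the "main fields" are the ones hash_graph names.
MAIN_FIELDS = ("x", "edge_index", "edge_attr")


def compare_main_fields(a: dict, b: dict) -> bool:
    """Hypothetical helper: True if a and b agree on every main field."""
    for field in MAIN_FIELDS:
        ta, tb = a.get(field), b.get(field)
        if ta is None or tb is None:
            # Equal only if the field is missing from both.
            if ta is not tb:
                return False
            continue
        # torch.equal checks shape and values; dtype is checked explicitly.
        if ta.dtype != tb.dtype or not torch.equal(ta, tb):
            return False
    return True
```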
Check whether every tensor attribute in a points to the same storage in b.
- Parameters:
a (torch_geometric.data.Data) – The first Data object.
b (torch_geometric.data.Data) – The second Data object.
- Returns:
True if every tensor attribute in a points to the same storage in b, False otherwise.
- Return type:
bool
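Storage sharing can be tested by comparing the base address of each tensor's underlying storage. With that test, views and slices of the same tensor count as shared, while clones do not.

The minimal sketch below assumes plain dicts stand in for Data attributes. tensors_share_storage is a hypothetical name, and Tensor.untyped_storage() requires PyTorch 2.0 or later.

```python
import torch


def tensors_share_storage(a: dict, b: dict) -> bool:
    """Hypothetical helper: True if every tensor in `a` shares storage with
    the same-named entry of `b`. Non-tensor values are skipped."""
    for key, ta in a.items():
        if not isinstance(ta, torch.Tensor):
            continue
        tb = b.get(key)
        if not isinstance(tb, torch.Tensor):
            return False
        # Views and slices report the base address of their shared storage.
        if ta.untyped_storage().data_ptr() != tb.untyped_storage().data_ptr():
            return False
    return True


x = torch.arange(6.0)
print(tensors_share_storage({"x": x}, {"x": x[2:]}))  # True: the slice is a view
print(tensors_share_storage({"x": x}, {"x": x.clone()}))  # False: new storage
```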
- gfn.utils.graphs.from_edge_indices(ei0, ei1, n_nodes, is_directed)¶
Return the index (or indices) corresponding to the provided edge(s).
This is the inverse operation of get_edge_indices(). Given the source- and target-node indices of one or several edges, this function returns the position of each edge in the enumeration produced by get_edge_indices for the same n_nodes / is_directed setting.
The enumeration rules are the same as in get_edge_indices:
- is_directed = False → only the strict upper-triangular part (i < j) is enumerated, using torch.triu_indices with offset=1.
- is_directed = True → the strict upper part is enumerated first, followed by the strict lower part (i > j).
- Parameters:
ei0 (int | torch.Tensor) – Source-node indices, as a Python int or a tensor.
ei1 (int | torch.Tensor) – Target-node indices, as a Python int or a tensor of the same shape as ei0. If the graph is undirected, the orientation is ignored: (i, j) and (j, i) map to the same index.
n_nodes (int) – Number of nodes in the graph.
is_directed (bool) – Whether the graph is directed.
- Returns:
The position(s) of each edge in the ordering returned by get_edge_indices.
- Return type:
int | torch.Tensor
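The enumeration rules above admit a closed-form inverse. Row i of the strict upper triangle starts at position i*n - i*(i+1)/2. Row i of the strict lower triangle starts i*(i-1)/2 entries into that part.

The sketch below assumes the strict lower part is appended in the row-major order of torch.tril_indices with offset=-1, since the description does not state that order. edge_position is a hypothetical name.

```python
import torch


def edge_position(ei0, ei1, n_nodes: int, is_directed: bool) -> torch.Tensor:
    """Hypothetical closed-form position of edge(s) (ei0, ei1).

    Assumed ordering: strict upper triangle (i < j) row-major, as
    torch.triu_indices(offset=1); if directed, followed by the strict lower
    triangle (i > j) row-major, as torch.tril_indices(offset=-1).
    """
    ei0 = torch.as_tensor(ei0)
    ei1 = torch.as_tensor(ei1)
    if (ei0 == ei1).any():
        raise ValueError("self-loops are not enumerated")
    n_upper = n_nodes * (n_nodes - 1) // 2
    if not is_directed:
        # Orientation is ignored: map (j, i) onto (i, j) with i < j.
        i, j = torch.minimum(ei0, ei1), torch.maximum(ei0, ei1)
        return i * n_nodes - i * (i + 1) // 2 + (j - i - 1)
    # Offset of row i in the upper part, plus the column offset within the row.
    pos_upper = ei0 * n_nodes - ei0 * (ei0 + 1) // 2 + (ei1 - ei0 - 1)
    # Upper part size, plus offset of row i in the lower part, plus column j.
    pos_lower = n_upper + ei0 * (ei0 - 1) // 2 + ei1
    return torch.where(ei0 < ei1, pos_upper, pos_lower)


print(edge_position(torch.tensor([1, 2]), torch.tensor([2, 1]), 4, False).tolist())  # [3, 3]
```

The test below checks that the formula reproduces positions 0, 1, 2, ... when fed the triu_indices/tril_indices enumeration.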
- gfn.utils.graphs.get_edge_indices(n_nodes, is_directed, device)¶
Get the source and target node indices of all possible edges between n_nodes nodes, excluding self-loops.
- Parameters:
n_nodes (int) – The number of nodes in the graph.
is_directed (bool) – Whether the graph is directed.
device (torch.device) – The device to run the computation on.
- Returns:
A tuple of two tensors, the source and target node indices.
- Return type:
tuple[torch.Tensor, torch.Tensor]
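The sketch below builds the enumeration described under from_edge_indices from torch.triu_indices and torch.tril_indices. The lower-triangle ordering used for the directed case is an assumption. sketch_get_edge_indices is a hypothetical name.

```python
import torch


def sketch_get_edge_indices(n_nodes: int, is_directed: bool,
                            device: torch.device = torch.device("cpu")):
    # Strict upper triangle (i < j) in row-major order.
    ei0, ei1 = torch.triu_indices(n_nodes, n_nodes, offset=1, device=device)
    if is_directed:
        # Assumption: strict lower triangle (i > j) appended in row-major order.
        lo0, lo1 = torch.tril_indices(n_nodes, n_nodes, offset=-1, device=device)
        ei0, ei1 = torch.cat([ei0, lo0]), torch.cat([ei1, lo1])
    return ei0, ei1


src, dst = sketch_get_edge_indices(3, is_directed=True)
print(src.tolist(), dst.tolist())  # [0, 0, 1, 1, 2, 2] [1, 2, 2, 0, 0, 1]
```

For n_nodes nodes this yields n_nodes*(n_nodes-1)/2 edges when undirected and twice that when directed.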
Helper function to check if two GraphStates objects share storage.
- Returns:
True if any tensor storage is shared between the two GraphStates.
- Parameters:
- Return type:
bool
- gfn.utils.graphs.hash_graph(data, directed)¶
Hash a PyG Data object (edge_index, edge_attr, x). Produces the same hash for graphs that are element-wise identical.
- Parameters:
data (torch_geometric.data.Data) – The graph to hash.
directed (bool)
- Return type:
str
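The minimal sketch below shows content hashing with hashlib, assuming the hash covers each tensor's dtype, shape and values. sketch_hash_graph is hypothetical and takes the tensors directly rather than a Data object. It does not reproduce the directed flag, whose role is not documented here.

```python
import hashlib

import torch


def sketch_hash_graph(x, edge_index, edge_attr=None) -> str:
    """Hypothetical helper: SHA-256 over dtype, shape and values of each field."""
    h = hashlib.sha256()
    for name, t in (("edge_index", edge_index), ("edge_attr", edge_attr), ("x", x)):
        h.update(name.encode())
        if t is None:
            h.update(b"<none>")  # distinguish a missing field from an empty one
            continue
        t = t.detach().cpu()
        h.update(str(t.dtype).encode())
        h.update(str(tuple(t.shape)).encode())
        # tolist() yields Python numbers whose repr round-trips exactly,
        # so element-wise identical tensors contribute identical bytes.
        h.update(repr(t.tolist()).encode())
    return h.hexdigest()
```

Because the digest depends only on tensor contents, two graphs built independently from equal tensors hash the same. Reversing the edge direction changes the digest.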