gfn.utils.graphs

Classes

GeometricBatch

A batch of graphs.

Functions

compare_data_objects(a, b)

Compare two Data objects along the main fields.

data_share_storage(a, b)

Return True if every tensor attribute in a points to the same storage as the corresponding attribute in b.

from_edge_indices(ei0, ei1, n_nodes, is_directed)

Return the index (or indices) corresponding to the provided edge(s).

get_edge_indices(n_nodes, is_directed, device)

Get the source and target node indices for the edges.

graph_states_share_storage(a, b)

Helper function to check if two GraphStates objects share storage.

hash_graph(data, directed)

Hash a PyG Data object (edge_index, edge_attr, x).

Module Contents

class gfn.utils.graphs.GeometricBatch

Bases: torch_geometric.data.Batch

A batch of graphs.

This class extends torch_geometric.data.Batch with support for extending a batch in place with another batch and for stacking a list of Data objects into a single batch.

Attributes:

tensor

The underlying torch_geometric.data.Data object.

batch_shape

The shape of the batch.

batch_ptrs

A tensor of pointers to the start of each graph in the batch.

extend(other)

Extends the current batch with another batch.

Parameters:

other (GeometricBatch) – The batch to extend with.

Return type:

None

classmethod stack(data_list)

Stacks a list of Data objects into a single GeometricBatch.

Parameters:

data_list (list[torch_geometric.data.Data]) – A list of Data objects to stack.

Returns:

A new GeometricBatch containing the stacked graphs.

Return type:

GeometricBatch
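GeometricBatch builds on PyG batching, in which several graphs are merged into one large disconnected graph and each graph's edge indices are shifted by the number of nodes that precede it. The pure-Python sketch below illustrates that collation idea for stack() and extend(); it is not GeometricBatch's code. ToyGraph and ToyBatch are hypothetical stand-ins for Data and GeometricBatch, and the ptr layout (start offsets followed by the total node count, as in PyG's Batch.ptr) is an assumption about how batch_ptrs is laid out.

```python
from dataclasses import dataclass, field


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a PyG Data object."""
    x: list[list[float]]               # one feature row per node
    edge_index: list[tuple[int, int]]  # (source, target) pairs using local node ids


@dataclass
class ToyBatch:
    """Hypothetical stand-in for a batch: graphs merged into one disconnected graph."""
    x: list[list[float]] = field(default_factory=list)
    edge_index: list[tuple[int, int]] = field(default_factory=list)
    # Start offset of each graph, followed by the total node count (PyG-style ptr, assumed).
    ptr: list[int] = field(default_factory=lambda: [0])

    def extend(self, other: "ToyBatch") -> None:
        """Append every graph of `other` to this batch in place."""
        offset = self.ptr[-1]  # nodes already in this batch
        self.x.extend(other.x)
        # Shift node ids so they index into the concatenated node list.
        self.edge_index.extend((s + offset, t + offset) for s, t in other.edge_index)
        self.ptr.extend(p + offset for p in other.ptr[1:])

    @classmethod
    def stack(cls, graphs: list[ToyGraph]) -> "ToyBatch":
        """Build a batch by extending an empty batch with one graph at a time."""
        batch = cls()
        for g in graphs:
            batch.extend(cls(x=list(g.x), edge_index=list(g.edge_index), ptr=[0, len(g.x)]))
        return batch


g1 = ToyGraph(x=[[1.0], [2.0]], edge_index=[(0, 1)])
g2 = ToyGraph(x=[[3.0], [4.0], [5.0]], edge_index=[(0, 2), (1, 2)])
batch = ToyBatch.stack([g1, g2])
print(batch.edge_index)  # [(0, 1), (2, 4), (3, 4)]
print(batch.ptr)         # [0, 2, 5]
```

Implementing stack() on top of extend() keeps the index-shifting logic in one place, so stacking a list and extending one batch with another produce identical results.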

gfn.utils.graphs.compare_data_objects(a, b)

Compare two Data objects along the main fields.

Parameters:
  • a (torch_geometric.data.Data) – The first Data object.

  • b (torch_geometric.data.Data) – The second Data object.

Return type:

bool

gfn.utils.graphs.data_share_storage(a, b)

Return True if every tensor attribute in a points to the same storage as the corresponding attribute in b.

Parameters:
  • a (torch_geometric.data.Data) – The first Data object.

  • b (torch_geometric.data.Data) – The second Data object.

Returns:

True if every tensor attribute in a points to the same storage in b, False otherwise.

Return type:

bool

gfn.utils.graphs.from_edge_indices(ei0, ei1, n_nodes, is_directed)

Return the index (or indices) corresponding to the provided edge(s).

This is the inverse operation of get_edge_indices(). Given the source- and target-node indices of one or several edges, this function returns the position of each edge in the enumeration produced by get_edge_indices for the same n_nodes/is_directed setting.

The enumeration rules are the same as in get_edge_indices:

  1. is_directed = False → only the strict upper-triangular part (i < j) is enumerated using torch.triu_indices with offset=1.

  2. is_directed = True → the strict upper part is enumerated first, followed by the strict lower part (i > j).

Parameters:
  • ei0 (int | torch.Tensor) – Source-node index or indices.

  • ei1 (int | torch.Tensor) – Target-node index or indices. ei0 and ei1 can be Python ints or tensors of matching shape. If the graph is undirected, the orientation is ignored ((i, j) and (j, i) map to the same index).

  • n_nodes (int) – Number of nodes in the graph.

  • is_directed (bool) – Whether the graph is directed.

Returns:

The position(s) of each edge in the ordering returned by get_edge_indices.

Return type:

int | torch.Tensor
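To make the mapping concrete, here is a pure-Python sketch of the inverse for a single edge, computed from closed-form offsets into the enumeration instead of a lookup table. edge_position is a hypothetical helper, not part of gfn.utils.graphs; it works on plain ints rather than tensors, and it assumes the strict lower part is enumerated row by row (the order torch.tril_indices uses with offset=-1), which the rules above do not pin down.

```python
def edge_position(i: int, j: int, n_nodes: int, is_directed: bool) -> int:
    """Hypothetical helper: index of edge (i, j) in the enumeration described above."""
    if i == j or not (0 <= i < n_nodes and 0 <= j < n_nodes):
        raise ValueError("expected two distinct node ids in [0, n_nodes)")
    n_upper = n_nodes * (n_nodes - 1) // 2  # number of edges in the strict upper triangle
    if not is_directed and i > j:
        i, j = j, i  # undirected: (j, i) maps to the same index as (i, j)
    if i < j:
        # Upper rows 0..i-1 hold (n-1) + (n-2) + ... + (n-i) = i*(2n-i-1)/2 edges.
        return i * (2 * n_nodes - i - 1) // 2 + (j - i - 1)
    # Directed and i > j: lower triangle, assumed row-major, placed after the upper part.
    # Lower rows 1..i-1 hold 1 + 2 + ... + (i-1) = i*(i-1)/2 edges.
    return n_upper + i * (i - 1) // 2 + j
```

For example, with n_nodes=4 the upper triangle holds 6 edges, so edge_position(2, 1, 4, True) returns 8, while edge_position(2, 1, 4, False) returns 3, the index of (1, 2).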

gfn.utils.graphs.get_edge_indices(n_nodes, is_directed, device)

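Get the source and target node indices of every possible edge (self-loops excluded) in a graph with n_nodes nodes, in the order described under from_edge_indices().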

Parameters:
  • n_nodes (int) – The number of nodes in the graph.

  • is_directed (bool) – Whether the graph is directed.

  • device (torch.device) – The device to run the computation on.

Returns:

A tuple of two tensors, the source and target node indices.

Return type:

tuple[torch.Tensor, torch.Tensor]
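The enumeration order documented for this function can be reproduced with plain lists. The sketch below is a hypothetical pure-Python helper (edge_indices) that returns lists instead of tensors and ignores device; it assumes the strict lower part is enumerated row by row, as torch.tril_indices does with offset=-1.

```python
def edge_indices(n_nodes: int, is_directed: bool) -> tuple[list[int], list[int]]:
    """Hypothetical helper: source and target lists in the documented edge order."""
    # Strict upper triangle (i < j), row by row, matching torch.triu_indices(n, n, offset=1).
    pairs = [(i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)]
    if is_directed:
        # Strict lower triangle (i > j) follows; row-major order is an assumption.
        pairs += [(i, j) for i in range(n_nodes) for j in range(i)]
    sources = [i for i, _ in pairs]
    targets = [j for _, j in pairs]
    return sources, targets
```

For n_nodes=3, edge_indices(3, False) returns ([0, 0, 1], [1, 2, 2]), the same pairs as torch.triu_indices(3, 3, offset=1); edge_indices(3, True) appends (1, 0), (2, 0) and (2, 1).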

gfn.utils.graphs.graph_states_share_storage(a, b)

Helper function to check if two GraphStates objects share storage.

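Parameters:
  • a (GraphStates) – The first GraphStates object.

  • b (GraphStates) – The second GraphStates object.

Returns:

True if any tensor storage is shared between the two GraphStates.

Return type:

bool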

gfn.utils.graphs.hash_graph(data, directed)

Hash a PyG Data object (edge_index, edge_attr, x). Produces the same hash for graphs that are element-wise identical.

Parameters:
  • data (torch_geometric.data.Data) – The graph to hash.

  • directed (bool) – Whether the graph is directed.

Return type:

str
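The docstring does not say how the hash is computed. As an illustration only, the sketch below digests a graph given as plain Python lists with hashlib, so element-wise identical inputs always yield the same hex string (a str, matching the return type). hash_toy_graph is hypothetical, and folding directed into the digest is an assumption about what that flag is for.

```python
import hashlib


def hash_toy_graph(
    x: list[list[float]],
    edge_index: list[tuple[int, int]],
    edge_attr: list[list[float]],
    directed: bool,
) -> str:
    """Hypothetical sketch: content hash of a graph stored as plain lists."""
    # repr() of nested lists/tuples of numbers is deterministic, so equal inputs give equal bytes.
    payload = repr((directed, x, edge_index, edge_attr)).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()
```

A cryptographic digest is preferable to Python's built-in hash() here because hash() of str and bytes is randomized per process (see PYTHONHASHSEED), so its values are not stable across runs.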