gfn.utils.graphs
================

.. py:module:: gfn.utils.graphs


Classes
-------

.. autoapisummary::

   gfn.utils.graphs.GeometricBatch


Functions
---------

.. autoapisummary::

   gfn.utils.graphs.compare_data_objects
   gfn.utils.graphs.data_share_storage
   gfn.utils.graphs.from_edge_indices
   gfn.utils.graphs.get_edge_indices
   gfn.utils.graphs.graph_states_share_storage
   gfn.utils.graphs.hash_graph


Module Contents
---------------

.. py:class:: GeometricBatch

   Bases: :py:obj:`torch_geometric.data.Batch`


   A batch of graphs.

   This class extends `torch_geometric.data.Batch` with support for extending a batch with another batch and for stacking a list of `Data` objects into a single batch.

   .. attribute:: tensor

      The underlying `torch_geometric.data.Data` object.

   .. attribute:: batch_shape

      The shape of the batch.

   .. attribute:: batch_ptrs

      A tensor of pointers to the start of each graph in the batch.


   .. py:method:: extend(other)

      Extends the current batch with another batch.

      :param other: The batch to extend with.


   .. py:method:: stack(data_list)
      :classmethod:


      Stacks a list of `Data` objects into a single `GeometricBatch`.

      :param data_list: A list of `Data` objects to stack.

      :returns: A new `GeometricBatch` containing the stacked graphs.


.. py:function:: compare_data_objects(a, b)

   Compare two `Data` objects along their main fields.


.. py:function:: data_share_storage(a, b)

   Check whether every tensor attribute in `a` shares its storage with the corresponding attribute in `b`.

   :param a: The first `Data` object.
   :param b: The second `Data` object.

   :returns: True if every tensor attribute in `a` points to the same storage in `b`, False otherwise.


.. py:function:: from_edge_indices(ei0, ei1, n_nodes, is_directed)

   Return the index (or indices) corresponding to the provided edge(s).

   This is the inverse operation of :func:`get_edge_indices`. Given the source- and target-node indices of one or several edges, this function returns the *position* of each edge in the enumeration produced by ``get_edge_indices`` for the same ``n_nodes``/``is_directed`` setting.

   The enumeration rules are the same as in ``get_edge_indices``:

   1. ``is_directed = False`` → only the strict upper-triangular part (``i < j``) is enumerated, using ``torch.triu_indices`` with ``offset=1``.
   2. ``is_directed = True`` → the strict upper part is enumerated first, followed by the strict lower part (``i > j``).

   :param ei0: Source-node indices: a Python ``int`` or a tensor with the same shape as ``ei1``.
   :param ei1: Target-node indices: a Python ``int`` or a tensor with the same shape as ``ei0``. If the graph is undirected, the orientation is ignored (``(i, j)`` and ``(j, i)`` map to the same index).
   :param n_nodes: Number of nodes in the graph.
   :param is_directed: Whether the graph is directed.

   :returns: The position(s) of each edge in the ordering returned by ``get_edge_indices``.
   :rtype: int | torch.Tensor


.. py:function:: get_edge_indices(n_nodes, is_directed, device)

   Get the source and target node indices of every possible edge between distinct nodes.

   :param n_nodes: The number of nodes in the graph.
   :param is_directed: Whether the graph is directed.
   :param device: The device on which to create the index tensors.

   :returns: A tuple of two tensors holding the source and target node indices, enumerated according to the rules listed under :func:`from_edge_indices`.


.. py:function:: graph_states_share_storage(a, b)

   Check whether two `GraphStates` objects share any tensor storage.

   :returns: True if *any* tensor storage is shared between the two `GraphStates` objects.


.. py:function:: hash_graph(data, directed)

   Hash a PyG `Data` object using its ``edge_index``, ``edge_attr`` and ``x`` fields.

   Graphs that are element-wise identical produce the same hash.
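
The edge enumeration described for :func:`get_edge_indices` and :func:`from_edge_indices` can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: it uses plain lists instead of ``torch`` tensors, the helpers ``edge_list`` and ``edge_position`` are hypothetical names, and it assumes the strict lower part is enumerated row by row, as ``torch.tril_indices(n, n, offset=-1)`` would produce (the documentation only states that the upper part comes first).

.. code-block:: python

   from typing import List, Tuple


   def edge_list(n_nodes: int, is_directed: bool) -> List[Tuple[int, int]]:
       """Hypothetical helper: enumerate candidate edges (no self-loops).

       The upper part mirrors torch.triu_indices(n, n, offset=1); the lower
       part ASSUMES the row-major order of torch.tril_indices(n, n, offset=-1).
       """
       # Strict upper triangle (i < j), row-major: (0, 1), (0, 2), ..., (1, 2), ...
       upper = [(i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)]
       if not is_directed:
           return upper
       # Strict lower triangle (i > j), row-major: (1, 0), (2, 0), (2, 1), ...
       lower = [(i, j) for i in range(n_nodes) for j in range(i)]
       return upper + lower


   def edge_position(i: int, j: int, n_nodes: int, is_directed: bool) -> int:
       """Hypothetical helper: closed-form inverse of edge_list for one edge."""
       if i == j:
           raise ValueError("self-loops are not enumerated")
       if not is_directed and i > j:
           # Undirected: orientation is ignored, so normalise to i < j.
           i, j = j, i
       if i < j:
           # Rows 0..i-1 of the strict upper triangle hold i*n - i*(i+1)/2 edges;
           # (i, j) is then the (j - i - 1)-th entry of row i.
           return i * n_nodes - i * (i + 1) // 2 + (j - i - 1)
       # Directed lower part: skip all n*(n-1)/2 upper edges, then rows 1..i-1
       # of the strict lower triangle (1 + 2 + ... + (i-1) edges), then j.
       n_upper = n_nodes * (n_nodes - 1) // 2
       return n_upper + i * (i - 1) // 2 + j


   if __name__ == "__main__":
       # Round trip: every enumerated edge maps back to its own position.
       for directed in (False, True):
           edges = edge_list(4, directed)
           positions = [edge_position(i, j, 4, directed) for i, j in edges]
           assert positions == list(range(len(edges)))
       print(edge_list(4, False))
       # prints [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]

A closed-form inverse like ``edge_position`` avoids building the full list of up to ``n_nodes * (n_nodes - 1)`` pairs just to look up the position of a single edge.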