tutorials.examples.train_graph_ring
===================================

.. py:module:: tutorials.examples.train_graph_ring

.. autoapi-nested-parse::

   Train a GFlowNet to generate ring graphs.

   This example demonstrates training a GFlowNet to generate graphs that are rings, in which
   each vertex has exactly two neighbors and the edges form a single cycle containing all
   vertices. The environment supports both directed and undirected ring generation.

   The number of modes grows factorially with the number of nodes: there are :math:`(n-1)!`
   distinct directed rings on :math:`n` labeled nodes, and :math:`(n-1)!/2` undirected rings
   for :math:`n \geq 3`. This makes learning from scratch very difficult even on relatively
   small graphs.

   For usage, see ``train_graph_ring.py -h``.

   Key components:

   - `RingGraphBuilding`: Environment for building ring graphs.
   - `RingPolicyModule`: GNN-based policy network for predicting actions.
   - `RingReward.directed_reward` / `RingReward.undirected_reward`: Reward functions for
     validating ring structures.


Attributes
----------

.. autoapisummary::

   tutorials.examples.train_graph_ring.parser


Classes
-------

.. autoapisummary::

   tutorials.examples.train_graph_ring.RingReward


Functions
---------

.. autoapisummary::

   tutorials.examples.train_graph_ring.init_env
   tutorials.examples.train_graph_ring.init_gflownet
   tutorials.examples.train_graph_ring.main
   tutorials.examples.train_graph_ring.render_states


Module Contents
---------------

.. py:class:: RingReward(directed, reward_val = 100.0, eps_val = 1e-06, device = torch.device('cpu'))

   Bases: :py:obj:`object`

   Reward that evaluates whether a graph forms a valid ring (directed or undirected cycle).

   :param directed: Whether the graph is directed.
   :param reward_val: The reward for valid rings.
   :param eps_val: The reward for invalid structures.
   :returns: When called, a tensor of rewards with the same batch shape as ``states``.

   .. py:method:: __call__(states)

   .. py:attribute:: device

   .. py:attribute:: directed

   .. py:method:: directed_reward(states)

      Compute reward for directed ring graphs.

      This function evaluates whether a graph forms a valid directed ring (cycle).

      A valid directed ring must satisfy these conditions:

      1. Each node has exactly one outgoing edge (row sum of 1 in the adjacency matrix).
      2. Each node has exactly one incoming edge (column sum of 1 in the adjacency matrix).
      3. Following the edges forms a single cycle that includes all nodes.

      :param states: A batch of graph states to evaluate.
      :returns: A tensor of rewards with the same batch shape as ``states``.

   .. py:attribute:: eps_val
      :value: 1e-06

   .. py:attribute:: reward_val
      :value: 100.0

   .. py:method:: undirected_reward(states)

      Compute reward for undirected ring graphs.

      This function evaluates whether a graph forms a valid undirected ring (cycle).

      A valid undirected ring must satisfy these conditions:

      1. Each node has exactly two neighbors (degree = 2).
      2. The graph forms a single connected cycle including all nodes.

      The algorithm:

      1. Checks that all nodes have degree 2.
      2. Performs a traversal starting from node 0, following edges.
      3. Checks whether the traversal visits all nodes and returns to the start.

      :param states: A batch of graph states to evaluate.
      :returns: A tensor of rewards with the same batch shape as ``states``.


.. py:function:: init_env(n_nodes, directed, device)

.. py:function:: init_gflownet(num_nodes, directed, use_gnn, embedding_dim, num_conv_layers, num_edge_classes, device)

.. py:function:: main(args)

   Main entry point for training a GFlowNet to generate ring graphs.

   This script demonstrates the complete workflow of training a GFlowNet to generate valid
   ring structures in both directed and undirected settings.

   For usage, see ``train_graph_ring.py -h``.

   The script performs the following steps:

   1. Initialize the environment and policy networks.
   2. Train the GFlowNet using trajectory balance.
   3. Visualize sample generated graphs.

.. py:data:: parser

.. py:function:: render_states(states, state_evaluator, directed)

   Visualize a batch of graph states as ring structures.

   This function creates a matplotlib visualization of graph states, rendering each one as a
   circular layout with nodes positioned evenly around a circle. For directed graphs, edges
   are drawn as arrows; for undirected graphs, edges are drawn as lines.

   The visualization includes:

   - Circular positioning of nodes.
   - Edges drawn between connected nodes.
   - The reward value displayed for each graph.

   :param states: A batch of graphs to visualize.
   :param state_evaluator: Function that computes the reward for each graph.
   :param directed: Whether to render directed or undirected edges.
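
The ring conditions listed for ``RingReward.directed_reward`` and
``RingReward.undirected_reward`` can be illustrated with a short standalone sketch. It is not
the module's implementation, which operates on batches of graph states with PyTorch. This
sketch assumes a single graph given as a dense 0/1 adjacency matrix (nested Python lists)
with no self-loops. The helpers ``is_directed_ring``, ``is_undirected_ring``, and
``ring_reward`` are hypothetical names, and walking the directed graph from node 0 is an
assumption, since the directed check only states the single-cycle condition.

.. code-block:: python

   from typing import List

   # Hypothetical representation: a dense 0/1 adjacency matrix for ONE graph,
   # where adj[i][j] == 1 means there is an edge from node i to node j.
   Adjacency = List[List[int]]


   def _has_self_loop(adj: Adjacency) -> bool:
       # Assumption of this sketch: ring graphs never contain self-loops.
       return any(adj[i][i] for i in range(len(adj)))


   def is_directed_ring(adj: Adjacency) -> bool:
       n = len(adj)
       if n == 0 or _has_self_loop(adj):
           return False
       # Condition 1: exactly one outgoing edge per node (row sums are 1).
       if any(sum(row) != 1 for row in adj):
           return False
       # Condition 2: exactly one incoming edge per node (column sums are 1).
       if any(sum(adj[i][j] for i in range(n)) != 1 for j in range(n)):
           return False
       # Condition 3: following outgoing edges from node 0 must reach every
       # node before coming back to node 0.
       visited, node = set(), 0
       while node not in visited:
           visited.add(node)
           node = adj[node].index(1)  # the unique successor of `node`
       return len(visited) == n


   def is_undirected_ring(adj: Adjacency) -> bool:
       n = len(adj)
       if n == 0 or _has_self_loop(adj):
           return False
       # An undirected graph must have a symmetric adjacency matrix.
       if any(adj[i][j] != adj[j][i] for i in range(n) for j in range(n)):
           return False
       # Step 1: every node has degree 2.
       if any(sum(row) != 2 for row in adj):
           return False
       # Step 2: walk from node 0, never stepping straight back to the previous node.
       visited, prev, node = set(), None, 0
       while node not in visited:
           visited.add(node)
           first, second = [j for j in range(n) if adj[node][j]]
           prev, node = node, (second if first == prev else first)
       # Step 3: the walk must visit all nodes and close back at node 0.
       return node == 0 and len(visited) == n


   def ring_reward(adj: Adjacency, directed: bool,
                   reward_val: float = 100.0, eps_val: float = 1e-6) -> float:
       # Defaults mirror RingReward's reward_val and eps_val.
       ok = is_directed_ring(adj) if directed else is_undirected_ring(adj)
       return reward_val if ok else eps_val


   # Usage: a directed 3-ring, two disjoint directed 2-cycles, and an undirected square.
   directed_ring = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]  # 0 -> 1 -> 2 -> 0
   two_cycles = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]
   square = [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]]  # 0-1-2-3-0
   print(ring_reward(directed_ring, directed=True))  # 100.0
   print(ring_reward(two_cycles, directed=True))     # 1e-06
   print(ring_reward(square, directed=False))        # 100.0

Rejecting self-loops and asymmetric matrices up front keeps both traversals simple: once the
row/column-sum or degree checks pass, every node has a well-defined next node, so the walk
from node 0 always closes, and only the number of visited nodes decides whether the graph is
one ring or several disjoint cycles.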