tutorials.examples.train_graph_ring
Train a GFlowNet to generate ring graphs.
This example demonstrates training a GFlowNet to generate graphs that are rings, where each vertex has exactly two neighbors (in the directed case, exactly one incoming and one outgoing edge) and the edges form a single cycle containing all vertices. The environment supports both directed and undirected ring generation.
The number of modes in this problem grows factorially with the number of nodes, which makes learning from scratch very difficult even on relatively small graphs.
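To give a sense of scale, the sketch below counts the valid rings on n labeled nodes, assuming each distinct valid ring counts as one mode (every valid ring receives the same reward). It uses the standard Hamiltonian-cycle count: (n-1)! directed rings and (n-1)!/2 undirected rings for n ≥ 3. The helper count_ring_modes is hypothetical and not part of the tutorial.

```python
from math import factorial


def count_ring_modes(n_nodes: int, directed: bool) -> int:
    """Number of distinct rings (Hamiltonian cycles) on n_nodes labeled nodes."""
    if n_nodes < 3:
        raise ValueError("this count is only used here for n_nodes >= 3")
    # Fix node 0 as the start; the other n-1 nodes can follow in any order.
    cycles = factorial(n_nodes - 1)
    # An undirected ring is the same graph read clockwise or counterclockwise.
    return cycles if directed else cycles // 2


for n in (4, 6, 8, 10):
    print(n, count_ring_modes(n, directed=True), count_ring_modes(n, directed=False))
```

For n_nodes = 10 this already gives 362880 directed rings and 181440 undirected ones.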
For usage, see the output of train_graph_ring.py -h.
Key components:
- RingGraphBuilding: Environment for building ring graphs.
- RingPolicyModule: GNN-based policy network for predicting actions.
- directed_reward/undirected_reward: Reward functions for validating ring structures.
Attributes
- parser
Classes
- RingReward: Evaluates whether a graph forms a valid ring (directed or undirected cycle).
Functions
- init_env
- init_gflownet
- main: Main execution for training a GFlowNet to generate ring graphs.
- render_states: Visualize a batch of graph states as ring structures.
Module Contents
- class tutorials.examples.train_graph_ring.RingReward(directed, reward_val=100.0, eps_val=1e-06, device=torch.device('cpu'))
Bases: object
Evaluates whether a graph forms a valid ring (directed or undirected cycle).
- Parameters:
directed (bool) – Whether the graph is directed.
reward_val (float) – The reward for valid rings.
eps_val (float) – The reward for invalid structures.
device (torch.device)
- Returns:
A tensor of rewards with the same batch shape as states.
- __call__(states)
- Parameters:
states (gfn.states.GraphStates)
- Return type:
torch.Tensor
- device
- directed
- directed_reward(states)
Compute reward for directed ring graphs.
This function evaluates whether a graph forms a valid directed ring (cycle). A valid directed ring must satisfy these conditions:
1. Each node has exactly one outgoing edge (row sum = 1 in the adjacency matrix).
2. Each node has exactly one incoming edge (column sum = 1 in the adjacency matrix).
3. Following the edges forms a single cycle that includes all nodes.
- Parameters:
states (gfn.states.GraphStates) – A batch of graph states to evaluate.
- Returns:
A tensor of rewards with the same batch shape as states.
- Return type:
torch.Tensor
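The three directed_reward conditions can be checked on a plain 0/1 adjacency matrix, as in the dependency-free sketch below. It is illustrative, not the tutorial's implementation: it works on a single nested-list matrix rather than a batched gfn.states.GraphStates, the helper names are hypothetical, and rejecting self-loops is an assumption. The reward helper mirrors the reward_val/eps_val defaults of RingReward.

```python
from typing import Sequence


def is_directed_ring(adj: Sequence[Sequence[int]]) -> bool:
    """Check the three directed-ring conditions on a 0/1 adjacency matrix."""
    n = len(adj)
    if n == 0:
        return False
    # Assumption: self-loops never count as ring edges.
    if any(adj[i][i] for i in range(n)):
        return False
    # Condition 1: exactly one outgoing edge per node (row sums).
    if any(sum(row) != 1 for row in adj):
        return False
    # Condition 2: exactly one incoming edge per node (column sums).
    if any(sum(adj[i][j] for i in range(n)) != 1 for j in range(n)):
        return False
    # Condition 3: with 1 and 2 holding, the edges form disjoint cycles, so
    # following successors from node 0 must take n steps to return to node 0.
    node, steps = 0, 0
    while True:
        node = list(adj[node]).index(1)  # unique successor of this node
        steps += 1
        if node == 0:
            return steps == n


def directed_ring_reward(adj, reward_val: float = 100.0, eps_val: float = 1e-6) -> float:
    """reward_val for a valid directed ring, eps_val otherwise."""
    return reward_val if is_directed_ring(adj) else eps_val


print(directed_ring_reward([[0, 1, 0], [0, 0, 1], [1, 0, 0]]))  # 0 -> 1 -> 2 -> 0
# Two disjoint 2-cycles pass conditions 1 and 2 but fail condition 3.
print(directed_ring_reward([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]))
```

The traversal in condition 3 is what distinguishes one ring through all nodes from several smaller cycles, which the row and column sums alone cannot detect.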
- eps_val = 1e-06
- reward_val = 100.0
- undirected_reward(states)
Compute reward for undirected ring graphs.
This function evaluates whether a graph forms a valid undirected ring (cycle). A valid undirected ring must satisfy these conditions:
1. Each node has exactly two neighbors (degree = 2).
2. The graph forms a single connected cycle that includes all nodes.
The algorithm:
1. Checks that all nodes have degree 2.
2. Performs a traversal starting from node 0, following edges.
3. Checks whether the traversal visits all nodes and returns to the start.
- Parameters:
states (gfn.states.GraphStates) – A batch of graph states to evaluate.
- Returns:
A tensor of rewards with the same batch shape as states.
- Return type:
torch.Tensor
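The degree-check-then-traverse algorithm described for undirected_reward can be sketched the same way. Assumptions not taken from the source: the input is a single nested-list 0/1 matrix, an edge {i, j} counts as present if either adj[i][j] or adj[j][i] is set, and graphs with fewer than three nodes are rejected. The function name is hypothetical.

```python
from typing import List, Sequence


def is_undirected_ring(adj: Sequence[Sequence[int]]) -> bool:
    """Check degree 2 for every node, then walk the cycle from node 0."""
    n = len(adj)
    if n < 3:
        return False  # assumption: a ring needs at least three nodes
    # Treat the edge {i, j} as present if either triangle of the matrix has it.
    neighbors: List[List[int]] = [
        [j for j in range(n) if j != i and (adj[i][j] or adj[j][i])]
        for i in range(n)
    ]
    # Step 1: every node must have exactly two neighbors.
    if any(len(nbrs) != 2 for nbrs in neighbors):
        return False
    # Step 2: walk from node 0, never stepping straight back to the previous node.
    prev, node, steps = 0, neighbors[0][0], 1
    while node != 0:
        a, b = neighbors[node]
        prev, node = node, (b if a == prev else a)
        steps += 1
    # Step 3: a single ring returns to node 0 only after visiting all n nodes.
    return steps == n


square = [[0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]]
print(is_undirected_ring(square))  # True
```

Because every node has degree 2 once step 1 passes, the graph is a union of disjoint cycles, so the walk is guaranteed to return to node 0; comparing the step count with n then rules out multiple cycles.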
- tutorials.examples.train_graph_ring.init_env(n_nodes, directed, device)
- Parameters:
n_nodes (int)
directed (bool)
device (torch.device)
- Return type:
- tutorials.examples.train_graph_ring.init_gflownet(num_nodes, directed, use_gnn, embedding_dim, num_conv_layers, num_edge_classes, device)
- Parameters:
num_nodes (int)
directed (bool)
use_gnn (bool)
embedding_dim (int)
num_conv_layers (int)
num_edge_classes (int)
device (torch.device)
- Return type:
- tutorials.examples.train_graph_ring.main(args)
Main execution for training a GFlowNet to generate ring graphs.
This script demonstrates the complete workflow of training a GFlowNet to generate valid ring structures in both directed and undirected settings.
For usage, see the output of train_graph_ring.py -h.
The script performs the following steps:
1. Initialize the environment and policy networks.
2. Train the GFlowNet using trajectory balance.
3. Visualize sample generated graphs.
- Parameters:
args (argparse.Namespace)
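For reference, trajectory balance learns a log-partition estimate log Z so that, for each complete trajectory s_0 → s_1 → ... → s_n = x, log Z + Σ log P_F(s_{t+1} | s_t) equals log R(x) + Σ log P_B(s_t | s_{t+1}); the loss is the squared difference. The scalar sketch below computes that objective for one trajectory from precomputed log-probabilities. It is not the library's batched implementation, and the function name is hypothetical.

```python
import math
from typing import Sequence


def trajectory_balance_loss(
    log_z: float,
    log_pf: Sequence[float],
    log_pb: Sequence[float],
    log_reward: float,
) -> float:
    """Squared trajectory-balance residual for one complete trajectory.

    log_pf[t] = log P_F(s_{t+1} | s_t) and log_pb[t] = log P_B(s_t | s_{t+1}).
    """
    if len(log_pf) != len(log_pb):
        raise ValueError("need one forward and one backward log-prob per transition")
    residual = log_z + sum(log_pf) - log_reward - sum(log_pb)
    return residual ** 2


# Three transitions, each with forward probability 0.5 and backward probability 1.
# With Z = 800 and R(x) = 100 the balance condition holds: 800 * 0.5**3 == 100.
loss = trajectory_balance_loss(math.log(800.0), [math.log(0.5)] * 3, [0.0] * 3, math.log(100.0))
print(loss)
```

In training, this loss would be averaged over a batch of sampled trajectories and minimized jointly over the policy parameters and log Z.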
- tutorials.examples.train_graph_ring.parser
- tutorials.examples.train_graph_ring.render_states(states, state_evaluator, directed)
Visualize a batch of graph states as ring structures.
This function creates a matplotlib visualization of graph states, rendering them as circular layouts with nodes positioned evenly around a circle. For directed graphs, edges are shown as arrows; for undirected graphs, edges are shown as lines.
The visualization includes:
- Circular positioning of nodes
- Drawing edges between connected nodes
- Displaying the reward value for each graph
- Parameters:
states (gfn.states.GraphStates) – A batch of graphs to visualize.
state_evaluator (callable) – Function to compute rewards for each graph.
directed (bool) – Whether to render directed or undirected edges.
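The circular layout described for render_states can be realized by placing node i at angle 2πi/n. The sketch below computes those coordinates and the edge segments to draw, without any plotting, so it runs without matplotlib. The function names are hypothetical, the input is assumed to be a single nested-list 0/1 adjacency matrix, and render_states itself may compute positions differently.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def circular_layout(n_nodes: int, radius: float = 1.0) -> List[Point]:
    """Place n_nodes points evenly on a circle centred at the origin."""
    return [
        (
            radius * math.cos(2 * math.pi * i / n_nodes),
            radius * math.sin(2 * math.pi * i / n_nodes),
        )
        for i in range(n_nodes)
    ]


def edge_segments(
    adj: Sequence[Sequence[int]], positions: Sequence[Point], directed: bool
) -> List[Tuple[Point, Point]]:
    """Return (start, end) coordinates for every edge to draw."""
    n = len(adj)
    segments = []
    for i in range(n):
        # Undirected edges are drawn once, so only look at j > i.
        start = 0 if directed else i + 1
        for j in range(start, n):
            if j == i:
                continue  # skip self-loops
            if adj[i][j] or (not directed and adj[j][i]):
                # For directed graphs the segment runs from source i to target j.
                segments.append((positions[i], positions[j]))
    return segments


ring = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]  # directed ring 0 -> 1 -> 2 -> 0
print(edge_segments(ring, circular_layout(len(ring)), directed=True))
```

The resulting coordinates could then be drawn with matplotlib, for example with Axes.plot for undirected edges or Axes.annotate with an arrowprops argument for directed arrows.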