tutorials.examples.train_graph_ring

Train a GFlowNet to generate ring graphs.

This example demonstrates training a GFlowNet to generate graphs that are rings, where each vertex has exactly two neighbors and the edges form a single cycle containing all vertices. The environment supports both directed and undirected ring generation.

The number of modes in this problem grows factorially with the number of nodes, which makes learning from scratch very difficult even on relatively small graphs.
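The factorial growth can be made concrete for small graphs. Assuming every ring on n labeled nodes counts as a separate mode (an assumption about how the environment distinguishes graphs), there are (n-1)! directed rings and (n-1)!/2 undirected rings for n ≥ 3. The brute-force count below uses a hypothetical helper, count_rings, which is not part of the module; it enumerates node orderings that start at node 0 and collects the distinct edge sets:

```python
from itertools import permutations
from math import factorial


def count_rings(n_nodes: int, directed: bool) -> int:
    """Brute-force count of distinct ring graphs on n labeled nodes (n >= 3).

    Hypothetical helper for illustration only.
    """
    rings = set()
    # Every ring passes through node 0, so fixing it first enumerates all cycles.
    for rest in permutations(range(1, n_nodes)):
        order = (0,) + rest
        edges = []
        for i in range(n_nodes):
            u, v = order[i], order[(i + 1) % n_nodes]
            # Directed edges keep their orientation; undirected edges are unordered.
            edges.append((u, v) if directed else frozenset((u, v)))
        rings.add(frozenset(edges))
    return len(rings)


# Columns: n, directed count, (n-1)!, undirected count, (n-1)!/2
for n in range(3, 8):
    print(n, count_rings(n, True), factorial(n - 1),
          count_rings(n, False), factorial(n - 1) // 2)
```

An ordering and its reverse give the same undirected edge set, which is where the factor of 2 comes from. By the formula, n = 10 already gives 9! = 362,880 directed rings.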

For usage, run train_graph_ring.py -h.

Key components:

  • RingGraphBuilding: Environment for building ring graphs.

  • RingPolicyModule: GNN-based policy network for predicting actions.

  • directed_reward/undirected_reward: Reward functions for validating ring structures.

Attributes

parser

Classes

RingReward

Callable reward that evaluates whether a graph forms a valid ring (directed or undirected cycle).

Functions

init_env(n_nodes, directed, device)

init_gflownet(num_nodes, directed, use_gnn, ...)

main(args)

Main execution for training a GFlowNet to generate ring graphs.

render_states(states, state_evaluator, directed)

Visualize a batch of graph states as ring structures.

Module Contents

class tutorials.examples.train_graph_ring.RingReward(directed, reward_val=100.0, eps_val=1e-06, device=torch.device('cpu'))

Bases: object

Callable reward that evaluates whether a graph forms a valid ring (directed or undirected cycle).

Parameters:
  • directed (bool) – Whether the graph is directed.

  • reward_val (float) – The reward for valid rings, directed or undirected.

  • eps_val (float) – The reward for invalid structures.

  • device (torch.device) – Device on which the reward tensor is created.

Returns:

A tensor of rewards with the same batch shape as states.
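A GFlowNet that satisfies its training objective samples each terminal graph x with probability proportional to its reward, so reward_val and eps_val fix how strongly a single valid ring is preferred over a single invalid graph. A worked ratio using the defaults from the class signature (the sum over x' ranges over whatever terminal graphs the environment can produce):

```latex
\begin{align*}
P_T(x) &= \frac{R(x)}{Z}, \qquad Z = \sum_{x'} R(x'), \\
\frac{P_T(x_{\mathrm{ring}})}{P_T(x_{\mathrm{invalid}})}
  &= \frac{\mathtt{reward\_val}}{\mathtt{eps\_val}}
   = \frac{100}{10^{-6}} = 10^{8}, \\
\log R(x_{\mathrm{ring}}) - \log R(x_{\mathrm{invalid}}) &= \ln 10^{8} \approx 18.42 .
\end{align*}
```

This is a per-graph ratio. Because invalid graphs can vastly outnumber valid rings, the total probability mass on invalid graphs may still be non-negligible. Keeping eps_val strictly positive also keeps log R finite for invalid graphs, which matters for log-space objectives such as trajectory balance.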

__call__(states)
Parameters:

states (gfn.states.GraphStates)

Return type:

torch.Tensor

device
directed
directed_reward(states)

Compute reward for directed ring graphs.

This function evaluates whether a graph forms a valid directed ring (cycle). A valid directed ring must satisfy these conditions:

  1. Each node has exactly one outgoing edge (row sum = 1 in the adjacency matrix).

  2. Each node has exactly one incoming edge (column sum = 1 in the adjacency matrix).

  3. Following the edges forms a single cycle that includes all nodes.

Parameters:

states (gfn.states.GraphStates) – A batch of graph states to evaluate.

Returns:

A tensor of rewards with the same batch shape as states.

Return type:

torch.Tensor
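The three directed-ring conditions can be checked directly on a dense adjacency matrix. The sketch below is illustrative only: is_directed_ring is a hypothetical name, and it takes a single 0/1 matrix as plain Python lists rather than a batch of GraphStates or torch tensors, so it is not the module's batched implementation.

```python
from typing import List

Matrix = List[List[int]]


def is_directed_ring(adj: Matrix) -> bool:
    """Check the directed-ring conditions on a 0/1 adjacency matrix.

    adj[i][j] == 1 means there is an edge i -> j. Hypothetical helper.
    """
    n = len(adj)
    if n == 0:
        return False
    # Condition 1: every node has exactly one outgoing edge (row sums == 1).
    if any(sum(row) != 1 for row in adj):
        return False
    # Condition 2: every node has exactly one incoming edge (column sums == 1).
    if any(sum(adj[i][j] for i in range(n)) != 1 for j in range(n)):
        return False
    # Condition 3: follow the unique outgoing edge from node 0. A single cycle
    # through all nodes returns to node 0 after exactly n steps, not before.
    node = 0
    for step in range(1, n + 1):
        node = adj[node].index(1)
        if node == 0:
            return step == n
    return False


ring = [[0, 1, 0], [0, 0, 1], [1, 0, 0]]  # 0 -> 1 -> 2 -> 0
print(is_directed_ring(ring))
```

Once conditions 1 and 2 hold, the matrix is a permutation matrix, so condition 3 reduces to asking whether that permutation is a single n-cycle rather than several shorter ones.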

eps_val = 1e-06
reward_val = 100.0
undirected_reward(states)

Compute reward for undirected ring graphs.

This function evaluates whether a graph forms a valid undirected ring (cycle). A valid undirected ring must satisfy these conditions:

  1. Each node has exactly two neighbors (degree = 2).

  2. The graph forms a single connected cycle that includes all nodes.

The algorithm:

  1. Checks that all nodes have degree 2.

  2. Traverses the graph starting from node 0, following edges.

  3. Checks whether the traversal visits all nodes and returns to the start.

Parameters:

states (gfn.states.GraphStates) – A batch of graph states to evaluate.

Returns:

A tensor of rewards with the same batch shape as states.

Return type:

torch.Tensor
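The three-step algorithm for undirected rings can be sketched on a single symmetric 0/1 adjacency matrix. As an illustration only: is_undirected_ring is a hypothetical name, the input is a plain list of lists rather than GraphStates, and the n < 3 guard is an added assumption (a simple undirected cycle needs at least three nodes).

```python
from typing import List

Matrix = List[List[int]]


def is_undirected_ring(adj: Matrix) -> bool:
    """Check the undirected-ring conditions on a symmetric 0/1 adjacency matrix.

    Hypothetical helper mirroring the degree check plus traversal from node 0.
    """
    n = len(adj)
    if n < 3:
        return False  # assumption: a simple undirected cycle needs >= 3 nodes
    # Step 1: every node has degree exactly 2.
    if any(sum(row) != 2 for row in adj):
        return False
    # Step 2: walk from node 0, never stepping straight back to the previous node.
    prev, node = -1, 0
    for step in range(1, n + 1):
        neighbors = [j for j in range(n) if adj[node][j]]
        nxt = neighbors[0] if neighbors[0] != prev else neighbors[1]
        prev, node = node, nxt
        # Step 3: returning to node 0 is valid only after visiting all n nodes.
        if node == 0:
            return step == n
    return False


triangle = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]
print(is_undirected_ring(triangle))
```

The degree check alone is not enough: a union of several disjoint cycles also has every degree equal to 2. The traversal catches this case because the walk returns to node 0 before it has visited every node.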

tutorials.examples.train_graph_ring.init_env(n_nodes, directed, device)
Parameters:
  • n_nodes (int)

  • directed (bool)

  • device (torch.device)

Return type:

gfn.gym.graph_building.GraphBuildingOnEdges

tutorials.examples.train_graph_ring.init_gflownet(num_nodes, directed, use_gnn, embedding_dim, num_conv_layers, num_edge_classes, device)
Parameters:
  • num_nodes (int)

  • directed (bool)

  • use_gnn (bool)

  • embedding_dim (int)

  • num_conv_layers (int)

  • num_edge_classes (int)

  • device (torch.device)

Return type:

gfn.gflownet.trajectory_balance.TBGFlowNet

tutorials.examples.train_graph_ring.main(args)

Main execution for training a GFlowNet to generate ring graphs.

This script demonstrates the complete workflow of training a GFlowNet to generate valid ring structures in both directed and undirected settings.

For usage, run train_graph_ring.py -h.

The script performs the following steps:
  1. Initialize the environment and policy networks.

  2. Train the GFlowNet using trajectory balance.

  3. Visualize sample generated graphs.

Parameters:

args (argparse.Namespace)
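Step 2 trains with trajectory balance (init_gflownet returns a TBGFlowNet). For one complete trajectory s_0 → s_1 → ... → s_n = x, the trajectory-balance loss is the squared log-space residual (log Z + Σ log P_F − log R(x) − Σ log P_B)², where log Z is itself a learned parameter. The sketch below writes this formula in plain Python for a single trajectory; trajectory_balance_loss is a hypothetical helper, not torchgfn's batched implementation.

```python
import math
from typing import Sequence


def trajectory_balance_loss(
    log_z: float,
    log_pf: Sequence[float],
    log_pb: Sequence[float],
    log_reward: float,
) -> float:
    """Squared trajectory-balance residual for one complete trajectory.

    log_pf[t] = log P_F(s_{t+1} | s_t), log_pb[t] = log P_B(s_t | s_{t+1}).
    Hypothetical helper for illustration only.
    """
    residual = log_z + sum(log_pf) - log_reward - sum(log_pb)
    return residual ** 2


# Two forward steps with P_F = 0.5 each, backward probabilities 1.0 and 0.5,
# and reward 100. Choosing log Z = log 200 balances this trajectory exactly.
loss = trajectory_balance_loss(
    log_z=math.log(200.0),
    log_pf=[math.log(0.5), math.log(0.5)],
    log_pb=[0.0, math.log(0.5)],
    log_reward=math.log(100.0),
)
print(f"{loss:.3e}")
```

When the residual is zero for every trajectory, the sampler matches the reward-proportional target. In practice the loss is averaged over a batch of sampled trajectories and minimized by gradient descent on the policy parameters and log Z together.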

tutorials.examples.train_graph_ring.parser
tutorials.examples.train_graph_ring.render_states(states, state_evaluator, directed)

Visualize a batch of graph states as ring structures.

This function creates a matplotlib visualization of graph states, rendering them as circular layouts with nodes positioned evenly around a circle. For directed graphs, edges are shown as arrows; for undirected graphs, edges are shown as lines.

The visualization includes:

  • Circular positioning of nodes.

  • Edges drawn between connected nodes.

  • The reward value displayed for each graph.

Parameters:
  • states (gfn.states.GraphStates) – A batch of graphs to visualize.

  • state_evaluator (callable) – Function that computes the reward for each graph.

  • directed (bool) – Whether to render directed or undirected edges.
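Positioning nodes evenly around a circle amounts to placing node i at angle 2πi/n. A minimal sketch of that geometry follows, independent of matplotlib and GraphStates. Both function names are hypothetical, and starting node 0 at angle 0 and going counterclockwise is an assumption, since the original does not say where node 0 is drawn.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def circular_layout(n_nodes: int, radius: float = 1.0) -> List[Point]:
    """Place n nodes evenly on a circle, node 0 at angle 0, counterclockwise."""
    return [
        (radius * math.cos(2 * math.pi * i / n_nodes),
         radius * math.sin(2 * math.pi * i / n_nodes))
        for i in range(n_nodes)
    ]


def edge_segments(adj: List[List[int]], directed: bool) -> List[Tuple[Point, Point]]:
    """Return (start, end) coordinates for every edge in a 0/1 adjacency matrix."""
    n = len(adj)
    pos = circular_layout(n)
    segments = []
    for i in range(n):
        for j in range(n):
            # An undirected edge appears twice in a symmetric matrix; keep i < j only.
            if adj[i][j] and (directed or i < j):
                segments.append((pos[i], pos[j]))
    return segments


triangle = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]
print(len(edge_segments(triangle, directed=False)))
```

With matplotlib, one way to draw each segment is ax.annotate("", xy=end, xytext=start, arrowprops=dict(arrowstyle="->")) for directed edges and ax.plot([x0, x1], [y0, y1]) for undirected ones.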