tutorials.examples.train_with_example_modes
Train a GFlowNet to generate ring graphs using example modes.
This script demonstrates how to use example modes to warm-start GFlowNet exploration. We show this in the context of generating ring graphs, where the number of modes grows quickly with the number of nodes, making learning from scratch very difficult even on relatively small graphs. Here, we show that warm-starting GFlowNet exploration with example modes allows us to learn a GFlowNet that can sample all modes.
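To give a rough sense of scale, the sketch below counts ring graphs on n labeled nodes. It assumes that a mode is a single cycle visiting every node exactly once; that definition is our assumption, not something stated on this page. Under that assumption there are (n-1)! directed rings and (n-1)!/2 undirected rings (for n ≥ 3).

```python
from math import factorial


def num_directed_rings(n_nodes: int) -> int:
    # Fix node 0 as the start of the cycle to remove rotational duplicates;
    # the remaining n_nodes - 1 nodes can follow in any order.
    return factorial(n_nodes - 1)


def num_undirected_rings(n_nodes: int) -> int:
    # Without edge directions, a cycle and its reversal are the same graph,
    # so the directed count is halved (only meaningful for n_nodes >= 3).
    if n_nodes < 3:
        raise ValueError("an undirected ring needs at least 3 nodes")
    return factorial(n_nodes - 1) // 2


for n in (4, 6, 8, 10):
    print(n, num_directed_rings(n), num_undirected_rings(n))
```

Under these assumptions, 4 nodes give only 6 directed rings, while 10 nodes already give 362880.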
For usage, see train_with_example_modes.py -h
The script performs the following steps:
- Initialize the environment and policy networks.
- If using expert data, generate all possible ring graphs and pre-fill the replay buffer with half of their forward trajectories (obtained by computing the backward trajectories from the final states, then reversing them).
- Train the GFlowNet using trajectory balance, with each batch containing a mix of 50% replay-buffer samples and 50% GFlowNet samples.
- At the end of training, evaluate the GFlowNet’s ability to recover all modes.
- Optionally, plot samples of generated graphs.
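The expert-data step can be sketched in plain Python. Starting from a finished ring, the backward walk removes one edge at a time (chosen uniformly at random here, as a stand-in for the backward policy), and reversing the removal order gives a forward trajectory that builds the same ring. The edge-list representation, the ("add_edge", src, dst) and ("exit",) action tuples, and all helper names are hypothetical; the real script operates on GraphStates and GraphActions. We read "half of their forward trajectories" as one trajectory for each ring in a randomly chosen half of the rings.

```python
import random


def backward_trajectory(ring_edges, rng):
    # Hypothetical backward walk: repeatedly remove one remaining edge,
    # recording each removal, until the graph has no edges left.
    remaining = list(ring_edges)
    removed = []
    while remaining:
        removed.append(remaining.pop(rng.randrange(len(remaining))))
    return removed


def forward_from_backward(removed_edges):
    # Reversing the removal order yields a valid forward trajectory:
    # the last edge removed is the first edge added.
    actions = [("add_edge", src, dst) for src, dst in reversed(removed_edges)]
    return actions + [("exit",)]


def prefill_buffer(all_rings, fraction=0.5, seed=0):
    # Keep a random subset (half by default) of the known modes as expert data.
    rng = random.Random(seed)
    chosen = rng.sample(all_rings, k=int(len(all_rings) * fraction))
    return [forward_from_backward(backward_trajectory(r, rng)) for r in chosen]
```

For example, prefill_buffer([[(0, 1), (1, 2), (2, 0)], [(0, 2), (2, 1), (1, 0)]]) returns one forward trajectory: three edge additions followed by an exit action.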
This tutorial uses the same environment as train_graph_ring.py.
Attributes
| parser |
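Functions
| count_recovered_modes(final_states, mode_hashes) | Count the number of unique modes in final_states that appear in mode_hashes. |
| generate_all_rings(n_nodes, device='cpu', max_rings=10000) | Generate all possible ring graphs for a given number of nodes using GraphActions. |
| main(args) | Main execution for training a GFlowNet to generate ring graphs. |
| per_step_decay(num_steps, total_drop) | Compute the per-step decay multiplier y (γ) so that after num_steps scheduler steps the learning rate has been multiplied by total_drop. |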
Module Contents
- tutorials.examples.train_with_example_modes.count_recovered_modes(final_states, mode_hashes)
Count the number of unique modes in final_states that appear in mode_hashes.
- Parameters:
final_states (gfn.states.GraphStates)
mode_hashes (set[str])
- Return type:
int
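The counting logic can be sketched as follows, assuming each graph is an edge list and that a mode's hash is a canonical string built from its sorted edges. graph_hash and count_recovered_modes_sketch are hypothetical helpers; this page does not specify how the real function hashes GraphStates.

```python
def graph_hash(edges: list[tuple[int, int]]) -> str:
    # Hypothetical canonical key: the sorted edge list as a string, so the
    # same edge set hashes identically regardless of construction order.
    return str(sorted(edges))


def count_recovered_modes_sketch(
    final_graphs: list[list[tuple[int, int]]], mode_hashes: set[str]
) -> int:
    # Hash every sampled final graph; repeated samples collapse in the set.
    sampled = {graph_hash(edges) for edges in final_graphs}
    # A mode counts as recovered if at least one sample matches it.
    return len(sampled & mode_hashes)
```

Using a set intersection means that sampling the same ring many times still counts it only once, which matches "unique modes" in the description.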
- tutorials.examples.train_with_example_modes.generate_all_rings(n_nodes, device='cpu', max_rings=10000)
Generate all possible ring graphs for a given number of nodes using GraphActions.
- Parameters:
n_nodes (int) – Number of nodes in the graph
device (Union[str, torch.device, Literal['cpu', 'cuda']]) – Device to use for tensor operations (“cpu” or “cuda”)
max_rings (int) – Maximum number of rings to generate.
- Returns:
List of tuples (final_state, actions), where final_state is the GraphState representing a valid ring and actions is the list of GraphActions that build the ring.
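The enumeration can be sketched with itertools.permutations, assuming directed rings on labeled nodes and a hypothetical action-tuple format in place of GraphActions. Fixing node 0 as the start of each cycle means every distinct ring is produced exactly once, and islice enforces a max_rings cap without materializing all (n-1)! permutations.

```python
from itertools import islice, permutations


def generate_all_rings_sketch(n_nodes: int, max_rings: int = 10000):
    rings = []
    # Node 0 always starts the cycle, so rotations are not generated twice.
    for perm in islice(permutations(range(1, n_nodes)), max_rings):
        order = (0, *perm)
        # Connect each node to the next one, wrapping the last back to node 0.
        edges = [(order[i], order[(i + 1) % n_nodes]) for i in range(n_nodes)]
        # Hypothetical action sequence: one edge addition per ring edge, then exit.
        actions = [("add_edge", src, dst) for src, dst in edges] + [("exit",)]
        rings.append((edges, actions))
    return rings
```

For undirected rings, a cycle and its reversal describe the same graph, so one would additionally drop one ring from each reversed pair.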
- tutorials.examples.train_with_example_modes.main(args)
Main execution for training a GFlowNet to generate ring graphs.
For usage, see train_with_example_modes.py -h
The function performs the following steps:
- Initialize the environment and policy networks.
- If using expert data, generate all possible ring graphs and pre-fill the replay buffer with half of their forward trajectories (obtained by computing the backward trajectories from the final states, then reversing them).
- Train the GFlowNet using trajectory balance, with each batch containing a mix of 50% replay-buffer samples and 50% GFlowNet samples.
- At the end of training, evaluate the GFlowNet’s ability to recover all modes.
- Optionally, plot samples of generated graphs.
- Parameters:
args (argparse.Namespace)
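The training objective and batch mixing can be sketched without the library. trajectory_balance_loss implements the standard trajectory balance objective, (log Z + Σ log P_F − log R(x) − Σ log P_B)², and mixed_batch draws half of a batch from the replay buffer and fills the rest with fresh samples. Representing each trajectory as a dict of precomputed log-probabilities (keys "log_pf", "log_pb", "log_reward") is a simplifying assumption. In practice, the forward log-probabilities of replayed trajectories must be recomputed under the current policy at every step.

```python
import random


def trajectory_balance_loss(log_z, log_pf, log_pb, log_reward):
    # Squared mismatch between the forward flow (log Z plus summed forward
    # log-probabilities) and the backward flow (log reward plus summed
    # backward log-probabilities) along one complete trajectory.
    return (log_z + sum(log_pf) - log_reward - sum(log_pb)) ** 2


def mixed_batch(replay_buffer, on_policy, batch_size, rng, replay_fraction=0.5):
    # Take up to replay_fraction of the batch from the replay buffer...
    n_replay = min(int(batch_size * replay_fraction), len(replay_buffer))
    batch = rng.sample(replay_buffer, n_replay)
    # ...and fill the remainder with freshly sampled GFlowNet trajectories.
    batch += on_policy[: batch_size - n_replay]
    return batch


def batch_loss(log_z, batch):
    # Mean trajectory balance loss over a batch of hypothetical trajectory records.
    losses = [
        trajectory_balance_loss(log_z, t["log_pf"], t["log_pb"], t["log_reward"])
        for t in batch
    ]
    return sum(losses) / len(losses)
```

With an 8-trajectory batch and a replay buffer holding at least 4 trajectories, mixed_batch returns 4 replayed and 4 on-policy trajectories.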
- tutorials.examples.train_with_example_modes.parser
- tutorials.examples.train_with_example_modes.per_step_decay(num_steps, total_drop)
Compute the per-step decay multiplier y (γ) so that after num_steps scheduler steps the learning rate has been multiplied by total_drop. Since lr_final = lr_init * y**num_steps and total_drop = lr_final / lr_init, it follows that y = total_drop**(1/num_steps).
- Parameters:
num_steps – Total number of scheduler.step() calls you will make.
total_drop – Desired overall multiplier (e.g. 0.1 for a 10× drop).
- Returns:
y, the per-step multiplier
- Return type:
float
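A direct transcription of the documented formula follows. It is named per_step_decay_sketch to keep it distinct from the script's own function, whose implementation this page does not show. One natural consumer is PyTorch's torch.optim.lr_scheduler.ExponentialLR, which multiplies the learning rate by gamma on every scheduler.step() call. Passing y as gamma is our assumption about how the value is used.

```python
def per_step_decay_sketch(num_steps: int, total_drop: float) -> float:
    # From lr_final = lr_init * y**num_steps and total_drop = lr_final / lr_init,
    # solve for the per-step multiplier y.
    if num_steps <= 0:
        raise ValueError("num_steps must be a positive integer")
    return total_drop ** (1.0 / num_steps)


# Simulate 1000 scheduler steps starting from lr = 1e-3 with a 10x overall drop.
y = per_step_decay_sketch(1000, 0.1)
lr = 1e-3
for _ in range(1000):
    lr *= y
print(y, lr)  # lr ends approximately at 1e-4
```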