tutorials.examples.train_with_example_modes
===========================================

.. py:module:: tutorials.examples.train_with_example_modes

.. autoapi-nested-parse::

   Train a GFlowNet to generate ring graphs using example modes.

   This script demonstrates how to use example modes to warm-start GFlowNet
   exploration. We show this in the context of generating ring graphs. The
   number of modes grows quickly with the number of nodes in the graph, so
   learning from scratch is very difficult even on relatively small graphs.
   Here, we show that using example modes to warm-start GFlowNet exploration
   lets us learn a GFlowNet that can sample all modes.

   For usage, see ``train_with_example_modes.py -h``.

   The script performs the following steps:

   1. Initialize the environment and policy networks.
   2. If using expert data, generate all possible ring graphs and pre-fill the
      replay buffer with half of their forward trajectories. These are found by
      computing the backward trajectories from the final states and then
      reversing them.
   3. Train the GFlowNet using trajectory balance. Each batch contains a mix of
      50% replay-buffer samples and 50% GFlowNet samples.
   4. At the end of training, evaluate the GFlowNet's ability to recover all
      modes.
   5. Optionally, plot samples of generated graphs.

   This tutorial uses the same environment as ``train_graph_ring.py``.



Attributes
----------

.. autoapisummary::

   tutorials.examples.train_with_example_modes.parser


Functions
---------

.. autoapisummary::

   tutorials.examples.train_with_example_modes.count_recovered_modes
   tutorials.examples.train_with_example_modes.generate_all_rings
   tutorials.examples.train_with_example_modes.main
   tutorials.examples.train_with_example_modes.per_step_decay


Module Contents
---------------

.. py:function:: count_recovered_modes(final_states, mode_hashes)

   Count the number of unique modes among ``final_states``. These are the
   distinct final states whose hash appears in ``mode_hashes``.
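The counting rule can be illustrated with a minimal sketch. It assumes each final state has already been reduced to a hashable key, such as a tuple of node indices. The real function takes ``GraphState`` objects, and this page does not describe how it hashes them. ``count_recovered_modes_sketch`` is a hypothetical name used only for illustration.

.. code-block:: python

   from typing import Hashable, Iterable, Set


   def count_recovered_modes_sketch(
       final_state_keys: Iterable[Hashable], mode_hashes: Set[Hashable]
   ) -> int:
       """Hypothetical sketch: count distinct sampled keys that are known modes.

       Assumes every final state was already converted to a hashable key;
       the real count_recovered_modes works on GraphState objects instead.
       """
       # Deduplicate first, so sampling the same mode many times counts once.
       unique_samples = set(final_state_keys)
       # Only samples that coincide with a known mode contribute to the count.
       return len(unique_samples & set(mode_hashes))


   # Two of the three samples are the same mode; (1, 2, 0) is not a listed key.
   modes = {(0, 1, 2), (0, 2, 1)}
   samples = [(0, 1, 2), (0, 1, 2), (1, 2, 0)]
   print(count_recovered_modes_sketch(samples, modes))  # prints 1

Note that with raw tuples, ``(1, 2, 0)`` and ``(0, 1, 2)`` are different keys even though they could describe rotations of the same ring. Any real hashing scheme has to decide whether such states count as the same mode.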
.. py:function:: generate_all_rings(n_nodes, device = 'cpu', max_rings = 10000)

   Generate all possible ring graphs for a given number of nodes using
   ``GraphActions``.

   :param n_nodes: Number of nodes in the graph.
   :param device: Device to use for tensor operations (``"cpu"`` or ``"cuda"``).
   :param max_rings: Maximum number of rings to generate.
   :returns: List of tuples ``(final_state, actions)``, where

             - ``final_state`` is the ``GraphState`` representing a valid ring
             - ``actions`` is the list of ``GraphActions`` that build the ring
   :rtype: List[Tuple[GraphState, List[GraphActions]]]


.. py:function:: main(args)

   Main entry point for training a GFlowNet to generate ring graphs.

   For usage, see ``train_with_example_modes.py -h``.

   The function performs the following steps:

   1. Initialize the environment and policy networks.
   2. If using expert data, generate all possible ring graphs and pre-fill the
      replay buffer with half of their forward trajectories. These are found by
      computing the backward trajectories from the final states and then
      reversing them.
   3. Train the GFlowNet using trajectory balance. Each batch contains a mix of
      50% replay-buffer samples and 50% GFlowNet samples.
   4. At the end of training, evaluate the GFlowNet's ability to recover all
      modes.
   5. Optionally, plot samples of generated graphs.


.. py:data:: parser

.. py:function:: per_step_decay(num_steps, total_drop)

   Compute the per-step decay multiplier :math:`y` (often written
   :math:`\gamma`). After ``num_steps`` scheduler steps, the learning rate has
   been multiplied by ``total_drop``:

   .. math::

      \mathrm{lr}_{\mathrm{final}} = \mathrm{lr}_{\mathrm{init}} \cdot y^{\mathrm{num\_steps}}
      \quad\Longrightarrow\quad
      y = \mathrm{total\_drop}^{1/\mathrm{num\_steps}}

   :param num_steps: Total number of ``scheduler.step()`` calls you will make.
   :param total_drop: Desired overall multiplier (e.g. 0.1 for a 10× drop).
   :returns: **y**, the per-step multiplier.
   :rtype: float
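The formula above reduces to a single exponentiation, shown in the sketch below. The input checks are an added assumption; this page does not say how ``per_step_decay`` handles non-positive arguments. ``per_step_decay_sketch`` is a hypothetical name.

.. code-block:: python

   import math


   def per_step_decay_sketch(num_steps: int, total_drop: float) -> float:
       """Return y such that y ** num_steps == total_drop (up to rounding)."""
       # Assumed validation: a non-positive step count or drop has no real root.
       if num_steps <= 0:
           raise ValueError("num_steps must be positive")
       if total_drop <= 0:
           raise ValueError("total_drop must be positive")
       # lr_final = lr_init * y**num_steps  =>  y = total_drop**(1/num_steps)
       return total_drop ** (1.0 / num_steps)


   # A 10x drop spread over 1000 scheduler steps gives y of about 0.9977.
   y = per_step_decay_sketch(1000, 0.1)

   # Simulate applying the multiplier once per scheduler step.
   lr = 1e-3
   for _ in range(1000):
       lr *= y
   print(math.isclose(lr, 1e-4, rel_tol=1e-9))  # prints True

In PyTorch, a multiplier like this is the ``gamma`` argument of ``torch.optim.lr_scheduler.ExponentialLR``. This page does not state which scheduler the script actually uses.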