tutorials.examples.train_box
============================

.. py:module:: tutorials.examples.train_box

.. autoapi-nested-parse::

   The goal of this script is to train a GFlowNet on the Box environment using Cartesian per-dimension increments.

   Example usage::

       python train_box.py --delta 0.25 --tied --loss TB
       python train_box.py --delta 0.1 --loss DB --n_components 5

   Based on results from `A theory of continuous generative flow networks <https://arxiv.org/abs/2301.12594>`_.



Attributes
----------

.. autoapisummary::

   tutorials.examples.train_box.DEFAULT_SEED
   tutorials.examples.train_box.parser


Functions
---------

.. autoapisummary::

   tutorials.examples.train_box.estimate_jsd
   tutorials.examples.train_box.get_test_states
   tutorials.examples.train_box.main
   tutorials.examples.train_box.plot_trajectories
   tutorials.examples.train_box.sample_from_reward


Module Contents
---------------

.. py:data:: DEFAULT_SEED
   :type: int
   :value: 4444


.. py:function:: estimate_jsd(kde1, kde2)

   Estimate the Jensen-Shannon divergence (JSD) between two distributions defined by kernel density estimates (KDEs).

   :returns: The estimated JSD as a float.


.. py:function:: get_test_states(n = 100, maxi = 1.0)

   Create test states covering :math:`[0, 1]^2` by discretizing it into an :math:`n \times n` grid.

   :returns: A numpy array of shape :math:`(n^2, 2)` containing the test states.


.. py:function:: main(args)

.. py:data:: parser

.. py:function:: plot_trajectories(env, sampler, n_trajectories = 100, output_path = None, alpha = 0.1)

   Plot sampled trajectories on the Box environment.

   Each trajectory is plotted as a line from ``s0`` to the terminal state, with transparency to visualize overlapping paths.

   :param env: The Box environment.
   :param sampler: The sampler used to generate trajectories.
   :param n_trajectories: Number of trajectories to sample and plot.
   :param output_path: Path to save the output plot. If ``None``, defaults to ``EXAMPLES_OUTPUTS / 'train_box_trajectories.png'``.
   :param alpha: Transparency of each trajectory line.


.. py:function:: sample_from_reward(env, n_samples)

   Sample states from the true reward distribution.

   Uses rejection sampling with a uniform proposal distribution on :math:`[0, 1]^2`.

   :returns: A numpy array of shape ``(n_samples, 2)`` containing the sampled states.
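
The rejection sampling described for ``sample_from_reward`` can be sketched as follows. This is a minimal sketch, not the module's implementation: the reward is a hypothetical standalone ``box_reward`` rather than the reward read from ``env``, and its corner-mode form and constants ``R0``, ``R1``, ``R2`` are assumptions modeled on the Box reward from the cited paper. Each uniform proposal on :math:`[0, 1]^2` is accepted with probability :math:`R(x) / R_{\max}`. This is valid because ``R0 + R1 + R2`` bounds the reward everywhere, so accepted points are distributed proportionally to :math:`R`.

.. code-block:: python

   import numpy as np

   # Hypothetical reward constants (assumed, not read from the Box environment).
   R0, R1, R2 = 0.1, 0.5, 2.0


   def box_reward(x: np.ndarray) -> np.ndarray:
       """Hypothetical corner-mode reward for points x of shape (m, 2) in [0, 1]^2."""
       ax = np.abs(x - 0.5)
       # Outer band: every coordinate at distance in (0.25, 0.5] from the center.
       outer = np.all((ax > 0.25) & (ax <= 0.5), axis=-1)
       # Inner band: every coordinate at distance in (0.3, 0.4) from the center.
       inner = np.all((ax > 0.3) & (ax < 0.4), axis=-1)
       return R0 + R1 * outer + R2 * inner


   def sample_from_reward_sketch(n_samples: int, seed: int = 4444) -> np.ndarray:
       """Rejection sampling from box_reward with a uniform proposal on [0, 1]^2."""
       rng = np.random.default_rng(seed)
       r_max = R0 + R1 + R2  # upper bound on box_reward
       accepted = []
       n_accepted = 0
       while n_accepted < n_samples:
           proposals = rng.uniform(0.0, 1.0, size=(n_samples, 2))
           # Accept each proposal with probability R(x) / R_max.
           keep = rng.uniform(0.0, 1.0, size=n_samples) < box_reward(proposals) / r_max
           accepted.append(proposals[keep])
           n_accepted += int(keep.sum())
       return np.concatenate(accepted, axis=0)[:n_samples]

The default seed mirrors the value of ``DEFAULT_SEED``. Proposals are drawn in batches so that the loop runs only a handful of times, even though most proposals land in the low-reward region and are rejected.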
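
The grid built by ``get_test_states`` can be sketched with ``numpy.linspace`` and ``numpy.meshgrid``. This sketch assumes that ``maxi`` is the upper bound of each coordinate and that both endpoints are included. The real function may offset the grid away from the boundary, so treat ``get_test_states_sketch`` as a hypothetical illustration.

.. code-block:: python

   import numpy as np


   def get_test_states_sketch(n: int = 100, maxi: float = 1.0) -> np.ndarray:
       """Return an (n * n, 2) array of grid points covering [0, maxi]^2."""
       # n evenly spaced values per axis, including both endpoints.
       axis = np.linspace(0.0, maxi, n)
       xx, yy = np.meshgrid(axis, axis)
       # Flatten the two (n, n) coordinate grids into a list of 2-D states.
       return np.stack([xx.ravel(), yy.ravel()], axis=-1)

For example, ``get_test_states_sketch(n=3, maxi=2.0)`` starts with the rows ``[0, 0]``, ``[1, 0]``, ``[2, 0]``, because ``meshgrid`` uses ``"xy"`` indexing by default.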
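
One way to estimate a JSD from two KDEs, as ``estimate_jsd`` does, is to evaluate both densities on a grid of test states, like the one returned by ``get_test_states``, and treat the normalized values as discrete distributions :math:`P` and :math:`Q`. Then :math:`\mathrm{JSD}(P, Q) = \tfrac{1}{2}\,\mathrm{KL}(P \,\|\, M) + \tfrac{1}{2}\,\mathrm{KL}(Q \,\|\, M)` with :math:`M = \tfrac{1}{2}(P + Q)`. The sketch below is an assumption about the approach, not the module's code. It replaces fitted KDE objects with a hypothetical NumPy-only Gaussian KDE, ``gaussian_kde_log_density``, that returns a log-density callable, and it works in log space for numerical stability.

.. code-block:: python

   import numpy as np


   def gaussian_kde_log_density(samples, bandwidth=0.1):
       """Hypothetical stand-in for a fitted KDE: returns a log-density function."""
       samples = np.asarray(samples, dtype=float)
       n, d = samples.shape
       log_norm = -0.5 * d * np.log(2.0 * np.pi * bandwidth**2) - np.log(n)

       def log_density(points):
           points = np.asarray(points, dtype=float)
           # Squared distances between every query point and every sample: (m, n).
           sq = np.sum((points[:, None, :] - samples[None, :, :]) ** 2, axis=-1)
           log_k = -sq / (2.0 * bandwidth**2)
           # Stable log-sum-exp over the samples.
           top = log_k.max(axis=1, keepdims=True)
           return top[:, 0] + np.log(np.exp(log_k - top).sum(axis=1)) + log_norm

       return log_density


   def _normalize(log_p):
       """Turn unnormalized log-densities on a grid into a discrete log-distribution."""
       top = log_p.max()
       return log_p - (top + np.log(np.exp(log_p - top).sum()))


   def estimate_jsd_sketch(log_density1, log_density2, grid):
       """Jensen-Shannon divergence (in nats) between two densities, via a grid."""
       log_p = _normalize(log_density1(grid))
       log_q = _normalize(log_density2(grid))
       # Mixture M = (P + Q) / 2, computed in log space.
       log_m = np.logaddexp(log_p, log_q) - np.log(2.0)
       kl_pm = np.sum(np.exp(log_p) * (log_p - log_m))
       kl_qm = np.sum(np.exp(log_q) * (log_q - log_m))
       return 0.5 * (kl_pm + kl_qm)

With natural logarithms the result lies in :math:`[0, \ln 2]`. It is 0 for identical densities and approaches :math:`\ln 2` for densities with disjoint support on the grid.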