gfn.gym.diffusion_sampling
==========================

.. py:module:: gfn.gym.diffusion_sampling


Attributes
----------

.. autoapisummary::

   gfn.gym.diffusion_sampling.TargetEntry
   gfn.gym.diffusion_sampling.logger


Classes
-------

.. autoapisummary::

   gfn.gym.diffusion_sampling.BaseTarget
   gfn.gym.diffusion_sampling.DiffusionSampling
   gfn.gym.diffusion_sampling.Funnel
   gfn.gym.diffusion_sampling.Grid25GaussianMixture
   gfn.gym.diffusion_sampling.ManyWell
   gfn.gym.diffusion_sampling.Posterior9of25GaussianMixture
   gfn.gym.diffusion_sampling.SimpleGaussianMixture


Module Contents
---------------

.. py:class:: BaseTarget(device, dim, n_gt_xs, seed = 0, plot_border = None)

   Bases: :py:obj:`abc.ABC`

   Base class for all target distributions for diffusion sampling.

   .. attribute:: device

      The device on which the target is stored.

   .. attribute:: dim

      The dimension of the target.

   .. attribute:: gt_xs

      The ground truth samples.

   .. attribute:: gt_xs_log_rewards

      The log rewards of the ground truth samples.


   .. py:method:: cached_sample(batch_size, seed = None)

      Cached sample from the target.

      :param batch_size: The number of samples to draw.
      :param seed: The seed for the random number generator.
      :returns: The samples and their log rewards.


   .. py:attribute:: device


   .. py:attribute:: dim


   .. py:method:: grad_log_reward(x)

      Gradient of the log reward function.

      :param x: The input tensor.
      :returns: The gradient of the log reward function.


   .. py:method:: gt_logz()
      :abstractmethod:

      Log partition function of the target.

      :returns: The log partition function.


   .. py:method:: log_reward(x)
      :abstractmethod:

      Log reward function.

      :param x: The input tensor.
      :returns: The log rewards for the input tensor.


   .. py:attribute:: plot_border
      :value: None


   .. py:method:: sample(batch_size, seed = None)
      :abstractmethod:

      Sample from the target.

      :param batch_size: The number of samples to draw.
      :param seed: The seed for the random number generator.
      :returns: The samples.


   .. py:method:: visualize(samples = None, show = False, prefix = '', output_dir = None, grid_width_n_points = 100, max_n_samples = 1000)
      :abstractmethod:

      Visualize the target.

      :param samples: The samples to visualize.
      :param show: Whether to show the plot.
      :param prefix: The prefix for the plot file name.
      :param output_dir: Directory to save the plot to. Required when ``show=False``; if it is *None* and ``show=False``, the figure is silently discarded.
      :param grid_width_n_points: The number of points along each axis of the visualization grid.
      :param max_n_samples: The maximum number of samples to visualize.


.. py:class:: DiffusionSampling(target_str, target_kwargs, num_discretization_steps, device = torch.device('cpu'), debug = False)

   Bases: :py:obj:`gfn.env.Env`

   Diffusion sampling environment.

   .. attribute:: target

      The target distribution.

   .. attribute:: device

      The device to use.


   .. py:attribute:: DIFFUSION_TARGETS
      :type: dict[str, TargetEntry]


   .. py:method:: backward_step(states, actions)

      Backward step function for the DiffusionSampling environment.

      :param states: The current states.
      :param actions: The actions, which correspond to the changes to the states.
      :returns: The previous states.


   .. py:method:: density_metrics(fwd_log_pfs, fwd_log_pbs, fwd_log_rewards, log_Z_learned, bwd_log_pfs = None, bwd_log_pbs = None, bwd_log_rewards = None, gt_log_Z = None)
      :staticmethod:


   .. py:attribute:: dim


   .. py:attribute:: dt


   .. py:method:: is_action_valid(states, actions, backward = False)

      Check whether the actions are valid.

      :param states: The current states.
      :param actions: The actions to check.
      :param backward: Whether the actions are backward actions.
      :returns: True if the actions are valid, False otherwise.


   .. py:method:: list_available_targets()
      :classmethod:

      Return metadata about the available targets and their default kwargs.

      Use this helper to see which kwargs each alias provides by default. The kwargs a target accepts or requires are determined by its class's constructor signature, not by these defaults.


   .. py:method:: log_reward(states)

      Log reward function for the DiffusionSampling environment.

      :param states: The current states.
      :returns: The log rewards for the input states.


   .. py:method:: make_states_class()

      Returns the States class for diffusion sampling.


   .. py:method:: step(states, actions)

      Step function for the DiffusionSampling environment.

      :param states: The current states.
      :param actions: The actions, which correspond to the changes to the states.
      :returns: The next states.


   .. py:attribute:: target


.. py:class:: Funnel(dim = 10, std = 1.0, device = torch.device('cpu'), seed = 0)

   Bases: :py:obj:`BaseTarget`

   Neal's funnel distribution target.

   x0 ~ Normal(0, std^2) and, for i >= 1, xi | x0 ~ Normal(0, exp(x0)).

   :param dim: Total dimensionality (x0 plus dim-1 conditional coordinates).
   :param std: Standard deviation of the marginal prior on x0.
   :param device: Torch device.
   :param seed: RNG seed.


   .. py:attribute:: dist_dominant


   .. py:method:: gt_logz()

      Log partition function of the target.

      :returns: The log partition function.


   .. py:method:: log_reward(x)

      Log-density of Neal's funnel distribution.

      Returns log p(x0) + the sum over i = 1, ..., dim-1 of log p(xi | x0).


   .. py:method:: sample(batch_size, seed = None)

      Sample from the target.

      :param batch_size: The number of samples to draw.
      :param seed: The seed for the random number generator.
      :returns: The samples.


   .. py:method:: visualize(samples = None, show = False, prefix = '', output_dir = None, grid_width_n_points = 100, max_n_samples = 500)

      Visualization is only supported in 2D (x0, x1).


.. py:class:: Grid25GaussianMixture(device, dim = 2, scale = math.sqrt(0.3), plot_border = 15.0, seed = 0)

   Bases: :py:obj:`BaseTarget`

   Fixed 5x5 Gaussian mixture prior used for RTB demos.


   .. py:attribute:: gmm


   .. py:method:: gt_logz()

      Log partition function of the target.

      :returns: The log partition function.


   .. py:attribute:: locs


   .. py:method:: log_reward(x)

      Log reward function.

      :param x: The input tensor.
      :returns: The log rewards for the input tensor.


   .. py:method:: sample(batch_size, seed = None)

      Sample from the target.

      :param batch_size: The number of samples to draw.
      :param seed: The seed for the random number generator.
      :returns: The samples.


   .. py:method:: visualize(samples = None, show = False, prefix = '', output_dir = None, grid_width_n_points = 100, max_n_samples = 1000)

      Visualize the target.

      :param samples: The samples to visualize.
      :param show: Whether to show the plot.
      :param prefix: The prefix for the plot file name.
      :param output_dir: Directory to save the plot to. Required when ``show=False``; if it is *None* and ``show=False``, the figure is silently discarded.
      :param grid_width_n_points: The number of points along each axis of the visualization grid.
      :param max_n_samples: The maximum number of samples to visualize.


.. py:class:: ManyWell(dim = 32, device = torch.device('cpu'), seed = 0)

   Bases: :py:obj:`BaseTarget`

   Many-well target distribution.

   The 32D (default) instance is the concatenation of 16 identical 2D double-well components. Each 2D block (x1, x2) has the unnormalized log-density::

       log p(x1, x2) = -x1^4 + 6 x1^2 + 0.5 x1 - 0.5 x2^2 + C

   The overall log-density is the sum over all independent 2D blocks. Sampling draws x1 in each block by rejection sampling from a simple Gaussian mixture proposal and draws x2 from a standard Normal.


   .. py:attribute:: Z_x1
      :value: 11784.50927


   .. py:attribute:: Z_x2


   .. py:method:: _block_log_density(x1, x2)
      :staticmethod:


   .. py:method:: _compute_envelope_k(proposal)


   .. py:method:: _make_proposal()


   .. py:method:: _rejection_sampling_x1(n_samples, proposal, k)
      :staticmethod:


   .. py:attribute:: component_mix


   .. py:method:: gt_logz()

      Log partition function of the target.

      :returns: The log partition function.


   .. py:method:: log_reward(x)

      Log reward function.

      :param x: The input tensor.
      :returns: The log rewards for the input tensor.


   .. py:attribute:: means


   .. py:method:: sample(batch_size, seed = None)

      Sample from the target.

      :param batch_size: The number of samples to draw.
      :param seed: The seed for the random number generator.
      :returns: The samples.


   .. py:attribute:: scales


   .. py:method:: visualize(samples = None, show = False, prefix = '', output_dir = None, grid_width_n_points = 100, max_n_samples = 500)

      Visualize the target.

      :param samples: The samples to visualize.
      :param show: Whether to show the plot.
      :param prefix: The prefix for the plot file name.
      :param output_dir: Directory to save the plot to. Required when ``show=False``; if it is *None* and ``show=False``, the figure is silently discarded.
      :param grid_width_n_points: The number of points along each axis of the visualization grid.
      :param max_n_samples: The maximum number of samples to visualize.


.. py:class:: Posterior9of25GaussianMixture(device, dim = 2, scale = math.sqrt(0.3), plot_border = 15.0, seed = 0)

   Bases: :py:obj:`BaseTarget`

   Posterior reward for the 25→9 GMM RTB demo.


   .. py:method:: gt_logz()

      Log partition function of the target.

      :returns: The log partition function.


   .. py:method:: log_reward(x)

      Log reward function.

      :param x: The input tensor.
      :returns: The log rewards for the input tensor.


   .. py:attribute:: posterior


   .. py:attribute:: prior


   .. py:method:: sample(batch_size, seed = None)

      Sample from the target.

      :param batch_size: The number of samples to draw.
      :param seed: The seed for the random number generator.
      :returns: The samples.


   .. py:method:: visualize(samples = None, show = False, prefix = '', output_dir = None, grid_width_n_points = 100, max_n_samples = 1000)

      Visualize the target.

      :param samples: The samples to visualize.
      :param show: Whether to show the plot.
      :param prefix: The prefix for the plot file name.
      :param output_dir: Directory to save the plot to. Required when ``show=False``; if it is *None* and ``show=False``, the figure is silently discarded.
      :param grid_width_n_points: The number of points along each axis of the visualization grid.
      :param max_n_samples: The maximum number of samples to visualize.


.. py:class:: SimpleGaussianMixture(num_components = 2, dim = 2, mean_val_range = (-10.0, 10.0), mixture_weight_range = (0.3, 0.7), degree_of_freedom_adjustment = 2, seed = 3, locs = None, device = torch.device('cpu'))

   Bases: :py:obj:`BaseTarget`

   Simple Gaussian mixture target distribution.

   This target distribution is adapted from https://github.com/DenisBless/variational_sampling_methods/blob/main/targets/gaussian_mixture.py.

   .. attribute:: ...


   .. py:attribute:: distribution


   .. py:method:: gt_logz()

      Log partition function of the target.

      :returns: The log partition function.


   .. py:method:: log_reward(x)

      Log reward function for the SimpleGaussianMixture target.

      :param x: The input tensor.
      :returns: The log rewards for the input tensor.


   .. py:method:: sample(batch_size, seed = None)

      Sample from the SimpleGaussianMixture target.

      :param batch_size: The number of samples to draw.
      :param seed: The seed for the random number generator.
      :returns: The samples.


   .. py:method:: visualize(samples = None, show = False, prefix = '', output_dir = None, grid_width_n_points = 100, max_n_samples = 500)

      Visualize the distribution.

      :param samples: The samples to visualize.
      :param show: Whether to show the plot.
      :param prefix: The prefix for the plot file name.
      :param output_dir: Directory to save the plot to.
      :param grid_width_n_points: The number of points along each axis of the visualization grid.
      :param max_n_samples: The maximum number of samples to visualize.


.. py:data:: TargetEntry

.. py:data:: logger

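
The factorization documented for :py:class:`Funnel` can be evaluated term by term. The following is a minimal, dependency-free sketch, not the library's PyTorch implementation. It assumes the second argument of ``Normal(0, .)`` is a variance (consistent with ``Normal(0, std^2)`` for x0), and ``normal_logpdf`` and ``funnel_log_density`` are hypothetical helper names.

.. code-block:: python

   import math
   from typing import Sequence


   def normal_logpdf(x: float, var: float) -> float:
       """Log-density of a zero-mean Normal with variance ``var`` at ``x``."""
       return -0.5 * (math.log(2.0 * math.pi * var) + x * x / var)


   def funnel_log_density(x: Sequence[float], std: float = 1.0) -> float:
       """Hypothetical helper: log p(x0) + sum over i >= 1 of log p(xi | x0).

       Assumes x0 ~ Normal(0, std^2) and xi | x0 ~ Normal(0, exp(x0)),
       where the second argument of Normal is a variance.
       """
       x0 = x[0]
       log_p = normal_logpdf(x0, std ** 2)
       # Every conditional coordinate shares the variance exp(x0).
       cond_var = math.exp(x0)
       log_p += sum(normal_logpdf(xi, cond_var) for xi in x[1:])
       return log_p


   # Usage: a 10-dimensional point at the origin (dim = 10, std = 1.0).
   print(funnel_log_density([0.0] * 10))

Because every factor is a normalized density, the funnel integrates to one, so its log partition function is 0 under this definition. The conditional coordinates have standard deviation exp(x0 / 2), which produces the narrow neck and wide mouth that make the funnel hard to sample.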
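
The ``Z_x1`` constant of :py:class:`ManyWell` is the normalizer of the x1 factor of one double-well block. The sketch below shows how such a constant can be reproduced by numerical quadrature (composite Simpson's rule) and how a ground-truth log partition function follows from the independence of the blocks. It assumes C = 0 in the block log-density and that the x2 factor exp(-0.5 x2^2) is normalized by sqrt(2 pi); ``block_log_density``, ``simpson`` and ``many_well_log_z`` are hypothetical names, not part of ``gfn.gym.diffusion_sampling``.

.. code-block:: python

   import math


   def block_log_density(x1: float, x2: float) -> float:
       """Unnormalized log-density of one 2D double-well block, assuming C = 0."""
       return -x1 ** 4 + 6.0 * x1 ** 2 + 0.5 * x1 - 0.5 * x2 ** 2


   def simpson(f, a: float, b: float, n: int = 20_000) -> float:
       """Composite Simpson's rule for f on [a, b] with an even number n of intervals."""
       if n % 2:
           raise ValueError("n must be even")
       h = (b - a) / n
       total = f(a) + f(b)
       for i in range(1, n):
           # Interior weights alternate 4, 2, 4, ..., 4.
           total += (4.0 if i % 2 else 2.0) * f(a + i * h)
       return total * h / 3.0


   # x1 factor: integrate exp(-x1^4 + 6 x1^2 + 0.5 x1). Outside |x1| <= 6 the
   # -x1^4 term makes the integrand negligible, so a finite range suffices.
   z_x1 = simpson(lambda t: math.exp(block_log_density(t, 0.0)), -6.0, 6.0)
   # x2 factor: exp(-0.5 x2^2) integrates to sqrt(2 * pi).
   z_x2 = math.sqrt(2.0 * math.pi)


   def many_well_log_z(dim: int = 32) -> float:
       """Log partition function of dim // 2 independent blocks (dim must be even)."""
       if dim % 2:
           raise ValueError("dim must be even")
       return (dim // 2) * (math.log(z_x1) + math.log(z_x2))


   print(z_x1, many_well_log_z(32))

Under these assumptions the quadrature result should agree closely with the documented value 11784.50927. Whether ``ManyWell.gt_logz`` computes exactly ``many_well_log_z`` depends on the implementation and is an assumption here.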
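
The rejection sampling scheme described for :py:class:`ManyWell` draws x from a Gaussian mixture proposal q and accepts it with probability f(x) / (k q(x)), where f is the unnormalized x1 factor of a block and k bounds the ratio f / q. The sketch below illustrates that scheme for the x1 coordinate of one block. The proposal parameters ``MEANS``, ``SCALES`` and ``MIX`` and the grid-based estimate of k with a safety margin are hypothetical choices, not the values held in ``means``, ``scales`` and ``component_mix``.

.. code-block:: python

   import math
   import random
   from typing import List, Tuple

   # Hypothetical proposal parameters (NOT the library's means/scales/component_mix).
   MEANS: Tuple[float, ...] = (-1.7, 1.7)
   SCALES: Tuple[float, ...] = (0.5, 0.5)
   MIX: Tuple[float, ...] = (0.2, 0.8)


   def target_unnorm(x1: float) -> float:
       """Unnormalized x1 factor of one block: exp(-x1^4 + 6 x1^2 + 0.5 x1)."""
       return math.exp(-x1 ** 4 + 6.0 * x1 ** 2 + 0.5 * x1)


   def proposal_pdf(x1: float) -> float:
       """Density of the Gaussian mixture proposal q at x1."""
       return sum(
           w * math.exp(-0.5 * ((x1 - m) / s) ** 2) / (s * math.sqrt(2.0 * math.pi))
           for m, s, w in zip(MEANS, SCALES, MIX)
       )


   def proposal_sample(rng: random.Random) -> float:
       """Draw from q: pick a component by weight, then sample its Gaussian."""
       c = rng.choices(range(len(MIX)), weights=MIX)[0]
       return rng.gauss(MEANS[c], SCALES[c])


   def envelope_k(lo: float = -4.0, hi: float = 4.0, n: int = 8001, margin: float = 1.1) -> float:
       """Estimate k >= max f/q on a grid, inflated by a safety margin."""
       step = (hi - lo) / (n - 1)
       ratio = max(target_unnorm(lo + i * step) / proposal_pdf(lo + i * step) for i in range(n))
       return margin * ratio


   def rejection_sample_x1(n_samples: int, k: float, seed: int = 0) -> List[float]:
       """Accept x ~ q with probability f(x) / (k * q(x)) until n_samples are kept."""
       rng = random.Random(seed)
       out: List[float] = []
       while len(out) < n_samples:
           x = proposal_sample(rng)
           if rng.random() * k * proposal_pdf(x) < target_unnorm(x):
               out.append(x)
       return out


   samples = rejection_sample_x1(5_000, envelope_k(), seed=0)

A grid search only approximates the supremum of f / q, which is why the sketch inflates it by a margin. The ratio is bounded because exp(-x1^4) decays faster than the Gaussian tails of q; the expected acceptance rate is Z_x1 / k, so a tighter proposal means fewer wasted draws.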
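
Gaussian mixture targets such as :py:class:`SimpleGaussianMixture` and :py:class:`Grid25GaussianMixture` evaluate the log of a weighted sum of component densities. A numerically stable way to compute that is the log-sum-exp trick; the sketch below applies it to isotropic components in pure Python. It is an illustration, not the library's implementation: ``scale`` is treated as the component standard deviation, and the 5x5 grid with spacing 5 in the usage line is hypothetical rather than the documented ``locs``.

.. code-block:: python

   import math
   from typing import Sequence


   def log_sum_exp(values: Sequence[float]) -> float:
       """Stable log(sum(exp(v))): shift by the max before exponentiating."""
       m = max(values)
       return m + math.log(sum(math.exp(v - m) for v in values))


   def isotropic_gmm_log_density(
       x: Sequence[float],
       locs: Sequence[Sequence[float]],
       weights: Sequence[float],
       scale: float,
   ) -> float:
       """log sum_k w_k N(x; loc_k, scale^2 I); weights must be positive."""
       d = len(x)
       log_norm = -0.5 * d * math.log(2.0 * math.pi * scale ** 2)
       terms = []
       for loc, w in zip(locs, weights):
           sq_dist = sum((xi - mi) ** 2 for xi, mi in zip(x, loc))
           terms.append(math.log(w) + log_norm - 0.5 * sq_dist / scale ** 2)
       return log_sum_exp(terms)


   # Hypothetical 5x5 grid with spacing 5 and equal weights (not the library's locs).
   grid = [(float(i), float(j)) for i in range(-10, 11, 5) for j in range(-10, 11, 5)]
   weights = [1.0 / len(grid)] * len(grid)
   print(isotropic_gmm_log_density((0.0, 0.0), grid, weights, scale=math.sqrt(0.3)))

Subtracting the maximum term before exponentiating keeps the result finite even far from every component, where each individual density underflows to zero. If ``log_reward`` is the log of such a normalized mixture density, the corresponding log partition function is 0.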