gfn.gym.diffusion_sampling

Attributes

TargetEntry

logger

Classes

BaseTarget

Base class for all target distributions for diffusion sampling.

DiffusionSampling

Diffusion sampling environment.

Funnel

Neal's funnel distribution target.

Grid25GaussianMixture

Fixed 5x5 Gaussian mixture prior used for RTB demos.

ManyWell

Many-well target distribution.

Posterior9of25GaussianMixture

Posterior reward for the 25→9 GMM RTB demo.

SimpleGaussianMixture

Simple Gaussian mixture target distribution.

Module Contents

class gfn.gym.diffusion_sampling.BaseTarget(device, dim, n_gt_xs, seed=0, plot_border=None)

Bases: abc.ABC

Base class for all target distributions for diffusion sampling.

Parameters:
  • device (torch.device)

  • dim (int)

  • n_gt_xs (int)

  • seed (int)

  • plot_border (float | tuple[float, float, float, float] | None)

device

The device on which the target is stored.

dim

The dimension of the target.

gt_xs

The ground truth samples.

gt_xs_log_rewards

The log rewards of the ground truth samples.

cached_sample(batch_size, seed=None)

Cached sample from the target.

Parameters:
  • batch_size (int) – The number of samples to draw.

  • seed (int | None) – The seed for the random number generator.

Returns:

The samples and the log rewards.

Return type:

tuple[torch.Tensor | None, torch.Tensor | None]

grad_log_reward(x)

Gradient of the log reward function.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The gradient of the log reward function.

Return type:

torch.Tensor
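A quick way to check a grad_log_reward implementation is to compare it against central finite differences. The sketch below uses plain Python lists instead of torch tensors, and an unnormalized standard normal as a stand-in log reward (its gradient is -x). log_reward_std_normal and finite_difference_grad are hypothetical helpers, not part of this module.

```python
from typing import Callable, List


def log_reward_std_normal(x: List[float]) -> float:
    # Stand-in log reward: unnormalized standard normal, whose gradient is -x
    return -0.5 * sum(xi * xi for xi in x)


def finite_difference_grad(
    log_reward: Callable[[List[float]], float], x: List[float], eps: float = 1e-5
) -> List[float]:
    # Central differences: (f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)
    grad = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        grad.append((log_reward(x_plus) - log_reward(x_minus)) / (2 * eps))
    return grad


print(finite_difference_grad(log_reward_std_normal, [0.5, -1.0, 2.0]))  # close to [-0.5, 1.0, -2.0]
```

For a quadratic log reward, central differences are exact up to rounding. Any larger mismatch therefore points to a bug rather than discretization error.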

abstract gt_logz()

Log partition function of the target.

Returns:

The log partition function.

Return type:

float

abstract log_reward(x)

Log reward function.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The log rewards for the input tensor.

Return type:

torch.Tensor

plot_border = None
abstract sample(batch_size, seed=None)

Sample from the target.

Parameters:
  • batch_size (int) – The number of samples to draw.

  • seed (int | None) – The seed for the random number generator.

Returns:

The samples.

Return type:

torch.Tensor

abstract visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=1000)

Visualize the target.

Parameters:
  • samples (torch.Tensor | None) – The samples to visualize.

  • show (bool) – Whether to show the plot.

  • prefix (str) – The prefix for the plot file name.

  • output_dir (str | pathlib.Path | None) – Directory to save the plot to. When show=False, the plot is saved only if output_dir is set; if output_dir is None, the figure is silently discarded.

  • grid_width_n_points (int) – The number of points along each axis of the visualization grid.

  • max_n_samples (int) – The maximum number of samples to visualize.

Return type:

None

class gfn.gym.diffusion_sampling.DiffusionSampling(target_str, target_kwargs, num_discretization_steps, device=torch.device('cpu'), debug=False)

Bases: gfn.env.Env

Diffusion sampling environment.

Parameters:
  • target_str (str)

  • target_kwargs (dict[str, Any] | None)

  • num_discretization_steps (float)

  • device (torch.device)

  • debug (bool)

target

The target distribution.

device

The device to use.

Return type:

torch.device

DIFFUSION_TARGETS: dict[str, TargetEntry]
backward_step(states, actions)

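Backward step function for the DiffusionSampling environment.

Parameters:
  • states (gfn.states.States) – The current states.

  • actions (gfn.actions.Actions) – The backward actions to apply.
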
Returns:

The previous states.

Return type:

gfn.states.States

static density_metrics(fwd_log_pfs, fwd_log_pbs, fwd_log_rewards, log_Z_learned, bwd_log_pfs=None, bwd_log_pbs=None, bwd_log_rewards=None, gt_log_Z=None)
Parameters:
  • fwd_log_pfs (torch.Tensor)

  • fwd_log_pbs (torch.Tensor)

  • fwd_log_rewards (torch.Tensor)

  • log_Z_learned (float)

  • bwd_log_pfs (torch.Tensor | None)

  • bwd_log_pbs (torch.Tensor | None)

  • bwd_log_rewards (torch.Tensor | None)

  • gt_log_Z (float | None)

Return type:

dict
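The arguments of density_metrics are per-trajectory log-probabilities from forward (and optionally backward) trajectories. The sketch below shows standard estimators that such inputs support: the ELBO and an importance-weighted estimate of log Z from forward trajectories, and the EUBO from backward trajectories started at target samples. The metrics the method actually returns, and their dictionary keys, are not documented here. sketch_density_metrics and its keys are therefore hypothetical, and the sketch omits log_Z_learned and gt_log_Z. It assumes each list holds one value per trajectory, already summed over time steps.

```python
import math
from typing import Dict, List, Optional


def logsumexp(values: List[float]) -> float:
    # Numerically stable log(sum(exp(v)))
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sketch_density_metrics(
    fwd_log_pfs: List[float],
    fwd_log_pbs: List[float],
    fwd_log_rewards: List[float],
    bwd_log_pfs: Optional[List[float]] = None,
    bwd_log_pbs: Optional[List[float]] = None,
    bwd_log_rewards: Optional[List[float]] = None,
) -> Dict[str, float]:
    # Log importance weights: log R(x) + log P_B(tau | x) - log P_F(tau)
    fwd_w = [r + pb - pf for pf, pb, r in zip(fwd_log_pfs, fwd_log_pbs, fwd_log_rewards)]
    n = len(fwd_w)
    metrics = {
        # ELBO = log Z - KL(P_F || P_B), a lower bound on log Z
        "elbo": sum(fwd_w) / n,
        # log of the mean importance weight, a lower bound on log Z in expectation
        "log_z_is": logsumexp(fwd_w) - math.log(n),
    }
    if bwd_log_pfs is not None and bwd_log_pbs is not None and bwd_log_rewards is not None:
        bwd_w = [r + pb - pf for pf, pb, r in zip(bwd_log_pfs, bwd_log_pbs, bwd_log_rewards)]
        # EUBO: an upper bound on log Z when backward trajectories start at exact target samples
        metrics["eubo"] = sum(bwd_w) / len(bwd_w)
    return metrics
```

Because log-mean-exp is never smaller than the plain mean, the importance-weighted estimate is always at least the ELBO on the same batch.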

dim
dt
is_action_valid(states, actions, backward=False)

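Check whether the actions are valid for the given states.

Parameters:
  • states (gfn.states.States) – The current states.

  • actions (gfn.actions.Actions) – The actions to check.

  • backward (bool) – Whether the actions are backward actions.
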
Returns:

True if the actions are valid, False otherwise.

Return type:

bool

classmethod list_available_targets()

Return metadata about available targets and their default kwargs.

This helper shows which kwargs each alias supplies by default. The kwargs a target accepts or requires are determined by its class constructor signature, not by this listing.

Return type:

dict[str, dict[str, Any]]
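Here is a minimal sketch, under stated assumptions, of how alias defaults can be merged with user-supplied target_kwargs, and how a constructor signature decides which kwargs are accepted. TargetSpec, REGISTRY, ToyTarget, build_target, and the alias "toy" are hypothetical stand-ins. The real TargetEntry and DIFFUSION_TARGETS may be structured differently.

```python
import inspect
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


class ToyTarget:
    # Hypothetical target; its __init__ signature defines the accepted kwargs
    def __init__(self, dim: int = 2, scale: float = 1.0, seed: int = 0) -> None:
        self.dim, self.scale, self.seed = dim, scale, seed


@dataclass
class TargetSpec:
    # Hypothetical registry entry: a target class plus its alias defaults
    target_cls: type
    default_kwargs: Dict[str, Any] = field(default_factory=dict)


REGISTRY: Dict[str, TargetSpec] = {
    "toy": TargetSpec(ToyTarget, {"dim": 2, "scale": 0.5}),
}


def build_target(alias: str, target_kwargs: Optional[Dict[str, Any]] = None) -> Any:
    spec = REGISTRY[alias]
    # User-supplied kwargs override the alias defaults
    kwargs = {**spec.default_kwargs, **(target_kwargs or {})}
    # The constructor signature, not the defaults, decides what is accepted
    accepted = set(inspect.signature(spec.target_cls).parameters)
    unknown = sorted(set(kwargs) - accepted)
    if unknown:
        raise TypeError(f"{spec.target_cls.__name__} does not accept {unknown}")
    return spec.target_cls(**kwargs)


target = build_target("toy", {"scale": 2.0})
print(target.dim, target.scale)  # 2 2.0
```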

log_reward(states)

Log reward function for the DiffusionSampling environment.

Parameters:

states (gfn.states.States) – The current states.

Returns:

The log rewards for the input states.

Return type:

torch.Tensor

make_states_class()

Returns the States class for diffusion sampling.

Return type:

type[gfn.states.States]

step(states, actions)

Step function for the DiffusionSampling environment.

Parameters:
  • states (gfn.states.States) – The current states.

  • actions (gfn.actions.Actions) – The actions to apply.

Returns:

The next states.

Return type:

gfn.states.States

class gfn.gym.diffusion_sampling.Funnel(dim=10, std=1.0, device=torch.device('cpu'), seed=0)

Bases: BaseTarget

Neal’s funnel distribution target.

x_0 ~ Normal(0, std^2), and for i >= 1, x_i | x_0 ~ Normal(0, exp(x_0)); the second argument of Normal is the variance.

Parameters:
  • dim (int) – Total dimensionality: x_0 plus dim - 1 conditional coordinates.

  • std (float) – Standard deviation for the marginal prior on x0.

  • device (torch.device) – Torch device.

  • seed (int) – RNG seed.
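The generative process described above translates directly into code. The following plain-Python sketch shows ancestral sampling and the log-density. It is not the library's torch implementation; funnel_sample, funnel_log_density, and normal_logpdf are hypothetical names. Because exp(x_0) is a variance, the conditional standard deviation is exp(x_0 / 2).

```python
import math
import random
from typing import List


def normal_logpdf(x: float, var: float) -> float:
    # log N(x; 0, var), parameterized by the variance
    return -0.5 * (math.log(2 * math.pi * var) + x * x / var)


def funnel_log_density(x: List[float], std: float = 1.0) -> float:
    # log p(x_0) + sum over i >= 1 of log p(x_i | x_0)
    x0 = x[0]
    cond_var = math.exp(x0)  # variance of each x_i given x_0
    return normal_logpdf(x0, std ** 2) + sum(normal_logpdf(xi, cond_var) for xi in x[1:])


def funnel_sample(dim: int, std: float, rng: random.Random) -> List[float]:
    # Ancestral sampling: x_0 first, then each x_i conditioned on x_0
    x0 = rng.gauss(0.0, std)
    cond_std = math.exp(0.5 * x0)  # sqrt(exp(x_0))
    return [x0] + [rng.gauss(0.0, cond_std) for _ in range(dim - 1)]


rng = random.Random(0)
x = funnel_sample(dim=10, std=1.0, rng=rng)
print(len(x), funnel_log_density(x, std=1.0))
```

When x_0 is very negative, the remaining coordinates are squeezed tightly around zero (the neck of the funnel). This is what makes the target hard for samplers.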

dist_dominant
gt_logz()

Log partition function of the target.

Returns:

The log partition function.

Return type:

float

log_reward(x)

Log-density of Neal’s funnel distribution.

Returns log p(x_0) + sum_{i=1}^{dim-1} log p(x_i | x_0).

Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

sample(batch_size, seed=None)

Sample from the target.

Parameters:
  • batch_size (int) – The number of samples to draw.

  • seed (int | None) – The seed for the random number generator.

Returns:

The samples.

Return type:

torch.Tensor

visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=500)

Visualize the target. Only supported in 2D, i.e. over the coordinates (x_0, x_1).

Parameters:
  • samples (torch.Tensor | None)

  • show (bool)

  • prefix (str)

  • output_dir (str | pathlib.Path | None)

  • grid_width_n_points (int)

  • max_n_samples (int)

Return type:

None

class gfn.gym.diffusion_sampling.Grid25GaussianMixture(device, dim=2, scale=math.sqrt(0.3), plot_border=15.0, seed=0)

Bases: BaseTarget

Fixed 5x5 Gaussian mixture prior used for relative trajectory balance (RTB) demos.

Parameters:
  • device (torch.device)

  • dim (int)

  • scale (float)

  • plot_border (float)

  • seed (int)
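Here is a minimal plain-Python sketch of a fixed 5x5 mixture of this kind. Several details are assumptions not stated in this reference: equal mixture weights, isotropic components with standard deviation scale (default sqrt(0.3)), and means on a grid with spacing 5, i.e. {-10, -5, 0, 5, 10} in each coordinate. grid_locs and grid_gmm_log_density are hypothetical names.

```python
import math
from itertools import product
from typing import List, Tuple


def grid_locs(spacing: float = 5.0, n: int = 5) -> List[Tuple[float, float]]:
    # n x n grid of means centered at the origin (spacing is an assumption)
    offsets = [spacing * (i - (n - 1) / 2) for i in range(n)]
    return list(product(offsets, offsets))


def grid_gmm_log_density(
    x: Tuple[float, float], locs: List[Tuple[float, float]], scale: float = math.sqrt(0.3)
) -> float:
    # Equal-weight isotropic 2D Gaussian mixture, combined with a stable log-sum-exp
    var = scale ** 2
    log_w = -math.log(len(locs))
    comps = [
        log_w - math.log(2 * math.pi * var) - ((x[0] - m0) ** 2 + (x[1] - m1) ** 2) / (2 * var)
        for m0, m1 in locs
    ]
    m = max(comps)
    return m + math.log(sum(math.exp(c - m) for c in comps))


locs = grid_locs()
print(len(locs), grid_gmm_log_density((0.0, 0.0), locs))
```

With a spacing of 5 and a variance of 0.3, the modes are well separated. Near any mode, the log-density is almost exactly that of its own component plus log(1/25).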

gmm
gt_logz()

Log partition function of the target.

Returns:

The log partition function.

Return type:

float

locs
log_reward(x)

Log reward function.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The log rewards for the input tensor.

Return type:

torch.Tensor

sample(batch_size, seed=None)

Sample from the target.

Parameters:
  • batch_size (int) – The number of samples to draw.

  • seed (int | None) – The seed for the random number generator.

Returns:

The samples.

Return type:

torch.Tensor

visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=1000)

Visualize the target.

Parameters:
  • samples (torch.Tensor | None) – The samples to visualize.

  • show (bool) – Whether to show the plot.

  • prefix (str) – The prefix for the plot file name.

  • output_dir (str | pathlib.Path | None) – Directory to save the plot to. When show=False, the plot is saved only if output_dir is set; if output_dir is None, the figure is silently discarded.

  • grid_width_n_points (int) – The number of points along each axis of the visualization grid.

  • max_n_samples (int) – The maximum number of samples to visualize.

Return type:

None

class gfn.gym.diffusion_sampling.ManyWell(dim=32, device=torch.device('cpu'), seed=0)

Bases: BaseTarget

Many-well target distribution.

The 32D (default) instance is the concatenation of 16 identical 2D double-well components. Each 2D block (x1, x2) has unnormalized log-density

log p(x1, x2) = -x1^4 + 6 x1^2 + 0.5 x1 - 0.5 x2^2 + C

The overall log-density is the sum over all independent 2D blocks.

Samples are drawn by rejection sampling for the x1 coordinate of each block, using a simple Gaussian mixture proposal; x2 is drawn from a standard normal.

Parameters:
  • dim (int)

  • device (torch.device)

  • seed (int)

Z_x1 = 11784.50927
Z_x2
static _block_log_density(x1, x2)
Parameters:
  • x1 (torch.Tensor)

  • x2 (torch.Tensor)

Return type:

torch.Tensor

_compute_envelope_k(proposal)
Parameters:

proposal (torch.distributions.Distribution)

Return type:

float

_make_proposal()
Return type:

torch.distributions.MixtureSameFamily

static _rejection_sampling_x1(n_samples, proposal, k)
Parameters:
  • n_samples (int)

  • proposal (torch.distributions.Distribution)

  • k (float)

Return type:

torch.Tensor
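The sampling scheme described above (rejection sampling for x1 with a Gaussian mixture proposal, standard normal for x2) can be sketched in plain Python. Several choices here are hypothetical: the proposal's weights, means, and standard deviations, the grid used to estimate the envelope constant k, and the 1.2 safety margin. The helpers only loosely mirror _make_proposal, _compute_envelope_k, and _rejection_sampling_x1, whose actual settings are not documented here.

A valid envelope needs k · g(x) >= q(x) for every x, where q is the unnormalized x1 density and g is the proposal density. Each draw is then accepted with probability q(x) / (k · g(x)).

```python
import math
import random
from typing import List

# Hypothetical two-component Gaussian mixture proposal for x1
WEIGHTS = (0.2, 0.8)
MEANS = (-1.7, 1.7)
STDS = (0.5, 0.5)


def log_q_x1(x: float) -> float:
    # Unnormalized double-well log-density of x1
    return -x ** 4 + 6 * x ** 2 + 0.5 * x


def proposal_pdf(x: float) -> float:
    return sum(
        w * math.exp(-0.5 * ((x - m) / s) ** 2) / (s * math.sqrt(2 * math.pi))
        for w, m, s in zip(WEIGHTS, MEANS, STDS)
    )


def proposal_sample(rng: random.Random) -> float:
    # Pick a component by weight, then draw from that Gaussian
    i = 0 if rng.random() < WEIGHTS[0] else 1
    return rng.gauss(MEANS[i], STDS[i])


def envelope_k(margin: float = 1.2, lo: float = -5.0, hi: float = 5.0, n: int = 20001) -> float:
    # Estimate sup q(x) / g(x) on a grid, then inflate it by a safety margin
    grid = (lo + (hi - lo) * j / (n - 1) for j in range(n))
    return margin * max(math.exp(log_q_x1(x)) / proposal_pdf(x) for x in grid)


def rejection_sample_x1(n_samples: int, k: float, rng: random.Random) -> List[float]:
    samples: List[float] = []
    while len(samples) < n_samples:
        x = proposal_sample(rng)
        # Accept with probability q(x) / (k * g(x)), which is <= 1 for a valid envelope
        if rng.random() * k * proposal_pdf(x) <= math.exp(log_q_x1(x)):
            samples.append(x)
    return samples


rng = random.Random(0)
k = envelope_k()
x1 = rejection_sample_x1(5000, k, rng)
x2 = [rng.gauss(0.0, 1.0) for _ in range(5000)]  # x2 is standard normal
print(sum(x > 0 for x in x1) / len(x1))  # fraction of x1 in the right-hand well
```

With these settings, roughly a third of proposals are accepted. About 85% of the accepted x1 values land in the right-hand well, reflecting the 0.5 · x1 tilt.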

component_mix
gt_logz()

Log partition function of the target.

Returns:

The log partition function.

Return type:

float
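Because the 2D blocks are independent, the log partition function factorizes over blocks. This sketch assumes that C = 0 in the block log-density and that gt_logz uses exactly this decomposition. Under that assumption, each block normalizes to Z_x1 · Z_x2 with Z_x2 = sqrt(2π) from the Gaussian x2 factor, so log Z = (dim / 2) · (log Z_x1 + ½ log 2π).

The plain-Python sketch below recomputes Z_x1 by trapezoidal quadrature and evaluates that expression. integrate_z_x1 and many_well_log_z are hypothetical helpers.

```python
import math


def log_q_x1(x: float) -> float:
    # Unnormalized double-well log-density of x1 (the x2 term is handled separately)
    return -x ** 4 + 6 * x ** 2 + 0.5 * x


def integrate_z_x1(lo: float = -6.0, hi: float = 6.0, n: int = 120001) -> float:
    # Trapezoidal rule; the integrand is negligible outside [-6, 6]
    h = (hi - lo) / (n - 1)
    vals = [math.exp(log_q_x1(lo + j * h)) for j in range(n)]
    return h * (math.fsum(vals) - 0.5 * (vals[0] + vals[-1]))


def many_well_log_z(dim: int = 32, z_x1: float = 11784.50927) -> float:
    # dim / 2 independent blocks, each normalizing to Z_x1 * sqrt(2 * pi)
    n_blocks = dim // 2
    return n_blocks * (math.log(z_x1) + 0.5 * math.log(2 * math.pi))


print(integrate_z_x1())   # close to 11784.50927
print(many_well_log_z())  # roughly 164.70 for the default dim=32
```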

log_reward(x)

Log reward function.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The log rewards for the input tensor.

Return type:

torch.Tensor

means
sample(batch_size, seed=None)

Sample from the target.

Parameters:
  • batch_size (int) – The number of samples to sample.

  • seed (int | None) – The seed for the random number generator.

Returns:

The samples.

Return type:

torch.Tensor

scales
visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=500)

Visualize the target.

Parameters:
  • samples (torch.Tensor | None) – The samples to visualize.

  • show (bool) – Whether to show the plot.

  • prefix (str) – The prefix for the plot file name.

  • output_dir (str | pathlib.Path | None) – Directory to save the plot to. When show=False, the plot is saved only if output_dir is set; if output_dir is None, the figure is silently discarded.

  • grid_width_n_points (int) – The number of points along each axis of the visualization grid.

  • max_n_samples (int) – The maximum number of samples to visualize.

Return type:

None

class gfn.gym.diffusion_sampling.Posterior9of25GaussianMixture(device, dim=2, scale=math.sqrt(0.3), plot_border=15.0, seed=0)

Bases: BaseTarget

Posterior reward for the 25→9 GMM RTB demo.

Parameters:
  • device (torch.device)

  • dim (int)

  • scale (float)

  • plot_border (float)

  • seed (int)

gt_logz()

Log partition function of the target.

Returns:

The log partition function.

Return type:

float

log_reward(x)

Log reward function.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The log rewards for the input tensor.

Return type:

torch.Tensor

posterior
prior
sample(batch_size, seed=None)

Sample from the target.

Parameters:
  • batch_size (int) – The number of samples to sample.

  • seed (int | None) – The seed for the random number generator.

Returns:

The samples.

Return type:

torch.Tensor

visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=1000)

Visualize the target.

Parameters:
  • samples (torch.Tensor | None) – The samples to visualize.

  • show (bool) – Whether to show the plot.

  • prefix (str) – The prefix for the plot file name.

  • output_dir (str | pathlib.Path | None) – Directory to save the plot to. When show=False, the plot is saved only if output_dir is set; if output_dir is None, the figure is silently discarded.

  • grid_width_n_points (int) – The number of points along each axis of the visualization grid.

  • max_n_samples (int) – The maximum number of samples to visualize.

Return type:

None

class gfn.gym.diffusion_sampling.SimpleGaussianMixture(num_components=2, dim=2, mean_val_range=(-10.0, 10.0), mixture_weight_range=(0.3, 0.7), degree_of_freedom_adjustment=2, seed=3, locs=None, device=torch.device('cpu'))

Bases: BaseTarget

Simple Gaussian mixture target distribution.

This target distribution is adapted from https://github.com/DenisBless/variational_sampling_methods/blob/main/targets/gaussian_mixture.py.

Parameters:
  • num_components (int)

  • dim (int)

  • mean_val_range (tuple[float, float])

  • mixture_weight_range (tuple[float, float])

  • degree_of_freedom_adjustment (int)

  • seed (int)

  • locs (numpy.ndarray | None)

  • device (torch.device)

...
distribution
gt_logz()

Log partition function of the target.

Returns:

The log partition function.

Return type:

float

log_reward(x)

Log reward function for the SimpleGaussianMixture target.

Parameters:

x (torch.Tensor) – The input tensor.

Returns:

The log rewards for the input tensor.

Return type:

torch.Tensor

sample(batch_size, seed=None)

Sample from the SimpleGaussianMixture target.

Parameters:
  • batch_size (int) – The number of samples to draw.

  • seed (int | None) – The seed for the random number generator.

Returns:

The samples.

Return type:

torch.Tensor
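Sampling from a mixture is usually done ancestrally: draw a component index from the mixture weights, then draw from that component. The plain-Python sketch below also shows one plausible use of the constructor arguments: means drawn uniformly from mean_val_range, and unnormalized weights drawn uniformly from mixture_weight_range and then normalized. That construction, the unit component standard deviation, and the helper names make_mixture and sample_mixture are assumptions for illustration. degree_of_freedom_adjustment is not modeled.

```python
import random
from typing import List, Tuple


def make_mixture(
    num_components: int = 2,
    dim: int = 2,
    mean_val_range: Tuple[float, float] = (-10.0, 10.0),
    mixture_weight_range: Tuple[float, float] = (0.3, 0.7),
    seed: int = 3,
) -> Tuple[List[float], List[List[float]]]:
    rng = random.Random(seed)
    means = [[rng.uniform(*mean_val_range) for _ in range(dim)] for _ in range(num_components)]
    raw = [rng.uniform(*mixture_weight_range) for _ in range(num_components)]
    total = sum(raw)
    weights = [w / total for w in raw]  # normalize so the weights sum to 1
    return weights, means


def sample_mixture(
    n: int, weights: List[float], means: List[List[float]], rng: random.Random, std: float = 1.0
) -> List[List[float]]:
    samples = []
    for _ in range(n):
        # Ancestral sampling: component index first, then a Gaussian around its mean
        (idx,) = rng.choices(range(len(weights)), weights=weights)
        samples.append([rng.gauss(m, std) for m in means[idx]])
    return samples


weights, means = make_mixture()
xs = sample_mixture(1000, weights, means, random.Random(0))
print(weights, len(xs))
```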

visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=500)

Visualize the distribution.

Parameters:
  • samples (torch.Tensor | None) – The samples to visualize.

  • show (bool) – Whether to show the plot.

  • prefix (str) – The prefix for the plot file name.

  • output_dir (str | pathlib.Path | None) – Directory to save the plot to.

  • grid_width_n_points (int) – The number of points along each axis of the visualization grid.

  • max_n_samples (int) – The maximum number of samples to visualize.

Return type:

None

gfn.gym.diffusion_sampling.TargetEntry
gfn.gym.diffusion_sampling.logger