gfn.gym.diffusion_sampling
Attributes
- TargetEntry
- logger
Classes
| Class | Description |
|---|---|
| BaseTarget | Base class for all target distributions for diffusion sampling. |
| DiffusionSampling | Diffusion sampling environment. |
| Funnel | Neal's funnel distribution target. |
| Grid25GaussianMixture | Fixed 5x5 Gaussian mixture prior used for RTB demos. |
| ManyWell | Many-well target distribution. |
| Posterior9of25GaussianMixture | Posterior reward for the 25→9 GMM RTB demo. |
| SimpleGaussianMixture | Simple Gaussian Mixture Target distribution. |
Module Contents
- class gfn.gym.diffusion_sampling.BaseTarget(device, dim, n_gt_xs, seed=0, plot_border=None)
Bases: abc.ABC
Base class for all target distributions for diffusion sampling.
- Parameters:
device (torch.device)
dim (int)
n_gt_xs (int)
seed (int)
plot_border (float | tuple[float, float, float, float] | None)
- device
The device on which the target is stored.
- dim
The dimension of the target.
- gt_xs
The ground-truth samples.
- gt_xs_log_rewards
The log rewards of the ground-truth samples.
- cached_sample(batch_size, seed=None)
Return cached samples from the target, together with their log rewards.
- Parameters:
batch_size (int) – The number of samples to draw.
seed (int | None) – The seed for the random number generator.
- Returns:
The samples and the log rewards.
- Return type:
tuple[torch.Tensor | None, torch.Tensor | None]
- device
- dim
- grad_log_reward(x)
Gradient of the log reward function.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The gradient of the log reward function.
- Return type:
torch.Tensor
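When writing a custom target, a central finite-difference check is a cheap way to sanity-check a gradient such as the one `grad_log_reward` returns. The sketch below is pure Python and does not call the library: `log_reward_std_normal` and `finite_difference_grad` are hypothetical helpers, and the toy target is a standard Gaussian whose analytic gradient is `-x`.

```python
import math


def log_reward_std_normal(x):
    """Hypothetical target: standard Gaussian log-density at the point x (a list of floats)."""
    return -0.5 * sum(v * v for v in x) - 0.5 * len(x) * math.log(2 * math.pi)


def finite_difference_grad(log_reward, x, eps=1e-5):
    """Central-difference estimate of the gradient of log_reward at x."""
    grad = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += eps
        x_minus[i] -= eps
        grad.append((log_reward(x_plus) - log_reward(x_minus)) / (2 * eps))
    return grad


x = [0.5, -1.0, 2.0]
approx = finite_difference_grad(log_reward_std_normal, x)
exact = [-v for v in x]  # analytic gradient of the standard Gaussian log-density is -x
print(max(abs(a - e) for a, e in zip(approx, exact)))
```

For a concrete target, the same comparison can be made between this estimate and the output of `grad_log_reward` at a few test points.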
- abstract gt_logz()
Log partition function of the target.
- Returns:
The log partition function.
- Return type:
float
- abstract log_reward(x)
Log reward function.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The log rewards for the input tensor.
- Return type:
torch.Tensor
- plot_border = None
- abstract sample(batch_size, seed=None)
Sample from the target.
- Parameters:
batch_size (int) – The number of samples to draw.
seed (int | None) – The seed for the random number generator.
- Returns:
The samples.
- Return type:
torch.Tensor
- abstract visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=1000)
Visualize the target.
- Parameters:
samples (torch.Tensor | None) – The samples to visualize.
show (bool) – Whether to show the plot.
prefix (str) – The prefix for the plot file name.
output_dir (str | pathlib.Path | None) – Directory to save the plot to. Required when show=False; when None and show=False, the figure is silently discarded.
grid_width_n_points (int) – The number of points along each axis of the visualization grid.
max_n_samples (int) – The maximum number of samples to visualize.
- Return type:
None
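The abstract methods above define the contract a concrete target fulfils: `log_reward`, `sample` and `gt_logz` (plus `visualize`). The following is a hypothetical, dependency-free stand-in that mirrors those method names; it does not subclass the real `BaseTarget`, which works on `torch.Tensor`s. The reward is deliberately unnormalized (a Gaussian density scaled by `c`), so the true log partition function is `log(c)` rather than 0.

```python
import abc
import math
import random


class ToyTarget(abc.ABC):
    """Hypothetical stand-in mirroring BaseTarget's abstract methods (not the library class)."""

    @abc.abstractmethod
    def log_reward(self, x): ...

    @abc.abstractmethod
    def sample(self, batch_size, seed=None): ...

    @abc.abstractmethod
    def gt_logz(self): ...


class ScaledGaussian1D(ToyTarget):
    """Reward R(x) = c * N(x; mu, sigma^2), so the log partition function is log(c)."""

    def __init__(self, mu=0.0, sigma=1.0, c=5.0):
        self.mu, self.sigma, self.c = mu, sigma, c

    def log_reward(self, x):
        z = (x - self.mu) / self.sigma
        return math.log(self.c) - 0.5 * z * z - math.log(self.sigma * math.sqrt(2 * math.pi))

    def sample(self, batch_size, seed=None):
        # A fresh seeded generator makes draws reproducible for a given seed.
        rng = random.Random(seed)
        return [rng.gauss(self.mu, self.sigma) for _ in range(batch_size)]

    def gt_logz(self):
        return math.log(self.c)


target = ScaledGaussian1D()
print(target.gt_logz())               # log(5)
print(len(target.sample(4, seed=0)))  # 4
```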
- class gfn.gym.diffusion_sampling.DiffusionSampling(target_str, target_kwargs, num_discretization_steps, device=torch.device('cpu'), debug=False)
Bases: gfn.env.Env
Diffusion sampling environment.
- Parameters:
target_str (str)
target_kwargs (dict[str, Any] | None)
num_discretization_steps (float)
device (torch.device)
debug (bool)
- target
The target distribution.
- device
The device to use.
- Return type:
torch.device
- DIFFUSION_TARGETS: dict[str, TargetEntry]
- backward_step(states, actions)
Backward step function for the DiffusionSampling environment.
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions, which correspond to the changes to the states.
- Returns:
The previous states.
- Return type:
gfn.states.States
- static density_metrics(fwd_log_pfs, fwd_log_pbs, fwd_log_rewards, log_Z_learned, bwd_log_pfs=None, bwd_log_pbs=None, bwd_log_rewards=None, gt_log_Z=None)
- Parameters:
fwd_log_pfs (torch.Tensor)
fwd_log_pbs (torch.Tensor)
fwd_log_rewards (torch.Tensor)
log_Z_learned (float)
bwd_log_pfs (torch.Tensor | None)
bwd_log_pbs (torch.Tensor | None)
bwd_log_rewards (torch.Tensor | None)
gt_log_Z (float | None)
- Return type:
dict
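`density_metrics` is undocumented here, but its arguments (forward log-probabilities, log rewards, a learned log Z, optional backward-trajectory quantities and a ground-truth log Z) match the standard estimators used to evaluate diffusion samplers. The sketch below computes those estimators: the ELBO and an importance-weighted log Z estimate from forward trajectories, the EUBO from backward trajectories, and absolute errors against `gt_log_Z`. The function name, the assumption that each input holds one summed log-probability per trajectory, and the returned keys are all hypothetical; the library's actual dictionary may differ.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sketch_density_metrics(fwd_log_pfs, fwd_log_pbs, fwd_log_rewards, log_Z_learned,
                           bwd_log_pfs=None, bwd_log_pbs=None, bwd_log_rewards=None,
                           gt_log_Z=None):
    """Hypothetical sketch; each list holds one summed log-probability per trajectory."""
    # Log importance weights log R(x) + log P_B(tau | x) - log P_F(tau) on forward trajectories.
    fwd_w = [r + pb - pf for pf, pb, r in zip(fwd_log_pfs, fwd_log_pbs, fwd_log_rewards)]
    metrics = {
        "elbo": sum(fwd_w) / len(fwd_w),  # lower bound on log Z in expectation
        "log_Z_is": logsumexp(fwd_w) - math.log(len(fwd_w)),  # importance-weighted estimate
    }
    if bwd_log_pfs is not None:
        # Same weights on backward trajectories from ground-truth samples: upper bound (EUBO).
        bwd_w = [r + pb - pf for pf, pb, r in zip(bwd_log_pfs, bwd_log_pbs, bwd_log_rewards)]
        metrics["eubo"] = sum(bwd_w) / len(bwd_w)
    if gt_log_Z is not None:
        metrics["log_Z_learned_abs_error"] = abs(log_Z_learned - gt_log_Z)
        metrics["log_Z_is_abs_error"] = abs(metrics["log_Z_is"] - gt_log_Z)
    return metrics


# Two forward trajectories whose log weights are both exactly 2.0.
m = sketch_density_metrics([-1.0, -3.0], [0.5, 0.5], [0.5, -1.5], log_Z_learned=1.9, gt_log_Z=2.0)
print(m["elbo"], m["log_Z_is"])
```

When every trajectory has the same log weight (a perfect sampler), the ELBO and the importance-weighted estimate coincide; otherwise Jensen's inequality keeps the ELBO below the importance-weighted estimate.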
- dim
- dt
- is_action_valid(states, actions, backward=False)
Check if the actions are valid.
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions to check.
backward (bool) – Whether to check the actions as backward actions.
- Returns:
True if the actions are valid, False otherwise.
- Return type:
bool
- classmethod list_available_targets()
Return metadata about available targets and their default kwargs.
This helper allows users to easily see which kwargs are provided by default for each alias. Note that accepted/required kwargs are determined by each target class’s constructor signature.
- Return type:
dict[str, dict[str, Any]]
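The constructor pairs a `target_str` alias with an optional `target_kwargs` dict, and `list_available_targets` reports the default kwargs for each alias. The sketch below shows one plausible way such a registry resolves a target: user kwargs are overlaid on the alias defaults, with user values taking precedence. Everything in it is hypothetical (`TOY_TARGETS`, the `(class, defaults)` tuple layout, `ToyFunnel`, `build_target`), since the fields of `TargetEntry` are not documented on this page.

```python
from __future__ import annotations

from typing import Any


class ToyFunnel:
    """Hypothetical target class; only records its constructor arguments."""

    def __init__(self, dim=10, std=1.0, seed=0):
        self.dim, self.std, self.seed = dim, std, seed


# Hypothetical registry: alias -> (target class, default constructor kwargs).
TOY_TARGETS: dict[str, tuple[type, dict[str, Any]]] = {
    "funnel": (ToyFunnel, {"dim": 10, "std": 1.0}),
}


def list_toy_targets() -> dict[str, dict[str, Any]]:
    """Report each alias with its class name and a copy of its default kwargs."""
    return {
        alias: {"class": cls.__name__, "defaults": dict(defaults)}
        for alias, (cls, defaults) in TOY_TARGETS.items()
    }


def build_target(target_str: str, target_kwargs: dict[str, Any] | None = None):
    """Look up an alias and overlay user kwargs on its defaults (user values win)."""
    if target_str not in TOY_TARGETS:
        raise ValueError(f"Unknown target {target_str!r}; available: {sorted(TOY_TARGETS)}")
    cls, defaults = TOY_TARGETS[target_str]
    return cls(**{**defaults, **(target_kwargs or {})})


target = build_target("funnel", {"dim": 2})
print(target.dim, target.std)  # 2 1.0
```

As the documentation notes, which kwargs are accepted is ultimately decided by the target class's constructor; in this sketch an unknown key would surface as a `TypeError` from `cls(...)`.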
- log_reward(states)
Log reward function for the DiffusionSampling environment.
- Parameters:
states (gfn.states.States) – The current states.
- Returns:
The log rewards for the input states.
- Return type:
torch.Tensor
- make_states_class()
Returns the States class for diffusion sampling.
- Return type:
type[gfn.states.States]
- step(states, actions)
Step function for the DiffusionSampling environment.
- Parameters:
states (gfn.states.States) – The current states.
actions (gfn.actions.Actions) – The actions, which correspond to the changes to the states.
- Returns:
The next states.
- Return type:
gfn.states.States
- target
- class gfn.gym.diffusion_sampling.Funnel(dim=10, std=1.0, device=torch.device('cpu'), seed=0)
Bases: BaseTarget
Neal’s funnel distribution target.
x_0 ~ Normal(0, std^2), and for i >= 1: x_i | x_0 ~ Normal(0, exp(x_0)).
- Parameters:
dim (int) – Total dimensionality (x_0 plus dim-1 conditional coordinates).
std (float) – Standard deviation for the marginal prior on x0.
device (torch.device) – Torch device.
seed (int) – RNG seed.
- dist_dominant
- gt_logz()
Log partition function of the target.
- Returns:
The log partition function.
- Return type:
float
- log_reward(x)
Log-density of Neal’s funnel distribution.
Returns log p(x_0) + sum_{i=1}^{dim-1} log p(x_i | x_0).
- Parameters:
x (torch.Tensor)
- Return type:
torch.Tensor
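The log-density above can be written out term by term. This pure-Python sketch (`normal_log_pdf` and `funnel_log_density` are hypothetical names) assumes the second argument of `Normal` is a variance, as the `std^2` in the class description suggests, so x_i given x_0 has variance exp(x_0). Because the result is a normalized log-density, `gt_logz` for such a target would be expected to be 0.

```python
import math


def normal_log_pdf(x, mean, var):
    """Log-density of a univariate Normal with the given mean and variance."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def funnel_log_density(x, std=1.0):
    """log p(x_0) + sum_{i>=1} log p(x_i | x_0) for one point x = [x_0, x_1, ...]."""
    x_0, rest = x[0], x[1:]
    log_p = normal_log_pdf(x_0, 0.0, std**2)  # x_0 ~ Normal(0, std^2)
    for x_i in rest:
        # Assumed variance exp(x_0) for every conditional coordinate.
        log_p += normal_log_pdf(x_i, 0.0, math.exp(x_0))
    return log_p


# At the origin in 2D with std=1, both terms equal -0.5 * log(2 * pi).
print(funnel_log_density([0.0, 0.0]))  # equals -log(2 * pi)
```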
- sample(batch_size, seed=None)
Sample from the target.
- Parameters:
batch_size (int) – The number of samples to draw.
seed (int | None) – The seed for the random number generator.
- Returns:
The samples.
- Return type:
torch.Tensor
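Sampling from the funnel is ancestral: draw x_0 first, then every remaining coordinate conditionally on it. A standard-library sketch under the same assumption that exp(x_0) is a variance (`sample_funnel` is a hypothetical name, and it returns nested lists rather than a tensor):

```python
import math
import random


def sample_funnel(batch_size, dim=10, std=1.0, seed=None):
    """Ancestral sampling: x_0 ~ Normal(0, std^2), then x_i | x_0 ~ Normal(0, exp(x_0))."""
    rng = random.Random(seed)
    samples = []
    for _ in range(batch_size):
        x_0 = rng.gauss(0.0, std)
        # random.gauss takes a standard deviation, hence exp(x_0 / 2) for variance exp(x_0).
        scale = math.exp(x_0 / 2)
        samples.append([x_0] + [rng.gauss(0.0, scale) for _ in range(dim - 1)])
    return samples


xs = sample_funnel(4, dim=3, seed=0)
print(len(xs), len(xs[0]))  # 4 3
```

The narrow "neck" at negative x_0 and the wide "mouth" at positive x_0 both come from the single shared `scale`, which is what makes this target hard for samplers with a fixed step size.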
- visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=500)
Visualization is only supported in 2D, over the coordinates (x_0, x_1).
- Parameters:
samples (torch.Tensor | None)
show (bool)
prefix (str)
output_dir (str | pathlib.Path | None)
grid_width_n_points (int)
max_n_samples (int)
- Return type:
None
- class gfn.gym.diffusion_sampling.Grid25GaussianMixture(device, dim=2, scale=math.sqrt(0.3), plot_border=15.0, seed=0)
Bases: BaseTarget
Fixed 5x5 Gaussian mixture prior used for RTB demos.
- Parameters:
device (torch.device)
dim (int)
scale (float)
plot_border (float)
seed (int)
- gmm
- gt_logz()
Log partition function of the target.
- Returns:
The log partition function.
- Return type:
float
- locs
- log_reward(x)
Log reward function.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The log rewards for the input tensor.
- Return type:
torch.Tensor
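For an equal-weight mixture of isotropic Gaussians, the log reward is a log-sum-exp over per-component log-densities. In this sketch the grid of means {-10, -5, 0, 5, 10}^2, the equal weights and the reading of `scale` as a per-component standard deviation are all assumptions for illustration (the real component locations live in the `locs` attribute, which is not documented here); `grid25_log_density` is a hypothetical name.

```python
import itertools
import math

SCALE = math.sqrt(0.3)  # assumed per-component standard deviation (the default `scale`)
# Assumed 5x5 grid of component means; the class's `locs` holds the real values.
LOCS = [(float(a), float(b)) for a, b in itertools.product([-10, -5, 0, 5, 10], repeat=2)]


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def grid25_log_density(x, y, locs=LOCS, scale=SCALE):
    """Log-density of an equal-weight mixture of isotropic 2D Gaussians."""
    var = scale**2
    component_log_pdfs = [
        -((x - mx) ** 2 + (y - my) ** 2) / (2 * var) - math.log(2 * math.pi * var)
        for mx, my in locs
    ]
    return logsumexp(component_log_pdfs) - math.log(len(locs))  # equal weights 1/25


# At a mode, one component dominates: about -log(25) - log(2 * pi * 0.3).
print(grid25_log_density(0.0, 0.0))
```

Working in log space with log-sum-exp avoids underflow far from every mode, where each component density on its own would round to zero.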
- sample(batch_size, seed=None)
Sample from the target.
- Parameters:
batch_size (int) – The number of samples to draw.
seed (int | None) – The seed for the random number generator.
- Returns:
The samples.
- Return type:
torch.Tensor
- visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=1000)
Visualize the target.
- Parameters:
samples (torch.Tensor | None) – The samples to visualize.
show (bool) – Whether to show the plot.
prefix (str) – The prefix for the plot file name.
output_dir (str | pathlib.Path | None) – Directory to save the plot to. Required when show=False; when None and show=False, the figure is silently discarded.
grid_width_n_points (int) – The number of points along each axis of the visualization grid.
max_n_samples (int) – The maximum number of samples to visualize.
- Return type:
None
- class gfn.gym.diffusion_sampling.ManyWell(dim=32, device=torch.device('cpu'), seed=0)
Bases: BaseTarget
Many-well target distribution.
The 32D (default) instance is the concatenation of 16 identical 2D double-well components. Each 2D block (x1, x2) has unnormalized log-density
log p(x1, x2) = -x1^4 + 6 x1^2 + 0.5 x1 - 0.5 x2^2 + C
The overall log-density is the sum over all independent 2D blocks.
Sampling uses rejection sampling for the x1 coordinate in each block with a simple Gaussian mixture proposal, and standard Normal for x2.
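Rejection sampling for one x1 coordinate can be sketched as follows. The proposal (an equal-weight mixture of two Normals centred at ±1.7 with standard deviation 0.5) and the grid search for the envelope constant k are illustrative choices, not necessarily what `_make_proposal` and `_compute_envelope_k` do. The method only requires that k·q(x1) dominates the unnormalized density exp(-x1^4 + 6 x1^2 + 0.5 x1) everywhere, which holds here because the quartic tails decay faster than the Gaussian ones.

```python
import math
import random

# Illustrative proposal (an assumption, not the library's _make_proposal):
# an equal-weight mixture of two Normals near the two wells.
MEANS = (-1.7, 1.7)
WEIGHTS = (0.5, 0.5)
STD = 0.5


def log_p_tilde_x1(x1):
    """Unnormalized log-density of x1 in one double-well block."""
    return -x1**4 + 6 * x1**2 + 0.5 * x1


def proposal_log_pdf(x):
    """Log-density of the two-component Gaussian mixture proposal."""
    dens = sum(
        w * math.exp(-0.5 * ((x - m) / STD) ** 2) / (STD * math.sqrt(2 * math.pi))
        for w, m in zip(WEIGHTS, MEANS)
    )
    return math.log(dens)


def envelope_k(lo=-5.0, hi=5.0, n=10001, margin=1.05):
    """Grid estimate of max p~(x) / q(x), inflated so that k * q(x) >= p~(x)."""
    grid = (lo + (hi - lo) * i / (n - 1) for i in range(n))
    return margin * max(math.exp(log_p_tilde_x1(x) - proposal_log_pdf(x)) for x in grid)


def rejection_sample_x1(n_samples, k, seed=None):
    """Draw from the proposal and accept each draw with probability p~(x) / (k q(x))."""
    rng = random.Random(seed)
    out = []
    while len(out) < n_samples:
        mean = MEANS[0] if rng.random() < WEIGHTS[0] else MEANS[1]  # pick a component
        x = rng.gauss(mean, STD)
        if rng.random() < math.exp(log_p_tilde_x1(x) - proposal_log_pdf(x)) / k:
            out.append(x)
    return out


k = envelope_k()
x1 = rejection_sample_x1(1000, k, seed=0)
rng_x2 = random.Random(1)
x2 = [rng_x2.gauss(0.0, 1.0) for _ in range(1000)]  # x2 is standard Normal
block = list(zip(x1, x2))  # 1000 samples of one 2D block (x1, x2)
```

The +0.5 x1 tilt makes the right-hand well noticeably heavier, so most accepted x1 values are positive even though the proposal weights both wells equally.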
- Parameters:
dim (int)
device (torch.device)
seed (int)
- Z_x1 = 11784.50927
- Z_x2
- static _block_log_density(x1, x2)
- Parameters:
x1 (torch.Tensor)
x2 (torch.Tensor)
- Return type:
torch.Tensor
- _compute_envelope_k(proposal)
- Parameters:
proposal (torch.distributions.Distribution)
- Return type:
float
- _make_proposal()
- Return type:
torch.distributions.MixtureSameFamily
- static _rejection_sampling_x1(n_samples, proposal, k)
- Parameters:
n_samples (int)
proposal (torch.distributions.Distribution)
k (float)
- Return type:
torch.Tensor
- component_mix
- gt_logz()
Log partition function of the target.
- Returns:
The log partition function.
- Return type:
float
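Because the 2D blocks are independent, the normalizer factorizes: each block contributes Z_x1 · Z_x2, where sqrt(2π) normalizes exp(-0.5 x2^2) (presumably the value held in `Z_x2`). Assuming `gt_logz` is assembled this way over dim / 2 blocks, the sketch below recomputes the documented `Z_x1 = 11784.50927` with composite Simpson's rule and forms log Z for the default `dim=32`. `simpson` and `many_well_log_z` are hypothetical helpers.

```python
import math


def simpson(f, a, b, n=4000):
    """Composite Simpson's rule on [a, b] with an even number n of intervals."""
    if n % 2:
        raise ValueError("n must be even")
    h = (b - a) / n
    total = f(a) + f(b)
    for i in range(1, n):
        total += (4 if i % 2 else 2) * f(a + i * h)
    return total * h / 3


def many_well_log_z(dim=32):
    """log Z for dim/2 independent double-well blocks (assumed factorization)."""
    # The integrand is negligible beyond |x1| = 6 (the exponent there is below -1000).
    z_x1 = simpson(lambda x: math.exp(-x**4 + 6 * x**2 + 0.5 * x), -6.0, 6.0)
    z_x2 = math.sqrt(2 * math.pi)  # normalizer of exp(-0.5 * x2^2)
    return (dim // 2) * (math.log(z_x1) + math.log(z_x2)), z_x1


log_z, z_x1 = many_well_log_z()
print(z_x1)  # close to the documented Z_x1 = 11784.50927
```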
- log_reward(x)
Log reward function.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The log rewards for the input tensor.
- Return type:
torch.Tensor
- means
- sample(batch_size, seed=None)
Sample from the target.
- Parameters:
batch_size (int) – The number of samples to draw.
seed (int | None) – The seed for the random number generator.
- Returns:
The samples.
- Return type:
torch.Tensor
- scales
- visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=500)
Visualize the target.
- Parameters:
samples (torch.Tensor | None) – The samples to visualize.
show (bool) – Whether to show the plot.
prefix (str) – The prefix for the plot file name.
output_dir (str | pathlib.Path | None) – Directory to save the plot to. Required when show=False; when None and show=False, the figure is silently discarded.
grid_width_n_points (int) – The number of points along each axis of the visualization grid.
max_n_samples (int) – The maximum number of samples to visualize.
- Return type:
None
- class gfn.gym.diffusion_sampling.Posterior9of25GaussianMixture(device, dim=2, scale=math.sqrt(0.3), plot_border=15.0, seed=0)
Bases: BaseTarget
Posterior reward for the 25→9 GMM RTB demo.
- Parameters:
device (torch.device)
dim (int)
scale (float)
plot_border (float)
seed (int)
- gt_logz()
Log partition function of the target.
- Returns:
The log partition function.
- Return type:
float
- log_reward(x)
Log reward function.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The log rewards for the input tensor.
- Return type:
torch.Tensor
- posterior
- prior
- sample(batch_size, seed=None)
Sample from the target.
- Parameters:
batch_size (int) – The number of samples to draw.
seed (int | None) – The seed for the random number generator.
- Returns:
The samples.
- Return type:
torch.Tensor
- visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=1000)
Visualize the target.
- Parameters:
samples (torch.Tensor | None) – The samples to visualize.
show (bool) – Whether to show the plot.
prefix (str) – The prefix for the plot file name.
output_dir (str | pathlib.Path | None) – Directory to save the plot to. Required when show=False; when None and show=False, the figure is silently discarded.
grid_width_n_points (int) – The number of points along each axis of the visualization grid.
max_n_samples (int) – The maximum number of samples to visualize.
- Return type:
None
- class gfn.gym.diffusion_sampling.SimpleGaussianMixture(num_components=2, dim=2, mean_val_range=(-10.0, 10.0), mixture_weight_range=(0.3, 0.7), degree_of_freedom_adjustment=2, seed=3, locs=None, device=torch.device('cpu'))
Bases: BaseTarget
Simple Gaussian Mixture Target distribution.
This target distribution is adapted from https://github.com/DenisBless/variational_sampling_methods/blob/main/targets/gaussian_mixture.py.
- Parameters:
num_components (int)
dim (int)
mean_val_range (tuple[float, float])
mixture_weight_range (tuple[float, float])
degree_of_freedom_adjustment (int)
seed (int)
locs (numpy.ndarray | None)
device (torch.device)
- ...
- distribution
- gt_logz()
Log partition function of the target.
- Returns:
The log partition function.
- Return type:
float
- log_reward(x)
Log reward function for the SimpleGaussianMixture target.
- Parameters:
x (torch.Tensor) – The input tensor.
- Returns:
The log rewards for the input tensor.
- Return type:
torch.Tensor
- sample(batch_size, seed=None)
Sample from the SimpleGaussianMixture target.
- Parameters:
batch_size (int) – The number of samples to draw.
seed (int | None) – The seed for the random number generator.
- Returns:
The samples.
- Return type:
torch.Tensor
- visualize(samples=None, show=False, prefix='', output_dir=None, grid_width_n_points=100, max_n_samples=500)
Visualize the distribution.
- Parameters:
samples (torch.Tensor | None) – The samples to visualize.
show (bool) – Whether to show the plot.
prefix (str) – The prefix for the plot file name.
output_dir (str | pathlib.Path | None) – Directory to save the plot to.
grid_width_n_points (int) – The number of points along each axis of the visualization grid.
max_n_samples (int) – The maximum number of samples to visualize.
- Return type:
None
- gfn.gym.diffusion_sampling.TargetEntry
- gfn.gym.diffusion_sampling.logger