gfn.gym.helpers.diffusion_utils

Functions

sliced_log_reward(x, target, dims)

viz_2d_slice(ax, target, dims, samples, plot_border[, ...])

Module Contents

gfn.gym.helpers.diffusion_utils.sliced_log_reward(x, target, dims)
Parameters:
  • x

  • target

  • dims

Return type:

torch.Tensor
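
From its name and signature, `sliced_log_reward` appears to evaluate the target's log reward on points that lie in a 2-D slice of the full state space, where `dims` selects the sliced coordinates. The stdlib-only sketch below illustrates that idea; it is not the library's implementation. Assumptions: the target is replaced by a plain callable `log_reward_fn`, the full dimension is passed explicitly as `full_dim` (both hypothetical, not the `BaseTarget` API), and every coordinate outside `dims` is held at zero. The real function may handle the remaining coordinates differently.

```python
import math
from typing import Callable, Sequence


def sliced_log_reward_sketch(
    x: Sequence[Sequence[float]],
    log_reward_fn: Callable[[list[float]], float],
    dims: tuple[int, int],
    full_dim: int,
) -> list[float]:
    """Hypothetical sketch: embed 2-D points into full_dim space and score them."""
    out = []
    for point in x:
        # ASSUMPTION: coordinates outside `dims` are fixed at 0.0.
        full = [0.0] * full_dim
        for d, value in zip(dims, point):
            full[d] = value
        out.append(log_reward_fn(full))
    return out


def std_normal_log_density(z: list[float]) -> float:
    # Log density of an isotropic standard Gaussian, used here as a toy target.
    return -0.5 * sum(v * v for v in z) - 0.5 * len(z) * math.log(2 * math.pi)


# Score two 2-D points on the (0, 2) slice of a 4-D standard Gaussian.
scores = sliced_log_reward_sketch(
    [[0.0, 0.0], [1.0, 1.0]], std_normal_log_density, (0, 2), 4
)
```

Holding the other coordinates at a fixed value gives a conditional slice of the density rather than a marginal; a marginal would instead require integrating over those coordinates.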

gfn.gym.helpers.diffusion_utils.viz_2d_slice(ax, target, dims, samples, plot_border, alpha=0.5, n_contour_levels=50, grid_width_n_points=200, log_reward_clamp_min=-10000.0, use_log_reward=False, max_n_samples=None)
Parameters:
  • ax (matplotlib.axes.Axes)

  • target (gfn.gym.diffusion_sampling.BaseTarget)

  • dims (tuple)

  • samples (torch.Tensor | None)

  • plot_border (tuple[float, float, float, float])

  • max_n_samples (int | None)

Return type:

None
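
The parameter names suggest that `viz_2d_slice` evaluates the (log) reward on a `grid_width_n_points` × `grid_width_n_points` grid spanning `plot_border`, clamps log rewards from below at `log_reward_clamp_min`, draws `n_contour_levels` contour levels on `ax`, and overlays at most `max_n_samples` of the given `samples` with transparency `alpha`. The stdlib-only sketch below covers only the grid computation, under stated assumptions: `plot_border` is read as `(x_min, x_max, y_min, y_max)`, the clamp is applied before exponentiation, and `grid_values` and `log_reward_fn` are hypothetical names. Plotting is omitted so the sketch runs without matplotlib; presumably the real function hands a grid like this to a contour call on `ax`.

```python
import math
from typing import Callable


def grid_values(
    log_reward_fn: Callable[[float, float], float],
    plot_border: tuple[float, float, float, float],
    grid_width_n_points: int = 200,
    log_reward_clamp_min: float = -10000.0,
    use_log_reward: bool = False,
) -> list[list[float]]:
    """Hypothetical sketch of the grid a 2-D contour plot would be drawn from."""
    if grid_width_n_points < 2:
        raise ValueError("grid_width_n_points must be at least 2")
    # ASSUMPTION: plot_border is ordered (x_min, x_max, y_min, y_max).
    x_min, x_max, y_min, y_max = plot_border
    n = grid_width_n_points
    xs = [x_min + i * (x_max - x_min) / (n - 1) for i in range(n)]
    ys = [y_min + j * (y_max - y_min) / (n - 1) for j in range(n)]
    rows = []
    for y in ys:
        row = []
        for x in xs:
            # Clamp very negative log rewards so contour levels stay finite.
            lr = max(log_reward_fn(x, y), log_reward_clamp_min)
            # Plot the log reward directly, or the reward exp(log_reward).
            row.append(lr if use_log_reward else math.exp(lr))
        rows.append(row)
    return rows


# Toy log reward: an unnormalized standard Gaussian in two dimensions.
grid = grid_values(
    lambda x, y: -0.5 * (x * x + y * y),
    (-3.0, 3.0, -3.0, 3.0),
    grid_width_n_points=5,
)
```

Clamping matters mainly when `use_log_reward=True`: without a floor, regions of near-zero reward produce extremely negative values that would stretch the contour levels over an unhelpful range.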