gfn.gym.helpers.diffusion_utils
Functions
- sliced_log_reward
- viz_2d_slice
Module Contents
- gfn.gym.helpers.diffusion_utils.sliced_log_reward(x, target, dims)
- Parameters:
x (torch.Tensor)
dims (tuple)
- Return type:
torch.Tensor
- gfn.gym.helpers.diffusion_utils.viz_2d_slice(ax, target, dims, samples, plot_border, alpha=0.5, n_contour_levels=50, grid_width_n_points=200, log_reward_clamp_min=-10000.0, use_log_reward=False, max_n_samples=None)
- Parameters:
ax (matplotlib.axes.Axes)
dims (tuple)
samples (torch.Tensor | None)
plot_border (tuple[float, float, float, float])
max_n_samples (int | None)
- Return type:
None
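The entries above give only signatures. As a rough illustration of how the two functions might fit together, here is a minimal sketch built on several assumptions. It assumes the target exposes an integer `dim` and a batched `log_reward(x)` method, that coordinates outside `dims` are held at zero when slicing, and that `plot_border` is ordered `(x_min, x_max, y_min, y_max)`. `ToyGaussianTarget`, `sliced_log_reward_sketch`, and `viz_2d_slice_sketch` are hypothetical names, not part of the torchgfn API. The library's own implementation may differ, for example in how it subsamples to `max_n_samples`.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend: draw without a display
import matplotlib.pyplot as plt
import torch


class ToyGaussianTarget:
    """HYPOTHETICAL target: standard Gaussian in `dim` dimensions.

    Only `dim` and `log_reward(x)` are assumed; this is not a torchgfn class.
    """

    def __init__(self, dim: int):
        self.dim = dim

    def log_reward(self, x: torch.Tensor) -> torch.Tensor:
        # Unnormalized log density -||x||^2 / 2, one value per row of x.
        return -0.5 * (x**2).sum(dim=-1)


def sliced_log_reward_sketch(x: torch.Tensor, target, dims: tuple) -> torch.Tensor:
    """Log reward of 2-D points `x` embedded at coordinates `dims`.

    ASSUMPTION: every coordinate not listed in `dims` is held at zero.
    """
    if x.shape[-1] != len(dims):
        raise ValueError("x must have one column per entry in dims")
    full = torch.zeros(x.shape[0], target.dim, dtype=x.dtype)
    full[:, list(dims)] = x  # write the slice coordinates into place
    return target.log_reward(full)


def viz_2d_slice_sketch(ax, target, dims, samples, plot_border, alpha=0.5,
                        n_contour_levels=50, grid_width_n_points=200,
                        log_reward_clamp_min=-10000.0, use_log_reward=False,
                        max_n_samples=None) -> None:
    # ASSUMPTION: plot_border is ordered (x_min, x_max, y_min, y_max).
    x_min, x_max, y_min, y_max = plot_border
    xs = torch.linspace(x_min, x_max, grid_width_n_points)
    ys = torch.linspace(y_min, y_max, grid_width_n_points)
    gx, gy = torch.meshgrid(xs, ys, indexing="xy")
    grid = torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=-1)

    with torch.no_grad():
        log_r = sliced_log_reward_sketch(grid, target, dims)
        # Clamp so very negative (or -inf) log rewards do not dominate the scale.
        log_r = log_r.clamp(min=log_reward_clamp_min)
    values = log_r if use_log_reward else log_r.exp()
    ax.contourf(
        gx.contiguous().numpy(),
        gy.contiguous().numpy(),
        values.reshape(gx.shape).numpy(),
        levels=n_contour_levels,
    )

    if samples is not None:
        # Project samples onto the two plotted coordinates.
        pts = samples[:, list(dims)].detach().cpu()
        if max_n_samples is not None:
            pts = pts[:max_n_samples]  # simplification: keep the first N rows
        ax.scatter(pts[:, 0].numpy(), pts[:, 1].numpy(), s=4, c="white", alpha=alpha)

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
```

For example, `viz_2d_slice_sketch(ax, ToyGaussianTarget(dim=4), dims=(1, 2), samples=torch.randn(500, 4), plot_border=(-3.0, 3.0, -3.0, 3.0))` shades exp(log reward) over the plane spanned by coordinates 1 and 2 and overlays the projected samples. Passing `use_log_reward=True` shades the clamped log reward instead, which can be easier to read when the reward spans many orders of magnitude.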