gfn.gflownet.mle

MLE loss for diffusion GFlowNets (forward PF with optional PB).

Key equations (per time step, shapes in comments):
  • Backward bridge (s_t -> s_{t-dt}):

    mean_bb = s_t * (1 - dt / t)               # (B, s_dim)
    std_bb  = sigma * sqrt(dt * (t - dt) / t)  # (B, 1), broadcast

    With learned PB corrections:

    mean = mean_bb + mean_corr
    std  = sqrt(std_bb^2 + corr_std^2)

  • Forward PF log-prob for increment Δ = s_t - s_{t-dt}:
    If PF predicts log_std:

    σ = exp(log_std) * sqrt(dt) * sqrt(t_scale)   (optionally combined with exploration noise)
    log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ_i)^2 + 2 log σ_i + log 2π ]

    Else (fixed variance):

    σ = sigma * sqrt(dt) * sqrt(t_scale)   (optionally combined with exploration noise)
    log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ)^2 + log(2π σ^2) ]

  • Loss = -mean over batch of Σ_t log p_t
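The equations above can be read as one pass over the time grid. The following is a minimal sketch of that pass for the fixed-variance case without learned PB corrections. It is not the library's implementation: `bridge_step`, `fixed_var_logpf`, `mle_loss_sketch` and `drift_fn` are hypothetical names introduced for illustration, and `drift_fn` stands in for whatever drift μ the forward PF predicts.

```python
import math
import torch


def bridge_step(s_t, t, dt, sigma):
    """Sample s_{t-dt} from the uncorrected backward bridge."""
    mean_bb = s_t * (1.0 - dt / t)  # (B, s_dim)
    # Clamp guards against t - dt going slightly negative through float rounding.
    std_bb = sigma * math.sqrt(dt * max(t - dt, 0.0) / t)
    return mean_bb + std_bb * torch.randn_like(s_t)


def fixed_var_logpf(s_prev, s_t, mu, dt, sigma, t_scale=1.0):
    """Log-prob of the increment Δ = s_t - s_prev under a fixed-variance PF."""
    std = sigma * math.sqrt(dt) * math.sqrt(t_scale)
    z = (s_t - s_prev - dt * mu) / std
    return -0.5 * (z ** 2 + math.log(2 * math.pi * std ** 2)).sum(dim=-1)  # (B,)


def mle_loss_sketch(x, drift_fn, num_steps, sigma, t_scale=1.0):
    """Loss = -mean over batch of Σ_t log p_t (assumed structure, see lead-in)."""
    dt = 1.0 / num_steps
    s_t = x
    logpf_sum = x.new_zeros(x.shape[0])
    for i in range(num_steps):
        t_curr = 1.0 - i * dt
        s_prev = bridge_step(s_t, t_curr, dt, sigma)  # walk backward
        mu = drift_fn(s_prev, t_curr - dt)            # hypothetical PF drift
        logpf_sum = logpf_sum + fixed_var_logpf(s_prev, s_t, mu, dt, sigma, t_scale)
        s_t = s_prev
    return -logpf_sum.mean()


torch.manual_seed(0)
x = torch.randn(4, 2)                              # terminal states, (B, s_dim)
zero_drift = lambda s, t: torch.zeros_like(s)      # placeholder for a trained PF
loss = mle_loss_sketch(x, zero_drift, num_steps=10, sigma=2.0)
```

In the last step the bridge standard deviation is zero and the mean is zero, so every trajectory ends at the origin at t = 0.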

Tensor conventions:
  • terminal_states: (B, s_dim) or (B, s_dim + 1), where the extra last column is a terminal indicator; that column is dropped if present.

  • Times: scalar dt = 1/num_steps; t_curr = 1 - i*dt; t_fwd = 1 - (i+1)*dt.
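As a small worked example of the time convention, the sketch below lists the (t_curr, t_fwd) pair for each step index i. The helper name `time_grid` is hypothetical and only illustrates the formulas above.

```python
def time_grid(num_steps):
    """Return [(t_curr, t_fwd), ...] with dt = 1/num_steps."""
    dt = 1.0 / num_steps
    return [(1.0 - i * dt, 1.0 - (i + 1) * dt) for i in range(num_steps)]


grid = time_grid(4)
# grid == [(1.0, 0.75), (0.75, 0.5), (0.5, 0.25), (0.25, 0.0)]
```

Each t_fwd equals the next step's t_curr, so the grid runs from t = 1 down to t = 0 in num_steps steps.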

Usage (user owns optimizer/loop):

```python
import torch

gfn = MLEDiffusion(pf=pf, pb=None, num_steps=100, sigma=2.0, t_scale=1.0)
opt = torch.optim.Adam(gfn.parameters(), lr=1e-3)

for it in range(n_iterations):
    # Sample a batch of terminal states.
    batch = env.sample(batch_size)  # batch shape (B, s_dim)
    opt.zero_grad()
    # Calculate the MLE loss under the backward / forward diffusion process.
    loss = gfn.loss(env, batch, exploration_std=0.0)
    loss.backward()
    opt.step()
```

Classes

MLEDiffusion

Maximum-likelihood diffusion GFlowNet (PF with optional PB).

Functions

dynamo_disable(fn)

Module Contents

class gfn.gflownet.mle.MLEDiffusion(pf, pb=None, *, num_steps, sigma, t_scale=1.0, pb_scale_range=0.1, learn_variance=False, reduction='mean', debug=False)

Bases: gfn.gflownet.base.GFlowNet

Maximum-likelihood diffusion GFlowNet (PF with optional PB).

The caller owns the training loop; this class provides:
  • sampling via the forward PF (for API compatibility)

  • .loss(env, terminal_states, …) computing the MLE objective

Parameters:
_assert_no_nan(logpf_sum)
Parameters:

logpf_sum (torch.Tensor)

Return type:

None

_extract_samples(terminal_states)

Normalize input to a (B, s_dim) tensor. Accepts torch.Tensor or States; drops a final column if size matches s_dim+1.

Parameters:

terminal_states (Any)

Return type:

tuple[torch.device, torch.dtype, torch.Tensor]
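The normalization described above can be sketched as follows for the plain-tensor case. This is an illustration under assumptions, not the method's actual body: `extract_samples_sketch` is a hypothetical name, `s_dim` is passed explicitly rather than read from the instance, and the States input path is not covered.

```python
import torch


def extract_samples_sketch(terminal_states: torch.Tensor, s_dim: int):
    """Return (device, dtype, samples) with samples of shape (B, s_dim)."""
    x = terminal_states
    if x.ndim != 2:
        raise ValueError(f"expected a 2-D tensor, got shape {tuple(x.shape)}")
    if x.shape[-1] == s_dim + 1:
        x = x[:, :-1]  # drop the trailing indicator column
    elif x.shape[-1] != s_dim:
        raise ValueError(f"last dim must be {s_dim} or {s_dim + 1}")
    return x.device, x.dtype, x


# A batch with an extra indicator column, shape (8, 3) for s_dim = 2.
with_flag = torch.cat([torch.randn(8, 2), torch.ones(8, 1)], dim=-1)
device, dtype, samples = extract_samples_sketch(with_flag, s_dim=2)
```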

debug = False
dt
learn_variance = False
loss(env, terminal_states, recalculate_all_logprobs=True, *, exploration_std=0.0)

Compute the MLE objective given terminal states sampled from the target.

Parameters:
  • terminal_states (Any) – torch.Tensor or States; shape (B, s_dim) or (B, s_dim+1).

  • exploration_std (float | torch.Tensor) – extra state-space noise (combined in quadrature with PF std).

  • env (gfn.env.Env)

  • recalculate_all_logprobs (bool)

Returns:

Scalar loss (mean reduction).

Return type:

torch.Tensor
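"Combined in quadrature" for exploration_std means the variances add, so the effective standard deviation is the square root of the sum of squares. A minimal sketch follows, assuming both terms are already on the same scale; whether the library rescales exploration_std by sqrt(dt) first is not stated here, and `combine_in_quadrature` is a hypothetical helper name.

```python
import torch


def combine_in_quadrature(pf_std: torch.Tensor, exploration_std) -> torch.Tensor:
    """Effective std when independent Gaussian noise terms are added."""
    extra = torch.as_tensor(exploration_std, dtype=pf_std.dtype)
    return torch.sqrt(pf_std ** 2 + extra ** 2)


pf_std = torch.tensor([3.0, 0.5])
combined = combine_in_quadrature(pf_std, 4.0)    # first entry: sqrt(9 + 16) = 5
unchanged = combine_in_quadrature(pf_std, 0.0)   # exploration_std=0.0 is a no-op
```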

num_steps
pb = None
pb_scale_range = 0.1
pf
reduction = 'mean'
s_dim
sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)
Parameters:
  • env (gfn.env.Env)

  • n (int)

  • conditions (torch.Tensor | None)

  • save_logprobs (bool)

  • save_estimator_outputs (bool)

  • policy_kwargs (Any)

sampler
sigma
t_scale = 1.0
to_training_samples(trajectories)
gfn.gflownet.mle.dynamo_disable(fn)