gfn.gflownet.mle¶
MLE loss for diffusion GFlowNets (forward PF with optional PB).
- Key equations (per time step, shapes in comments):
- Backward bridge (s_t -> s_{t-dt}):
mean_bb = s_t * (1 - dt / t) # (B, s_dim)
std_bb = sigma * sqrt(dt*(t-dt)/t) # (B, 1) broadcast
- With learned PB corrections:
mean = mean_bb + mean_corr
std = sqrt(std_bb^2 + corr_std^2)
- Forward PF log-prob for increment Δ = s_t - s_{t-dt}:
- If PF predicts log_std:
σ = exp(log_std) * sqrt(dt) * sqrt(t_scale); optionally combined in quadrature with exploration_std
log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ_i)^2 + 2 log σ_i + log 2π ]
- Else (fixed variance):
σ = sigma * sqrt(dt) * sqrt(t_scale); optionally combined in quadrature with exploration_std
log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ)^2 + log(2π σ^2) ]
Loss = -mean over batch of Σ_t log p_t
- Tensor conventions:
terminal_states: (B, s_dim) or (B, s_dim + 1), where the last column is an extra terminal indicator; that column is dropped if present.
Times: scalar dt = 1/num_steps; t_curr = 1 - i*dt; t_fwd = 1 - (i+1)*dt.
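The equations and time conventions above can be sketched in plain Python. This is a minimal illustration, not the module's implementation: it uses only the standard library and processes one row of length s_dim at a time, whereas the module operates on batched torch tensors. The helper names `time_grid`, `bridge_moments`, and `pf_log_prob_fixed` are hypothetical, and the clamp inside the square root is an added safeguard against floating-point round-off at the last step.

```python
import math


def time_grid(num_steps):
    """Return dt and the (t_curr, t_fwd) pair for every step i."""
    dt = 1.0 / num_steps
    pairs = [(1.0 - i * dt, 1.0 - (i + 1) * dt) for i in range(num_steps)]
    return dt, pairs


def bridge_moments(s_t, t, dt, sigma):
    """Backward bridge s_t -> s_{t-dt}, without learned PB corrections."""
    mean_bb = [x * (1.0 - dt / t) for x in s_t]  # one row, length s_dim
    # Clamp at zero (assumption): round-off can make t - dt slightly negative.
    std_bb = sigma * math.sqrt(max(dt * (t - dt) / t, 0.0))  # shared by all dims
    return mean_bb, std_bb


def pf_log_prob_fixed(delta, mu, dt, sigma, t_scale=1.0):
    """Fixed-variance forward log-prob of one increment (one row)."""
    std = sigma * math.sqrt(dt) * math.sqrt(t_scale)
    return -0.5 * sum(
        ((d - dt * m) / std) ** 2 + math.log(2.0 * math.pi * std ** 2)
        for d, m in zip(delta, mu)
    )


dt, pairs = time_grid(4)
mean_bb, std_bb = bridge_moments([2.0, -4.0], t=pairs[0][0], dt=dt, sigma=2.0)
print(mean_bb, std_bb)
```

At the final step (t equal to dt) the bridge mean and std both collapse to zero, which pins the backward process to the origin.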
Usage (user owns optimizer/loop):
```python
import torch

gfn = MLEDiffusion(pf=pf, pb=None, num_steps=100, sigma=2.0, t_scale=1.0)
opt = torch.optim.Adam(gfn.parameters(), lr=1e-3)
for it in range(n_iterations):
    # Sample a batch of terminal states.
    batch = env.sample(batch_size)  # batch shape (B, s_dim)
    opt.zero_grad()
    # Calculate the MLE loss under the backward / forward diffusion process.
    # env is passed first, matching the loss(env, terminal_states, ...) signature.
    loss = gfn.loss(env, batch, exploration_std=0.0)
    loss.backward()
    opt.step()
```
Classes¶
MLEDiffusion | Maximum-likelihood diffusion GFlowNet (PF with optional PB).
Functions¶
dynamo_disable(fn)
Module Contents¶
- class gfn.gflownet.mle.MLEDiffusion(pf, pb=None, *, num_steps, sigma, t_scale=1.0, pb_scale_range=0.1, learn_variance=False, reduction='mean', debug=False)¶
Bases: gfn.gflownet.base.GFlowNet
Maximum-likelihood diffusion GFlowNet (PF with optional PB).
- The caller owns the training loop; this class provides:
sampling via the forward PF (for API compatibility)
.loss(env, terminal_states, …) computing the MLE objective
- Parameters:
pb (Optional[gfn.estimators.PinnedBrownianMotionBackward])
num_steps (int)
sigma (float)
t_scale (float)
pb_scale_range (float)
learn_variance (bool)
reduction (str)
debug (bool)
- _assert_no_nan(logpf_sum)¶
- Parameters:
logpf_sum (torch.Tensor)
- Return type:
None
- _extract_samples(terminal_states)¶
Normalize input to a (B, s_dim) tensor. Accepts torch.Tensor or States; drops a final column if size matches s_dim+1.
- Parameters:
terminal_states (Any)
- Return type:
tuple[torch.device, torch.dtype, torch.Tensor]
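The column-dropping rule described for `_extract_samples` can be illustrated on plain nested lists. This is a sketch under stated assumptions, not the method's actual code: the helper name `drop_indicator_column` is hypothetical, the real method works on torch tensors or States and also returns the device and dtype, and raising `ValueError` for any other width is an assumption made here for clarity.

```python
def drop_indicator_column(rows, s_dim):
    """Return rows with exactly s_dim entries each.

    Rows of width s_dim + 1 lose their final (terminal indicator) column;
    rows of width s_dim are copied unchanged.
    """
    out = []
    for row in rows:
        if len(row) == s_dim + 1:
            out.append(list(row[:-1]))  # drop the indicator column
        elif len(row) == s_dim:
            out.append(list(row))
        else:
            # Assumption: any other width is treated as an error.
            raise ValueError(f"expected width {s_dim} or {s_dim + 1}, got {len(row)}")
    return out


with_indicator = [[1.0, 2.0, 0.0], [3.0, 4.0, 1.0]]
print(drop_indicator_column(with_indicator, s_dim=2))
```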
- debug = False¶
- dt¶
- learn_variance = False¶
- loss(env, terminal_states, recalculate_all_logprobs=True, *, exploration_std=0.0)¶
Compute the MLE objective given terminal states sampled from the target.
- Parameters:
terminal_states (Any) – torch.Tensor or States; shape (B, s_dim) or (B, s_dim+1).
exploration_std (float | torch.Tensor) – extra state-space noise (combined in quadrature with PF std).
env (gfn.env.Env)
recalculate_all_logprobs (bool)
- Returns:
Scalar loss (mean reduction).
- Return type:
torch.Tensor
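How `exploration_std` enters the learned-variance log-probability can be sketched as follows. This is an illustration only: the helper names `combined_std` and `pf_log_prob_learned` are hypothetical, the code handles a single row with the standard library rather than batched torch tensors, and applying the exploration noise after the sqrt(dt) and sqrt(t_scale) scaling is an assumption that the source does not spell out.

```python
import math


def combined_std(pf_std, exploration_std=0.0):
    """Combine the PF std with extra exploration noise in quadrature."""
    return math.hypot(pf_std, exploration_std)  # sqrt(pf_std^2 + exploration_std^2)


def pf_log_prob_learned(delta, mu, log_std, dt, t_scale=1.0, exploration_std=0.0):
    """Learned-variance forward log-prob of one increment (one row)."""
    total = 0.0
    for d, m, ls in zip(delta, mu, log_std):
        std = math.exp(ls) * math.sqrt(dt) * math.sqrt(t_scale)
        # Assumption: exploration noise is added after the time scaling.
        std = combined_std(std, exploration_std)
        total += ((d - dt * m) / std) ** 2 + 2.0 * math.log(std) + math.log(2.0 * math.pi)
    return -0.5 * total


print(combined_std(3.0, 4.0))
```

With `exploration_std=0.0` the combined std equals the PF std, so the expression reduces to the learned-variance formula without exploration.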
- num_steps¶
- pb = None¶
- pb_scale_range = 0.1¶
- pf¶
- reduction = 'mean'¶
- s_dim¶
- sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)¶
- Parameters:
env (gfn.env.Env)
n (int)
conditions (torch.Tensor | None)
save_logprobs (bool)
save_estimator_outputs (bool)
policy_kwargs (Any)
- sampler¶
- sigma¶
- t_scale = 1.0¶
- to_training_samples(trajectories)¶
- gfn.gflownet.mle.dynamo_disable(fn)¶