gfn.gflownet.mle
================

.. py:module:: gfn.gflownet.mle

.. autoapi-nested-parse::

   MLE loss for diffusion GFlowNets (forward PF with optional PB).

   Key equations (per time step, shapes in comments):

   - Backward bridge (s_t -> s_{t-dt}):
       mean_bb = s_t * (1 - dt / t)                # (B, s_dim)
       std_bb  = sigma * sqrt(dt*(t-dt)/t)         # (B, 1) broadcast
     With learned PB corrections:
       mean = mean_bb + mean_corr
       std  = sqrt(std_bb^2 + corr_std^2)

   - Forward PF log-prob for increment Δ = s_t - s_{t-dt}:
       If PF predicts log_std:
         σ = exp(log_std) * sqrt(dt) * sqrt(t_scale); optionally combined in quadrature with exploration_std
         log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ_i)^2 + 2 log σ_i + log 2π ]
       Else (fixed variance):
         σ = sigma * sqrt(dt) * sqrt(t_scale); optionally combined in quadrature with exploration_std
         log p = -0.5 * Σ_i [ ((Δ - dt μ)_i / σ)^2 + log(2π σ^2) ]

   - Loss = -mean over batch of Σ_t log p_t

   Tensor conventions:

   - terminal_states: (B, s_dim), or (B, s_dim + 1) where the last column is an
     extra terminal indicator; that column is dropped if present.
   - Times: scalar dt = 1/num_steps; t_curr = 1 - i*dt; t_fwd = 1 - (i+1)*dt.

   Usage (user owns optimizer/loop):

   ```python
   import torch

   from gfn.gflownet.mle import MLEDiffusion

   # pf, env, batch_size and n_iterations are assumed to be defined by the caller.
   gfn = MLEDiffusion(pf=pf, pb=None, num_steps=100, sigma=2.0, t_scale=1.0)
   opt = torch.optim.Adam(gfn.parameters(), lr=1e-3)

   for it in range(n_iterations):
       # Sample a batch of terminal states.
       batch = env.sample(batch_size)  # batch shape (B, s_dim)

       opt.zero_grad()
       # Compute the MLE loss under the backward / forward diffusion process.
       loss = gfn.loss(env, batch, exploration_std=0.0)
       loss.backward()
       opt.step()
   ```


Classes
-------

.. autoapisummary::

   gfn.gflownet.mle.MLEDiffusion


Functions
---------

.. autoapisummary::

   gfn.gflownet.mle.dynamo_disable


Module Contents
---------------

.. py:class:: MLEDiffusion(pf, pb = None, *, num_steps, sigma, t_scale = 1.0, pb_scale_range = 0.1, learn_variance = False, reduction = 'mean', debug = False)

   Bases: :py:obj:`gfn.gflownet.base.GFlowNet`


   Maximum-likelihood diffusion GFlowNet (PF with optional PB).

   The caller owns the training loop; this class provides:

   - sampling via the forward PF (for API compatibility)
   - `.loss(env, terminal_states, ...)` computing the MLE objective


   .. py:method:: _assert_no_nan(logpf_sum)


   .. py:method:: _extract_samples(terminal_states)

      Normalize input to a (B, s_dim) tensor.

      Accepts torch.Tensor or States; drops the final column if the last
      dimension has size s_dim+1.



   .. py:attribute:: debug
      :value: False



   .. py:attribute:: dt


   .. py:attribute:: learn_variance
      :value: False



   .. py:method:: loss(env, terminal_states, recalculate_all_logprobs = True, *, exploration_std = 0.0)

      Compute the MLE objective given terminal states sampled from the target.

      :param terminal_states: torch.Tensor or States; shape (B, s_dim) or (B, s_dim+1).
      :param exploration_std: extra state-space noise (combined in quadrature with PF std).

      :returns: Scalar loss (mean reduction).



   .. py:attribute:: num_steps


   .. py:attribute:: pb
      :value: None



   .. py:attribute:: pb_scale_range
      :value: 0.1



   .. py:attribute:: pf


   .. py:attribute:: reduction
      :value: 'mean'



   .. py:attribute:: s_dim


   .. py:method:: sample_trajectories(env, n, conditions = None, save_logprobs = False, save_estimator_outputs = False, **policy_kwargs)


   .. py:attribute:: sampler


   .. py:attribute:: sigma


   .. py:attribute:: t_scale
      :value: 1.0



   .. py:method:: to_training_samples(trajectories)


.. py:function:: dynamo_disable(fn)
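The per-step equations in the module docstring can be checked by hand with a small dependency-free sketch. It is not the library's implementation: the helper names ``bridge_stats``, ``combine_std`` and ``fixed_variance_log_prob`` are hypothetical, the code handles a single sample as a plain list instead of a ``(B, s_dim)`` tensor, and it assumes the exploration noise enters the fixed-variance case in quadrature, as the ``exploration_std`` parameter description suggests.

```python
import math


def bridge_stats(s_t, t, dt, sigma):
    """Backward-bridge mean and std for one sample, s_t -> s_{t-dt}."""
    mean_bb = [x * (1.0 - dt / t) for x in s_t]
    std_bb = sigma * math.sqrt(dt * (t - dt) / t)
    return mean_bb, std_bb


def combine_std(std_a, std_b):
    """Combine two independent standard deviations in quadrature."""
    return math.sqrt(std_a ** 2 + std_b ** 2)


def fixed_variance_log_prob(delta, mu, dt, sigma, t_scale=1.0, exploration_std=0.0):
    """Gaussian log-prob of an increment in the fixed-variance case."""
    # Assumption: exploration noise is combined in quadrature with the PF std.
    std = combine_std(sigma * math.sqrt(dt) * math.sqrt(t_scale), exploration_std)
    return -0.5 * sum(
        ((d - dt * m) / std) ** 2 + math.log(2.0 * math.pi * std ** 2)
        for d, m in zip(delta, mu)
    )


# Halfway along the bridge (t=1, dt=0.5) the mean is halved.
mean_bb, std_bb = bridge_stats([2.0, -1.0], t=1.0, dt=0.5, sigma=2.0)
print(mean_bb, std_bb)  # [1.0, -0.5] 1.0
```

At ``t == dt`` the bridge mean and standard deviation both collapse to zero, which matches a process pinned at the origin at time 0. The loss described above would then be the negative batch mean of these log-probabilities summed over all time steps.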