PolicyMixin: Policies and Rollouts¶
Estimators become policy-capable by mixing in a small, uniform rollout API. This lets the same Sampler and probability utilities drive different estimator families (discrete, graph, conditional, recurrent) without bespoke glue code.
This guide explains:
- The Policy rollout API and RolloutContext
- Vectorized vs non‑vectorized probability paths
- How policies integrate with the Sampler and probability calculators
- How to implement a new policy mixin or tailor the default behavior
Concepts and Goals¶
A policy‑capable estimator exposes:
- is_vectorized (bool): whether the estimator can be evaluated in a single vectorized call (no per‑step carry).
- init_context(batch_size, device, conditions): allocate a per‑rollout context.
- compute_dist(states_active, ctx, step_mask, ...) -> (Distribution, ctx): run the model and build a torch.distributions.Distribution.
- log_probs(actions_active, dist, ctx, step_mask, vectorized, ...) -> (Tensor, ctx): evaluate log‑probs, optionally padded to batch.
- get_current_estimator_output(ctx): access the last raw model output when requested.
All per‑step artifacts (e.g., log‑probs, raw outputs, recurrent state) are owned by the RolloutContext and recorded by the mixin.
RolloutContext¶
The RolloutContext is a lightweight container created once per rollout:
- batch_size, device, optional conditions
- Optional carry (for recurrent policies)
- Per‑step buffers: trajectory_log_probs, trajectory_estimator_outputs
- current_estimator_output for cached reuse or immediate retrieval
- extras: dict for arbitrary policy‑specific data
See src/gfn/estimators.py for the full definition.
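As a rough illustration of the container described above, the fields can be pictured as a small dataclass. This is a hypothetical stand-in, not the library's definition: the class name RolloutContextSketch, the use of plain lists for the per‑step buffers, and the string device are assumptions made so the sketch runs without the library.

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class RolloutContextSketch:
    """Hypothetical stand-in for RolloutContext; field names follow the list above."""

    batch_size: int
    device: str  # assumption: the real container likely stores a torch.device
    conditions: Optional[Any] = None
    carry: Optional[Any] = None  # recurrent state, only used by recurrent policies
    # Per-step buffers; plain lists are an assumption of this sketch.
    trajectory_log_probs: list = field(default_factory=list)
    trajectory_estimator_outputs: list = field(default_factory=list)
    current_estimator_output: Optional[Any] = None
    extras: dict = field(default_factory=dict)


# One context per rollout; buffers start empty and are never shared between rollouts.
ctx = RolloutContextSketch(batch_size=4, device="cpu")
ctx.extras["num_compute_calls"] = 1
```

Using default_factory gives every rollout its own buffers, which matches the "created once per rollout" lifecycle.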
PolicyMixin (vectorized, default)¶
PolicyMixin enables vectorized evaluation by default (is_vectorized=True).
- init_context(batch_size, device, conditions) returns a fresh RolloutContext with empty buffers.
- compute_dist(...):
  - Slices conditions by step_mask when provided; uses full conditions when step_mask=None (vectorized).
  - Optionally reuses ctx.current_estimator_output (e.g., PF with cached trajectories.estimator_outputs).
  - Calls the estimator module and builds a Distribution via to_probability_distribution.
  - When save_estimator_outputs=True, sets ctx.current_estimator_output and records a padded copy to ctx.trajectory_estimator_outputs for non‑vectorized calls.
- log_probs(...):
  - vectorized=True: returns raw dist.log_prob(...) (may include -inf for illegal actions) and optionally records to trajectory_log_probs.
  - vectorized=False: strict inf‑check, pads to shape (N,) using step_mask, records when requested.
Code reference (log‑probs behavior): src/gfn/estimators.py.
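The two log_probs behaviors described above can be sketched with plain PyTorch. The helper name padded_log_probs, the fill value of 0.0 for inactive rows, and the RuntimeError type are assumptions of this sketch; the library's own implementation lives in src/gfn/estimators.py and may differ in these details.

```python
import torch
from torch.distributions import Categorical


def padded_log_probs(dist, actions_active, step_mask, fill_value=0.0):
    """Hypothetical helper mirroring the non-vectorized path: strict check, then pad to (N,)."""
    lp = dist.log_prob(actions_active)  # shape: (number of active rows,)
    if not torch.isfinite(lp).all():
        raise RuntimeError("non-finite log-probability for an active action")
    # Assumption: inactive rows are filled with 0.0.
    out = torch.full(
        (step_mask.shape[0],), fill_value, dtype=lp.dtype, device=lp.device
    )
    out[step_mask] = lp
    return out


# Non-vectorized path: N = 4 rows, three of them active at this step.
step_mask = torch.tensor([True, False, True, True])
dist = Categorical(logits=torch.zeros(3, 2))  # uniform over 2 actions
actions = torch.tensor([0, 1, 0])
padded = padded_log_probs(dist, actions, step_mask)

# Vectorized path: the raw log-prob is returned as-is, so an illegal
# (masked-out) action simply yields -inf instead of raising.
masked = Categorical(logits=torch.tensor([[0.0, float("-inf")]]))
raw = masked.log_prob(torch.tensor([1]))
```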
RecurrentPolicyMixin (per‑step)¶
RecurrentPolicyMixin sets is_vectorized=False and threads a carry through steps:
- init_context(...) requires the estimator to implement init_carry(batch_size, device); stores the result in ctx.carry.
- compute_dist(...) must call the estimator as (states_active, ctx.carry) -> (est_out, new_carry), update ctx.carry, build the Distribution, and record outputs when requested (with padding when masked).
- log_probs(...) follows the non‑vectorized path (pad and strict checks) and can reuse the same recording semantics as PolicyMixin.
Code reference (carry update and padded recording): src/gfn/estimators.py.
Integration with the Sampler¶
The Sampler uses the policy API directly. It creates a single ctx per rollout, then repeats compute_dist → sample → optional log_probs while some trajectories are active. Per‑step artifacts are recorded into ctx by the mixin when flags are enabled.
Excerpt (per‑step call pattern): src/gfn/samplers.py.
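The per‑step call pattern can be illustrated with a self‑contained toy. Everything here is hypothetical: ToyPolicy, the rollout function, the integer "state" that just counts steps, and the two‑action space (0 = continue, 1 = exit) are stand‑ins chosen so the loop runs without an environment. The real Sampler in src/gfn/samplers.py handles environments, dummy actions, and more flags.

```python
from types import SimpleNamespace

import torch
from torch.distributions import Categorical


class ToyPolicy:
    """Hypothetical stand-in exposing the rollout methods named in this guide."""

    is_vectorized = True

    def init_context(self, batch_size, device, conditions=None):
        return SimpleNamespace(
            batch_size=batch_size, device=device,
            conditions=conditions, trajectory_log_probs=[],
        )

    def compute_dist(self, states_active, ctx, step_mask=None):
        # Uniform logits over two actions: 0 = continue, 1 = exit.
        logits = torch.zeros(states_active.shape[0], 2)
        return Categorical(logits=logits), ctx

    def log_probs(self, actions_active, dist, ctx, step_mask=None, save_logprobs=False):
        lp = dist.log_prob(actions_active)
        if step_mask is not None:  # pad back to the full batch
            padded = torch.zeros(step_mask.shape[0])
            padded[step_mask] = lp
            lp = padded
        if save_logprobs:
            ctx.trajectory_log_probs.append(lp)
        return lp, ctx


def rollout(policy, batch_size, max_steps=10):
    ctx = policy.init_context(batch_size, torch.device("cpu"))  # one ctx per rollout
    states = torch.zeros(batch_size, dtype=torch.long)  # toy state: step counter
    active = torch.ones(batch_size, dtype=torch.bool)
    for _ in range(max_steps):
        if not active.any():
            break
        dist, ctx = policy.compute_dist(states[active], ctx, step_mask=active)
        actions = dist.sample()
        _, ctx = policy.log_probs(actions, dist, ctx, step_mask=active, save_logprobs=True)
        states[active] += 1
        still_active = active.clone()
        still_active[active] = actions == 0  # an exit action ends the trajectory
        active = still_active
    return states, ctx


states, ctx = rollout(ToyPolicy(), batch_size=4)
```

Each recorded entry in ctx.trajectory_log_probs has the full batch shape, so the entries can be stacked into a per‑trajectory tensor afterwards.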
Integration with probability calculators (PF/PB)¶
Probability utilities in src/gfn/utils/prob_calculations.py branch on is_vectorized but call the same two methods in both paths:
- compute_dist(states_active, ctx, step_mask=None or mask)
- log_probs(actions_active, dist, ctx, step_mask=None or mask, vectorized=...)
Key differences:
- Vectorized (fast path)
  - step_mask=None, vectorized=True.
  - May reuse cached estimator outputs by pre‑setting ctx.current_estimator_output.
  - log_probs returns raw dist.log_prob(...) and does not raise on -inf.
- Non‑vectorized (per‑step path)
  - Uses legacy‑accurate masks and alignments:
    - PF (trajectories): ~states.is_sink_state[t] & ~actions.is_dummy[t]
    - PB (trajectories): aligns action at t with state at t+1, using ~states.is_sink_state[t+1] & ~states.is_initial_state[t+1] & ~actions.is_dummy[t] & ~actions.is_exit[t] (skips t==0).
    - Transitions: legacy PB mask on next_states with ~actions.is_exit.
  - log_probs pads back to (N,) and raises if any ±inf remains after masking.
See src/gfn/utils/prob_calculations.py for full branching.
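To make the per‑step masks concrete, the sketch below evaluates the PF and PB expressions quoted above on hand‑built boolean tensors. The function names pf_mask and pb_mask and the tensor layout (states indexed as (T+1, N), actions as (T, N)) are assumptions for illustration. Any special handling of t == 0 is left to the caller and not modeled here.

```python
import torch


def pf_mask(is_sink_state, is_dummy, t):
    # Forward: the state at t must be real and the action at t must not be padding.
    return ~is_sink_state[t] & ~is_dummy[t]


def pb_mask(is_sink_state, is_initial_state, is_dummy, is_exit, t):
    # Backward: the action at t is aligned with the state it leads to (t + 1).
    return (
        ~is_sink_state[t + 1]
        & ~is_initial_state[t + 1]
        & ~is_dummy[t]
        & ~is_exit[t]
    )


# Two toy trajectories with T = 3 action steps (so 4 state rows):
#   trajectory 0: s0 -> s1 -> s2 -> exit
#   trajectory 1: s0 -> s1 -> exit, then one dummy (padding) step
is_sink_state = torch.tensor(
    [[False, False], [False, False], [False, True], [True, True]]
)
is_initial_state = torch.tensor(
    [[True, True], [False, False], [False, False], [False, False]]
)
is_dummy = torch.tensor([[False, False], [False, False], [False, True]])
is_exit = torch.tensor([[False, False], [False, True], [True, False]])

pf = [pf_mask(is_sink_state, is_dummy, t) for t in range(3)]
pb = [pb_mask(is_sink_state, is_initial_state, is_dummy, is_exit, t) for t in range(3)]
```

In this toy, PF covers every real action including the exits, while PB covers only the non‑exit transitions: two for trajectory 0 and one for trajectory 1.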
Built‑in policy‑capable estimators¶
- DiscretePolicyEstimator: logits → Categorical with masking, optional temperature and epsilon‑greedy mixing in log‑space.
- DiscreteGraphPolicyEstimator: multi‑head logits (TensorDict) → GraphActionDistribution with per‑component masks and transforms.
- RecurrentDiscretePolicyEstimator: sequence models that maintain a carry; requires init_carry and returns (logits, carry) in forward.
- Conditional variants exist for state+condition architectures.
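One way to picture "masking, temperature, and epsilon‑greedy mixing in log‑space" is the following sketch. It is not the library's implementation: the helper name mix_with_uniform, its argument names, and the choice of a uniform distribution over valid actions as the exploration component are assumptions. The identity used is log((1 - ε)·p + ε·u) = logaddexp(log(1 - ε) + log p, log ε + log u), with ε in [0, 1).

```python
import math

import torch


def mix_with_uniform(logits, mask, temperature=1.0, epsilon=0.0):
    """Hypothetical helper: masked, tempered log-probs mixed with a uniform over valid actions."""
    neg_inf = torch.full_like(logits, float("-inf"))
    # Invalid actions get -inf so they receive zero probability.
    masked = torch.where(mask, logits / temperature, neg_inf)
    log_p = torch.log_softmax(masked, dim=-1)
    if epsilon == 0.0:
        return log_p
    n_valid = mask.sum(dim=-1, keepdim=True).to(logits.dtype)
    log_u = torch.where(mask, -torch.log(n_valid).expand_as(logits), neg_inf)
    mixed = torch.logaddexp(
        log_p + math.log1p(-epsilon), log_u + math.log(epsilon)
    )
    # Re-apply the mask so invalid actions stay at exactly -inf.
    return torch.where(mask, mixed, neg_inf)


logits = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[True, True, False]])  # third action is illegal
log_probs = mix_with_uniform(logits, mask, temperature=1.0, epsilon=0.1)
dist = torch.distributions.Categorical(logits=log_probs)
```

Working in log‑space avoids materializing probabilities that underflow to zero for very negative logits.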
How to write a new policy (or mixin variant)¶
Most users only need to implement to_probability_distribution (or reuse the provided ones). If you need a new interface or extra tracking, you can either:
- Use PolicyMixin (stateless, vectorized) and override to_probability_distribution on your estimator.
- Use RecurrentPolicyMixin (per‑step, carry) and implement init_carry plus a forward(states, carry) that returns (estimator_outputs, new_carry).
- Create a custom mixin derived from PolicyMixin to tailor compute_dist/log_probs (e.g., custom caching, diagnostics).
Minimal stateless policy (discrete)¶
import torch
from torch import nn
from gfn.estimators import DiscretePolicyEstimator

class SmallMLP(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128), nn.ReLU(), nn.Linear(128, output_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# Forward policy over n_actions
policy = DiscretePolicyEstimator(module=SmallMLP(input_dim=32, output_dim=17), n_actions=17)
Use with the Sampler:
from gfn.samplers import Sampler
sampler = Sampler(policy)
trajectories = sampler.sample_trajectories(env, n=64, save_logprobs=True)
Minimal recurrent policy¶
import torch
from torch import nn
from gfn.estimators import RecurrentDiscretePolicyEstimator

class TinyRNN(nn.Module):
    def __init__(self, vocab_size: int, hidden: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, hidden)
        self.rnn = nn.GRU(hidden, hidden, batch_first=True)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, tokens: torch.Tensor, carry: dict[str, torch.Tensor]):
        x = self.embed(tokens)
        h0 = carry.get("h", torch.zeros(1, tokens.size(0), x.size(-1), device=tokens.device))
        y, h = self.rnn(x, h0)
        logits = self.head(y)
        return logits, {"h": h}

    def init_carry(self, batch_size: int, device: torch.device) -> dict[str, torch.Tensor]:
        return {"h": torch.zeros(1, batch_size, self.embed.embedding_dim, device=device)}

policy = RecurrentDiscretePolicyEstimator(module=TinyRNN(vocab_size=33, hidden=64), n_actions=33)
Custom mixin variant (advanced)¶
If you need to add diagnostics or custom caching, subclass PolicyMixin and override compute_dist/log_probs to interact with ctx.extras.
from typing import Any
from torch.distributions import Distribution
from gfn.estimators import PolicyMixin

class TracingPolicyMixin(PolicyMixin):
    def compute_dist(self, states_active, ctx, step_mask=None, save_estimator_outputs=False, **kw):
        dist, ctx = super().compute_dist(states_active, ctx, step_mask, save_estimator_outputs, **kw)
        # Count how many times the model was evaluated during this rollout.
        ctx.extras.setdefault("num_compute_calls", 0)
        ctx.extras["num_compute_calls"] += 1
        return dist, ctx

    def log_probs(self, actions_active, dist: Distribution, ctx: Any, step_mask=None, vectorized=False, save_logprobs=False):
        lp, ctx = super().log_probs(actions_active, dist, ctx, step_mask, vectorized, save_logprobs)
        # Overwrite on every call so the value reflects the most recent step
        # (setdefault would keep only the first step's mean).
        ctx.extras["last_lp_mean"] = lp.mean().detach()
        return lp, ctx
Keep is_vectorized consistent with your evaluation strategy. If you switch to False, ensure your estimator supports per‑step rollouts and masking semantics.