Conditional GFlowNets
A conditional GFlowNet learns to sample from different distributions depending on an external condition variable. Instead of learning a single distribution p(x) ∝ R(x), it learns a family of distributions p(x | c) ∝ R(x, c) parameterized by a condition c.
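Here is a minimal illustration. It uses a hypothetical toy reward, not one of torchgfn's environments. Take four objects x ∈ {0, 1, 2, 3} and R(x, c) = exp(c · x). Normalizing separately for each condition gives a different distribution for each c:

```python
import math


def reward(x: int, c: float) -> float:
    """Hypothetical condition-dependent reward R(x, c) = exp(c * x)."""
    return math.exp(c * x)


def conditional_dist(c: float, xs=range(4)) -> list[float]:
    """Return p(x | c) = R(x, c) / Z(c) over a small, enumerable set of objects."""
    weights = [reward(x, c) for x in xs]
    z = sum(weights)  # partition function Z(c) for this particular condition
    return [w / z for w in weights]


print(conditional_dist(0.0))  # [0.25, 0.25, 0.25, 0.25]
print(conditional_dist(1.0))  # probabilities increase with x
```

At c = 0 every object gets the same reward, so p(x | 0) is uniform. As c grows, mass shifts toward larger x. A conditional GFlowNet learns this whole family in one model instead of training a separate sampler for each c.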
This is useful when:
- The reward function depends on an external parameter (e.g., temperature, task ID)
- You want a single model that interpolates between different target distributions
- The condition represents side information that changes the desirable outputs
Conditional Estimators
torchgfn provides conditional variants of the standard estimators:
ConditionalDiscretePolicyEstimator
Wraps a neural network that takes both state and condition as input to produce action logits:
```python
from gfn.estimators import ConditionalDiscretePolicyEstimator

pf_estimator = ConditionalDiscretePolicyEstimator(
    module=pf_module,
    n_actions=env.n_actions,
    preprocessor=preprocessor,
    is_backward=False,
)
```
The module receives both the (preprocessed) state and the condition. A typical architecture uses separate encoders for the state and the condition, concatenates their embeddings, and passes the result through a shared trunk:
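```python
from gfn.utils.modules import MLP

state_encoder = MLP(input_dim=state_dim, output_dim=hidden_dim, hidden_dim=hidden_dim)
condition_encoder = MLP(input_dim=condition_dim, output_dim=hidden_dim, hidden_dim=hidden_dim)
# Input: both embeddings concatenated (2 * hidden_dim); output: one logit per action.
trunk = MLP(input_dim=2 * hidden_dim, output_dim=output_dim, hidden_dim=hidden_dim)
```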
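The same two-encoder architecture can also be written as a plain PyTorch module. This is a sketch for illustration. The class name `ConditionalPolicyNet`, the one-layer encoders, and the `forward(states, conditions)` signature are assumptions. They are not the interface that torchgfn's estimators require.

```python
import torch
from torch import nn


class ConditionalPolicyNet(nn.Module):
    """Hypothetical network: encode state and condition separately, concatenate, map to logits."""

    def __init__(self, state_dim: int, condition_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.state_encoder = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU())
        self.condition_encoder = nn.Sequential(nn.Linear(condition_dim, hidden_dim), nn.ReLU())
        # The trunk sees both embeddings side by side, hence 2 * hidden_dim inputs.
        self.trunk = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, states: torch.Tensor, conditions: torch.Tensor) -> torch.Tensor:
        merged = torch.cat([self.state_encoder(states), self.condition_encoder(conditions)], dim=-1)
        return self.trunk(merged)  # one logit per action


net = ConditionalPolicyNet(state_dim=8, condition_dim=1, hidden_dim=32, output_dim=5)
logits = net(torch.randn(4, 8), torch.rand(4, 1))  # batch of 4 states, one scalar condition each
print(logits.shape)  # torch.Size([4, 5])
```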
ConditionalLogZEstimator
Estimates the log-partition function as a function of the condition only (not the state):
```python
from gfn.estimators import ConditionalLogZEstimator

logz_estimator = ConditionalLogZEstimator(module=logz_module)
```
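This makes sense because Z(c) = Σ_x R(x, c). Summing over all objects x removes any dependence on a particular state. The reward landscape still changes with the condition, however, and so does Z.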
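Below is a stdlib-only sketch of the mapping from c to log Z(c). It assumes a hypothetical toy reward log R(x, c) = c · x on four objects. It computes log Z(c) stably with the log-sum-exp trick. Under TB, a `ConditionalLogZEstimator` would be trained to approximate this mapping.

```python
import math


def log_reward(x: int, c: float) -> float:
    """Hypothetical log-reward log R(x, c) = c * x."""
    return c * x


def log_partition(c: float, xs=range(4)) -> float:
    """log Z(c) = log sum_x R(x, c), via log-sum-exp to avoid overflow for large rewards."""
    logs = [log_reward(x, c) for x in xs]
    m = max(logs)
    return m + math.log(sum(math.exp(v - m) for v in logs))


for c in [0.0, 0.5, 1.0]:
    # At c = 0 all four rewards equal 1, so log Z(0) = log 4.
    print(c, log_partition(c))
```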
ConditionalScalarEstimator
For DB/SubTB losses, estimates the log state flow log F(s | c), conditioned on both the state and the condition:
```python
from gfn.estimators import ConditionalScalarEstimator

logf_estimator = ConditionalScalarEstimator(
    module=logf_module,
    preprocessor=preprocessor,
)
```
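The flow estimator needs the condition because the detailed-balance constraint for a transition s → s' holds separately for every c:

F(s | c) P_F(s' | s, c) = F(s' | c) P_B(s | s', c)

Below is a plain-Python sketch of the squared log-space residual for one transition. The function name is hypothetical, and this is the textbook DB objective rather than torchgfn's DBGFlowNet code.

```python
import math


def conditional_db_loss(log_flow_s: float, log_pf: float, log_flow_next: float, log_pb: float) -> float:
    """Squared detailed-balance residual for one transition s -> s' under a fixed condition c.

    All inputs must be evaluated with the same c:
    log F(s | c), log P_F(s' | s, c), log F(s' | c), log P_B(s | s', c).
    """
    residual = log_flow_s + log_pf - log_flow_next - log_pb
    return residual ** 2


# Balanced: F(s|c) = 2, P_F = 0.5, F(s'|c) = 1, P_B = 1, so 2 * 0.5 = 1 * 1.
print(conditional_db_loss(math.log(2.0), math.log(0.5), 0.0, 0.0))
# Unbalanced: same transition, but log F(s'|c) = 1 makes the residual -1.
print(conditional_db_loss(math.log(2.0), math.log(0.5), 1.0, 0.0))
```

At the end of a trajectory, the flow is tied to the reward R(x, c). The exact form of that boundary condition varies between implementations.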
Sampling with Conditions
Pass conditions when sampling trajectories:
```python
import torch

# Sample one condition vector per trajectory, uniformly in [0, 1)
conditions = torch.rand(batch_size, condition_dim, device=device)
trajectories = gflownet.sample_trajectories(
    env, n=batch_size, conditions=conditions
)
```
The conditions are threaded through to all estimators automatically.
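Here is a toy, stdlib-only sketch of what "threading" a condition means. Each trajectory is assigned one condition, and the policy receives that same value at every step. The environment is a counter that either steps or stops and is capped at 3. Both the environment and the function names are hypothetical; this is not torchgfn's sampler.

```python
import math
import random


def forward_policy(state: int, condition: float) -> list[float]:
    """Hypothetical conditional policy: probabilities of [step, stop] given state and condition."""
    if state >= 3:
        return [0.0, 1.0]  # must stop at the boundary
    p_step = 1.0 / (1.0 + math.exp(-condition))  # larger conditions favor stepping
    return [p_step, 1.0 - p_step]


def sample_trajectories(conditions: list[float], rng: random.Random) -> list[list[int]]:
    """Sample one trajectory per condition; each condition is reused at every step of its trajectory."""
    trajectories = []
    for c in conditions:
        state, traj = 0, [0]
        while True:
            probs = forward_policy(state, c)  # the same c is passed at every step
            action = rng.choices([0, 1], weights=probs)[0]
            if action == 1:  # stop action
                break
            state += 1
            traj.append(state)
        trajectories.append(traj)
    return trajectories


print(sample_trajectories([-5.0, 0.0, 5.0], random.Random(0)))
```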
Validation
Validate across a range of condition values to ensure the model generalizes:
```python
import torch

for c_value in [0.0, 0.25, 0.5, 0.75, 1.0]:
    # Use the same condition value for every sample in the validation batch.
    conditions = torch.full((n_val, condition_dim), c_value, device=device)
    terminating_states = gflownet.sample_terminating_states(env, n=n_val, conditions=conditions)
    # Compare against env.true_dist(conditions=c_value)
```
Check that the learned distribution matches the target at each condition value, including the ends of the condition range (e.g., a uniform target at one extreme and a multimodal one at the other).
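One way to score each condition is the total variation distance between the empirical distribution of sampled terminal states and the true distribution. This metric is an assumption here, not necessarily the one train_conditional.py reports. The stdlib sketch below assumes terminal states have already been mapped to integer indices.

```python
from collections import Counter


def empirical_dist(samples: list[int], n_states: int) -> list[float]:
    """Fraction of samples that landed on each terminal-state index 0 .. n_states - 1."""
    counts = Counter(samples)
    return [counts[i] / len(samples) for i in range(n_states)]


def total_variation(p: list[float], q: list[float]) -> float:
    """Total variation distance: half the L1 distance between two distributions."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


# Hypothetical samples for one condition, scored against a uniform target.
samples = [0, 0, 1, 3]
true_probs = [0.25, 0.25, 0.25, 0.25]
print(total_variation(empirical_dist(samples, 4), true_probs))  # 0.25
```

Repeating this for each `c_value` gives one score per condition. A score that is low mid-range but high at one extreme points to poor generalization at that end of the range.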
Supported Loss Functions
All standard losses (TB, DB, SubTB, FM, ZVar) work with conditional estimators. The GFlowNet classes (TBGFlowNet, DBGFlowNet, etc.) handle conditions transparently, so the loss computation code needs no changes.
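Here is what "transparently" means for TB: each term is evaluated under the trajectory's own condition, and the form of the loss does not change. Below is a plain-Python sketch of the standard trajectory-balance residual. The function name is hypothetical, and this is not torchgfn's TBGFlowNet code.

```python
import math


def conditional_tb_loss(
    log_z_c: float, log_pf_steps: list[float], log_reward_c: float, log_pb_steps: list[float]
) -> float:
    """Squared trajectory-balance residual for one trajectory sampled under condition c.

    log_z_c:      log Z(c), the output of the conditional log-partition estimator.
    log_pf_steps: log P_F(s_{t+1} | s_t, c) for each step of the trajectory.
    log_reward_c: log R(x, c) for the trajectory's terminal object x.
    log_pb_steps: log P_B(s_t | s_{t+1}, c) for each step of the trajectory.
    """
    residual = log_z_c + sum(log_pf_steps) - log_reward_c - sum(log_pb_steps)
    return residual ** 2


# One-step trajectory to one of 4 equally rewarded objects (R = 1), uniform P_F, P_B = 1:
# log Z(c) = log 4 exactly cancels log P_F = log(1/4), so the residual is zero.
print(conditional_tb_loss(math.log(4.0), [math.log(0.25)], 0.0, [0.0]))
```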
Example Environment
ConditionalHyperGrid is a built-in environment where the reward function is parameterized by a continuous condition. It provides true_dist() and log_partition() methods that accept conditions, making it easy to validate the learned conditional distributions.
See train_conditional.py for a complete example with five loss variants, condition-dependent validation, and interpolation-quality metrics.