Recurrent and Non-Autoregressive Policies¶
The default policy estimators in torchgfn are feedforward: they process each state independently with no memory of previous steps. For sequential generation tasks where context from prior steps is useful, torchgfn provides recurrent policy estimators. Separately, non-autoregressive environments allow the same terminal state to be reached via multiple action orderings.
Recurrent Policies¶
RecurrentDiscretePolicyEstimator¶
Class: RecurrentDiscretePolicyEstimator
Wraps a recurrent neural network (LSTM or GRU) that maintains a hidden state (“carry”) across trajectory steps. This allows the policy to condition on the full history of actions taken so far.
from gfn.estimators import RecurrentDiscretePolicyEstimator
from gfn.modules import RecurrentDiscreteSequenceModel

# `env` is assumed to be an already-constructed torchgfn environment.
model = RecurrentDiscreteSequenceModel(
    input_dim=env.preprocessor.output_dim,
    output_dim=env.n_actions,
    hidden_dim=64,
    rnn_type="lstm",  # or "gru"
)
pf_estimator = RecurrentDiscretePolicyEstimator(module=model, n_actions=env.n_actions)
Carry Management¶
The recurrent hidden state is managed automatically by the RecurrentPolicyMixin:
init_carry(batch_size): called by the sampler at the start of each rollout to allocate the initial hidden state
The carry is updated at each step and stored in the RolloutContext
When computing log-probabilities for existing trajectories (e.g., during off-policy loss), the carry is re-threaded step-by-step
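The carry life cycle described above can be illustrated without torchgfn. The following is a minimal pure-Python sketch, not the library's implementation: `ToyRecurrentPolicy`, its `step` method, `sample_rollout`, `score_trajectory`, and the scalar weights are hypothetical stand-ins, and only the name `init_carry` is borrowed from the text. It shows the carry being allocated once per rollout, updated at every step, and re-threaded from its initial value when a stored trajectory is scored again.

```python
import math
import random
from typing import List, Optional, Tuple


class ToyRecurrentPolicy:
    """Hypothetical stand-in for a recurrent policy (not torchgfn code)."""

    def __init__(self, n_actions: int, seed: int = 0):
        rng = random.Random(seed)
        self.n_actions = n_actions
        # A single scalar hidden unit keeps the example small.
        self.w_in = [rng.uniform(-1, 1) for _ in range(n_actions)]
        self.w_rec = rng.uniform(-1, 1)
        self.w_out = [rng.uniform(-1, 1) for _ in range(n_actions)]

    def init_carry(self, batch_size: int) -> List[float]:
        # Fresh hidden state, one entry per trajectory in the batch.
        return [0.0] * batch_size

    def step(self, prev_action: Optional[int], carry: float) -> Tuple[List[float], float]:
        """Return log-probabilities over actions and the updated carry."""
        x = 0.0 if prev_action is None else self.w_in[prev_action]
        new_carry = math.tanh(x + self.w_rec * carry)
        logits = [w * new_carry for w in self.w_out]
        log_norm = math.log(sum(math.exp(logit) for logit in logits))
        return [logit - log_norm for logit in logits], new_carry


def sample_rollout(policy, length, rng):
    """Sample one trajectory, threading the carry through every step."""
    carry = policy.init_carry(batch_size=1)[0]
    actions, total_logp, prev = [], 0.0, None
    for _ in range(length):
        log_probs, carry = policy.step(prev, carry)
        weights = [math.exp(lp) for lp in log_probs]
        action = rng.choices(range(policy.n_actions), weights=weights)[0]
        actions.append(action)
        total_logp += log_probs[action]
        prev = action
    return actions, total_logp


def score_trajectory(policy, actions):
    """Re-thread the carry from its initial value to score stored actions."""
    carry = policy.init_carry(batch_size=1)[0]
    total_logp, prev = 0.0, None
    for action in actions:
        log_probs, carry = policy.step(prev, carry)
        total_logp += log_probs[action]
        prev = action
    return total_logp


policy = ToyRecurrentPolicy(n_actions=3)
actions, logp_sampled = sample_rollout(policy, length=5, rng=random.Random(1))
logp_rescored = score_trajectory(policy, actions)
```

Because the log-probability at step t depends on the carry produced by steps 0 to t-1, scoring a stored trajectory has to replay the steps in order. This is the sequential cost that a feedforward policy avoids.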
Tree DAG Simplification¶
For environments where the DAG is a tree (each terminal state is reachable by exactly one path), every state along a trajectory has a single parent, so the backward policy is trivial: each backward transition has probability 1. In this case, you can skip the backward policy entirely:
from gfn.gflownet import TBGFlowNet
gflownet = TBGFlowNet(pf=pf_estimator, pb=None, init_logZ=0.0, constant_pb=True)
Setting pb=None and constant_pb=True tells the loss function to treat the backward probabilities as constant, so no backward model has to be learned or evaluated.
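To see why dropping the backward model is safe on a tree, it helps to look at the trajectory balance residual directly. The sketch below is a plain-Python illustration of the standard trajectory balance objective, not torchgfn's internal code. The function name `trajectory_balance_loss` and its arguments are hypothetical. When no backward log-probabilities are supplied, each backward transition is treated as having probability 1, so its log-probability contributes 0.

```python
import math


def trajectory_balance_loss(log_z, log_pf_steps, log_reward, log_pb_steps=None):
    """Squared trajectory-balance residual for a single trajectory.

    If `log_pb_steps` is None, every backward transition is assumed to have
    probability 1 (log-probability 0), as is the case when the DAG is a tree.
    """
    sum_log_pf = sum(log_pf_steps)
    sum_log_pb = 0.0 if log_pb_steps is None else sum(log_pb_steps)
    residual = log_z + sum_log_pf - log_reward - sum_log_pb
    return residual ** 2


# Three forward steps with probabilities 0.5, 0.25, 0.5 and a reward of 2.
log_pf_steps = [math.log(0.5), math.log(0.25), math.log(0.5)]
log_reward = math.log(2.0)

# Omitting the backward term ...
loss_no_pb = trajectory_balance_loss(0.0, log_pf_steps, log_reward)
# ... matches passing explicit backward log-probabilities of 0 at every step.
loss_explicit_pb = trajectory_balance_loss(
    0.0, log_pf_steps, log_reward, log_pb_steps=[0.0, 0.0, 0.0]
)
```

The two losses coincide, which is the sense in which the backward term can be dropped for tree-structured DAGs.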
See: train_bitsequence_recurrent.py (LSTM policy on bit sequences with tree DAG).
Non-Autoregressive Generation¶
NonAutoregressiveBitSequence¶
In autoregressive environments, actions are applied in a fixed order (e.g., left-to-right). Non-autoregressive environments allow the same terminal state to be reached by filling in positions in any order, creating a richer DAG structure.
The NonAutoregressiveBitSequence environment demonstrates this: each action specifies both a position and a value (action = position * n_words + word), and the action mask prevents filling already-occupied positions.
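The encoding `action = position * n_words + word` and the associated mask can be sketched in a few lines. This is an illustration only and not the environment's actual code: the helper names, the `EMPTY` sentinel, and the list-based state are assumptions, and a real environment may reserve additional actions (for example an exit action) that are omitted here.

```python
EMPTY = -1  # assumed sentinel for a position that has not been filled yet


def encode_action(position, word, n_words):
    """Pack a (position, word) pair into a single action index."""
    return position * n_words + word


def decode_action(action, n_words):
    """Recover (position, word) from an action index."""
    return divmod(action, n_words)


def forward_mask(state, n_words):
    """Return one flag per action: True if the action is allowed."""
    mask = []
    for value in state:
        # Every word is allowed at an empty position, none at a filled one.
        mask.extend([value == EMPTY] * n_words)
    return mask


def apply_action(state, action, n_words):
    """Return a new state with the chosen position filled."""
    position, word = decode_action(action, n_words)
    if state[position] != EMPTY:
        raise ValueError(f"position {position} is already filled")
    new_state = list(state)
    new_state[position] = word
    return new_state


n_words = 2
state = [EMPTY, EMPTY, EMPTY]
action = encode_action(position=2, word=1, n_words=n_words)
state = apply_action(state, action, n_words)
mask = forward_mask(state, n_words)
```

After filling position 2, both actions that target that position are masked out, while the four actions for positions 0 and 1 remain available.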
This has implications for training:
The DAG has more paths to each terminal state, providing more training signal per terminal state
The backward policy must account for multiple valid predecessors
Action masking is essential to prevent invalid actions
See: train_bitsequence_non_autoregressive.py.
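The first two implications can be made concrete by enumerating a tiny example. The sketch below is hypothetical helper code, independent of torchgfn, and assumes a state is a tuple with `EMPTY` marking unfilled positions. It lists every fill order that reaches a given terminal sequence and the predecessors of a partially filled state.

```python
from itertools import permutations
from math import factorial

EMPTY = -1  # assumed sentinel for an unfilled position


def trajectories_to(terminal):
    """Enumerate every fill order that produces `terminal`."""
    trajectories = []
    for order in permutations(range(len(terminal))):
        state = [EMPTY] * len(terminal)
        path = [tuple(state)]
        for position in order:
            state[position] = terminal[position]
            path.append(tuple(state))
        trajectories.append(path)
    return trajectories


def parents(state):
    """Predecessors of `state`: clear any one filled position."""
    result = []
    for position, value in enumerate(state):
        if value != EMPTY:
            parent = list(state)
            parent[position] = EMPTY
            result.append(tuple(parent))
    return result


terminal = (1, 0, 1)
paths = trajectories_to(terminal)

# A state with two filled positions has two valid predecessors.
state_parents = parents((1, EMPTY, 1))
# One simple choice for the backward policy is uniform over those parents.
uniform_pb = 1.0 / len(state_parents)
```

A sequence of length L has L! fill orders, compared with a single trajectory in the left-to-right case, and a state with k filled positions has k parents that the backward policy must distribute probability over.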
When to Use Recurrent Policies¶
Recurrent policies are most useful when:
The optimal action depends on the sequence of previous actions, not just the current state
The state representation doesn’t fully capture the relevant history
You’re generating sequences where long-range dependencies matter
For most environments (grids, graphs, continuous spaces), feedforward policies with a good state representation are sufficient and faster to train. Recurrent policies add overhead from sequential carry threading and are harder to parallelize.
Comparing Approaches¶
| Approach | Policy type | DAG structure | Best for |
|---|---|---|---|
| Feedforward (default) | | Any | Most environments |
| Recurrent | RecurrentDiscretePolicyEstimator | Any | Sequential generation with history dependence |
| Non-autoregressive | | Richer (multiple paths) | Order-invariant generation |
| Tree DAG + no PB | Any PF, pb=None | Tree only | Simplification when backward policy is uniform |