gfn.estimators

Attributes

REDUCTION_FUNCTIONS

_POLICY_REQUIRED_METHODS

Classes

ConditionalDiscretePolicyEstimator

Conditional forward or backward policy estimators for discrete environments.

ConditionalLogZEstimator

Conditional logZ estimator.

ConditionalScalarEstimator

Class for conditionally estimating scalars (logZ, DB/SubTB state logF).

DiffusionPolicyEstimator

Base class for diffusion policy estimators.

DiscreteGraphPolicyEstimator

Forward or backward policy estimators for graph-based environments.

DiscretePolicyEstimator

Forward or backward policy estimators for discrete environments.

Estimator

Base class for modules mapping states to distributions or scalar values.

LogitBasedEstimator

Base class for estimators that output logits.

PinnedBrownianMotionBackward

Backward policy estimator for a pinned Brownian motion diffusion process.

PinnedBrownianMotionForward

Forward policy estimator for a pinned Brownian motion diffusion process.

PolicyEstimatorProtocol

Static-typing surface for estimators that are policy-capable.

PolicyMixin

Mixin enabling an Estimator to act as a policy (distribution over actions).

RecurrentDiscretePolicyEstimator

Discrete policy estimator for recurrent architectures with explicit carry.

RecurrentPolicyMixin

Mixin for recurrent policies that maintain and update a rollout carry.

RolloutContext

Structured per‑rollout state owned by estimators.

ScalarEstimator

Class for estimating scalars such as logZ of TB or state/edge flows of DB/SubTB.

Functions

validate_policy_estimator(estimator[, name])

Checks that an estimator implements the PolicyMixin interface.

Module Contents

class gfn.estimators.ConditionalDiscretePolicyEstimator(state_module, condition_module, final_module, n_actions, preprocessor=None, is_backward=False, debug=False)

Bases: DiscretePolicyEstimator

Conditional forward or backward policy estimators for discrete environments.

Estimates either of the following, given a condition \(c\):

  • \(s \mapsto (P_F(s' \mid s, c))_{s' \in Children(s)}\) (conditional forward policy)

  • \(s' \mapsto (P_B(s \mid s', c))_{s \in Parents(s')}\) (conditional backward policy)

This estimator is designed for discrete environments where the policy depends on both the state and some condition information. It uses a multi-module architecture where states and conditions are processed separately before being combined.

Parameters:
  • state_module (torch.nn.Module)

  • condition_module (torch.nn.Module)

  • final_module (torch.nn.Module)

  • n_actions (int)

  • preprocessor (gfn.preprocessors.Preprocessor | None)

  • is_backward (bool)

  • debug (bool)

module

The neural network module for state processing.

condition_module

The neural network module for condition processing.

final_module

The neural network module that combines state and condition.

n_actions

Total number of actions in the discrete environment.

preprocessor

Preprocessor object that transforms raw States objects to tensors.

is_backward

Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents.

_forward_trunk(states, conditions)

Forward pass of the trunk of the module.

This method processes the states and conditions inputs separately, then combines them through the final module.

Parameters:
  • states (gfn.states.States) – The input states.

  • conditions (torch.Tensor) – The condition tensor.

Returns:

The output of the trunk of the module, as a tensor of shape (*batch_shape, output_dim).

Return type:

torch.Tensor
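The trunk pattern above can be illustrated with a minimal sketch. It makes two assumptions. First, the state and condition embeddings are concatenated along the last dimension before final_module; the actual combination rule is not stated here. Second, the states are passed as an already-preprocessed tensor rather than a gfn States object. The three linear layers are hypothetical stand-ins for state_module, condition_module and final_module.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for state_module, condition_module and final_module.
state_dim, cond_dim, hidden, output_dim = 4, 2, 8, 5
state_module = nn.Linear(state_dim, hidden)
condition_module = nn.Linear(cond_dim, hidden)
# Assumption: embeddings are concatenated, so final_module sees 2 * hidden features.
final_module = nn.Linear(2 * hidden, output_dim)


def forward_trunk(state_tensor: torch.Tensor, conditions: torch.Tensor) -> torch.Tensor:
    """Embed preprocessed states and conditions separately, then combine them."""
    state_emb = state_module(state_tensor)      # (*batch_shape, hidden)
    cond_emb = condition_module(conditions)     # (*batch_shape, hidden)
    combined = torch.cat([state_emb, cond_emb], dim=-1)
    return final_module(combined)               # (*batch_shape, output_dim)


out = forward_trunk(torch.randn(3, state_dim), torch.randn(3, cond_dim))
print(out.shape)
```

Because nn.Linear acts on the last dimension, the same sketch works for any batch shape.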

condition_module
final_module
forward(states, conditions)

Forward pass of the module.

Parameters:
  • states (gfn.states.States) – The input states.

  • conditions (torch.Tensor) – The condition tensor.

Returns:

The output of the module, as a tensor of shape (*batch_shape, output_dim).

Return type:

torch.Tensor

n_actions
class gfn.estimators.ConditionalLogZEstimator(module, reduction='mean')

Bases: ScalarEstimator

Conditional logZ estimator.

This estimator is used to estimate the logZ of a GFlowNet from a conditions tensor. Since the conditions are given as a tensor, it does not have a preprocessor. Reduction is used to aggregate the outputs of the module into a single scalar.

Parameters:
  • module (torch.nn.Module)

  • reduction (str)

module

The neural network module to use.

reduction

String name of one of the REDUCTION_FUNCTIONS keys.

_calculate_module_output(input)
Parameters:

input (torch.Tensor)

Return type:

torch.Tensor

class gfn.estimators.ConditionalScalarEstimator(state_module, condition_module, final_module, preprocessor=None, reduction='mean', debug=False)

Bases: ConditionalDiscretePolicyEstimator

Class for conditionally estimating scalars (logZ, DB/SubTB state logF).

Similar to ScalarEstimator, the function approximator used for final_module need not directly output a scalar. If it does not, reduction will be used to aggregate the outputs of the module into a single scalar.

Parameters:
  • state_module (torch.nn.Module)

  • condition_module (torch.nn.Module)

  • final_module (torch.nn.Module)

  • preprocessor (gfn.preprocessors.Preprocessor | None)

  • reduction (str)

  • debug (bool)

module

The neural network module for state processing.

condition_module

The neural network module for condition processing.

final_module

The neural network module that combines state and condition.

preprocessor

Preprocessor object that transforms raw States objects to tensors.

is_backward

Always False for ConditionalScalarEstimator (since it’s direction-agnostic).

reduction_function

Function used to reduce multi-dimensional outputs to scalars.

property expected_output_dim: int

Expected output dimension of the module.

Returns:

Always 1, as this estimator outputs scalar values.

Return type:

int

forward(states, conditions)

Forward pass of the module.

Parameters:
  • states (gfn.states.States) – The input states.

  • conditions (torch.Tensor) – The condition tensor.

Returns:

The output of the module, as a tensor of shape (*batch_shape, 1).

Return type:

torch.Tensor

reduction_function
abstract to_probability_distribution(states, module_output, **policy_kwargs)

Transforms the output of the module into a probability distribution.

This method should not be called for ConditionalScalarEstimator as it outputs scalar values, not probability distributions.

Raises:

NotImplementedError – This method is not implemented for scalar estimators.

Return type:

torch.distributions.Distribution

class gfn.estimators.DiffusionPolicyEstimator(s_dim, module, is_backward=False, debug=False)

Bases: PolicyMixin, Estimator

Base class for diffusion policy estimators.

Parameters:
  • s_dim (int)

  • module (torch.nn.Module)

  • is_backward (bool)

  • debug (bool)

property expected_output_dim: int

Expected output dimension of the module.

Returns:

The expected output dimension of the module, or None if the output dimension is not well-defined (e.g., when the output is a TensorDict for GraphActions).

Return type:

int

abstract forward(input)

Forward pass of the module.

Parameters:

input (gfn.states.States) – The input to the module as states.

Returns:

The output of the module, as a tensor of shape (*batch_shape, output_dim).

Return type:

torch.Tensor

s_dim
abstract to_probability_distribution(states, module_output, **policy_kwargs)

Transform the output of the module into an IsotropicGaussian distribution.

Parameters:
  • states (gfn.states.States) – The states to use, states.tensor.shape = (*batch_shape, s_dim + 1).

  • module_output (torch.Tensor) – The output of the module (actions), as a tensor of shape (*batch_shape, s_dim).

  • **policy_kwargs (Any) – Keyword arguments to modify the distribution.

Returns:

An IsotropicGaussian distribution.

Return type:

gfn.utils.distributions.IsotropicGaussian

class gfn.estimators.DiscreteGraphPolicyEstimator(module, preprocessor=None, is_backward=False, debug=False)

Bases: PolicyMixin, LogitBasedEstimator

Forward or backward policy estimators for graph-based environments.

Estimates either of the following, where \(s\) and \(s'\) are graph states:

  • \(s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}\) (forward policy)

  • \(s' \mapsto (P_B(s \mid s'))_{s \in Parents(s')}\) (backward policy)

This estimator is designed for graph-based environments where actions modify graphs and states are represented as graphs. The output is a TensorDict containing logits for different action components (action type, node class, edge class, edge index).

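Parameters:
  • module (torch.nn.Module)

  • preprocessor (gfn.preprocessors.Preprocessor | None)

  • is_backward (bool)

  • debug (bool)

module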

The neural network module to use.

preprocessor

Preprocessor object that transforms GraphStates objects to tensors.

is_backward

Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents.

property expected_output_dim: int | None

Expected output dimension of the module.

Returns:

None, as the output_dim of a TensorDict is not well-defined.

Return type:

Optional[int]

to_probability_distribution(states, module_output, sf_bias=0.0, temperature=defaultdict(lambda : ...), epsilon=defaultdict(lambda : ...))

Returns a probability distribution given a batch of states and module output.

Similar to DiscretePolicyEstimator.to_probability_distribution(), but handles the complex structure of graph actions through a TensorDict. The method applies masks, biases, temperature scaling, and epsilon-greedy exploration to each action component separately.

Parameters:
  • states (gfn.states.States) – The graph states where the policy is evaluated.

  • module_output (tensordict.TensorDict) – The output of the module as a TensorDict containing logits for different action components.

  • sf_bias (float) – Scalar to subtract from the exit action logit before dividing by temperature.

  • temperature (dict[str, float]) – Dictionary mapping action component keys to temperature values for scaling logits.

  • epsilon (dict[str, float]) – Dictionary mapping action component keys to epsilon values for exploration.

Returns:

A GraphActionDistribution over the graph actions.

Return type:

torch.distributions.Distribution
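To show how a dictionary-valued temperature can scale each action component separately, here is a small sketch. The component keys are hypothetical; the real TensorDict keys may differ. The fallback temperature of 1.0 is also an assumption, because the actual defaults are elided in the signature above.

```python
from collections import defaultdict

import torch

# Hypothetical per-component logits; real keys and shapes may differ.
logits = {
    "action_type": torch.tensor([2.0, 0.0, -1.0]),
    "node_class": torch.tensor([0.5, 0.5]),
}
# Assumption: components without an explicit entry fall back to temperature 1.0.
temperature = defaultdict(lambda: 1.0, {"action_type": 2.0})

# Each component is divided by its own temperature.
scaled = {key: value / temperature[key] for key, value in logits.items()}
print(scaled)
```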

class gfn.estimators.DiscretePolicyEstimator(module, n_actions, preprocessor=None, is_backward=False, debug=False)

Bases: PolicyMixin, LogitBasedEstimator

Forward or backward policy estimators for discrete environments.

Estimates either:

  • \(s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}\) (forward policy)

  • \(s' \mapsto (P_B(s \mid s'))_{s \in Parents(s')}\) (backward policy)

This estimator is designed for discrete environments where actions are represented by integer indices and states have forward/backward masks indicating valid actions.

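Parameters:
  • module (torch.nn.Module)

  • n_actions (int)

  • preprocessor (gfn.preprocessors.Preprocessor | None)

  • is_backward (bool)

  • debug (bool)

module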

The neural network module to use.

n_actions

Total number of actions in the discrete environment.

preprocessor

Preprocessor object that transforms raw States objects to tensors.

is_backward

Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents.

property expected_output_dim: int

Expected output dimension of the module.

Returns:

n_actions for forward policies, n_actions - 1 for backward policies.

Return type:

int

n_actions
to_probability_distribution(states, module_output, sf_bias=0.0, temperature=1.0, epsilon=0.0)

Returns a Categorical distribution given a batch of states and module output.

This implementation stays in logit/log-prob space for numerical stability. When epsilon > 0, the epsilon-greedy distribution is built by mixing the original distribution with a uniform distribution over the valid actions. The resulting normalized log-probs are then passed to the Categorical as logits.

The sf_bias, temperature and epsilon arguments modify the base distribution, for example to encourage exploration.

Parameters:
  • states (gfn.states.DiscreteStates) – The discrete states where the policy is evaluated.

  • module_output (torch.Tensor) – The output of the module as a tensor of shape (*batch_shape, output_dim).

  • sf_bias (float) – Scalar to subtract from the exit action logit before dividing by temperature. Does nothing if set to 0.0 (default), in which case it’s on policy.

  • temperature (float) – Scalar to divide the logits by before softmax. Does nothing if set to 1.0 (default), in which case it’s on policy.

  • epsilon (float) – With probability epsilon, an action is drawn uniformly from the valid actions instead of from the policy. Does nothing if set to 0.0 (default), in which case it’s on policy.

Returns:

A Categorical distribution over the actions.

Return type:

torch.distributions.Categorical
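The following sketch covers the epsilon == 0 path described above: mask invalid actions, subtract sf_bias from the exit logit, divide by the temperature, and build a Categorical. masked_categorical is a hypothetical helper, not the library's method. Placing the exit action at the last index is an illustrative assumption.

```python
import torch
from torch.distributions import Categorical


def masked_categorical(
    logits: torch.Tensor,
    mask: torch.Tensor,
    sf_index: int,
    sf_bias: float = 0.0,
    temperature: float = 1.0,
) -> Categorical:
    """Hypothetical helper: bias the exit logit, apply temperature, mask invalid actions."""
    logits = logits.clone()
    logits[..., sf_index] -= sf_bias                  # subtract bias from the exit logit...
    logits = logits / temperature                     # ...before dividing by the temperature
    logits = logits.masked_fill(~mask, float("-inf"))  # invalid actions get probability 0
    return Categorical(logits=logits)


mask = torch.tensor([[True, True, False, True],
                     [True, False, False, True]])
# Assumption: the exit action is the last index (sf_index = 3).
dist = masked_categorical(torch.zeros(2, 4), mask, sf_index=3)
biased = masked_categorical(torch.zeros(2, 4), mask, sf_index=3, sf_bias=5.0)
print(dist.probs, biased.probs)
```

A positive sf_bias lowers the probability of exiting, which favors longer trajectories.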

classmethod uniform(n_actions, preprocessor=None)

Create a uniform backward policy estimator for discrete environments.

Outputs equal logits for all actions, resulting in a uniform distribution over valid parent actions (masking is still applied).

Parameters:
  • n_actions (int) – Total number of actions in the discrete environment.

  • preprocessor (gfn.preprocessors.Preprocessor | None) – Preprocessor object that transforms states to tensors. Required because the input dimension depends on the environment.

Returns:

A DiscretePolicyEstimator with is_backward=True and no learnable parameters.

Return type:

DiscretePolicyEstimator
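As a sketch of why equal logits plus masking yield a uniform backward policy, the example below feeds zero logits of size n_actions - 1 to a Categorical. The size n_actions - 1 matches the backward expected_output_dim documented above. The masks are hypothetical example data, not library output.

```python
import torch
from torch.distributions import Categorical

n_actions = 5
# A backward policy has no exit action, so it outputs n_actions - 1 logits.
backward_dim = n_actions - 1

# Equal (zero) logits: nothing is learned.
logits = torch.zeros(2, backward_dim)
# Hypothetical backward masks: True marks a valid parent action.
backward_masks = torch.tensor([[True, False, True, True],
                               [False, False, True, False]])

dist = Categorical(logits=logits.masked_fill(~backward_masks, float("-inf")))
# Row 0 is uniform over its 3 valid parents; row 1 has a single valid parent.
print(dist.probs)
```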

class gfn.estimators.Estimator(module, preprocessor=None, is_backward=False, debug=False)

Bases: abc.ABC, torch.nn.Module

Base class for modules mapping states to distributions or scalar values.

Training a GFlowNet requires parameterizing one or more of the following functions:

  • \(s \mapsto (\log F(s \rightarrow s'))_{s' \in Children(s)}\)

  • \(s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}\)

  • \(s' \mapsto (P_B(s \mid s'))_{s \in Parents(s')}\)

  • \(s \mapsto (\log F(s))_{s \in States}\)

This class is the base class for all such function estimators. The estimators need to encapsulate a nn.Module, which takes a batch of preprocessed states as input and outputs a batch of outputs of the desired shape. When the goal is to represent a probability distribution, the outputs would correspond to the parameters of the distribution, e.g. logits for a categorical distribution for discrete environments.

The forward method, invoked through __call__, outputs logits or other distribution parameters. To obtain a probability distribution directly, override and use the to_probability_distribution() method.

The preprocessor is also encapsulated in the estimator. These function estimators implement the __call__ method, which takes States objects as inputs and calls the module on the preprocessed states.

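Parameters:
  • module (torch.nn.Module)

  • preprocessor (gfn.preprocessors.Preprocessor | None)

  • is_backward (bool)

  • debug (bool)

module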

The neural network module to use. If it is a Tabular module (from gfn.utils.modules), then the environment preprocessor needs to be an EnumPreprocessor.

preprocessor

Preprocessor object that transforms raw States objects to tensors that can be used as input to the module. Optional, defaults to IdentityPreprocessor.

is_backward

Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents.

__repr__()

Returns a string representation of the Estimator.

Returns:

A string summary of the Estimator.

debug = False
abstract property expected_output_dim: int | None

Expected output dimension of the module.

Returns:

The expected output dimension of the module, or None if the output dimension is not well-defined (e.g., when the output is a TensorDict for GraphActions).

Return type:

Optional[int]

forward(input)

Forward pass of the module.

Parameters:

input (gfn.states.States) – The input to the module as states.

Returns:

The output of the module, as a tensor of shape (*batch_shape, output_dim).

Return type:

torch.Tensor

is_backward = False
module
preprocessor = None
abstract to_probability_distribution(states, module_output, **policy_kwargs)

Transforms the output of the module into a probability distribution.

The kwargs may contain parameters to modify a base distribution, for example to encourage exploration.

This method is optional for modules that don’t need to output probability distributions (e.g., when estimating logF for flow matching). However, it is required for modules that define policies, as it converts raw module outputs into probability distributions over actions. See DiscretePolicyEstimator for an example using categorical distributions for discrete actions, or BoxPFEstimator for examples using continuous distributions like Beta mixtures.

Parameters:
  • states (gfn.states.States) – The states to use.

  • module_output (torch.Tensor) – The output of the module as a tensor of shape (*batch_shape, output_dim).

  • **policy_kwargs (Any) – Keyword arguments to modify the distribution.

Returns:

A distribution object.

Return type:

torch.distributions.Distribution
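The wrapping pattern described for Estimator (preprocess raw states, then call the encapsulated module) is sketched below. ToyEstimator is a hypothetical class, not gfn.estimators.Estimator. Its identity lambda only mirrors the documented IdentityPreprocessor default.

```python
import torch
import torch.nn as nn


class ToyEstimator(nn.Module):
    """Hypothetical stand-in for the Estimator pattern (not the gfn class)."""

    def __init__(self, module: nn.Module, preprocessor=None):
        super().__init__()
        self.module = module
        # Mirrors the documented default: identity preprocessing when none is given.
        self.preprocessor = preprocessor if preprocessor is not None else (lambda s: s)

    def forward(self, states):
        # Turn raw states into tensors, then run the wrapped module on them.
        return self.module(self.preprocessor(states))


est = ToyEstimator(nn.Linear(3, 4), preprocessor=lambda s: s.float())
out = est(torch.ones(2, 3, dtype=torch.long))
print(out.shape)
```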

class gfn.estimators.LogitBasedEstimator(module, preprocessor=None, is_backward=False, debug=False)

Bases: Estimator

Base class for estimators that output logits.

This class is used to define estimators that output logits, which can be used to construct probability distributions.

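Parameters:
  • module (torch.nn.Module)

  • preprocessor (gfn.preprocessors.Preprocessor | None)

  • is_backward (bool)

  • debug (bool)

module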

The neural network module to use.

preprocessor

Preprocessor object that transforms raw States objects to tensors.

is_backward

Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents.

static _compute_logits_for_distribution(logits, masks, sf_index, sf_bias, temperature, epsilon, debug=False)

Return logits to feed a Categorical:

  • If epsilon == 0: masked, biased, temperature-scaled logits.

  • Otherwise: normalized log-probs of the epsilon-greedy mixture (valid as logits).

Parameters:
  • logits (torch.Tensor)

  • masks (torch.Tensor)

  • sf_index (int | None)

  • sf_bias (float)

  • temperature (float)

  • epsilon (float)

  • debug (bool)

Return type:

torch.Tensor

static _mix_with_uniform_in_log_space(lsm, masks, epsilon, debug=False)

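Compute \(\log((1-\epsilon)\,p + \epsilon\,u)\) in log space, where \(p\) is the distribution whose log-probs are given by lsm and \(u\) is the uniform distribution over valid actions.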

Parameters:
  • lsm (torch.Tensor)

  • masks (torch.Tensor)

  • epsilon (float)

  • debug (bool)

Return type:

torch.Tensor
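The log-space mixture can be computed with torch.logaddexp, since \(\log(a + b) = \mathrm{logaddexp}(\log a, \log b)\). The sketch below assumes 0 < epsilon < 1. Its two functions are hypothetical re-implementations in the spirit of _uniform_log_probs and _mix_with_uniform_in_log_space, not the library code.

```python
import math

import torch


def uniform_log_probs(masks: torch.Tensor) -> torch.Tensor:
    """log(1 / n_valid) on valid actions, -inf on invalid ones."""
    n_valid = masks.sum(dim=-1, keepdim=True).to(torch.float32)
    return torch.where(masks, -torch.log(n_valid), torch.tensor(float("-inf")))


def mix_with_uniform_in_log_space(lsm: torch.Tensor, masks: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Return log((1 - eps) * p + eps * u) without leaving log space (0 < eps < 1)."""
    log_u = uniform_log_probs(masks)
    return torch.logaddexp(lsm + math.log1p(-epsilon), log_u + math.log(epsilon))


logits = torch.tensor([[1.0, 2.0, 0.5, -1.0]])
masks = torch.tensor([[True, True, False, True]])
lsm = torch.log_softmax(logits.masked_fill(~masks, float("-inf")), dim=-1)
mixed = mix_with_uniform_in_log_space(lsm, masks, 0.1)
print(mixed.exp())  # still a valid distribution; the masked action keeps probability 0
```

Staying in log space avoids exponentiating very negative logits, and the -inf entries of masked actions propagate cleanly through logaddexp.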

static _prepare_logits(logits, masks, sf_index, sf_bias, temperature, debug=False)

Clone and apply mask, bias and temperature to logits.

Parameters:
  • logits (torch.Tensor)

  • masks (torch.Tensor)

  • sf_index (int | None)

  • sf_bias (float)

  • temperature (float)

  • debug (bool)

Return type:

torch.Tensor

static _uniform_log_probs(masks)

Uniform log-probs over valid actions; -inf for invalid.

Parameters:

masks (torch.Tensor)

Return type:

torch.Tensor

class gfn.estimators.PinnedBrownianMotionBackward(s_dim, pb_module, sigma, num_discretization_steps, n_variance_outputs=0, pb_scale_range=0.1)

Bases: DiffusionPolicyEstimator

Backward policy estimator for a pinned Brownian motion diffusion process.

Parameters:
  • s_dim (int)

  • pb_module (torch.nn.Module)

  • sigma (float)

  • num_discretization_steps (int)

  • n_variance_outputs (int)

  • pb_scale_range (float)

dt
property expected_output_dim: int

Expected output dimension of the module.

Returns:

The expected output dimension of the module, or None if the output dimension is not well-defined (e.g., when the output is a TensorDict for GraphActions).

Return type:

int

forward(input)

Forward pass of the module.

Parameters:

input (gfn.states.States)

Return type:

torch.Tensor

n_variance_outputs = 0
pb_scale_range = 0.1
sigma
to_probability_distribution(states, module_output, **policy_kwargs)

Transform the output of the module into an IsotropicGaussian distribution, which is the distribution of the previous states under the pinned Brownian motion process, possibly controlled by the output of the backward module. If the module is a fixed backward module, the module_output is a zero vector (no control). Includes optional learned corrections.

Parameters:
  • states (gfn.states.States) – The states to use, states.tensor.shape = (*batch_shape, s_dim + 1).

  • module_output (torch.Tensor) – The output of the module (actions), as a tensor of shape (*batch_shape, s_dim).

  • **policy_kwargs (Any) – Keyword arguments to modify the distribution.

Returns:

An IsotropicGaussian distribution (distribution of the previous states).

Return type:

gfn.utils.distributions.IsotropicGaussian

class gfn.estimators.PinnedBrownianMotionForward(s_dim, pf_module, sigma, num_discretization_steps, n_variance_outputs=0)

Bases: DiffusionPolicyEstimator

Forward policy estimator for a pinned Brownian motion diffusion process.

Parameters:
  • s_dim (int)

  • pf_module (torch.nn.Module)

  • sigma (float)

  • num_discretization_steps (int)

  • n_variance_outputs (int)

dt
property expected_output_dim: int

Expected output dimension of the module.

Returns:

The expected output dimension of the module, or None if the output dimension is not well-defined (e.g., when the output is a TensorDict for GraphActions).

Return type:

int

forward(input)

Forward pass of the module.

Parameters:

input (gfn.states.States) – The input to the module as states.

Returns:

The output of the module, as a tensor of shape (*batch_shape, output_dim).

Return type:

torch.Tensor

n_variance_outputs = 0
num_discretization_steps
sigma
to_probability_distribution(states, module_output, **policy_kwargs)

Transform the output of the module into an IsotropicGaussian distribution, which is the distribution of the next states under the pinned Brownian motion controlled by the output of the module.

Parameters:
  • states (gfn.states.States) – The states to use, states.tensor.shape = (*batch_shape, s_dim + 1).

  • module_output (torch.Tensor) – The output of the module (actions), as a tensor of shape (*batch_shape, s_dim).

  • **policy_kwargs (Any) – Keyword arguments to modify the distribution. Supported key: exploration_std, an optional float or Tensor controlling extra exploration noise on top of the base diffusion std. When provided, the extra noise is combined in variance space (logaddexp) with the base diffusion variance; non-positive values are ignored.

Returns:

An IsotropicGaussian distribution (distribution of the next states).

Return type:

gfn.utils.distributions.IsotropicGaussian
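Combining noise "in variance space (logaddexp)" means the total variance is the sum of the base variance and the exploration variance, computed on log-variances. The sketch below illustrates this under stated assumptions. combined_std is a hypothetical helper, and it accepts only a float exploration_std for brevity, although the method also accepts a Tensor.

```python
import math
from typing import Optional

import torch


def combined_std(base_std: torch.Tensor, exploration_std: Optional[float]) -> torch.Tensor:
    """Hypothetical helper: add exploration variance to the base variance in log space."""
    if exploration_std is None or exploration_std <= 0:
        return base_std  # missing or non-positive exploration noise is ignored
    # log(total_var) = logaddexp(log(base_std**2), log(exploration_std**2))
    log_var = torch.logaddexp(
        2 * torch.log(base_std),
        torch.full_like(base_std, 2 * math.log(exploration_std)),
    )
    return torch.exp(0.5 * log_var)


# Approximately 0.5, since sqrt(0.3**2 + 0.4**2) = 0.5.
print(combined_std(torch.tensor([0.3]), 0.4))
```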

class gfn.estimators.PolicyEstimatorProtocol

Bases: Protocol

Static-typing surface for estimators that are policy-capable.

This protocol captures the methods provided by the PolicyMixin so that external code (e.g., samplers/probability calculators) can use a precise type rather than relying on dynamic attributes. This helps static analyzers avoid false positives like “Tensor is not callable” when calling mixin methods.

compute_dist(states_active, ctx, step_mask=None, **policy_kwargs)
Parameters:
  • states_active (gfn.states.States)

  • ctx (Any)

  • step_mask (Optional[torch.Tensor])

  • policy_kwargs (Any)

Return type:

tuple[torch.distributions.Distribution, Any]

init_context(batch_size, device, conditions=None)
Parameters:
  • batch_size (int)

  • device (torch.device)

  • conditions (Optional[torch.Tensor])

Return type:

Any

is_vectorized: bool
log_probs(actions_active, dist, ctx, step_mask=None, vectorized=False, **kwargs)
Parameters:
  • actions_active (torch.Tensor)

  • dist (torch.distributions.Distribution)

  • ctx (Any)

  • step_mask (Optional[torch.Tensor])

  • vectorized (bool)

  • kwargs (Any)

Return type:

tuple[torch.Tensor, Any]

class gfn.estimators.PolicyMixin

Mixin enabling an Estimator to act as a policy (distribution over actions).

Provides the generic rollout API (init_context, compute_dist, log_probs) directly on the estimator. Standard policies should inherit from this mixin.

compute_dist(states_active, ctx, step_mask=None, save_estimator_outputs=False, **policy_kwargs)

Run the estimator for active rows and build an action Distribution.

Parameters:
  • states_active (gfn.states.States) – The states to run the estimator on.

  • ctx (Any) – The per-rollout context (e.g., a RolloutContext created by init_context).

  • step_mask (Optional[torch.Tensor]) – The mask to slice the conditions to the active subset.

  • save_estimator_outputs (bool) – Whether to save the estimator outputs.

  • **policy_kwargs (Any) – Additional keyword arguments to pass to the estimator.

Returns:

A tuple containing the distribution and the context.

Return type:

tuple[torch.distributions.Distribution, Any]

Notes

  • Uses step_mask to slice the conditions to the active subset. When step_mask is None, the estimator is running in a vectorized context.

  • Saves the raw estimator output in ctx.current_estimator_output for optional recording in record_step.

get_current_estimator_output(ctx)

Expose the most recent per-step estimator output saved during compute.

Parameters:

ctx (Any)

Return type:

Optional[torch.Tensor]

init_context(batch_size, device, conditions=None)

Create a new per-rollout context.

Stores rollout invariants (batch size, device, optional conditions) and initializes empty buffers for per-step artifacts.

Parameters:
  • batch_size (int)

  • device (torch.device)

  • conditions (Optional[torch.Tensor])

Return type:

RolloutContext

property is_vectorized: bool

Whether the policy supports vectorized probability calculations.

Return type:

bool

log_probs(actions_active, dist, ctx, step_mask=None, vectorized=False, save_logprobs=False)

Compute log-probs, optionally padding back to full batch when non-vectorized.

Parameters:
  • actions_active (torch.Tensor)

  • dist (torch.distributions.Distribution)

  • ctx (Any)

  • step_mask (Optional[torch.Tensor])

  • vectorized (bool)

  • save_logprobs (bool)

Return type:

tuple[torch.Tensor, Any]
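A boolean step_mask drives both operations: it slices the conditions down to the active rows, and in non-vectorized mode it pads log-probs back to the full batch. The sketch below shows both with plain boolean indexing. The helper names are hypothetical. The padding value of 0.0 for inactive rows is an assumption, not documented behavior.

```python
import torch


def slice_conditions(conditions: torch.Tensor, step_mask: torch.Tensor) -> torch.Tensor:
    """Keep only the condition rows of trajectories that are still active."""
    return conditions[step_mask]


def pad_log_probs(active_log_probs: torch.Tensor, step_mask: torch.Tensor, fill_value: float = 0.0) -> torch.Tensor:
    """Scatter active-row log-probs back into a tensor covering the full batch."""
    full = torch.full(
        step_mask.shape, fill_value,
        dtype=active_log_probs.dtype, device=active_log_probs.device,
    )
    full[step_mask] = active_log_probs  # inactive rows keep fill_value (assumed 0.0)
    return full


step_mask = torch.tensor([True, False, True, False])
active_conditions = slice_conditions(torch.randn(4, 3), step_mask)
padded = pad_log_probs(torch.tensor([-0.5, -1.2]), step_mask)
print(active_conditions.shape, padded)
```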

gfn.estimators.REDUCTION_FUNCTIONS
class gfn.estimators.RecurrentDiscretePolicyEstimator(module, n_actions, preprocessor=None, is_backward=False, debug=False)

Bases: RecurrentPolicyMixin, DiscretePolicyEstimator

Discrete policy estimator for recurrent architectures with explicit carry.

Many sequence models (e.g., RNN/LSTM/GRU/Transformer in autoregressive mode) maintain a recurrent hidden state (“carry”) that must be threaded through successive calls during sampling. This class formalizes that pattern for GFlowNet policies by:

  • Exposing a forward signature forward(states, carry) -> (logits, carry) so the policy can update and return the next carry at each step.

  • Requiring an init_carry(batch_size, device) method to allocate the initial hidden state for a rollout.

  • Ensuring the per-step output (logits over actions) is derived from the latest token/time step while the internal model may process sequences.

The sampler relies on RecurrentPolicyMixin, which calls this estimator with the current carry, updates the carry at every step, and records per-step artifacts. Non-recurrent estimators should use the default PolicyMixin and the standard DiscretePolicyEstimator base class instead.

Notes

  • Forward is intended for on-policy generation; off-policy evaluation over entire trajectories typically requires different batching and masking.

  • init_carry is a hard requirement for compatibility with RecurrentPolicyMixin.

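Parameters:
  • module (torch.nn.Module)

  • n_actions (int)

  • preprocessor (gfn.preprocessors.Preprocessor | None)

  • is_backward (bool)

  • debug (bool)

module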

The neural network module to use.

n_actions

Total number of actions in the discrete environment.

preprocessor

Preprocessor object that transforms states to tensors.

is_backward

Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents.

forward(states, carry)

Forward pass of the module.

Parameters:
  • states (gfn.states.States) – The input states.

  • carry (dict[str, torch.Tensor]) – The carry from the previous step.

Returns:

A tuple (logits, carry): the logits as a tensor of shape (*batch_shape, output_dim), and the updated carry.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]

init_carry(batch_size, device)
Parameters:
  • batch_size (int)

  • device (torch.device)

Return type:

dict[str, torch.Tensor]
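The recurrent contract (init_carry(batch_size, device) plus forward(states, carry) -> (logits, carry)) is sketched below with a GRU cell. ToyRecurrentPolicy and the carry key "hidden" are hypothetical. States are plain tensors here rather than gfn States, and the loop stands in for the sampler threading the carry through successive steps.

```python
from typing import Dict

import torch
import torch.nn as nn


class ToyRecurrentPolicy(nn.Module):
    """Hypothetical GRU policy following forward(states, carry) -> (logits, carry)."""

    def __init__(self, input_dim: int, hidden_dim: int, n_actions: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.cell = nn.GRUCell(input_dim, hidden_dim)
        self.head = nn.Linear(hidden_dim, n_actions)

    def init_carry(self, batch_size: int, device: torch.device) -> Dict[str, torch.Tensor]:
        # Allocate the initial hidden state for a rollout.
        return {"hidden": torch.zeros(batch_size, self.hidden_dim, device=device)}

    def forward(self, states: torch.Tensor, carry: Dict[str, torch.Tensor]):
        hidden = self.cell(states, carry["hidden"])   # update the recurrent state
        return self.head(hidden), {"hidden": hidden}  # per-step logits and new carry


policy = ToyRecurrentPolicy(input_dim=3, hidden_dim=8, n_actions=4)
carry = policy.init_carry(batch_size=2, device=torch.device("cpu"))
for _ in range(3):  # thread the carry through successive sampling steps
    logits, carry = policy(torch.randn(2, 3), carry)
print(logits.shape)
```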

class gfn.estimators.RecurrentPolicyMixin

Bases: PolicyMixin

Mixin for recurrent policies that maintain and update a rollout carry.

compute_dist(states_active, ctx, step_mask=None, save_estimator_outputs=False, **policy_kwargs)

Run estimator with carry and update it.

Differs from the default PolicyMixin by calling estimator(states_active, ctx.carry) -> (est_out, new_carry), storing the updated carry and saving current_estimator_output before building the Distribution.

Parameters:
  • states_active (gfn.states.States)

  • ctx (Any)

  • step_mask (Optional[torch.Tensor])

  • save_estimator_outputs (bool)

  • policy_kwargs (Any)

Return type:

tuple[torch.distributions.Distribution, Any]

init_context(batch_size, device, conditions=None)

Create a new per-rollout context.

Stores rollout invariants (batch size, device, optional conditions) and initializes empty buffers for per-step artifacts.

Parameters:
  • batch_size (int)

  • device (torch.device)

  • conditions (Optional[torch.Tensor])

Return type:

RolloutContext

property is_vectorized: bool

Whether the policy supports vectorized probability calculations.

Return type:

bool

class gfn.estimators.RolloutContext(batch_size, device, conditions=None)

Structured per‑rollout state owned by estimators.

Holds rollout invariants and optional per‑step buffers; use extras for estimator‑specific fields without changing the class shape.

Parameters:
  • batch_size (int)

  • device (torch.device)

  • conditions (Optional[torch.Tensor])

__slots__ = ('batch_size', 'device', 'conditions', 'carry', 'trajectory_log_probs',...
batch_size
carry = None
conditions = None
current_estimator_output: torch.Tensor | None = None
device
extras: Dict[str, Any]
trajectory_estimator_outputs: List[torch.Tensor] = []
trajectory_log_probs: List[torch.Tensor] = []
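The point of a slotted context with an extras dictionary is that its attribute set stays fixed while estimators can still stash their own fields. ToyRolloutContext below is a hypothetical mirror built from the attributes listed above, not the library class. Creating the buffers in __init__ gives every rollout its own fresh lists.

```python
from typing import Any, Dict, List, Optional

import torch


class ToyRolloutContext:
    """Hypothetical mirror of a slotted per-rollout context."""

    __slots__ = ("batch_size", "device", "conditions", "carry",
                 "trajectory_log_probs", "trajectory_estimator_outputs",
                 "current_estimator_output", "extras")

    def __init__(self, batch_size: int, device: torch.device,
                 conditions: Optional[torch.Tensor] = None):
        self.batch_size = batch_size
        self.device = device
        self.conditions = conditions
        self.carry = None
        # Per-instance buffers, so no state leaks between rollouts.
        self.trajectory_log_probs: List[torch.Tensor] = []
        self.trajectory_estimator_outputs: List[torch.Tensor] = []
        self.current_estimator_output: Optional[torch.Tensor] = None
        self.extras: Dict[str, Any] = {}


ctx = ToyRolloutContext(4, torch.device("cpu"))
ctx.extras["step_count"] = 0  # estimator-specific field, no new attribute needed
```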
class gfn.estimators.ScalarEstimator(module, preprocessor=None, reduction='mean', debug=False)

Bases: Estimator

Class for estimating scalars such as logZ of TB or state/edge flows of DB/SubTB.

Training a GFlowNet sometimes requires estimating scalar values, such as the log partition function logZ (for TB) or state/edge flows (for DB/SubTB). This Estimator is designed for those cases.

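Parameters:
  • module (torch.nn.Module)

  • preprocessor (gfn.preprocessors.Preprocessor | None)

  • reduction (str)

  • debug (bool)

module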

The neural network module to use. This doesn’t have to directly output a scalar. If it does not, reduction will be used to aggregate the outputs of the module into a single scalar.

preprocessor

Preprocessor object that transforms raw States objects to tensors that can be used as input to the module.

is_backward

Always False for ScalarEstimator (since it’s direction-agnostic).

reduction_function

Function used to reduce multi-dimensional outputs to scalars.

_calculate_module_output(input)
Parameters:

input (gfn.states.States)

Return type:

torch.Tensor

property expected_output_dim: int

Expected output dimension of the module.

Returns:

Always 1, as this estimator outputs scalar values.

Return type:

int

forward(input)

Forward pass of the module.

Parameters:

input (gfn.states.States) – The input to the module as states.

Returns:

The output of the module, as a tensor of shape (*batch_shape, 1).

Return type:

torch.Tensor

reduction_function
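When the module does not output a scalar, a reduction collapses its last dimension to produce the documented (*batch_shape, 1) shape. The sketch below illustrates this. The REDUCTIONS table is hypothetical, since the contents of REDUCTION_FUNCTIONS are not listed here; only 'mean' is known to be a valid key, because it is the documented default.

```python
import torch

# Hypothetical reduction table; REDUCTION_FUNCTIONS may contain different entries.
REDUCTIONS = {
    "mean": lambda x: x.mean(dim=-1, keepdim=True),
    "sum": lambda x: x.sum(dim=-1, keepdim=True),
}


def reduce_to_scalar(module_output: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    """Collapse (*batch_shape, k) module outputs to (*batch_shape, 1)."""
    if module_output.shape[-1] == 1:
        return module_output  # already one scalar per state
    return REDUCTIONS[reduction](module_output)


x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(reduce_to_scalar(x), reduce_to_scalar(x, "sum"))
```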
gfn.estimators._POLICY_REQUIRED_METHODS = ('init_context', 'compute_dist', 'log_probs')
gfn.estimators.validate_policy_estimator(estimator, name='estimator')

Checks that an estimator implements the PolicyMixin interface.

Parameters:
  • estimator (Any) – The estimator to validate.

  • name (str) – Label for error messages (e.g., “pf”, “pb”).

Raises:

TypeError – If a required method is missing.

Return type:

None
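A check in the spirit of validate_policy_estimator can loop over the names in _POLICY_REQUIRED_METHODS and raise TypeError on the first one that is missing. The sketch below is hypothetical: the function name, the callable() test and the error message wording are assumptions, not the library's implementation.

```python
from typing import Any

_POLICY_REQUIRED_METHODS = ("init_context", "compute_dist", "log_probs")


def validate_policy_estimator_sketch(estimator: Any, name: str = "estimator") -> None:
    """Raise TypeError if any required policy method is missing or not callable."""
    for method in _POLICY_REQUIRED_METHODS:
        if not callable(getattr(estimator, method, None)):
            raise TypeError(f"{name} is missing required policy method '{method}'.")


class FullPolicy:
    def init_context(self, *args, **kwargs): ...
    def compute_dist(self, *args, **kwargs): ...
    def log_probs(self, *args, **kwargs): ...


class PartialPolicy:
    def init_context(self, *args, **kwargs): ...


validate_policy_estimator_sketch(FullPolicy(), name="pf")  # passes silently
```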