gfn.estimators ============== .. py:module:: gfn.estimators Attributes ---------- .. autoapisummary:: gfn.estimators.REDUCTION_FUNCTIONS gfn.estimators._POLICY_REQUIRED_METHODS Classes ------- .. autoapisummary:: gfn.estimators.ConditionalDiscretePolicyEstimator gfn.estimators.ConditionalLogZEstimator gfn.estimators.ConditionalScalarEstimator gfn.estimators.DiffusionPolicyEstimator gfn.estimators.DiscreteGraphPolicyEstimator gfn.estimators.DiscretePolicyEstimator gfn.estimators.Estimator gfn.estimators.LogitBasedEstimator gfn.estimators.PinnedBrownianMotionBackward gfn.estimators.PinnedBrownianMotionForward gfn.estimators.PolicyEstimatorProtocol gfn.estimators.PolicyMixin gfn.estimators.RecurrentDiscretePolicyEstimator gfn.estimators.RecurrentPolicyMixin gfn.estimators.RolloutContext gfn.estimators.ScalarEstimator Functions --------- .. autoapisummary:: gfn.estimators.validate_policy_estimator Module Contents --------------- .. py:class:: ConditionalDiscretePolicyEstimator(state_module, condition_module, final_module, n_actions, preprocessor = None, is_backward = False, debug = False) Bases: :py:obj:`DiscretePolicyEstimator` Conditional forward or backward policy estimators for discrete environments. Estimates either, with condition $c$: - $s \mapsto (P_F(s' \mid s, c))_{s' \in Children(s)}$ (conditional forward policy) - $s' \mapsto (P_B(s \mid s', c))_{s \in Parents(s')}$ (conditional backward policy) This estimator is designed for discrete environments where the policy depends on both the state and some condition information. It uses a multi-module architecture where states and conditions are processed separately before being combined. .. attribute:: module The neural network module for state processing. .. attribute:: condition_module The neural network module for condition processing. .. attribute:: final_module The neural network module that combines state and condition. .. attribute:: n_actions Total number of actions in the discrete environment. .. 
attribute:: preprocessor Preprocessor object that transforms raw States objects to tensors. .. attribute:: is_backward Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. .. py:method:: _forward_trunk(states, conditions) Forward pass of the trunk of the module. This method processes the states and conditions inputs separately, then combines them through the final module. :param states: The input states. :param conditions: The condition tensor. :returns: The output of the trunk of the module, as a tensor of shape (*batch_shape, output_dim). .. py:attribute:: condition_module .. py:attribute:: final_module .. py:method:: forward(states, conditions) Forward pass of the module. :param states: The input states. :param conditions: The condition tensor. :returns: The output of the module, as a tensor of shape (*batch_shape, output_dim). .. py:attribute:: n_actions .. py:class:: ConditionalLogZEstimator(module, reduction = 'mean') Bases: :py:obj:`ScalarEstimator` Conditional logZ estimator. This estimator is used to estimate the logZ of a GFlowNet from a conditions tensor. Since the conditions are given as a tensor, it does not have a preprocessor. Reduction is used to aggregate the outputs of the module into a single scalar. .. attribute:: module The neural network module to use. .. attribute:: reduction String name of one of the REDUCTION_FUNCTIONS keys. .. py:method:: _calculate_module_output(input) .. py:class:: ConditionalScalarEstimator(state_module, condition_module, final_module, preprocessor = None, reduction = 'mean', debug = False) Bases: :py:obj:`ConditionalDiscretePolicyEstimator` Class for conditionally estimating scalars (logZ, DB/SubTB state logF). Similar to `ScalarEstimator`, the function approximator used for `final_module` need not directly output a scalar. If it does not, `reduction` will be used to aggregate the outputs of the module into a single scalar. .. 
attribute:: module The neural network module for state processing. .. attribute:: condition_module The neural network module for condition processing. .. attribute:: final_module The neural network module that combines state and condition. .. attribute:: preprocessor Preprocessor object that transforms raw States objects to tensors. .. attribute:: is_backward Always False for ConditionalScalarEstimator (since it's direction-agnostic). .. attribute:: reduction_function Function used to reduce multi-dimensional outputs to scalars. .. py:property:: expected_output_dim :type: int Expected output dimension of the module. :returns: Always 1, as this estimator outputs scalar values. .. py:method:: forward(states, conditions) Forward pass of the module. :param states: The input states. :param conditions: The condition tensor. :returns: The output of the module, as a tensor of shape (*batch_shape, 1). .. py:attribute:: reduction_function .. py:method:: to_probability_distribution(states, module_output, **policy_kwargs) :abstractmethod: Transforms the output of the module into a probability distribution. This method should not be called for ConditionalScalarEstimator as it outputs scalar values, not probability distributions. :raises NotImplementedError: This method is not implemented for scalar estimators. .. py:class:: DiffusionPolicyEstimator(s_dim, module, is_backward = False, debug = False) Bases: :py:obj:`PolicyMixin`, :py:obj:`Estimator` Base class for diffusion policy estimators. .. py:property:: expected_output_dim :type: int Expected output dimension of the module. :returns: The expected output dimension of the module, or None if the output dimension is not well-defined (e.g., when the output is a TensorDict for GraphActions). .. py:method:: forward(input) :abstractmethod: Forward pass of the module. :param input: The input to the module as states. :returns: The output of the module, as a tensor of shape (*batch_shape, output_dim). .. py:attribute:: s_dim .. 
py:method:: to_probability_distribution(states, module_output, **policy_kwargs) :abstractmethod: Transform the output of the module into an IsotropicGaussian distribution. :param states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1). :param module_output: The output of the module (actions), as a tensor of shape (*batch_shape, s_dim). :param \*\*policy_kwargs: Keyword arguments to modify the distribution. :returns: An IsotropicGaussian distribution. .. py:class:: DiscreteGraphPolicyEstimator(module, preprocessor = None, is_backward = False, debug = False) Bases: :py:obj:`PolicyMixin`, :py:obj:`LogitBasedEstimator` Forward or backward policy estimators for graph-based environments. Estimates either, where $s$ and $s'$ are graph states: - $s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}$ (forward policy) - $s' \mapsto (P_B(s \mid s'))_{s \in Parents(s')}$ (backward policy) This estimator is designed for graph-based environments where actions modify graphs and states are represented as graphs. The output is a TensorDict containing logits for different action components (action type, node class, edge class, edge index). .. attribute:: module The neural network module to use. .. attribute:: preprocessor Preprocessor object that transforms GraphStates objects to tensors. .. attribute:: is_backward Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. .. py:property:: expected_output_dim :type: Optional[int] Expected output dimension of the module. :returns: None, as the output_dim of a TensorDict is not well-defined. .. py:method:: to_probability_distribution(states, module_output, sf_bias = 0.0, temperature = defaultdict(lambda: 1.0), epsilon = defaultdict(lambda: 0.0)) Returns a probability distribution given a batch of states and module output. 
Similar to `DiscretePolicyEstimator.to_probability_distribution()`, but handles the complex structure of graph actions through a TensorDict. The method applies masks, biases, temperature scaling, and epsilon-greedy exploration to each action component separately. :param states: The graph states where the policy is evaluated. :param module_output: The output of the module as a TensorDict containing logits for different action components. :param sf_bias: Scalar to subtract from the exit action logit before dividing by temperature. :param temperature: Dictionary mapping action component keys to temperature values for scaling logits. :param epsilon: Dictionary mapping action component keys to epsilon values for exploration. :returns: A GraphActionDistribution over the graph actions. .. py:class:: DiscretePolicyEstimator(module, n_actions, preprocessor = None, is_backward = False, debug = False) Bases: :py:obj:`PolicyMixin`, :py:obj:`LogitBasedEstimator` Forward or backward policy estimators for discrete environments. Estimates either: - $s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}$ (forward policy) - $s' \mapsto (P_B(s \mid s'))_{s \in Parents(s')}$ (backward policy) This estimator is designed for discrete environments where actions are represented by integer indices and states have forward/backward masks indicating valid actions. .. attribute:: module The neural network module to use. .. attribute:: n_actions Total number of actions in the discrete environment. .. attribute:: preprocessor Preprocessor object that transforms raw States objects to tensors. .. attribute:: is_backward Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. .. py:property:: expected_output_dim :type: int Expected output dimension of the module. :returns: n_actions for forward policies, n_actions - 1 for backward policies. .. py:attribute:: n_actions .. 
py:method:: to_probability_distribution(states, module_output, sf_bias = 0.0, temperature = 1.0, epsilon = 0.0) Returns a Categorical distribution given a batch of states and module output. This implementation stays in logit/log-prob space for numerical stability. When epsilon > 0, we construct the epsilon-greedy distribution by mixing the original distribution with a uniform distribution and pass the resulting normalized log-probs as logits. The kwargs may contain parameters to modify a base distribution, for example to encourage exploration. :param states: The discrete states where the policy is evaluated. :param module_output: The output of the module as a tensor of shape (*batch_shape, output_dim). :param sf_bias: Scalar to subtract from the exit action logit before dividing by temperature. Does nothing if set to 0.0 (default), in which case it's on policy. :param temperature: Scalar to divide the logits by before softmax. Does nothing if set to 1.0 (default), in which case it's on policy. :param epsilon: With probability epsilon, a random action is chosen. Does nothing if set to 0.0 (default), in which case it's on policy. :returns: A Categorical distribution over the actions. .. py:method:: uniform(n_actions, preprocessor = None) :classmethod: Create a uniform backward policy estimator for discrete environments. Outputs equal logits for all actions, resulting in a uniform distribution over valid parent actions (masking is still applied). :param n_actions: Total number of actions in the discrete environment. :param preprocessor: Preprocessor object that transforms states to tensors. Required because the input dimension depends on the environment. :returns: A ``DiscretePolicyEstimator`` with ``is_backward=True`` and no learnable parameters. .. py:class:: Estimator(module, preprocessor = None, is_backward = False, debug = False) Bases: :py:obj:`abc.ABC`, :py:obj:`torch.nn.Module` Base class for modules mapping states to distributions or scalar values. 
Training a GFlowNet requires parameterizing one or more of the following functions: - $s \mapsto (\log F(s \rightarrow s'))_{s' \in Children(s)}$ - $s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}$ - $s' \mapsto (P_B(s \mid s'))_{s \in Parents(s')}$ - $s \mapsto (\log F(s))_{s \in States}$ This class is the base class for all such function estimators. The estimators need to encapsulate a `nn.Module`, which takes a batch of preprocessed states as input and outputs a batch of outputs of the desired shape. When the goal is to represent a probability distribution, the outputs would correspond to the parameters of the distribution, e.g. logits for a categorical distribution for discrete environments. The call method is used to output logits, or the parameters to distributions. Otherwise, one can override and use the `to_probability_distribution()` method to directly output a probability distribution. The preprocessor is also encapsulated in the estimator. These function estimators implement the `__call__` method, which takes `States` objects as inputs and calls the module on the preprocessed states. .. attribute:: module The neural network module to use. If it is a Tabular module (from `gfn.utils.modules`), then the environment preprocessor needs to be an `EnumPreprocessor`. .. attribute:: preprocessor Preprocessor object that transforms raw States objects to tensors that can be used as input to the module. Optional, defaults to `IdentityPreprocessor`. .. attribute:: is_backward Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. .. py:method:: __repr__() Returns a string representation of the Estimator. :returns: A string summary of the Estimator. .. py:attribute:: debug :value: False .. py:property:: expected_output_dim :type: Optional[int] :abstractmethod: Expected output dimension of the module. 
:returns: The expected output dimension of the module, or None if the output dimension is not well-defined (e.g., when the output is a TensorDict for GraphActions). .. py:method:: forward(input) Forward pass of the module. :param input: The input to the module as states. :returns: The output of the module, as a tensor of shape (*batch_shape, output_dim). .. py:attribute:: is_backward :value: False .. py:attribute:: module .. py:attribute:: preprocessor :value: None .. py:method:: to_probability_distribution(states, module_output, **policy_kwargs) :abstractmethod: Transforms the output of the module into a probability distribution. The kwargs may contain parameters to modify a base distribution, for example to encourage exploration. This method is optional for modules that don't need to output probability distributions (e.g., when estimating logF for flow matching). However, it is required for modules that define policies, as it converts raw module outputs into probability distributions over actions. See `DiscretePolicyEstimator` for an example using categorical distributions for discrete actions, or `BoxPFEstimator` for examples using continuous distributions like Beta mixtures. :param states: The states to use. :param module_output: The output of the module as a tensor of shape (*batch_shape, output_dim). :param \*\*policy_kwargs: Keyword arguments to modify the distribution. :returns: A distribution object. .. py:class:: LogitBasedEstimator(module, preprocessor = None, is_backward = False, debug = False) Bases: :py:obj:`Estimator` Base class for estimators that output logits. This class is used to define estimators that output logits, which can be used to construct probability distributions. .. attribute:: module The neural network module to use. .. attribute:: preprocessor Preprocessor object that transforms raw States objects to tensors. .. 
attribute:: is_backward Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. .. py:method:: _compute_logits_for_distribution(logits, masks, sf_index, sf_bias, temperature, epsilon, debug = False) :staticmethod: Return logits to feed a Categorical: - If epsilon == 0: masked, biased, temperature-scaled logits. - Else: normalized log-probs of the epsilon-greedy mixture (valid as logits). .. py:method:: _mix_with_uniform_in_log_space(lsm, masks, epsilon, debug = False) :staticmethod: Compute log((1-eps) p + eps u) in log space. .. py:method:: _prepare_logits(logits, masks, sf_index, sf_bias, temperature, debug = False) :staticmethod: Clone and apply mask, bias and temperature to logits. .. py:method:: _uniform_log_probs(masks) :staticmethod: Uniform log-probs over valid actions; -inf for invalid. .. py:class:: PinnedBrownianMotionBackward(s_dim, pb_module, sigma, num_discretization_steps, n_variance_outputs = 0, pb_scale_range = 0.1) Bases: :py:obj:`DiffusionPolicyEstimator` Base class for diffusion policy estimators. .. py:attribute:: dt .. py:property:: expected_output_dim :type: int Expected output dimension of the module. :returns: The expected output dimension of the module, or None if the output dimension is not well-defined (e.g., when the output is a TensorDict for GraphActions). .. py:method:: forward(input) Forward pass of the module. .. py:attribute:: n_variance_outputs :value: 0 .. py:attribute:: pb_scale_range :value: 0.1 .. py:attribute:: sigma .. py:method:: to_probability_distribution(states, module_output, **policy_kwargs) Transform the output of the module into an IsotropicGaussian distribution, which is the distribution of the previous states under the pinned Brownian motion process, possibly controlled by the output of the backward module. If the module is a fixed backward module, the `module_output` is a zero vector (no control). Includes optional learned corrections. 
:param states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1). :param module_output: The output of the module (actions), as a tensor of shape (*batch_shape, s_dim). :param \*\*policy_kwargs: Keyword arguments to modify the distribution. :returns: An IsotropicGaussian distribution (distribution of the previous states) .. py:class:: PinnedBrownianMotionForward(s_dim, pf_module, sigma, num_discretization_steps, n_variance_outputs = 0) Bases: :py:obj:`DiffusionPolicyEstimator` Base class for diffusion policy estimators. .. py:attribute:: dt .. py:property:: expected_output_dim :type: int Expected output dimension of the module. :returns: The expected output dimension of the module, or None if the output dimension is not well-defined (e.g., when the output is a TensorDict for GraphActions). .. py:method:: forward(input) Forward pass of the module. :param input: The input to the module as states. :returns: The output of the module, as a tensor of shape (*batch_shape, output_dim). .. py:attribute:: n_variance_outputs :value: 0 .. py:attribute:: num_discretization_steps .. py:attribute:: sigma .. py:method:: to_probability_distribution(states, module_output, **policy_kwargs) Transform the output of the module into an IsotropicGaussian distribution, which is the distribution of the next states under the pinned Brownian motion controlled by the output of the module. :param states: The states to use, states.tensor.shape = (*batch_shape, s_dim + 1). :param module_output: The output of the module (actions), as a tensor of shape (*batch_shape, s_dim). :param \*\*policy_kwargs: Keyword arguments to modify the distribution. Supported keys: - exploration_std: Optional float or Tensor controlling extra exploration noise on top of the base diffusion std. When provided, the extra noise is combined in variance-space (logaddexp) with the base diffusion variance; non-positive values are ignored. 
:returns: An IsotropicGaussian distribution (distribution of the next states) .. py:class:: PolicyEstimatorProtocol Bases: :py:obj:`Protocol` Static-typing surface for estimators that are policy-capable. This protocol captures the methods provided by the PolicyMixin so that external code (e.g., samplers/probability calculators) can use a precise type rather than relying on dynamic attributes. This helps static analyzers avoid false positives like "Tensor is not callable" when calling mixin methods. .. py:method:: compute_dist(states_active, ctx, step_mask = None, **policy_kwargs) .. py:method:: init_context(batch_size, device, conditions = None) .. py:attribute:: is_vectorized :type: bool .. py:method:: log_probs(actions_active, dist, ctx, step_mask = None, vectorized = False, **kwargs) .. py:class:: PolicyMixin Mixin enabling an `Estimator` to act as a policy (distribution over actions). Provides the generic rollout API (`init_context`, `compute_dist`, `log_probs`) directly on the estimator. Standard policies should inherit from this mixin. .. py:method:: compute_dist(states_active, ctx, step_mask = None, save_estimator_outputs = False, **policy_kwargs) Run the estimator for active rows and build an action Distribution. :param states_active: The states to run the estimator on. :param ctx: The context to run the estimator on. :param step_mask: The mask to slice the conditions to the active subset. :param save_estimator_outputs: Whether to save the estimator outputs. :param \*\*policy_kwargs: Additional keyword arguments to pass to the estimator. :returns: A tuple containing the distribution and the context. - Uses `step_mask` to slice conditions to the active subset. When `step_mask` is None, the estimator is running in a vectorized context. - Saves the raw estimator output in `ctx.current_estimator_output` for optional recording in `record_step`. .. py:method:: get_current_estimator_output(ctx) Expose the most recent per-step estimator output saved during `compute_dist`. 
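As a rough illustration of how a ``step_mask`` can be used to slice per-row data down to the active subset, and how per-active-row results can be padded back to the full batch, here is a minimal pure-Python sketch. It is not the library's implementation: plain lists stand in for tensors, and the helper names ``slice_active`` and ``pad_to_full_batch`` as well as the fill value ``0.0`` are assumptions made for this example only.

```python
from typing import List, Sequence


def slice_active(values: Sequence[float], step_mask: Sequence[bool]) -> List[float]:
    """Keep only the rows whose mask entry is True (the active subset)."""
    return [v for v, active in zip(values, step_mask) if active]


def pad_to_full_batch(
    active_values: Sequence[float],
    step_mask: Sequence[bool],
    fill: float = 0.0,  # assumed fill value, chosen for illustration only
) -> List[float]:
    """Scatter per-active-row values back into a full-batch list."""
    remaining = iter(active_values)
    return [next(remaining) if active else fill for active in step_mask]


# Hypothetical batch of four rows, two of which are still active.
conditions = [0.1, 0.2, 0.3, 0.4]
step_mask = [True, False, True, False]

# Slice conditions to the active rows before running a policy on them.
active_conditions = slice_active(conditions, step_mask)

# Pretend these log-probs were computed for the two active rows only.
active_log_probs = [-1.5, -0.7]

# Pad back so the result lines up with the full batch again.
full_log_probs = pad_to_full_batch(active_log_probs, step_mask)
```

In the non-vectorized case, results are computed for the active rows only and then realigned with the full batch. When no mask is given, no slicing or padding is needed.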
.. py:method:: init_context(batch_size, device, conditions = None) Create a new per-rollout context. Stores rollout invariants (batch size, device, optional conditions) and initializes empty buffers for per-step artifacts. .. py:property:: is_vectorized :type: bool Used for vectorized probability calculations. .. py:method:: log_probs(actions_active, dist, ctx, step_mask = None, vectorized = False, save_logprobs = False) Compute log-probs, optionally padding back to full batch when non-vectorized. .. py:data:: REDUCTION_FUNCTIONS .. py:class:: RecurrentDiscretePolicyEstimator(module, n_actions, preprocessor = None, is_backward = False, debug = False) Bases: :py:obj:`RecurrentPolicyMixin`, :py:obj:`DiscretePolicyEstimator` Discrete policy estimator for recurrent architectures with explicit carry. Many sequence models (e.g., RNN/LSTM/GRU/Transformer in autoregressive mode) maintain a recurrent hidden state ("carry") that must be threaded through successive calls during sampling. This class formalizes that pattern for GFlowNet policies by: - Exposing a forward signature ``forward(states, carry) -> (logits, carry)`` so the policy can update and return the next carry at each step. - Requiring an ``init_carry(batch_size, device)`` method to allocate the initial hidden state for a rollout. - Ensuring the per-step output (``logits`` over actions) is derived from the latest token/time step while the internal model may process sequences. The sampler uses a ``RecurrentPolicyMixin`` which calls this estimator with the current carry, updates the carry on every step, and records per-step artifacts. Non-recurrent estimators should use the default PolicyMixin and the standard ``DiscretePolicyEstimator`` base class instead. .. rubric:: Notes - Forward is intended for on-policy generation; off-policy evaluation over entire trajectories typically requires different batching and masking. - ``init_carry`` is a hard requirement for compatibility with the recurrent PolicyMixin. .. 
attribute:: module The neural network module to use. .. attribute:: n_actions Total number of actions in the discrete environment. .. attribute:: preprocessor Preprocessor object that transforms states to tensors. .. attribute:: is_backward Flag indicating whether this estimator is for backward policy, i.e., is used for predicting probability distributions over parents. .. py:method:: forward(states, carry) Forward pass of the module. :param states: The input states. :param carry: The carry from the previous step. :returns: A tuple of the module output, as a tensor of shape (*batch_shape, output_dim), and the updated carry. .. py:method:: init_carry(batch_size, device) .. py:class:: RecurrentPolicyMixin Bases: :py:obj:`PolicyMixin` Mixin for recurrent policies that maintain and update a rollout carry. .. py:method:: compute_dist(states_active, ctx, step_mask = None, save_estimator_outputs = False, **policy_kwargs) Run estimator with carry and update it. Differs from the default PolicyMixin by calling `estimator(states_active, ctx.carry) -> (est_out, new_carry)`, storing the updated carry and saving `current_estimator_output` before building the Distribution. .. py:method:: init_context(batch_size, device, conditions = None) Create a new per-rollout context. Stores rollout invariants (batch size, device, optional conditions) and initializes empty buffers for per-step artifacts. .. py:property:: is_vectorized :type: bool Used for vectorized probability calculations. .. py:class:: RolloutContext(batch_size, device, conditions = None) Structured per‑rollout state owned by estimators. Holds rollout invariants and optional per‑step buffers; use ``extras`` for estimator‑specific fields without changing the class shape. .. py:attribute:: __slots__ :value: ('batch_size', 'device', 'conditions', 'carry', 'trajectory_log_probs',... .. py:attribute:: batch_size .. py:attribute:: carry :value: None .. py:attribute:: conditions :value: None .. 
py:attribute:: current_estimator_output :type: Optional[torch.Tensor] :value: None .. py:attribute:: device .. py:attribute:: extras :type: Dict[str, Any] .. py:attribute:: trajectory_estimator_outputs :type: List[torch.Tensor] :value: [] .. py:attribute:: trajectory_log_probs :type: List[torch.Tensor] :value: [] .. py:class:: ScalarEstimator(module, preprocessor = None, reduction = 'mean', debug = False) Bases: :py:obj:`Estimator` Class for estimating scalars such as logZ of TB or state/edge flows of DB/SubTB. Training a GFlowNet sometimes requires the estimation of precise scalar values, such as the partition function (for TB) or state/edge flows (for DB/SubTB). This Estimator is designed for those cases. .. attribute:: module The neural network module to use. This doesn't have to directly output a scalar. If it does not, `reduction` will be used to aggregate the outputs of the module into a single scalar. .. attribute:: preprocessor Preprocessor object that transforms raw States objects to tensors that can be used as input to the module. .. attribute:: is_backward Always False for ScalarEstimator (since it's direction-agnostic). .. attribute:: reduction_function Function used to reduce multi-dimensional outputs to scalars. .. py:method:: _calculate_module_output(input) .. py:property:: expected_output_dim :type: int Expected output dimension of the module. :returns: Always 1, as this estimator outputs scalar values. .. py:method:: forward(input) Forward pass of the module. :param input: The input to the module as states. :returns: The output of the module, as a tensor of shape (*batch_shape, 1). .. py:attribute:: reduction_function .. py:data:: _POLICY_REQUIRED_METHODS :value: ('init_context', 'compute_dist', 'log_probs') .. py:function:: validate_policy_estimator(estimator, name = 'estimator') Checks that an estimator implements the PolicyMixin interface. :param estimator: The estimator to validate. :param name: Label for error messages (e.g., "pf", "pb"). 
:raises TypeError: If a required method is missing.
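The kind of check ``validate_policy_estimator`` describes can be sketched as follows. This is a minimal stand-in, not the library's code: it assumes that "implements the PolicyMixin interface" means each name in ``_POLICY_REQUIRED_METHODS`` resolves to a callable attribute. The function ``check_policy_methods``, the error message wording, and the two toy classes are hypothetical and exist only for this example.

```python
# Method names taken from the documented value of _POLICY_REQUIRED_METHODS.
REQUIRED_METHODS = ("init_context", "compute_dist", "log_probs")


def check_policy_methods(estimator, name="estimator"):
    """Raise TypeError if any required policy method is missing (hypothetical helper)."""
    missing = [
        method
        for method in REQUIRED_METHODS
        if not callable(getattr(estimator, method, None))
    ]
    if missing:
        raise TypeError(
            f"{name} is missing required policy methods: {', '.join(missing)}"
        )


class ToyPolicy:
    """Toy object exposing all three rollout methods (bodies are placeholders)."""

    def init_context(self, batch_size, device, conditions=None):
        return {"batch_size": batch_size, "device": device, "conditions": conditions}

    def compute_dist(self, states_active, ctx, step_mask=None, **policy_kwargs):
        return None, ctx

    def log_probs(self, actions_active, dist, ctx, step_mask=None, vectorized=False):
        return None, ctx


class ToyScalar:
    """Toy object without the rollout methods, standing in for a scalar estimator."""

    def forward(self, x):
        return x


check_policy_methods(ToyPolicy(), name="pf")  # passes silently

try:
    check_policy_methods(ToyScalar(), name="pb")
except TypeError as err:
    message = str(err)  # names the label and the missing methods
```

Passing a label such as ``"pf"`` or ``"pb"`` shows which estimator failed the check when several are validated together.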