gfn.utils.modules ================= .. py:module:: gfn.utils.modules .. autoapi-nested-parse:: This file contains some examples of modules that can be used with GFN. Attributes ---------- .. autoapisummary:: gfn.utils.modules.ACTIVATION_FNS gfn.utils.modules.logger Classes ------- .. autoapisummary:: gfn.utils.modules.AutoregressiveDiscreteSequenceModel gfn.utils.modules.DiffusionFixedBackwardModule gfn.utils.modules.DiffusionPISGradNetBackward gfn.utils.modules.DiffusionPISGradNetForward gfn.utils.modules.DiffusionPISJointPolicy gfn.utils.modules.DiffusionPISStateEncoding gfn.utils.modules.DiffusionPISTimeEncoding gfn.utils.modules.DiscreteUniform gfn.utils.modules.GraphActionGNN gfn.utils.modules.GraphActionUniform gfn.utils.modules.GraphEdgeActionMLP gfn.utils.modules.GraphScalarMLP gfn.utils.modules.LinearTransformer gfn.utils.modules.MLP gfn.utils.modules.NoisyLinear gfn.utils.modules.RecurrentDiscreteSequenceModel gfn.utils.modules.SinusoidalPositionalEmbedding gfn.utils.modules.Tabular gfn.utils.modules.TransformerDiscreteSequenceModel gfn.utils.modules.UniformModule gfn.utils.modules._AutoregressiveTransformerBlock Functions --------- .. autoapisummary:: gfn.utils.modules.sinusoidal_position_encoding Module Contents --------------- .. py:data:: ACTIVATION_FNS .. py:class:: AutoregressiveDiscreteSequenceModel Bases: :py:obj:`abc.ABC`, :py:obj:`torch.nn.Module` Helper class that provides a standard way to create an ABC using inheritance. .. py:method:: forward(x, carry) :abstractmethod: Compute the logits for the next tokens in the sequence. :param x: (B, T) tensor of input token indices where ``T`` is the number of newly supplied timesteps (``T`` may be 1 for incremental decoding). :type x: torch.Tensor :param carry: Carry from previous steps for recurrent processing (e.g., hidden states). :type carry: dict[str, torch.Tensor] :returns: Logits for the next token at each supplied timestep with shape (B, T, vocab) and updated carry. 
:rtype: tuple[torch.Tensor, dict[str, torch.Tensor]] .. py:method:: init_carry(batch_size, device) :abstractmethod: Initialize the carry for the sequence model. :param batch_size: Batch size. :type batch_size: int :param device: Device to allocate carry tensors on. :type device: torch.device :returns: Initialized carry. :rtype: dict[str, torch.Tensor] .. py:property:: vocab_size :type: int :abstractmethod: Size of the vocabulary (excluding BOS token). .. py:class:: DiffusionFixedBackwardModule(s_dim) Bases: :py:obj:`torch.nn.Module` Fixed Backward Module for DiffusionPISGradNet. .. attribute:: input_dim The dimension of the input. .. py:method:: forward(preprocessed_states) Forward pass of the module. :param preprocessed_states: The preprocessed states (shape: (*batch_shape, s_dim + 1)). :returns: The output of the module (shape: (*batch_shape, s_dim)). .. py:attribute:: input_dim .. py:class:: DiffusionPISGradNetBackward(s_dim, harmonics_dim = 64, t_emb_dim = 64, s_emb_dim = 64, hidden_dim = 64, joint_layers = 2, zero_init = False, clipping = False, gfn_clip = 10000.0, pb_scale_range = 0.1, log_var_range = 4.0, learn_variance = True) Bases: :py:obj:`torch.nn.Module` Learnable backward correction module (PIS-style) for diffusion. Produces mean and optional log-std corrections that are tanh-scaled by `pb_scale_range` to stay close to the analytic Brownian bridge. .. py:attribute:: clipping :value: False .. py:method:: forward(preprocessed_states) .. py:attribute:: gfn_clip :value: 10000.0 .. py:attribute:: harmonics_dim :value: 64 .. py:attribute:: hidden_dim :value: 64 .. py:attribute:: joint_layers :value: 2 .. py:attribute:: joint_model .. py:attribute:: learn_variance :value: True .. py:attribute:: log_var_range :value: 4.0 .. py:attribute:: out_dim .. py:attribute:: pb_scale_range :value: 0.1 .. py:attribute:: s_dim .. py:attribute:: s_emb_dim :value: 64 .. py:attribute:: s_model .. py:attribute:: t_emb_dim :value: 64 .. py:attribute:: t_model ..
py:attribute:: zero_init :value: False .. py:class:: DiffusionPISGradNetForward(s_dim, harmonics_dim = 64, t_emb_dim = 64, s_emb_dim = 64, hidden_dim = 64, joint_layers = 2, zero_init = False, clipping = False, gfn_clip = 10000.0, t_scale = 1.0, log_var_range = 4.0, learn_variance = False) Bases: :py:obj:`torch.nn.Module` PISGradNet for diffusion sampling. This architecture was first introduced in the Path Integral Sampler (PIS) (https://arxiv.org/abs/2111.15141) and adapted for GFlowNet-based training by Sendera et al. (https://arxiv.org/abs/2508.03044). .. attribute:: s_dim The dimension of the states. .. attribute:: harmonics_dim The dimension of the Fourier features. .. attribute:: t_emb_dim The dimension of the time embedding. .. attribute:: s_emb_dim The dimension of the state embedding. .. attribute:: hidden_dim The dimension of the hidden layers. .. attribute:: joint_layers The number of layers in the joint policy. .. attribute:: zero_init Whether to initialize the weights and biases of the final layer to zero. .. attribute:: out_dim The dimension of the output. .. attribute:: t_model The time encoding module. .. attribute:: s_model The state encoding module. .. attribute:: joint_model The joint policy module. .. py:attribute:: clipping :value: False .. py:method:: forward(preprocessed_states) Forward pass of the module. :param preprocessed_states: The preprocessed states (shape: (*batch_shape, s_dim + 1)). :returns: The output of the module (shape: (*batch_shape, s_dim)). .. py:attribute:: gfn_clip :value: 10000.0 .. py:attribute:: harmonics_dim :value: 64 .. py:attribute:: hidden_dim :value: 64 .. py:attribute:: input_dim .. py:attribute:: joint_layers :value: 2 .. py:attribute:: joint_model .. py:attribute:: learn_variance :value: False .. py:attribute:: log_var_range :value: 4.0 .. py:attribute:: out_dim .. py:attribute:: s_dim .. py:attribute:: s_emb_dim :value: 64 .. py:attribute:: s_model .. py:attribute:: t_emb_dim :value: 64 ..
py:attribute:: t_model .. py:attribute:: t_scale :value: 1.0 .. py:attribute:: zero_init :value: False .. py:class:: DiffusionPISJointPolicy(s_emb_dim, hidden_dim, out_dim, num_layers, zero_init = False) Bases: :py:obj:`torch.nn.Module` Joint Policy Module for DiffusionPISGradNet. See DiffusionPISGradNet for more details. .. py:method:: forward(s_emb, t_emb) .. py:attribute:: model .. py:class:: DiffusionPISStateEncoding(x_dim, s_emb_dim) Bases: :py:obj:`torch.nn.Module` State Encoding Module for DiffusionPISGradNet. See DiffusionPISGradNet for more details. .. py:method:: forward(s) .. py:attribute:: s_model .. py:class:: DiffusionPISTimeEncoding(harmonics_dim, t_emb_dim, hidden_dim) Bases: :py:obj:`torch.nn.Module` Time Encoding Module for DiffusionPISGradNet. See DiffusionPISGradNet for more details. .. py:method:: forward(t) :param t: torch.Tensor .. py:attribute:: t_model .. py:attribute:: timestep_phase .. py:class:: DiscreteUniform(output_dim) Bases: :py:obj:`UniformModule` Uniform distribution over discrete actions. Backward-compatible alias for ``UniformModule(output_dim, fill_value=0.0)``. Prefer ``UniformModule`` for new code. .. py:class:: GraphActionGNN(num_node_classes, num_edge_classes, directed, embedding_dim = 128, num_conv_layers = 1, is_backward = False) Bases: :py:obj:`torch.nn.Module` Implements a GNN for graph action prediction. .. py:method:: _group_mean(tensor, batch_ptr) :staticmethod: .. py:attribute:: action_type_mlp .. py:attribute:: conv_blks .. py:attribute:: edge_class_mlp .. py:attribute:: embedding .. py:method:: forward(states_tensor) .. py:property:: input_dim .. py:attribute:: is_backward :value: False .. py:attribute:: is_directed .. py:attribute:: node_class_mlp .. py:attribute:: node_index_mlp .. py:attribute:: norm .. py:attribute:: num_conv_layers :value: 1 .. py:attribute:: num_edge_classes .. py:attribute:: num_node_classes .. 
py:class:: GraphActionUniform(edges_dim, num_edge_classes, num_node_classes) Bases: :py:obj:`torch.nn.Module` Implements a uniform distribution over discrete actions given a graph state. It uses a zero function approximator (a function that always outputs 0) to be used as logits by a DiscretePBEstimator. .. attribute:: output_dim The size of the output space. .. py:attribute:: edges_dim .. py:method:: forward(states_tensor) Forward method for the uniform distribution. :param states_tensor: a batch of states appropriately preprocessed for ingestion by the uniform distribution. :returns: A TensorDict containing logits for each action component, with all values set to 1 to represent a uniform distribution: - GraphActions.ACTION_TYPE_KEY: Tensor of shape [*batch_shape, 3] for the 3 possible action types - GraphActions.EDGE_CLASS_KEY: Tensor of shape [*batch_shape, num_edge_classes] for edge class logits - GraphActions.NODE_CLASS_KEY: Tensor of shape [*batch_shape, num_node_classes] for node class logits - GraphActions.EDGE_INDEX_KEY: Tensor of shape [*batch_shape, edges_dim] for edge index logits .. py:attribute:: input_dim :value: 1 .. py:attribute:: num_edge_classes .. py:attribute:: num_node_classes .. py:class:: GraphEdgeActionMLP(n_nodes, directed, num_node_classes, num_edge_classes, n_hidden_layers = 2, n_hidden_layers_exit = 1, embedding_dim = 128, is_backward = False) Bases: :py:obj:`torch.nn.Module` Network that processes flattened adjacency matrices to predict graph actions. Unlike the GNN-based GraphActionGNN, this module uses standard MLPs to process the entire adjacency matrix as a flattened vector. This approach: 1. Can directly process global graph structure without message passing. 2. May be more effective for small graphs where global patterns are important. 3. Does not require complex graph neural network operations. The module architecture consists of: - An MLP to process the flattened adjacency matrix into an embedding. 
- An edge MLP that predicts logits for each possible edge action. - An exit MLP that predicts a logit for the exit action. :param n_nodes: Number of nodes in the graph. :param directed: Whether the graph is directed or undirected. :param n_hidden_layers: Number of hidden layers in the MLP for the edge actions. :param n_hidden_layers_exit: Number of hidden layers in the MLP for the exit action. :param embedding_dim: Dimension of internal embeddings. :param is_backward: Whether this is a backward policy. .. py:attribute:: _edges_dim .. py:attribute:: _input_dim .. py:attribute:: _output_dim .. py:attribute:: edge_class_mlp .. py:attribute:: edge_mlp .. py:property:: edges_dim :type: int .. py:attribute:: exit_mlp .. py:attribute:: features_embedding .. py:method:: forward(states_tensor) Forward pass to compute action logits from graph states. Process: 1. Convert the graph representation to adjacency matrices 2. Process the flattened adjacency matrices through the main MLP 3. Predict logits for edge actions and exit action :param states_tensor: A GeometricBatch containing graph state information :returns: A tensor of logits for all possible actions .. py:attribute:: hidden_dim :value: 128 .. py:property:: input_dim :type: int .. py:attribute:: is_backward :value: False .. py:attribute:: is_directed .. py:attribute:: mlp .. py:attribute:: n_nodes .. py:attribute:: node_class_mlp .. py:attribute:: node_index_mlp .. py:attribute:: num_edge_classes .. py:attribute:: num_node_classes .. py:property:: output_dim :type: int .. py:class:: GraphScalarMLP(n_nodes, directed, embedding_dim = 128, n_hidden_layers = 2, n_outputs = 1) Bases: :py:obj:`torch.nn.Module` Graph encoder that maps adjacency structure to ``n_outputs`` scalar outputs. .. py:attribute:: backbone .. py:method:: forward(states_tensor) Encode graphs into a scalar per element of the batch. .. py:attribute:: head .. py:attribute:: input_dim .. py:attribute:: is_directed .. py:attribute:: n_nodes ..
py:class:: LinearTransformer(dim, depth, max_seq_len, n_heads = 8, causal = False) Bases: :py:obj:`torch.nn.Module` The Linear Transformer module. Implements "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, François Fleuret, ICML 2020). Expresses self-attention as a linear dot-product of kernel feature maps and makes use of the associativity property of matrix products to reduce the complexity of the attention computation from O(n^2) to O(n). Implementation from https://github.com/lucidrains/linear-attention-transformer. :param dim: The dimension of the input. :param depth: The depth of the transformer. :param max_seq_len: The maximum sequence length. :param n_heads: The number of attention heads. :param causal: Whether to use causal attention. .. py:method:: forward(x) .. py:attribute:: module .. py:attribute:: output_dim .. py:class:: MLP(input_dim, output_dim, hidden_dim = 256, n_hidden_layers = 2, n_noisy_layers = 0, activation_fn = 'relu', trunk = None, add_layer_norm = False, std_init = 0.1) Bases: :py:obj:`torch.nn.Module` Implements a basic MLP with optional noisy layers for exploration. When `trunk` is provided, the MLP will be a wrapper around the trunk. See `Noisy Networks for Exploration (Fortunato et al., ICLR 2018)` for more details on the noisy layers. .. py:attribute:: _input_dim .. py:attribute:: _output_dim .. py:method:: forward(preprocessed_states) Forward method for the neural network. :param preprocessed_states: a batch of states appropriately preprocessed for ingestion by the MLP. The shape of the tensor should be (*batch_shape, input_dim). Returns: a tensor of shape (*batch_shape, output_dim). .. py:property:: input_dim .. py:property:: output_dim .. py:class:: NoisyLinear(in_features, out_features, bias = True, device = None, dtype = None, std_init = 0.1) Bases: :py:obj:`torch.nn.Linear` Noisy Linear Layer.
Presented in "Noisy Networks for Exploration", https://arxiv.org/abs/1706.10295v3 A Noisy Linear Layer is a linear layer with parametric noise added to the weights. This induced stochasticity can be used in RL networks for the agent's policy to aid efficient exploration. The parameters of the noise are learned with gradient descent along with any other remaining network weights. Factorized Gaussian noise is the type of noise usually employed. Taken from torchrl v0.9.2. :param in_features: input features dimension :type in_features: int :param out_features: out features dimension :type out_features: int :param bias: if ``True``, a bias term will be added to the matrix multiplication: Ax + b. Defaults to ``True`` :type bias: bool, optional :param device: device of the layer. Defaults to ``"cpu"`` :type device: DEVICE_TYPING, optional :param dtype: dtype of the parameters. Defaults to ``None`` (default pytorch dtype) :type dtype: torch.dtype, optional :param std_init: initial value of the Gaussian standard deviation before optimization. Defaults to ``0.1`` :type std_init: scalar, optional .. py:method:: _scale_noise(size) .. py:property:: bias :type: torch.Tensor | None .. py:attribute:: in_features .. py:attribute:: out_features .. py:method:: reset_noise() .. py:method:: reset_parameters() .. py:attribute:: std_init :value: 0.1 .. py:property:: weight :type: torch.Tensor .. py:attribute:: weight_mu .. py:attribute:: weight_sigma .. py:class:: RecurrentDiscreteSequenceModel(vocab_size, embedding_dim, hidden_size, num_layers = 1, rnn_type = 'lstm', dropout = 0.0) Bases: :py:obj:`AutoregressiveDiscreteSequenceModel` Helper class that provides a standard way to create an ABC using inheritance. .. py:attribute:: _vocab_size .. py:attribute:: embedding .. py:attribute:: embedding_dim .. py:method:: forward(x, carry) Compute the logits for the next tokens in the sequence. 
:param x: (B, T) tensor of input token indices where ``T`` is the number of newly supplied timesteps (``T`` may be 1 for incremental decoding). :type x: torch.Tensor :param carry: Carry from previous steps for recurrent processing (e.g., hidden states). :type carry: dict[str, torch.Tensor] :returns: Logits for the next token at each supplied timestep with shape (B, T, vocab) and updated carry. :rtype: tuple[torch.Tensor, dict[str, torch.Tensor]] .. py:attribute:: gru :type: torch.nn.GRU | None .. py:attribute:: hidden_size .. py:method:: init_carry(batch_size, device) Initialize the carry for the sequence model. :param batch_size: Batch size. :type batch_size: int :param device: Device to allocate carry tensors on. :type device: torch.device :returns: Initialized carry. :rtype: dict[str, torch.Tensor] .. py:attribute:: lstm :type: torch.nn.LSTM | None .. py:attribute:: num_layers :value: 1 .. py:attribute:: output_projection .. py:attribute:: rnn_type :value: '' .. py:property:: vocab_size :type: int Size of the vocabulary (excluding BOS token). .. py:class:: SinusoidalPositionalEmbedding(embedding_dim, max_length = 2048, base = 10000.0) Bases: :py:obj:`torch.nn.Module` Sinusoidal positional embeddings for transformer-style models. The module caches a precomputed table of embeddings and extends it on demand. Forward accepts either a sequence length or explicit position indices. .. py:attribute:: _pe :type: torch.Tensor .. py:attribute:: base .. py:attribute:: embedding_dim .. py:method:: forward(positions = None, seq_len = None) Look up positional embeddings. :param positions: Optional tensor of position indices. Can have any shape, and the returned embeddings will append ``embedding_dim`` to that shape. Defaults to ``None``. :param seq_len: Optional sequence length. When provided, returns the first ``seq_len`` embeddings from the table. :returns: Tensor of positional embeddings on the same device/dtype as the cached table. 
:raises ValueError: If both or neither of ``positions`` and ``seq_len`` are provided, or if indices exceed the cached range. .. py:property:: pe :type: torch.Tensor Return the cached positional embedding table. .. py:class:: Tabular(n_states, output_dim) Bases: :py:obj:`torch.nn.Module` Implements a tabular policy. This class is only compatible with the EnumPreprocessor. .. attribute:: table a tensor with dimensions [n_states, output_dim]. .. attribute:: device the device that holds this policy. .. py:attribute:: device :value: None .. py:method:: forward(preprocessed_states) Forward method for the tabular policy. :param preprocessed_states: a batch of states appropriately preprocessed for ingestion by the tabular policy. The shape of the tensor should be (*batch_shape, 1). Returns: a tensor of shape (*batch_shape, output_dim). .. py:attribute:: table .. py:class:: TransformerDiscreteSequenceModel(vocab_size, embedding_dim, num_heads, ff_hidden_dim, num_layers, max_position_embeddings, dropout = 0.0, positional_embedding = 'learned') Bases: :py:obj:`AutoregressiveDiscreteSequenceModel` Helper class that provides a standard way to create an ABC using inheritance. .. py:attribute:: _positional_embedding_type :value: 'learned' .. py:attribute:: _vocab_size .. py:attribute:: embedding_dim .. py:attribute:: embedding_dropout .. py:attribute:: ff_hidden_dim .. py:attribute:: final_norm .. py:method:: forward(x, carry) Compute the logits for the next tokens in the sequence. :param x: (B, T) tensor of input token indices where ``T`` is the number of newly supplied timesteps (``T`` may be 1 for incremental decoding). :type x: torch.Tensor :param carry: Carry from previous steps for recurrent processing (e.g., hidden states). :type carry: dict[str, torch.Tensor] :returns: Logits for the next token at each supplied timestep with shape (B, T, vocab) and updated carry. :rtype: tuple[torch.Tensor, dict[str, torch.Tensor]] .. py:attribute:: head_dim .. 
py:method:: init_carry(batch_size, device) Initialize the carry for the sequence model. :param batch_size: Batch size. :type batch_size: int :param device: Device to allocate carry tensors on. :type device: torch.device :returns: Initialized carry. :rtype: dict[str, torch.Tensor] .. py:attribute:: key_names .. py:attribute:: layers .. py:attribute:: max_position_embeddings .. py:attribute:: num_heads .. py:attribute:: num_layers .. py:attribute:: output_projection .. py:attribute:: token_embedding .. py:attribute:: value_names .. py:property:: vocab_size :type: int Size of the vocabulary (excluding BOS token). .. py:class:: UniformModule(output_dim, input_dim = None, fill_value = 0.0, skip_normalization = False) Bases: :py:obj:`torch.nn.Module` Constant-output module for non-learned (uniform/fixed) policies. Outputs a constant tensor for all inputs. Plug into any ``Estimator`` as the ``module`` argument to get a non-learned policy. Typical fill values: * ``0.0`` — equal logits → uniform categorical (discrete environments). * ``1.0`` — fixed params for sigmoid-based continuous estimators. .. attribute:: output_dim Output dimension matching the estimator's expected_output_dim. .. attribute:: fill_value Constant value to fill the output tensor. .. attribute:: skip_normalization If True, signals to estimators that outputs should be used directly without normalization transforms (e.g. sigmoid scaling of concentration parameters). .. py:attribute:: fill_value :value: 0.0 .. py:method:: forward(preprocessed_states) Return a constant tensor of shape ``(*batch_shape, output_dim)``. :param preprocessed_states: Tensor of shape ``(*batch_shape, input_dim)``. :returns: Tensor of shape ``(*batch_shape, output_dim)`` filled with ``fill_value``. .. py:attribute:: output_dim .. py:attribute:: skip_normalization :value: False .. py:class:: _AutoregressiveTransformerBlock(embed_dim, num_heads, ff_hidden_dim, dropout) Bases: :py:obj:`torch.nn.Module` .. py:attribute:: attn_dropout .. 
py:attribute:: embed_dim .. py:attribute:: ff_dropout .. py:method:: forward(hidden, key_carry, value_carry) .. py:attribute:: head_dim .. py:attribute:: k_proj .. py:attribute:: linear1 .. py:attribute:: linear2 .. py:attribute:: norm1 .. py:attribute:: norm2 .. py:attribute:: num_heads .. py:attribute:: out_proj .. py:attribute:: q_proj .. py:attribute:: residual_dropout .. py:attribute:: v_proj .. py:data:: logger .. py:function:: sinusoidal_position_encoding(length, embedding_dim, base = 10000.0) Create 1D sinusoidal positional embeddings. :param length: Number of positions to encode. Must be non-negative. :param embedding_dim: Dimensionality of each embedding. Must be positive. :param base: Exponential base used to compute the angular frequencies. :returns: A ``(length, embedding_dim)`` tensor of sinusoidal encodings. :raises ValueError: If ``length`` is negative, ``embedding_dim`` is not positive, or ``base`` is not positive.
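As an illustration of what a ``(length, embedding_dim)`` sinusoidal table can look like, the following is a minimal pure-Python sketch, not the library's implementation. It assumes the layout from "Attention Is All You Need" (sines in even columns, cosines in odd columns, each pair sharing the frequency ``base ** (-2k / embedding_dim)``); the column layout actually used by ``sinusoidal_position_encoding`` may differ. The helper name ``sinusoidal_table`` is hypothetical, and it returns nested lists rather than a tensor.

```python
import math


def sinusoidal_table(length, embedding_dim, base=10000.0):
    """Hypothetical helper: build a (length, embedding_dim) sinusoidal table.

    Assumption: even columns hold sines and odd columns hold cosines.
    The library function may arrange the columns differently.
    """
    # Mirror the documented argument checks.
    if length < 0:
        raise ValueError("length must be non-negative")
    if embedding_dim <= 0:
        raise ValueError("embedding_dim must be positive")
    if base <= 0:
        raise ValueError("base must be positive")

    table = []
    for pos in range(length):
        row = []
        for i in range(embedding_dim):
            # Columns 2k and 2k + 1 share the frequency base ** (-2k / embedding_dim).
            freq = base ** (-(2 * (i // 2)) / embedding_dim)
            angle = pos * freq
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


table = sinusoidal_table(length=4, embedding_dim=6)
```

Under this assumed layout, position 0 encodes as alternating zeros and ones, because ``sin(0) = 0`` and ``cos(0) = 1`` at every frequency.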