gfn.utils.modules

This module contains example neural network modules that can be used with GFN.

Attributes

ACTIVATION_FNS

logger

Classes

AutoregressiveDiscreteSequenceModel

Abstract interface for autoregressive models over discrete token sequences.

DiffusionFixedBackwardModule

Fixed Backward Module for DiffusionPISGradNet.

DiffusionPISGradNetBackward

Learnable backward correction module (PIS-style) for diffusion.

DiffusionPISGradNetForward

PISGradNet for diffusion sampling.

DiffusionPISJointPolicy

Joint Policy Module for DiffusionPISGradNet.

DiffusionPISStateEncoding

State Encoding Module for DiffusionPISGradNet.

DiffusionPISTimeEncoding

Time Encoding Module for DiffusionPISGradNet.

DiscreteUniform

Uniform distribution over discrete actions.

GraphActionGNN

Implements a GNN for graph action prediction.

GraphActionUniform

Implements a uniform distribution over discrete actions given a graph state.

GraphEdgeActionMLP

Network that processes flattened adjacency matrices to predict graph actions.

GraphScalarMLP

Graph encoder that maps adjacency structure to n_outputs scalar outputs.

LinearTransformer

The Linear Transformer module.

MLP

Implements a basic MLP with optional noisy layers for exploration.

NoisyLinear

Noisy Linear Layer.

RecurrentDiscreteSequenceModel

Recurrent (LSTM or GRU) autoregressive model over discrete token sequences.

SinusoidalPositionalEmbedding

Sinusoidal positional embeddings for transformer-style models.

Tabular

Implements a tabular policy.

TransformerDiscreteSequenceModel

Transformer-based autoregressive model over discrete token sequences.

UniformModule

Constant-output module for non-learned (uniform/fixed) policies.

_AutoregressiveTransformerBlock

Functions

sinusoidal_position_encoding(length, embedding_dim[, base])

Create 1D sinusoidal positional embeddings.

Module Contents

gfn.utils.modules.ACTIVATION_FNS
class gfn.utils.modules.AutoregressiveDiscreteSequenceModel

Bases: abc.ABC, torch.nn.Module

Abstract interface for autoregressive models over discrete token sequences.

abstract forward(x, carry)

Compute the logits for the next tokens in the sequence.

Parameters:
  • x (torch.Tensor) – (B, T) tensor of input token indices where T is the number of newly supplied timesteps (T may be 1 for incremental decoding).

  • carry (dict[str, torch.Tensor]) – Carry from previous steps for recurrent processing (e.g., hidden states).

Returns:

Logits for the next token at each supplied timestep with shape (B, T, vocab), and the updated carry.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]

abstract init_carry(batch_size, device)

Initialize the carry for the sequence model.

Parameters:
  • batch_size (int) – Batch size.

  • device (torch.device) – Device to allocate carry tensors on.

Returns:

Initialized carry.

Return type:

dict[str, torch.Tensor]

abstract property vocab_size: int

Size of the vocabulary (excluding BOS token).

Return type:

int
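The forward / init_carry contract above supports incremental decoding: encode the prompt in one call, then feed one token at a time (T = 1) while threading the carry through. The sketch below is a framework-free toy in plain Python. `CountingModel` and `greedy_decode` are hypothetical names, not part of gfn; nested lists stand in for tensors, and the carry is just a per-sequence token count rather than real hidden states.

```python
class CountingModel:
    """Hypothetical toy following the forward(x, carry) / init_carry contract.

    Nested lists stand in for tensors: x is (B, T) token ids, the returned
    logits are (B, T, vocab), and the carry tracks tokens consumed per sequence.
    """

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def init_carry(self, batch_size, device=None):
        # device mirrors the documented signature but is unused in this toy.
        return {"length": [0] * batch_size}

    def forward(self, x, carry):
        lengths = list(carry["length"])  # copy: never mutate the incoming carry
        logits = []
        for b, row in enumerate(x):
            row_logits = []
            for _token in row:
                lengths[b] += 1
                # Toy rule: favour token (length mod vocab) so decoding is predictable.
                favoured = lengths[b] % self.vocab_size
                row_logits.append([5.0 if v == favoured else 0.0 for v in range(self.vocab_size)])
            logits.append(row_logits)
        return logits, {"length": lengths}


def greedy_decode(model, prompt, n_new):
    """Encode the whole prompt once, then decode n_new tokens one step (T = 1) at a time."""
    carry = model.init_carry(batch_size=len(prompt))
    logits, carry = model.forward(prompt, carry)
    outputs = [[] for _ in prompt]
    for _ in range(n_new):
        # Greedy choice from the logits at the last supplied timestep of each sequence.
        next_tokens = []
        for seq_logits in logits:
            last = seq_logits[-1]
            next_tokens.append(last.index(max(last)))
        for b, tok in enumerate(next_tokens):
            outputs[b].append(tok)
        # Feed only the new tokens, passing the updated carry along.
        logits, carry = model.forward([[tok] for tok in next_tokens], carry)
    return outputs, carry


outputs, carry = greedy_decode(CountingModel(vocab_size=4), prompt=[[0, 1], [2, 3]], n_new=3)
print(outputs)  # [[2, 3, 0], [2, 3, 0]]
```

The point of the carry is that each decoding step only processes the newly supplied tokens; everything the model needs from earlier steps travels in the dictionary.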

class gfn.utils.modules.DiffusionFixedBackwardModule(s_dim)

Bases: torch.nn.Module

Fixed Backward Module for DiffusionPISGradNet.

Parameters:

s_dim (int)

input_dim

The dimension of the input.

forward(preprocessed_states)

Forward pass of the module.

Parameters:

preprocessed_states (torch.Tensor) – The preprocessed states (shape: (*batch_shape, s_dim + 1))

Returns:

The output of the module (shape: (*batch_shape, s_dim)).

input_dim
class gfn.utils.modules.DiffusionPISGradNetBackward(s_dim, harmonics_dim=64, t_emb_dim=64, s_emb_dim=64, hidden_dim=64, joint_layers=2, zero_init=False, clipping=False, gfn_clip=10000.0, pb_scale_range=0.1, log_var_range=4.0, learn_variance=True)

Bases: torch.nn.Module

Learnable backward correction module (PIS-style) for diffusion.

Produces mean and optional log-std corrections that are tanh-scaled by pb_scale_range to stay close to the analytic Brownian bridge.
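To make "tanh-scaled by pb_scale_range" concrete, the plain-Python sketch below bounds a raw correction and adds it to the analytic Brownian-bridge backward mean. It assumes, as in PIS, that the forward process starts at the origin at t = 0, and it assumes the mean correction is simply added to the bridge mean; both the additive combination and the function names are illustrative assumptions, not the module's code. The log-variance correction governed by log_var_range is omitted.

```python
import math


def brownian_bridge_backward(x_t, t, dt, sigma):
    """Analytic backward step t -> t - dt for Brownian motion pinned at 0 at time 0.

    Given X_t = x_t, X_{t-dt} is Gaussian with mean x_t * (t - dt) / t and
    variance sigma**2 * dt * (t - dt) / t (the same for every coordinate).
    """
    if not 0 < dt <= t:
        raise ValueError("need 0 < dt <= t")
    mean = [x * (t - dt) / t for x in x_t]
    var = sigma ** 2 * dt * (t - dt) / t
    return mean, var


def bounded_correction(raw, pb_scale_range=0.1):
    """tanh squashes each raw output into (-1, 1); scaling bounds it by pb_scale_range."""
    return [pb_scale_range * math.tanh(r) for r in raw]


def corrected_backward_mean(x_t, raw, t, dt, sigma, pb_scale_range=0.1):
    """Hypothetical combination: bridge mean plus the bounded learned correction."""
    mean, var = brownian_bridge_backward(x_t, t, dt, sigma)
    delta = bounded_correction(raw, pb_scale_range)
    return [m + d for m, d in zip(mean, delta)], var


# Even an enormous raw output moves the mean by at most pb_scale_range.
print(corrected_backward_mean([2.0, -2.0], raw=[1e6, 0.0], t=1.0, dt=0.5, sigma=1.0))
# ([1.1, -1.0], 0.25)
```

Bounding the correction this way keeps the learned backward policy within a fixed distance of the bridge, however large the network's raw output becomes.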

Parameters:
  • s_dim (int)

  • harmonics_dim (int)

  • t_emb_dim (int)

  • s_emb_dim (int)

  • hidden_dim (int)

  • joint_layers (int)

  • zero_init (bool)

  • clipping (bool)

  • gfn_clip (float)

  • pb_scale_range (float)

  • log_var_range (float)

  • learn_variance (bool)

clipping = False
forward(preprocessed_states)
Parameters:

preprocessed_states (torch.Tensor)

Return type:

torch.Tensor

gfn_clip = 10000.0
harmonics_dim = 64
hidden_dim = 64
joint_layers = 2
joint_model
learn_variance = True
log_var_range = 4.0
out_dim
pb_scale_range = 0.1
s_dim
s_emb_dim = 64
s_model
t_emb_dim = 64
t_model
zero_init = False
class gfn.utils.modules.DiffusionPISGradNetForward(s_dim, harmonics_dim=64, t_emb_dim=64, s_emb_dim=64, hidden_dim=64, joint_layers=2, zero_init=False, clipping=False, gfn_clip=10000.0, t_scale=1.0, log_var_range=4.0, learn_variance=False)

Bases: torch.nn.Module

PISGradNet for diffusion sampling.

This architecture was first introduced in Path Integral Sampler (PIS) (https://arxiv.org/abs/2111.15141) and adapted for GFlowNet-based training by Sendera et al. (https://arxiv.org/abs/2508.03044).

Parameters:
  • s_dim (int)

  • harmonics_dim (int)

  • t_emb_dim (int)

  • s_emb_dim (int)

  • hidden_dim (int)

  • joint_layers (int)

  • zero_init (bool)

  • clipping (bool)

  • gfn_clip (float)

  • t_scale (float)

  • log_var_range (float)

  • learn_variance (bool)

s_dim

The dimension of the states.

harmonics_dim

The dimension of the Fourier features.

t_emb_dim

The dimension of the time embedding.

s_emb_dim

The dimension of the state embedding.

hidden_dim

The dimension of the hidden layers.

joint_layers

The number of layers in the joint policy.

zero_init

Whether to initialize the weights and biases of the final layer to zero.

out_dim

The dimension of the output.

t_model

The time encoding module.

s_model

The state encoding module.

joint_model

The joint policy module.

clipping = False
forward(preprocessed_states)

Forward pass of the module.

Parameters:

preprocessed_states (torch.Tensor) – The preprocessed states (shape: (*batch_shape, s_dim + 1))

Returns:

The output of the module (shape: (*batch_shape, s_dim)).

gfn_clip = 10000.0
harmonics_dim = 64
hidden_dim = 64
input_dim
joint_layers = 2
joint_model
learn_variance = False
log_var_range = 4.0
out_dim
s_dim
s_emb_dim = 64
s_model
t_emb_dim = 64
t_model
t_scale = 1.0
zero_init = False
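The documented input shape (*batch_shape, s_dim + 1) suggests that each row carries the s_dim state coordinates plus one extra feature. The plain-Python sketch below assumes that this feature is the time t and that it sits last. It also assumes that clipping=True clamps outputs to [-gfn_clip, gfn_clip]. Both assumptions are inferred from the parameter names, and the batch is kept flat for simplicity.

```python
def split_state_and_time(preprocessed_states, s_dim):
    """Split rows of length s_dim + 1 into (state coordinates, time)."""
    states, times = [], []
    for row in preprocessed_states:
        if len(row) != s_dim + 1:
            raise ValueError(f"expected {s_dim + 1} features, got {len(row)}")
        states.append(row[:s_dim])  # first s_dim entries: the state x
        times.append(row[s_dim])  # assumed last entry: the time t
    return states, times


def clip_output(values, gfn_clip=10000.0, clipping=False):
    """Clamp every entry to [-gfn_clip, gfn_clip] when clipping is enabled."""
    if not clipping:
        return list(values)
    return [max(-gfn_clip, min(gfn_clip, v)) for v in values]


states, times = split_state_and_time([[0.3, -1.2, 0.25]], s_dim=2)
print(states, times)  # [[0.3, -1.2]] [0.25]
print(clip_output([2e4, -3.0], clipping=True))  # [10000.0, -3.0]
```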
class gfn.utils.modules.DiffusionPISJointPolicy(s_emb_dim, hidden_dim, out_dim, num_layers, zero_init=False)

Bases: torch.nn.Module

Joint Policy Module for DiffusionPISGradNet.

See DiffusionPISGradNetForward for more details.

Parameters:
  • s_emb_dim (int)

  • hidden_dim (int)

  • out_dim (int)

  • num_layers (int)

  • zero_init (bool)

forward(s_emb, t_emb)
Parameters:
  • s_emb (torch.Tensor)

  • t_emb (torch.Tensor)

Return type:

torch.Tensor

model
class gfn.utils.modules.DiffusionPISStateEncoding(x_dim, s_emb_dim)

Bases: torch.nn.Module

State Encoding Module for DiffusionPISGradNet.

See DiffusionPISGradNetForward for more details.

Parameters:
  • x_dim (int)

  • s_emb_dim (int)

forward(s)
Parameters:

s (torch.Tensor)

Return type:

torch.Tensor

s_model
class gfn.utils.modules.DiffusionPISTimeEncoding(harmonics_dim, t_emb_dim, hidden_dim)

Bases: torch.nn.Module

Time Encoding Module for DiffusionPISGradNet.

See DiffusionPISGradNetForward for more details.

Parameters:
  • harmonics_dim (int)

  • t_emb_dim (int)

  • hidden_dim (int)

forward(t)
Parameters:

t (torch.Tensor) – Time values to encode.

Return type:

torch.Tensor

t_model
timestep_phase
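The plain-Python sketch below shows the kind of Fourier time features that PIS-style networks compute before their time MLP. It assumes coefficients evenly spaced in [0.1, 100], as in the original PIS code, and a per-harmonic phase that would correspond to the learnable timestep_phase attribute. It has not been checked against this module's implementation. The real module presumably then maps these 2 * harmonics_dim features to t_emb_dim with a small MLP of width hidden_dim, which is omitted here.

```python
import math


def pis_time_features(t, harmonics_dim, phase=None, low=0.1, high=100.0):
    """Fourier features [sin(c*t + phase), cos(c*t + phase)] of length 2 * harmonics_dim.

    The coefficients c are evenly spaced in [low, high], like
    torch.linspace(low, high, harmonics_dim). phase defaults to zeros here;
    a learnable phase (cf. timestep_phase) would be trained instead.
    """
    if harmonics_dim < 2:
        raise ValueError("harmonics_dim must be at least 2 for an evenly spaced grid")
    step = (high - low) / (harmonics_dim - 1)
    coeffs = [low + i * step for i in range(harmonics_dim)]
    phase = phase if phase is not None else [0.0] * harmonics_dim
    angles = [c * t + p for c, p in zip(coeffs, phase)]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


features = pis_time_features(0.3, harmonics_dim=64)
print(len(features))  # 128
```

Spreading frequencies from slow to fast lets the downstream MLP resolve both coarse and fine differences in t.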
class gfn.utils.modules.DiscreteUniform(output_dim)

Bases: UniformModule

Uniform distribution over discrete actions.

Backward-compatible alias for UniformModule(output_dim, fill_value=0.0). Prefer UniformModule for new code.

Parameters:

output_dim (int)

class gfn.utils.modules.GraphActionGNN(num_node_classes, num_edge_classes, directed, embedding_dim=128, num_conv_layers=1, is_backward=False)

Bases: torch.nn.Module

Implements a GNN for graph action prediction.

Parameters:
  • num_node_classes (int)

  • num_edge_classes (int)

  • directed (bool)

  • embedding_dim (int)

  • num_conv_layers (int)

  • is_backward (bool)

static _group_mean(tensor, batch_ptr)
Parameters:
  • tensor (torch.Tensor)

  • batch_ptr (torch.Tensor)

Return type:

torch.Tensor
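A per-graph mean like _group_mean can be sketched in plain Python. The sketch assumes that batch_ptr follows the PyTorch Geometric `ptr` convention: graph g's nodes occupy rows ptr[g] through ptr[g + 1] - 1, so len(ptr) equals the number of graphs plus one. The handling of empty graphs is not documented here, so the zero vector the sketch returns for them is a choice of the sketch.

```python
def group_mean(rows, batch_ptr):
    """Per-graph mean of node rows, where graph g owns rows[batch_ptr[g]:batch_ptr[g + 1]]."""
    dim = len(rows[0]) if rows else 0
    means = []
    for start, end in zip(batch_ptr[:-1], batch_ptr[1:]):
        segment = rows[start:end]
        if not segment:
            # Sketch choice: a graph with no nodes gets a zero vector.
            means.append([0.0] * dim)
            continue
        means.append([sum(r[d] for r in segment) / len(segment) for d in range(dim)])
    return means


nodes = [[1.0, 2.0], [3.0, 4.0], [10.0, 10.0]]
print(group_mean(nodes, batch_ptr=[0, 2, 2, 3]))  # [[2.0, 3.0], [0.0, 0.0], [10.0, 10.0]]
```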

action_type_mlp
conv_blks
edge_class_mlp
embedding
forward(states_tensor)
Parameters:

states_tensor (gfn.utils.graphs.GeometricBatch)

Return type:

tensordict.TensorDict

property input_dim
is_backward = False
is_directed
node_class_mlp
node_index_mlp
norm
num_conv_layers = 1
num_edge_classes
num_node_classes
class gfn.utils.modules.GraphActionUniform(edges_dim, num_edge_classes, num_node_classes)

Bases: torch.nn.Module

Implements a uniform distribution over discrete actions given a graph state.

It uses a constant function approximator (every output entry has the same value) whose outputs are used as logits by a DiscretePBEstimator; equal logits yield a uniform distribution.

Parameters:
  • edges_dim (int)

  • num_edge_classes (int)

  • num_node_classes (int)

output_dim

The size of the output space.

edges_dim
forward(states_tensor)

Forward method for the uniform distribution.

Parameters:

states_tensor (gfn.utils.graphs.GeometricBatch) – a batch of states appropriately preprocessed for ingestion by the uniform distribution.

Returns:

A TensorDict containing logits for each action component, with all values set to 1 to represent a uniform distribution:

  • GraphActions.ACTION_TYPE_KEY: Tensor of shape [*batch_shape, 3] for the 3 possible action types

  • GraphActions.EDGE_CLASS_KEY: Tensor of shape [*batch_shape, num_edge_classes] for edge class logits

  • GraphActions.NODE_CLASS_KEY: Tensor of shape [*batch_shape, num_node_classes] for node class logits

  • GraphActions.EDGE_INDEX_KEY: Tensor of shape [*batch_shape, edges_dim] for edge index logits

Return type:

tensordict.TensorDict

input_dim = 1
num_edge_classes
num_node_classes
class gfn.utils.modules.GraphEdgeActionMLP(n_nodes, directed, num_node_classes, num_edge_classes, n_hidden_layers=2, n_hidden_layers_exit=1, embedding_dim=128, is_backward=False)

Bases: torch.nn.Module

Network that processes flattened adjacency matrices to predict graph actions.

Unlike the GNN-based GraphActionGNN, this module uses standard MLPs to process the entire adjacency matrix as a flattened vector. This approach:

  1. Can directly process global graph structure without message passing.

  2. May be more effective for small graphs where global patterns are important.

  3. Does not require complex graph neural network operations.

The module architecture consists of:

  • An MLP that processes the flattened adjacency matrix into an embedding.

  • An edge MLP that predicts logits for each possible edge action.

  • An exit MLP that predicts a logit for the exit action.
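The plain-Python sketch below shows the flattening step. It assumes that self-loops are excluded and that an undirected graph keeps only its upper triangle. Under those assumptions the edge head scores n(n-1) slots for a directed graph or n(n-1)/2 for an undirected one, and the exit head adds one more logit. The real slot ordering in GraphEdgeActionMLP may differ, and the helper names are hypothetical.

```python
def edge_slots(n_nodes, directed):
    """Candidate edge slots: ordered pairs i != j if directed, else pairs i < j."""
    if directed:
        return [(i, j) for i in range(n_nodes) for j in range(n_nodes) if i != j]
    return [(i, j) for i in range(n_nodes) for j in range(i + 1, n_nodes)]


def flatten_adjacency(n_nodes, edges, directed):
    """0/1 feature vector with one entry per edge slot, usable as MLP input."""
    if directed:
        present = set(edges)
    else:
        # Undirected edges are stored canonically as (min, max).
        present = {(min(i, j), max(i, j)) for i, j in edges}
    return [1.0 if slot in present else 0.0 for slot in edge_slots(n_nodes, directed)]


# Undirected 3-node graph with the single edge 1-0; slots are (0, 1), (0, 2), (1, 2).
print(flatten_adjacency(3, [(1, 0)], directed=False))  # [1.0, 0.0, 0.0]
print(len(edge_slots(4, directed=True)), len(edge_slots(4, directed=False)))  # 12 6
```

The input size grows quadratically with n_nodes, which is one reason this design suits the small graphs mentioned above.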

Parameters:
  • n_nodes (int) – Number of nodes in the graph.

  • directed (bool) – Whether the graph is directed or undirected.

  • n_hidden_layers (int) – Number of hidden layers in the MLP for the edge actions.

  • n_hidden_layers_exit (int) – Number of hidden layers in the MLP for the exit action.

  • embedding_dim (int) – Dimension of internal embeddings.

  • is_backward (bool) – Whether this is a backward policy.

  • num_node_classes (int)

  • num_edge_classes (int)

_edges_dim
_input_dim
_output_dim
edge_class_mlp
edge_mlp
property edges_dim: int
Return type:

int

exit_mlp
features_embedding
forward(states_tensor)

Forward pass to compute action logits from graph states.

Process:

  1. Convert the graph representation to adjacency matrices.

  2. Process the flattened adjacency matrices through the main MLP.

  3. Predict logits for edge actions and the exit action.

Parameters:

states_tensor (gfn.utils.graphs.GeometricBatch) – A GeometricBatch containing graph state information

Returns:

A TensorDict of logits for all possible actions.

Return type:

tensordict.TensorDict

hidden_dim = 128
property input_dim: int
Return type:

int

is_backward = False
is_directed
mlp
n_nodes
node_class_mlp
node_index_mlp
num_edge_classes
num_node_classes
property output_dim: int
Return type:

int

class gfn.utils.modules.GraphScalarMLP(n_nodes, directed, embedding_dim=128, n_hidden_layers=2, n_outputs=1)

Bases: torch.nn.Module

Graph encoder that maps adjacency structure to n_outputs scalar outputs.

Parameters:
  • n_nodes (int)

  • directed (bool)

  • embedding_dim (int)

  • n_hidden_layers (int)

  • n_outputs (int)

backbone
forward(states_tensor)

Encode graphs into a scalar per element of the batch.

Parameters:

states_tensor (gfn.utils.graphs.GeometricBatch)

Return type:

torch.Tensor

head
input_dim
is_directed
n_nodes
class gfn.utils.modules.LinearTransformer(dim, depth, max_seq_len, n_heads=8, causal=False)

Bases: torch.nn.Module

The Linear Transformer module.

Implements "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, François Fleuret, ICML 2020).

Expresses self-attention as a linear dot-product of kernel feature maps and makes use of the associativity of matrix products to reduce the complexity of the attention computation from O(n^2) to O(n) in the sequence length n.

Implementation from https://github.com/lucidrains/linear-attention-transformer.
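To see why associativity helps, the plain-Python sketch below computes non-causal kernelized attention both ways and checks that the results agree. It uses the elu(x) + 1 feature map from the paper. This illustrates the idea only: it is not the lucidrains implementation, which may use a different feature map. A causal variant (causal=True) would replace the global sums with running prefix sums.

```python
import math


def phi(x):
    """Feature map elu(x) + 1; always positive, so normalisers never vanish."""
    return x + 1.0 if x > 0 else math.exp(x)


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def quadratic_attention(q, k, v):
    """Normalised (phi(Q) phi(K)^T) V: materialises an n x n score matrix."""
    fq = [[phi(x) for x in row] for row in q]
    fk = [[phi(x) for x in row] for row in k]
    scores = matmul(fq, transpose(fk))  # n x n
    out = matmul(scores, v)
    return [[x / sum(s) for x in o] for o, s in zip(out, scores)]


def linear_attention(q, k, v):
    """Same result via phi(Q) (phi(K)^T V): only a d x d_v summary, so cost is linear in n."""
    fq = [[phi(x) for x in row] for row in q]
    fk = [[phi(x) for x in row] for row in k]
    kv = matmul(transpose(fk), v)  # d x d_v, independent of n
    k_sum = [sum(col) for col in zip(*fk)]  # sum_j phi(k_j), length d
    out = matmul(fq, kv)
    norms = [sum(a * b for a, b in zip(row, k_sum)) for row in fq]
    return [[x / z for x in o] for o, z in zip(out, norms)]


q = [[0.5, -1.0], [2.0, 0.1], [-0.3, 0.7]]
k = [[1.0, 0.0], [-0.5, 0.5], [0.2, -2.0]]
v = [[1.0, 2.0], [3.0, -1.0], [0.0, 4.0]]
quad, lin = quadratic_attention(q, k, v), linear_attention(q, k, v)
print(max(abs(a - b) for ra, rb in zip(quad, lin) for a, b in zip(ra, rb)))  # agree up to rounding
```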

Parameters:
  • dim (int) – The dimension of the input.

  • depth (int) – The depth of the transformer.

  • max_seq_len (int) – The maximum sequence length.

  • n_heads (int) – The number of attention heads.

  • causal (bool) – Whether to use causal attention.

forward(x)
Parameters:

x (torch.Tensor)

Return type:

torch.Tensor

module
output_dim
class gfn.utils.modules.MLP(input_dim, output_dim, hidden_dim=256, n_hidden_layers=2, n_noisy_layers=0, activation_fn='relu', trunk=None, add_layer_norm=False, std_init=0.1)

Bases: torch.nn.Module

Implements a basic MLP with optional noisy layers for exploration.

When trunk is provided, the MLP will be a wrapper around the trunk.

See Noisy Networks for Exploration (Fortunato et al., ICLR 2018) for more details on the noisy layers.

Parameters:
  • input_dim (int)

  • output_dim (int)

  • hidden_dim (int)

  • n_hidden_layers (Optional[int])

  • n_noisy_layers (int)

  • activation_fn (Literal['relu', 'leaky_relu', 'tanh', 'elu'])

  • trunk (Optional[torch.nn.Module])

  • add_layer_norm (bool)

  • std_init (float)

_input_dim
_output_dim
forward(preprocessed_states)

Forward method for the neural network.

Parameters:

preprocessed_states (torch.Tensor) – a batch of states appropriately preprocessed for ingestion by the MLP. The shape of the tensor should be (*batch_shape, input_dim).

Returns:

A tensor of shape (*batch_shape, output_dim).

Return type:

torch.Tensor

property input_dim
property output_dim
class gfn.utils.modules.NoisyLinear(in_features, out_features, bias=True, device=None, dtype=None, std_init=0.1)

Bases: torch.nn.Linear

Noisy Linear Layer.

Presented in “Noisy Networks for Exploration”, https://arxiv.org/abs/1706.10295v3

A Noisy Linear Layer is a linear layer with parametric noise added to the weights. This induced stochasticity can be used in RL networks for the agent’s policy to aid efficient exploration. The parameters of the noise are learned with gradient descent along with any other remaining network weights. Factorized Gaussian noise is the type of noise usually employed.

Taken from torchrl v0.9.2.
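The factorized Gaussian scheme described above can be sketched in plain Python. It draws in_features + out_features standard normals, transforms them with f(x) = sign(x) * sqrt(|x|), and takes the outer product, as in the Noisy Networks paper. This illustrates the technique and is not the torchrl code. The actual layer stores mu and sigma as weight_mu / weight_sigma and presumably resamples its noise in reset_noise().

```python
import math
import random


def scale_noise(size, rng):
    """Factorised-noise helper: f(x) = sign(x) * sqrt(|x|) applied to N(0, 1) draws."""
    return [math.copysign(math.sqrt(abs(x)), x) for x in (rng.gauss(0.0, 1.0) for _ in range(size))]


def noisy_parameters(weight_mu, weight_sigma, bias_mu, bias_sigma, rng):
    """Sample w = mu + sigma * eps, with eps_w = outer(f(eps_out), f(eps_in)) and eps_b = f(eps_out)."""
    out_features, in_features = len(weight_mu), len(weight_mu[0])
    eps_in = scale_noise(in_features, rng)  # in_features draws
    eps_out = scale_noise(out_features, rng)  # out_features draws
    weight = [
        [weight_mu[i][j] + weight_sigma[i][j] * eps_out[i] * eps_in[j] for j in range(in_features)]
        for i in range(out_features)
    ]
    bias = [bias_mu[i] + bias_sigma[i] * eps_out[i] for i in range(out_features)]
    return weight, bias


rng = random.Random(0)
mu = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
sigma = [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
w, b = noisy_parameters(mu, sigma, [0.0, 0.0], [1.0, 1.0], rng)
# With mu = 0 and sigma = 1 the weight noise is rank one: 2 + 3 draws instead of 2 * 3.
print(abs(w[0][0] * w[1][1] - w[0][1] * w[1][0]) < 1e-12)  # True
```

Factorizing the noise reduces the random draws per resample from in_features * out_features to in_features + out_features, which is the reason the paper prefers it for larger layers.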

Parameters:
  • in_features (int) – input feature dimension

  • out_features (int) – output feature dimension

  • bias (bool, optional) – if True, a bias term will be added to the matrix multiplication: Ax + b. Defaults to True

  • device (DEVICE_TYPING, optional) – device of the layer. Defaults to "cpu"

  • dtype (torch.dtype, optional) – dtype of the parameters. Defaults to None (default pytorch dtype)

  • std_init (scalar, optional) – initial value of the Gaussian standard deviation before optimization. Defaults to 0.1

_scale_noise(size)
Parameters:

size (int | torch.Size)

Return type:

torch.Tensor

property bias: torch.Tensor | None
Return type:

torch.Tensor | None

in_features
out_features
reset_noise()
Return type:

None

reset_parameters()
Return type:

None

std_init = 0.1
property weight: torch.Tensor
Return type:

torch.Tensor

weight_mu
weight_sigma
class gfn.utils.modules.RecurrentDiscreteSequenceModel(vocab_size, embedding_dim, hidden_size, num_layers=1, rnn_type='lstm', dropout=0.0)

Bases: AutoregressiveDiscreteSequenceModel

Recurrent (LSTM or GRU) autoregressive model over discrete token sequences.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • hidden_size (int)

  • num_layers (int)

  • rnn_type (Literal['lstm', 'gru'])

  • dropout (float)

_vocab_size
embedding
embedding_dim
forward(x, carry)

Compute the logits for the next tokens in the sequence.

Parameters:
  • x (torch.Tensor) – (B, T) tensor of input token indices where T is the number of newly supplied timesteps (T may be 1 for incremental decoding).

  • carry (dict[str, torch.Tensor]) – Carry from previous steps for recurrent processing (e.g., hidden states).

Returns:

Logits for the next token at each supplied timestep with shape (B, T, vocab), and the updated carry.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]

gru: torch.nn.GRU | None
hidden_size
init_carry(batch_size, device)

Initialize the carry for the sequence model.

Parameters:
  • batch_size (int) – Batch size.

  • device (torch.device) – Device to allocate carry tensors on.

Returns:

Initialized carry.

Return type:

dict[str, torch.Tensor]

lstm: torch.nn.LSTM | None
num_layers = 1
output_projection
rnn_type = ''
property vocab_size: int

Size of the vocabulary (excluding BOS token).

Return type:

int

class gfn.utils.modules.SinusoidalPositionalEmbedding(embedding_dim, max_length=2048, base=10000.0)

Bases: torch.nn.Module

Sinusoidal positional embeddings for transformer-style models.

The module caches a precomputed table of embeddings and extends it on demand. Forward accepts either a sequence length or explicit position indices.

Parameters:
  • embedding_dim (int)

  • max_length (int)

  • base (float)

_pe: torch.Tensor
base
embedding_dim
forward(positions=None, seq_len=None)

Look up positional embeddings.

Parameters:
  • positions (Optional[torch.Tensor]) – Optional tensor of position indices. Can have any shape, and the returned embeddings will append embedding_dim to that shape. Defaults to None.

  • seq_len (Optional[int]) – Optional sequence length. When provided, returns the first seq_len embeddings from the table.

Returns:

Tensor of positional embeddings on the same device/dtype as the cached table.

Raises:

ValueError – If both or neither of positions and seq_len are provided, or if indices exceed the cached range.

Return type:

torch.Tensor

property pe: torch.Tensor

Return the cached positional embedding table.

Return type:

torch.Tensor

class gfn.utils.modules.Tabular(n_states, output_dim)

Bases: torch.nn.Module

Implements a tabular policy.

This class is only compatible with the EnumPreprocessor.

Parameters:
  • n_states (int)

  • output_dim (int)

table

a tensor with dimensions [n_states, output_dim].

device

the device that holds this policy.

device = None
forward(preprocessed_states)

Forward method for the tabular policy.

Parameters:

preprocessed_states (torch.Tensor) – a batch of states appropriately preprocessed for ingestion by the tabular policy. The shape of the tensor should be (*batch_shape, 1).

Returns:

A tensor of shape (*batch_shape, output_dim).

Return type:

torch.Tensor

table
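Because the table has one row per state, forward presumably amounts to a row lookup. That is why the inputs must be integer state indices, such as those produced by the EnumPreprocessor. The plain-Python sketch below uses a flat batch of shape (batch, 1). The function name is hypothetical, and the bounds check is a choice of the sketch.

```python
def tabular_forward(table, preprocessed_states):
    """Return table[s] for each state index s; inputs have shape (batch, 1)."""
    outputs = []
    for (index,) in preprocessed_states:
        index = int(index)
        if not 0 <= index < len(table):
            raise IndexError(f"state index {index} outside [0, {len(table)})")
        outputs.append(list(table[index]))  # copy the row: one output_dim vector per state
    return outputs


table = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]  # n_states = 3, output_dim = 2
print(tabular_forward(table, [[2], [0], [2]]))  # [[4.0, 5.0], [0.0, 1.0], [4.0, 5.0]]
```

Since every state owns its own row, training changes only the rows of states that actually appear in a batch.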
class gfn.utils.modules.TransformerDiscreteSequenceModel(vocab_size, embedding_dim, num_heads, ff_hidden_dim, num_layers, max_position_embeddings, dropout=0.0, positional_embedding='learned')

Bases: AutoregressiveDiscreteSequenceModel

Transformer-based autoregressive model over discrete token sequences.

Parameters:
  • vocab_size (int)

  • embedding_dim (int)

  • num_heads (int)

  • ff_hidden_dim (int)

  • num_layers (int)

  • max_position_embeddings (int)

  • dropout (float)

  • positional_embedding (Literal['learned', 'sinusoidal'])

_positional_embedding_type = 'learned'
_vocab_size
embedding_dim
embedding_dropout
ff_hidden_dim
final_norm
forward(x, carry)

Compute the logits for the next tokens in the sequence.

Parameters:
  • x (torch.Tensor) – (B, T) tensor of input token indices where T is the number of newly supplied timesteps (T may be 1 for incremental decoding).

  • carry (dict[str, torch.Tensor]) – Carry from previous steps for recurrent processing (e.g., hidden states).

Returns:

Logits for the next token at each supplied timestep with shape (B, T, vocab), and the updated carry.

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]

head_dim
init_carry(batch_size, device)

Initialize the carry for the sequence model.

Parameters:
  • batch_size (int) – Batch size.

  • device (torch.device) – Device to allocate carry tensors on.

Returns:

Initialized carry.

Return type:

dict[str, torch.Tensor]

key_names
layers
max_position_embeddings
num_heads
num_layers
output_projection
token_embedding
value_names
property vocab_size: int

Size of the vocabulary (excluding BOS token).

Return type:

int

class gfn.utils.modules.UniformModule(output_dim, input_dim=None, fill_value=0.0, skip_normalization=False)

Bases: torch.nn.Module

Constant-output module for non-learned (uniform/fixed) policies.

Outputs a constant tensor for all inputs. Plug into any Estimator as the module argument to get a non-learned policy.

Typical fill values:

  • 0.0 — equal logits → uniform categorical (discrete environments).

  • 1.0 — fixed params for sigmoid-based continuous estimators.

Parameters:
  • output_dim (int)

  • input_dim (int | None)

  • fill_value (float)

  • skip_normalization (bool)

output_dim

Output dimension matching the estimator’s expected_output_dim.

fill_value

Constant value to fill the output tensor.

skip_normalization

If True, signals to estimators that outputs should be used directly without normalization transforms (e.g. sigmoid scaling of concentration parameters).

fill_value = 0.0
forward(preprocessed_states)

Return a constant tensor of shape (*batch_shape, output_dim).

Parameters:

preprocessed_states (torch.Tensor) – Tensor of shape (*batch_shape, input_dim).

Returns:

Tensor of shape (*batch_shape, output_dim) filled with fill_value.

Return type:

torch.Tensor

output_dim
skip_normalization = False
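Any constant vector of logits gives a uniform categorical after the softmax, so fill_value=0.0 is a natural choice for discrete environments. The plain-Python sketch below reproduces the documented output shape with nested lists and uses hypothetical helper names. In PyTorch a comparable result would come from something like torch.full((*x.shape[:-1], output_dim), fill_value), though the module's actual code is not shown here.

```python
import math


def constant_output(batch_shape, output_dim, fill_value=0.0):
    """Nested lists of shape (*batch_shape, output_dim), every entry fill_value."""
    if not batch_shape:
        return [fill_value] * output_dim
    return [constant_output(batch_shape[1:], output_dim, fill_value) for _ in range(batch_shape[0])]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


out = constant_output((2, 3), output_dim=4)
print(softmax(out[1][2]))  # [0.25, 0.25, 0.25, 0.25]
```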
class gfn.utils.modules._AutoregressiveTransformerBlock(embed_dim, num_heads, ff_hidden_dim, dropout)

Bases: torch.nn.Module

Parameters:
  • embed_dim (int)

  • num_heads (int)

  • ff_hidden_dim (int)

  • dropout (float)

attn_dropout
embed_dim
ff_dropout
forward(hidden, key_carry, value_carry)
Parameters:
  • hidden (torch.Tensor)

  • key_carry (torch.Tensor)

  • value_carry (torch.Tensor)

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]

head_dim
k_proj
linear1
linear2
norm1
norm2
num_heads
out_proj
q_proj
residual_dropout
v_proj
gfn.utils.modules.logger
gfn.utils.modules.sinusoidal_position_encoding(length, embedding_dim, base=10000.0)

Create 1D sinusoidal positional embeddings.

Parameters:
  • length (int) – Number of positions to encode. Must be non-negative.

  • embedding_dim (int) – Dimensionality of each embedding. Must be positive.

  • base (float) – Exponential base used to compute the angular frequencies.

Returns:

A (length, embedding_dim) tensor of sinusoidal encodings.

Raises:

ValueError – If length is negative, embedding_dim is not positive, or base is not positive.

Return type:

torch.Tensor
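For reference, the standard Transformer formulation sets PE[pos, 2i] = sin(pos / base^(2i/d)) and PE[pos, 2i+1] = cos(pos / base^(2i/d)), where d is embedding_dim. The plain-Python sketch below follows that interleaved layout and the validation rules documented above. This section does not say whether gfn interleaves sines and cosines or concatenates them, or how it handles an odd embedding_dim, so treat the column order as an assumption. The function name is hypothetical.

```python
import math


def sinusoidal_table(length, embedding_dim, base=10000.0):
    """Interleaved sin/cos table of shape (length, embedding_dim) as nested lists."""
    if length < 0:
        raise ValueError("length must be non-negative")
    if embedding_dim <= 0:
        raise ValueError("embedding_dim must be positive")
    if base <= 0:
        raise ValueError("base must be positive")
    table = []
    for pos in range(length):
        row = []
        for dim in range(embedding_dim):
            # Columns 2i and 2i + 1 share the angular frequency 1 / base^(2i / embedding_dim).
            even = dim - dim % 2
            angle = pos / base ** (even / embedding_dim)
            row.append(math.sin(angle) if dim % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


print(sinusoidal_table(2, 4)[0])  # [0.0, 1.0, 0.0, 1.0]
```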