gfn.utils.modules¶
This file contains some examples of modules that can be used with GFN.
Attributes¶
Classes¶
AutoregressiveDiscreteSequenceModel – Helper class that provides a standard way to create an ABC using inheritance.
DiffusionFixedBackwardModule – Fixed Backward Module for DiffusionPISGradNet.
DiffusionPISGradNetBackward – Learnable backward correction module (PIS-style) for diffusion.
DiffusionPISGradNetForward – PISGradNet for diffusion sampling.
DiffusionPISJointPolicy – Joint Policy Module for DiffusionPISGradNet.
DiffusionPISStateEncoding – State Encoding Module for DiffusionPISGradNet.
DiffusionPISTimeEncoding – Time Encoding Module for DiffusionPISGradNet.
DiscreteUniform – Uniform distribution over discrete actions.
GraphActionGNN – Implements a GNN for graph action prediction.
GraphActionUniform – Implements a uniform distribution over discrete actions given a graph state.
GraphEdgeActionMLP – Network that processes flattened adjacency matrices to predict graph actions.
GraphScalarMLP – Graph encoder that maps adjacency structure to n scalar outputs.
LinearTransformer – The Linear Transformer module.
MLP – Implements a basic MLP with optional noisy layers for exploration.
NoisyLinear – Noisy Linear Layer.
RecurrentDiscreteSequenceModel – Helper class that provides a standard way to create an ABC using inheritance.
SinusoidalPositionalEmbedding – Sinusoidal positional embeddings for transformer-style models.
Tabular – Implements a tabular policy.
TransformerDiscreteSequenceModel – Helper class that provides a standard way to create an ABC using inheritance.
UniformModule – Constant-output module for non-learned (uniform/fixed) policies.
Functions¶
sinusoidal_position_encoding – Create 1D sinusoidal positional embeddings.
Module Contents¶
- gfn.utils.modules.ACTIVATION_FNS¶
- class gfn.utils.modules.AutoregressiveDiscreteSequenceModel¶
Bases:
abc.ABC, torch.nn.Module
Helper class that provides a standard way to create an ABC using inheritance.
- abstract forward(x, carry)¶
Compute the logits for the next tokens in the sequence.
- Parameters:
x (torch.Tensor) – (B, T) tensor of input token indices, where T is the number of newly supplied timesteps (T may be 1 for incremental decoding).
carry (dict[str, torch.Tensor]) – Carry from previous steps for recurrent processing (e.g., hidden states).
- Returns:
Logits for the next token at each supplied timestep with shape (B, T, vocab), and the updated carry.
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
- abstract init_carry(batch_size, device)¶
Initialize the carry for the sequence model.
- Parameters:
batch_size (int) – Batch size.
device (torch.device) – Device to allocate carry tensors on.
- Returns:
Initialized carry.
- Return type:
dict[str, torch.Tensor]
- property vocab_size: int¶
- Abstractmethod:
- Return type:
int
Size of the vocabulary (excluding BOS token).
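The forward/init_carry contract above can be illustrated with a toy, dependency-free stand-in. This is only a sketch: `ToyRunningSumModel` is a hypothetical class, plain lists replace tensors, the `device` argument is omitted, and the logit rule is made up. It shows that feeding one timestep at a time (T = 1) while threading the carry through reproduces the logits of a single full-sequence call.

```python
# Toy stand-in for the forward/init_carry protocol (not the library's code).


class ToyRunningSumModel:
    """Hypothetical model whose logits depend on the sum of tokens seen so far."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size

    def init_carry(self, batch_size):
        # One running total per batch element.
        return {"total": [0] * batch_size}

    def forward(self, x, carry):
        # x: list of B rows, each holding T newly supplied token indices.
        totals = list(carry["total"])
        logits = []
        for b, row in enumerate(x):
            row_logits = []
            for token in row:
                totals[b] += token
                target = totals[b] % self.vocab_size
                # One logit per vocabulary entry at this timestep.
                row_logits.append([-abs(v - target) for v in range(self.vocab_size)])
            logits.append(row_logits)
        return logits, {"total": totals}


model = ToyRunningSumModel(vocab_size=4)
tokens = [[1, 2, 3], [0, 3, 3]]  # shape (B=2, T=3)

# Full pass: all timesteps supplied at once.
full_logits, _ = model.forward(tokens, model.init_carry(batch_size=2))

# Incremental pass: one timestep per call, reusing the returned carry.
carry = model.init_carry(batch_size=2)
step_logits = [[], []]
for t in range(3):
    out, carry = model.forward([[row[t]] for row in tokens], carry)
    for b in range(2):
        step_logits[b].append(out[b][0])
```

A real subclass would keep hidden states or key/value caches in the carry instead of a running total.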
- class gfn.utils.modules.DiffusionFixedBackwardModule(s_dim)¶
Bases:
torch.nn.Module
Fixed Backward Module for DiffusionPISGradNet.
- Parameters:
s_dim (int)
- input_dim¶
The dimension of the input.
- forward(preprocessed_states)¶
Forward pass of the module.
- input_dim¶
- class gfn.utils.modules.DiffusionPISGradNetBackward(s_dim, harmonics_dim=64, t_emb_dim=64, s_emb_dim=64, hidden_dim=64, joint_layers=2, zero_init=False, clipping=False, gfn_clip=10000.0, pb_scale_range=0.1, log_var_range=4.0, learn_variance=True)¶
Bases:
torch.nn.Module
Learnable backward correction module (PIS-style) for diffusion.
Produces mean and optional log-std corrections that are tanh-scaled by pb_scale_range to stay close to the analytic Brownian bridge.
- Parameters:
s_dim (int)
harmonics_dim (int)
t_emb_dim (int)
s_emb_dim (int)
hidden_dim (int)
joint_layers (int)
zero_init (bool)
clipping (bool)
gfn_clip (float)
pb_scale_range (float)
log_var_range (float)
learn_variance (bool)
- clipping = False¶
- forward(preprocessed_states)¶
- Parameters:
preprocessed_states (torch.Tensor)
- Return type:
torch.Tensor
- gfn_clip = 10000.0¶
- harmonics_dim = 64¶
- joint_layers = 2¶
- joint_model¶
- learn_variance = True¶
- log_var_range = 4.0¶
- out_dim¶
- pb_scale_range = 0.1¶
- s_dim¶
- s_emb_dim = 64¶
- s_model¶
- t_emb_dim = 64¶
- t_model¶
- zero_init = False¶
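The tanh scaling mentioned in the class description can be sketched as follows. This assumes the correction is computed as `pb_scale_range * tanh(raw)`; the function name `bounded_correction` is hypothetical, and how the module combines the correction with the analytic Brownian bridge is not shown here.

```python
import math


def bounded_correction(raw_output, pb_scale_range=0.1):
    # tanh maps any real number into [-1, 1], so the scaled correction
    # can never exceed pb_scale_range in magnitude.
    return math.tanh(raw_output) * pb_scale_range


corrections = [bounded_correction(raw) for raw in (-50.0, -1.0, 0.0, 1.0, 50.0)]
```

However large the raw network output grows, the correction stays within ±pb_scale_range, which keeps the learned backward policy close to the analytic one.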
- class gfn.utils.modules.DiffusionPISGradNetForward(s_dim, harmonics_dim=64, t_emb_dim=64, s_emb_dim=64, hidden_dim=64, joint_layers=2, zero_init=False, clipping=False, gfn_clip=10000.0, t_scale=1.0, log_var_range=4.0, learn_variance=False)¶
Bases:
torch.nn.Module
PISGradNet for diffusion sampling.
This architecture was first introduced in Path Integral Sampler (PIS) (https://arxiv.org/abs/2111.15141) and adapted for GFlowNet-based training by Sendera et al. (https://arxiv.org/abs/2508.03044).
- Parameters:
s_dim (int)
harmonics_dim (int)
t_emb_dim (int)
s_emb_dim (int)
hidden_dim (int)
joint_layers (int)
zero_init (bool)
clipping (bool)
gfn_clip (float)
t_scale (float)
log_var_range (float)
learn_variance (bool)
- s_dim¶
The dimension of the states.
- harmonics_dim¶
The dimension of the Fourier features.
- t_emb_dim¶
The dimension of the time embedding.
- s_emb_dim¶
The dimension of the state embedding.
- hidden_dim¶
The dimension of the hidden layers.
- joint_layers¶
The number of layers in the joint policy.
- zero_init¶
Whether to initialize the weights and biases of the final layer to zero.
- out_dim¶
The dimension of the output.
- t_model¶
The time encoding module.
- s_model¶
The state encoding module.
- joint_model¶
The joint policy module.
- clipping = False¶
- forward(preprocessed_states)¶
Forward pass of the module.
- gfn_clip = 10000.0¶
- harmonics_dim = 64¶
- hidden_dim = 64¶
- input_dim¶
- joint_layers = 2¶
- joint_model¶
- learn_variance = False¶
- log_var_range = 4.0¶
- out_dim¶
- s_dim¶
- s_emb_dim = 64¶
- s_model¶
- t_emb_dim = 64¶
- t_model¶
- t_scale = 1.0¶
- zero_init = False¶
- class gfn.utils.modules.DiffusionPISJointPolicy(s_emb_dim, hidden_dim, out_dim, num_layers, zero_init=False)¶
Bases:
torch.nn.Module
Joint Policy Module for DiffusionPISGradNet.
See DiffusionPISGradNet for more details.
- Parameters:
s_emb_dim (int)
hidden_dim (int)
out_dim (int)
num_layers (int)
zero_init (bool)
- forward(s_emb, t_emb)¶
- Parameters:
s_emb (torch.Tensor)
t_emb (torch.Tensor)
- Return type:
torch.Tensor
- model¶
- class gfn.utils.modules.DiffusionPISStateEncoding(x_dim, s_emb_dim)¶
Bases:
torch.nn.Module
State Encoding Module for DiffusionPISGradNet.
See DiffusionPISGradNet for more details.
- Parameters:
x_dim (int)
s_emb_dim (int)
- forward(s)¶
- Parameters:
s (torch.Tensor)
- Return type:
torch.Tensor
- s_model¶
- class gfn.utils.modules.DiffusionPISTimeEncoding(harmonics_dim, t_emb_dim, hidden_dim)¶
Bases:
torch.nn.Module
Time Encoding Module for DiffusionPISGradNet.
See DiffusionPISGradNet for more details.
- Parameters:
harmonics_dim (int)
t_emb_dim (int)
hidden_dim (int)
- forward(t)¶
- Parameters:
t (torch.Tensor)
- Return type:
torch.Tensor
- t_model¶
- timestep_phase¶
- class gfn.utils.modules.DiscreteUniform(output_dim)¶
Bases:
UniformModule
Uniform distribution over discrete actions.
Backward-compatible alias for UniformModule(output_dim, fill_value=0.0). Prefer UniformModule for new code.
- Parameters:
output_dim (int)
- class gfn.utils.modules.GraphActionGNN(num_node_classes, num_edge_classes, directed, embedding_dim=128, num_conv_layers=1, is_backward=False)¶
Bases:
torch.nn.Module
Implements a GNN for graph action prediction.
- Parameters:
num_node_classes (int)
num_edge_classes (int)
directed (bool)
embedding_dim (int)
num_conv_layers (int)
is_backward (bool)
- static _group_mean(tensor, batch_ptr)¶
- Parameters:
tensor (torch.Tensor)
batch_ptr (torch.Tensor)
- Return type:
torch.Tensor
- action_type_mlp¶
- conv_blks¶
- edge_class_mlp¶
- embedding¶
- forward(states_tensor)¶
- Parameters:
states_tensor (gfn.utils.graphs.GeometricBatch)
- Return type:
tensordict.TensorDict
- property input_dim¶
- is_backward = False¶
- is_directed¶
- node_class_mlp¶
- node_index_mlp¶
- norm¶
- num_conv_layers = 1¶
- num_edge_classes¶
- num_node_classes¶
- class gfn.utils.modules.GraphActionUniform(edges_dim, num_edge_classes, num_node_classes)¶
Bases:
torch.nn.Module
Implements a uniform distribution over discrete actions given a graph state.
It uses a zero function approximator (a function that always outputs 0) to be used as logits by a DiscretePBEstimator.
- Parameters:
edges_dim (int)
num_edge_classes (int)
num_node_classes (int)
- output_dim¶
The size of the output space.
- edges_dim¶
- forward(states_tensor)¶
Forward method for the uniform distribution.
- Parameters:
states_tensor (gfn.utils.graphs.GeometricBatch) – a batch of states appropriately preprocessed for ingestion by the uniform distribution.
- Returns:
A TensorDict containing logits for each action component, with all values set to 1 to represent a uniform distribution:
GraphActions.ACTION_TYPE_KEY: Tensor of shape [*batch_shape, 3] for the 3 possible action types
GraphActions.EDGE_CLASS_KEY: Tensor of shape [*batch_shape, num_edge_classes] for edge class logits
GraphActions.NODE_CLASS_KEY: Tensor of shape [*batch_shape, num_node_classes] for node class logits
GraphActions.EDGE_INDEX_KEY: Tensor of shape [*batch_shape, edges_dim] for edge index logits
- Return type:
tensordict.TensorDict
- input_dim = 1¶
- num_edge_classes¶
- num_node_classes¶
- class gfn.utils.modules.GraphEdgeActionMLP(n_nodes, directed, num_node_classes, num_edge_classes, n_hidden_layers=2, n_hidden_layers_exit=1, embedding_dim=128, is_backward=False)¶
Bases:
torch.nn.Module
Network that processes flattened adjacency matrices to predict graph actions.
Unlike the GNN-based GraphActionGNN, this module uses standard MLPs to process the entire adjacency matrix as a flattened vector. This approach:
Can directly process global graph structure without message passing.
May be more effective for small graphs where global patterns are important.
Does not require complex graph neural network operations.
The module architecture consists of:
An MLP to process the flattened adjacency matrix into an embedding.
An edge MLP that predicts logits for each possible edge action.
An exit MLP that predicts a logit for the exit action.
- Parameters:
n_nodes (int) – Number of nodes in the graph.
directed (bool) – Whether the graph is directed or undirected.
n_hidden_layers (int) – Number of hidden layers in the MLP for the edge actions.
n_hidden_layers_exit (int) – Number of hidden layers in the MLP for the exit action.
embedding_dim (int) – Dimension of internal embeddings.
is_backward (bool) – Whether this is a backward policy.
num_node_classes (int)
num_edge_classes (int)
- _edges_dim¶
- _input_dim¶
- _output_dim¶
- edge_class_mlp¶
- edge_mlp¶
- property edges_dim: int¶
- Return type:
int
- exit_mlp¶
- features_embedding¶
- forward(states_tensor)¶
Forward pass to compute action logits from graph states.
Process:
1. Convert the graph representation to adjacency matrices.
2. Process the flattened adjacency matrices through the main MLP.
3. Predict logits for edge actions and the exit action.
- Parameters:
states_tensor (gfn.utils.graphs.GeometricBatch) – A GeometricBatch containing graph state information
- Returns:
A TensorDict of logits for all possible actions
- Return type:
tensordict.TensorDict
- property input_dim: int¶
- Return type:
int
- is_backward = False¶
- is_directed¶
- mlp¶
- n_nodes¶
- node_class_mlp¶
- node_index_mlp¶
- num_edge_classes¶
- num_node_classes¶
- property output_dim: int¶
- Return type:
int
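The flattening step that GraphEdgeActionMLP relies on can be sketched without any tensor library. The helper names `count_possible_edges` and `flatten_adjacency` are hypothetical, and the sketch assumes that self-loops are excluded and that an undirected graph stores each edge symmetrically; the library's own `edges_dim` and preprocessing may differ.

```python
def count_possible_edges(n_nodes, directed):
    # Assumption of this sketch: self-loops are not valid edge actions.
    ordered_pairs = n_nodes * (n_nodes - 1)
    return ordered_pairs if directed else ordered_pairs // 2


def flatten_adjacency(n_nodes, edges, directed):
    # Build a dense n_nodes x n_nodes matrix, then flatten it row by row.
    adjacency = [[0.0] * n_nodes for _ in range(n_nodes)]
    for src, dst in edges:
        adjacency[src][dst] = 1.0
        if not directed:
            adjacency[dst][src] = 1.0
    return [value for row in adjacency for value in row]


flat = flatten_adjacency(3, [(0, 1), (1, 2)], directed=False)
n_directed = count_possible_edges(3, directed=True)
n_undirected = count_possible_edges(3, directed=False)
```

The flattened vector has n_nodes² entries regardless of how many edges exist, which is why this approach suits small graphs better than large sparse ones.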
- class gfn.utils.modules.GraphScalarMLP(n_nodes, directed, embedding_dim=128, n_hidden_layers=2, n_outputs=1)¶
Bases:
torch.nn.Module
Graph encoder that maps adjacency structure to n scalar outputs.
- Parameters:
n_nodes (int)
directed (bool)
embedding_dim (int)
n_hidden_layers (int)
n_outputs (int)
- backbone¶
- forward(states_tensor)¶
Encode graphs into a scalar per element of the batch.
- Parameters:
states_tensor (gfn.utils.graphs.GeometricBatch)
- Return type:
torch.Tensor
- head¶
- input_dim¶
- is_directed¶
- n_nodes¶
- class gfn.utils.modules.LinearTransformer(dim, depth, max_seq_len, n_heads=8, causal=False)¶
Bases:
torch.nn.Module
The Linear Transformer module.
Implements “Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention” (Angelos Katharopoulos, Apoorv Vyas, Nikolaos Pappas, François Fleuret, ICML 2020).
Expresses self-attention as a linear dot-product of kernel feature maps and makes use of the associativity property of matrix products to reduce the complexity of the attention computation from O(n^2) to O(n).
Implementation from https://github.com/lucidrains/linear-attention-transformer.
- Parameters:
dim (int) – The dimension of the input.
depth (int) – The depth of the transformer.
max_seq_len (int) – The maximum sequence length.
n_heads (int) – The number of attention heads.
causal (bool) – Whether to use causal attention.
- forward(x)¶
- Parameters:
x (torch.Tensor)
- Return type:
torch.Tensor
- module¶
- output_dim¶
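The associativity argument behind linear attention can be checked numerically on tiny matrices. The sketch below is non-causal and uses the elu(x) + 1 feature map from the cited paper; the referenced implementation may use a different feature map, and all function names here are hypothetical. It compares (φ(Q)φ(K)ᵀ)V, which builds an n × n matrix, with φ(Q)(φ(K)ᵀV), which only builds a d × d summary.

```python
import math


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def feature_map(m):
    # elu(x) + 1: equals x + 1 for x > 0 and exp(x) otherwise, so always positive.
    return [[x + 1.0 if x > 0 else math.exp(x) for x in row] for row in m]


def quadratic_attention(q, k, v):
    # Materialises the full n x n score matrix before multiplying by V.
    scores = matmul(feature_map(q), transpose(feature_map(k)))
    out = matmul(scores, v)
    return [[x / sum(s) for x in o] for o, s in zip(out, scores)]


def linear_attention(q, k, v):
    # Computes the key/value summary once, then applies it to every query.
    phi_q, phi_k = feature_map(q), feature_map(k)
    kv = matmul(transpose(phi_k), v)
    k_sum = [sum(col) for col in zip(*phi_k)]
    out = matmul(phi_q, kv)
    norms = [sum(a * b for a, b in zip(row, k_sum)) for row in phi_q]
    return [[x / z for x in o] for o, z in zip(out, norms)]


q = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]
k = [[0.7, 0.1], [-0.3, 0.2], [0.0, -0.9]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
slow = quadratic_attention(q, k, v)
fast = linear_attention(q, k, v)
```

Both orderings give the same result up to floating-point error; only the second avoids the cost that grows quadratically with sequence length.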
- class gfn.utils.modules.MLP(input_dim, output_dim, hidden_dim=256, n_hidden_layers=2, n_noisy_layers=0, activation_fn='relu', trunk=None, add_layer_norm=False, std_init=0.1)¶
Bases:
torch.nn.Module
Implements a basic MLP with optional noisy layers for exploration.
When trunk is provided, the MLP will be a wrapper around the trunk.
See Noisy Networks for Exploration (Fortunato et al., ICLR 2018) for more details on the noisy layers.
- Parameters:
input_dim (int)
output_dim (int)
hidden_dim (int)
n_hidden_layers (Optional[int])
n_noisy_layers (int)
activation_fn (Literal['relu', 'leaky_relu', 'tanh', 'elu'])
trunk (Optional[torch.nn.Module])
add_layer_norm (bool)
std_init (float)
- _input_dim¶
- _output_dim¶
- forward(preprocessed_states)¶
Forward method for the neural network.
- Parameters:
preprocessed_states (torch.Tensor) – a batch of states appropriately preprocessed for ingestion by the MLP. The shape of the tensor should be (*batch_shape, input_dim).
- Return type:
torch.Tensor
Returns: a tensor of shape (*batch_shape, output_dim).
- property input_dim¶
- property output_dim¶
- class gfn.utils.modules.NoisyLinear(in_features, out_features, bias=True, device=None, dtype=None, std_init=0.1)¶
Bases:
torch.nn.Linear
Noisy Linear Layer.
Presented in “Noisy Networks for Exploration”, https://arxiv.org/abs/1706.10295v3
A Noisy Linear Layer is a linear layer with parametric noise added to the weights. This induced stochasticity can be used in RL networks for the agent’s policy to aid efficient exploration. The parameters of the noise are learned with gradient descent along with any other remaining network weights. Factorized Gaussian noise is the type of noise usually employed.
Taken from torchrl v0.9.2.
- Parameters:
in_features (int) – input features dimension
out_features (int) – out features dimension
bias (bool, optional) – if True, a bias term will be added to the matrix multiplication: Ax + b. Defaults to True.
device (DEVICE_TYPING, optional) – device of the layer. Defaults to "cpu".
dtype (torch.dtype, optional) – dtype of the parameters. Defaults to None (default pytorch dtype).
std_init (scalar, optional) – initial value of the Gaussian standard deviation before optimization. Defaults to 0.1.
- _scale_noise(size)¶
- Parameters:
size (int | torch.Size)
- Return type:
torch.Tensor
- property bias: torch.Tensor | None¶
- Return type:
torch.Tensor | None
- in_features¶
- out_features¶
- reset_noise()¶
- Return type:
None
- reset_parameters()¶
- Return type:
None
- std_init = 0.1¶
- property weight: torch.Tensor¶
- Return type:
torch.Tensor
- weight_mu¶
- weight_sigma¶
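Factorized Gaussian noise, as described in the Noisy Networks paper, can be sketched with plain lists. This is an illustration under stated assumptions rather than the layer's implementation: `scale_noise` and `noisy_weight` are hypothetical names, bias noise is omitted, and the weight is assumed to be `weight_mu + weight_sigma * outer(eps_out, eps_in)` with f(x) = sign(x)·sqrt(|x|) applied to standard normal samples.

```python
import math
import random


def scale_noise(size, rng):
    # f(x) = sign(x) * sqrt(|x|) applied to standard normal samples.
    samples = [rng.gauss(0.0, 1.0) for _ in range(size)]
    return [math.copysign(math.sqrt(abs(x)), x) for x in samples]


def noisy_weight(weight_mu, weight_sigma, rng):
    out_features, in_features = len(weight_mu), len(weight_mu[0])
    eps_in = scale_noise(in_features, rng)
    eps_out = scale_noise(out_features, rng)
    # The outer product needs in + out samples instead of in * out.
    return [
        [
            weight_mu[o][i] + weight_sigma[o][i] * eps_out[o] * eps_in[i]
            for i in range(in_features)
        ]
        for o in range(out_features)
    ]


rng = random.Random(0)
mu = [[0.5, -0.5, 0.0], [1.0, 2.0, 3.0]]
sigma_zero = [[0.0] * 3 for _ in range(2)]
sigma = [[0.1] * 3 for _ in range(2)]
w_deterministic = noisy_weight(mu, sigma_zero, rng)
w_noisy = noisy_weight(mu, sigma, rng)
```

With sigma at zero the layer reduces to an ordinary linear layer; since sigma is learned, the network can tune how much exploration noise each weight receives.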
- class gfn.utils.modules.RecurrentDiscreteSequenceModel(vocab_size, embedding_dim, hidden_size, num_layers=1, rnn_type='lstm', dropout=0.0)¶
Bases:
AutoregressiveDiscreteSequenceModel
Helper class that provides a standard way to create an ABC using inheritance.
- Parameters:
vocab_size (int)
embedding_dim (int)
hidden_size (int)
num_layers (int)
rnn_type (Literal['lstm', 'gru'])
dropout (float)
- _vocab_size¶
- embedding¶
- embedding_dim¶
- forward(x, carry)¶
Compute the logits for the next tokens in the sequence.
- Parameters:
x (torch.Tensor) – (B, T) tensor of input token indices, where T is the number of newly supplied timesteps (T may be 1 for incremental decoding).
carry (dict[str, torch.Tensor]) – Carry from previous steps for recurrent processing (e.g., hidden states).
- Returns:
Logits for the next token at each supplied timestep with shape (B, T, vocab), and the updated carry.
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
- gru: torch.nn.GRU | None¶
- init_carry(batch_size, device)¶
Initialize the carry for the sequence model.
- Parameters:
batch_size (int) – Batch size.
device (torch.device) – Device to allocate carry tensors on.
- Returns:
Initialized carry.
- Return type:
dict[str, torch.Tensor]
- lstm: torch.nn.LSTM | None¶
- num_layers = 1¶
- output_projection¶
- rnn_type = ''¶
- property vocab_size: int¶
Size of the vocabulary (excluding BOS token).
- Return type:
int
- class gfn.utils.modules.SinusoidalPositionalEmbedding(embedding_dim, max_length=2048, base=10000.0)¶
Bases:
torch.nn.Module
Sinusoidal positional embeddings for transformer-style models.
The module caches a precomputed table of embeddings and extends it on demand. Forward accepts either a sequence length or explicit position indices.
- Parameters:
embedding_dim (int)
max_length (int)
base (float)
- _pe: torch.Tensor¶
- base¶
- embedding_dim¶
- forward(positions=None, seq_len=None)¶
Look up positional embeddings.
- Parameters:
positions (Optional[torch.Tensor]) – Optional tensor of position indices. Can have any shape, and the returned embeddings will append embedding_dim to that shape. Defaults to None.
seq_len (Optional[int]) – Optional sequence length. When provided, returns the first seq_len embeddings from the table.
- Returns:
Tensor of positional embeddings on the same device/dtype as the cached table.
- Raises:
ValueError – If both or neither of positions and seq_len are provided, or if indices exceed the cached range.
- Return type:
torch.Tensor
- property pe: torch.Tensor¶
Return the cached positional embedding table.
- Return type:
torch.Tensor
- class gfn.utils.modules.Tabular(n_states, output_dim)¶
Bases:
torch.nn.Module
Implements a tabular policy.
This class is only compatible with the EnumPreprocessor.
- Parameters:
n_states (int)
output_dim (int)
- table¶
a tensor with dimensions [n_states, output_dim].
- device¶
the device that holds this policy.
- device = None¶
- forward(preprocessed_states)¶
Forward method for the tabular policy.
- Parameters:
preprocessed_states (torch.Tensor) – a batch of states appropriately preprocessed for ingestion by the tabular policy. The shape of the tensor should be (*batch_shape, 1).
- Return type:
torch.Tensor
Returns: a tensor of shape (*batch_shape, output_dim).
- table¶
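The shapes documented for Tabular suggest a simple index lookup, sketched below with plain lists. This assumes each preprocessed state carries a single integer index in its last dimension and that the output is the matching row of the table; `tabular_forward` is a hypothetical helper, not the module's forward method.

```python
def tabular_forward(table, preprocessed_states):
    # Each state is a length-1 list holding its integer index,
    # mirroring the documented (*batch_shape, 1) input shape.
    return [list(table[state[0]]) for state in preprocessed_states]


table = [[0.0, 0.0], [1.0, -1.0], [0.5, 2.0]]  # n_states=3, output_dim=2
batch = [[2], [0], [2]]
out = tabular_forward(table, batch)
```

Each output row has output_dim entries, so a batch of shape (3, 1) yields an output of shape (3, 2).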
- class gfn.utils.modules.TransformerDiscreteSequenceModel(vocab_size, embedding_dim, num_heads, ff_hidden_dim, num_layers, max_position_embeddings, dropout=0.0, positional_embedding='learned')¶
Bases:
AutoregressiveDiscreteSequenceModel
Helper class that provides a standard way to create an ABC using inheritance.
- Parameters:
vocab_size (int)
embedding_dim (int)
num_heads (int)
ff_hidden_dim (int)
num_layers (int)
max_position_embeddings (int)
dropout (float)
positional_embedding (Literal['learned', 'sinusoidal'])
- _positional_embedding_type = 'learned'¶
- _vocab_size¶
- embedding_dim¶
- embedding_dropout¶
- final_norm¶
- forward(x, carry)¶
Compute the logits for the next tokens in the sequence.
- Parameters:
x (torch.Tensor) – (B, T) tensor of input token indices, where T is the number of newly supplied timesteps (T may be 1 for incremental decoding).
carry (dict[str, torch.Tensor]) – Carry from previous steps for recurrent processing (e.g., hidden states).
- Returns:
Logits for the next token at each supplied timestep with shape (B, T, vocab), and the updated carry.
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
- head_dim¶
- init_carry(batch_size, device)¶
Initialize the carry for the sequence model.
- Parameters:
batch_size (int) – Batch size.
device (torch.device) – Device to allocate carry tensors on.
- Returns:
Initialized carry.
- Return type:
dict[str, torch.Tensor]
- key_names¶
- layers¶
- max_position_embeddings¶
- num_heads¶
- num_layers¶
- output_projection¶
- token_embedding¶
- value_names¶
- property vocab_size: int¶
Size of the vocabulary (excluding BOS token).
- Return type:
int
- class gfn.utils.modules.UniformModule(output_dim, input_dim=None, fill_value=0.0, skip_normalization=False)¶
Bases:
torch.nn.Module
Constant-output module for non-learned (uniform/fixed) policies.
Outputs a constant tensor for all inputs. Plug into any Estimator as the module argument to get a non-learned policy.
Typical fill values:
0.0 – equal logits → uniform categorical (discrete environments).
1.0 – fixed params for sigmoid-based continuous estimators.
- Parameters:
output_dim (int)
input_dim (int | None)
fill_value (float)
skip_normalization (bool)
- output_dim¶
Output dimension matching the estimator’s expected_output_dim.
- fill_value¶
Constant value to fill the output tensor.
- skip_normalization¶
If True, signals to estimators that outputs should be used directly without normalization transforms (e.g. sigmoid scaling of concentration parameters).
- fill_value = 0.0¶
- forward(preprocessed_states)¶
Return a constant tensor of shape (*batch_shape, output_dim).
- Parameters:
preprocessed_states (torch.Tensor) – Tensor of shape (*batch_shape, input_dim).
- Returns:
Tensor of shape (*batch_shape, output_dim) filled with fill_value.
- Return type:
torch.Tensor
- output_dim¶
- skip_normalization = False¶
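Why equal logits give a uniform categorical distribution can be verified with a few lines of plain Python. The helpers `softmax` and `constant_logits` are hypothetical and stand in for the estimator's normalisation and the module's constant output; the sketch assumes the logits are passed through a softmax.

```python
import math


def softmax(logits):
    # Subtract the maximum for numerical stability before exponentiating.
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def constant_logits(output_dim, fill_value=0.0):
    return [fill_value] * output_dim


probs_zero = softmax(constant_logits(4, fill_value=0.0))
probs_one = softmax(constant_logits(4, fill_value=1.0))
```

Any constant fill value yields the same probabilities under a softmax, so the choice of 0.0 for discrete policies is a convention rather than a requirement.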
- class gfn.utils.modules._AutoregressiveTransformerBlock(embed_dim, num_heads, ff_hidden_dim, dropout)¶
Bases:
torch.nn.Module
- Parameters:
embed_dim (int)
num_heads (int)
ff_hidden_dim (int)
dropout (float)
- attn_dropout¶
- embed_dim¶
- ff_dropout¶
- forward(hidden, key_carry, value_carry)¶
- Parameters:
hidden (torch.Tensor)
key_carry (torch.Tensor)
value_carry (torch.Tensor)
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- head_dim¶
- k_proj¶
- linear1¶
- linear2¶
- norm1¶
- norm2¶
- num_heads¶
- out_proj¶
- q_proj¶
- residual_dropout¶
- v_proj¶
- gfn.utils.modules.logger¶
- gfn.utils.modules.sinusoidal_position_encoding(length, embedding_dim, base=10000.0)¶
Create 1D sinusoidal positional embeddings.
- Parameters:
length (int) – Number of positions to encode. Must be non-negative.
embedding_dim (int) – Dimensionality of each embedding. Must be positive.
base (float) – Exponential base used to compute the angular frequencies.
- Returns:
A (length, embedding_dim) tensor of sinusoidal encodings.
- Raises:
ValueError – If length is negative, embedding_dim is not positive, or base is not positive.
- Return type:
torch.Tensor
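A dependency-free sketch of 1D sinusoidal encodings is given below. It assumes the interleaved layout of the original Transformer paper (sine in even dimensions, cosine in odd ones, with dimensions 2k and 2k+1 sharing the frequency base^(-2k/embedding_dim)); the library function may arrange the sine and cosine halves differently. The name `sinusoidal_encoding_sketch` is hypothetical, and it returns nested lists rather than a tensor.

```python
import math


def sinusoidal_encoding_sketch(length, embedding_dim, base=10000.0):
    # Same argument checks as documented for the library function.
    if length < 0 or embedding_dim <= 0 or base <= 0:
        raise ValueError("length must be >= 0; embedding_dim and base must be > 0")
    table = []
    for pos in range(length):
        row = []
        for i in range(embedding_dim):
            # Dimensions 2k and 2k+1 share one angular frequency.
            freq = base ** (-(2 * (i // 2)) / embedding_dim)
            angle = pos * freq
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


table = sinusoidal_encoding_sketch(length=4, embedding_dim=6)
```

Position 0 always encodes to alternating 0 and 1 under this layout, and lower dimensions oscillate faster than higher ones as the position grows.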