gfn.preprocessors

Classes

EnumPreprocessor

Preprocessor for environments with discrete, enumerable states.

IdentityPreprocessor

Simple preprocessor that returns states without modification.

KHotPreprocessor

Preprocessor for grid-structured discrete states with multi-dimensional encoding.

OneHotPreprocessor

Preprocessor that converts discrete states to one-hot encoded vectors.

Preprocessor

Base class for state preprocessors.

Module Contents

class gfn.preprocessors.EnumPreprocessor(get_states_indices)

Bases: Preprocessor

Preprocessor for environments with discrete, enumerable states.

This preprocessor converts discrete states to their unique integer indices, making them suitable for neural network processing. It is designed for environments with a finite number of states where each state can be uniquely identified by an index.
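The idea of a unique integer index can be made concrete on a grid. The sketch below is a plain-Python illustration. It assumes states are grid coordinates in `[0, height)` and flattens them in row-major (mixed-radix) order. `grid_state_index` and `grid_states_indices` are hypothetical helpers showing what a `get_states_indices` callable might compute. They are not part of `gfn`.

```python
from typing import List, Sequence


def grid_state_index(coords: Sequence[int], height: int) -> int:
    """Map grid coordinates (each in [0, height)) to a unique row-major index.

    Hypothetical helper, not part of gfn: it only illustrates one way a
    get_states_indices callable could assign unique indices on a grid.
    """
    index = 0
    for c in coords:
        if not 0 <= c < height:
            raise ValueError(f"coordinate {c} out of range [0, {height})")
        index = index * height + c  # shift earlier digits, then add this one
    return index


def grid_states_indices(states: List[Sequence[int]], height: int) -> List[int]:
    """Apply grid_state_index to a flat batch of states."""
    return [grid_state_index(s, height) for s in states]


# With height=3 and ndim=2 there are 3**2 = 9 states, indexed 0..8.
print(grid_states_indices([[0, 0], [1, 2], [2, 2]], height=3))  # [0, 5, 8]
```

Because each coordinate acts as one digit in base `height`, distinct states always receive distinct indices.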

Parameters:

get_states_indices (Callable[[gfn.states.DiscreteStates], torch.Tensor])

output_dim

Always 1, as states are represented by single indices.

get_states_indices

Function that returns unique indices for states.

preprocess(states)

Preprocesses the states by returning their unique indices.

Parameters:

states (gfn.states.DiscreteStates) – The discrete states to preprocess.

Returns:

A tensor of shape (*batch_shape, 1) containing the unique indices of the states.

Return type:

torch.Tensor
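The trailing dimension of size 1 in the return shape is easy to miss. The following plain-Python sketch shows it for a flat batch. It assumes lists stand in for `DiscreteStates` and tensors, and `enum_preprocess` and `indices_3x3` are hypothetical names. With torch tensors, adding such a trailing dimension is what `Tensor.unsqueeze(-1)` does.

```python
from typing import Callable, List, Sequence

State = Sequence[int]


def enum_preprocess(
    states: List[State],
    get_states_indices: Callable[[List[State]], List[int]],
) -> List[List[int]]:
    """Return one index per state, with a trailing dimension of size 1.

    Hypothetical sketch of the documented behaviour; the real
    EnumPreprocessor works on gfn.states.DiscreteStates and torch tensors.
    """
    indices = get_states_indices(states)
    return [[i] for i in indices]  # shape (batch, 1)


def indices_3x3(states: List[State]) -> List[int]:
    """Hypothetical index function: row-major index on a 3x3 grid."""
    return [s[0] * 3 + s[1] for s in states]


out = enum_preprocess([[0, 0], [1, 2], [2, 2]], indices_3x3)
print(out)  # [[0], [5], [8]]
```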

class gfn.preprocessors.IdentityPreprocessor(output_dim, target_dtype=None)

Bases: Preprocessor

Simple preprocessor that returns states without modification.

This is the default preprocessor. It handles both graph-based and tensor-based states by returning them as-is.

Parameters:
  • output_dim (int | None)

  • target_dtype (torch.dtype | None)

output_dim

The dimensionality of the input states.

preprocess(states)

Returns the states without any preprocessing.

Parameters:

states (gfn.states.States) – The states to preprocess.

Returns:

Tensor or GeometricBatch representing the states.

Return type:

torch.Tensor | gfn.utils.graphs.GeometricBatch

class gfn.preprocessors.KHotPreprocessor(height, ndim)

Bases: Preprocessor

Preprocessor for grid-structured discrete states with multi-dimensional encoding.

This preprocessor is designed for environments with grid-like state spaces where each dimension can take on a finite number of values. It creates a k-hot encoding where each dimension is one-hot encoded and then concatenated.
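The k-hot layout can be illustrated on a single state. This plain-Python sketch assumes integer coordinates in `[0, height)`. It builds a one-hot vector of length `height` for each dimension and concatenates them, giving `height * ndim` entries with exactly `ndim` ones. `k_hot` is a hypothetical helper, not the class's implementation.

```python
from typing import List, Sequence


def k_hot(state: Sequence[int], height: int) -> List[int]:
    """Concatenate one one-hot vector of length `height` per dimension.

    Hypothetical sketch; integer coordinates are required, matching the
    note that the preprocessor only works for integer state tensors.
    """
    encoding: List[int] = []
    for value in state:
        if not isinstance(value, int):
            raise TypeError("k-hot encoding requires integer coordinates")
        if not 0 <= value < height:
            raise ValueError(f"value {value} out of range [0, {height})")
        one_hot = [0] * height
        one_hot[value] = 1  # mark this dimension's value
        encoding.extend(one_hot)
    return encoding  # length height * ndim, exactly ndim ones


print(k_hot([1, 2], height=3))  # [0, 1, 0, 0, 0, 1]
```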

Parameters:
  • height (int)

  • ndim (int)

output_dim

The total output dimension (height * ndim).

height

Number of unique values per dimension.

ndim

Number of dimensions in the state space.

output_dim: int
preprocess(states)

Preprocesses the states by creating k-hot encoded vectors.

Each dimension of the state is one-hot encoded and then concatenated into a single vector.

Parameters:

states (gfn.states.DiscreteStates) – The discrete states to preprocess.

Returns:

A tensor of shape (*batch_shape, height * ndim) containing k-hot encoded states.

Return type:

torch.Tensor

Note

This preprocessor only works for integer state tensors.

class gfn.preprocessors.OneHotPreprocessor(n_states, get_states_indices)

Bases: Preprocessor

Preprocessor that converts discrete states to one-hot encoded vectors.

This preprocessor is designed for environments with enumerable states and represents each state as a one-hot vector. The output dimension equals the total number of possible states.
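One-hot encoding turns each state index into a vector of length `n_states` with a single 1. The plain-Python sketch below assumes the indices have already been produced by some `get_states_indices` function. `one_hot` is a hypothetical helper. With torch tensors the same operation is available as `torch.nn.functional.one_hot(indices, num_classes=n_states)`.

```python
from typing import List


def one_hot(indices: List[int], n_states: int) -> List[List[int]]:
    """Map each state index to a one-hot row of length n_states.

    Hypothetical sketch; lists stand in for torch tensors.
    """
    rows: List[List[int]] = []
    for idx in indices:
        if not 0 <= idx < n_states:
            raise ValueError(f"index {idx} out of range [0, {n_states})")
        row = [0] * n_states
        row[idx] = 1  # the only nonzero entry marks the state
        rows.append(row)
    return rows  # shape (batch, n_states)


print(one_hot([0, 5, 8], n_states=9)[1])  # [0, 0, 0, 0, 0, 1, 0, 0, 0]
```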

Parameters:
  • n_states (int)

  • get_states_indices (Callable[[gfn.states.DiscreteStates], torch.Tensor])

output_dim

The total number of states in the environment.

get_states_indices

Function that returns unique indices for states.

output_dim: int
preprocess(states)

Preprocesses the states by converting them to one-hot encoded vectors.

Parameters:

states (gfn.states.DiscreteStates) – The discrete states to preprocess.

Returns:

A tensor of shape (*batch_shape, n_states) containing one-hot encoded states.

Return type:

torch.Tensor
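The three discrete preprocessors differ sharply in output size. The sketch below compares them for a hypercube grid. It assumes every grid cell is a valid state, so `n_states = height ** ndim`. `output_dims` is a hypothetical helper. The one-hot width grows exponentially with `ndim`, while the k-hot width grows only linearly.

```python
from typing import Dict


def output_dims(height: int, ndim: int) -> Dict[str, int]:
    """Output dimension of each discrete preprocessor on a full grid.

    Assumes every one of the height ** ndim grid cells is a state.
    """
    return {
        "EnumPreprocessor": 1,                    # a single index
        "OneHotPreprocessor": height ** ndim,     # one slot per state
        "KHotPreprocessor": height * ndim,        # one slot per (dim, value)
    }


print(output_dims(8, 4))
# {'EnumPreprocessor': 1, 'OneHotPreprocessor': 4096, 'KHotPreprocessor': 32}
```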

class gfn.preprocessors.Preprocessor(output_dim, target_dtype=None)

Bases: abc.ABC

Base class for state preprocessors.

Preprocessors transform raw state tensors into formats suitable for neural network inputs. They handle the conversion from environment-specific state representations to standardized tensor formats that can be processed by neural networks.

Parameters:
  • output_dim (int | None)

  • target_dtype (torch.dtype | None)

output_dim

The dimensionality of the preprocessed output tensor. It must match the input dimension of the neural network that consumes it. If None, the output dimension is not checked.

__call__(states)

Calls the preprocess method and validates the output shape.

Parameters:

states (gfn.states.States | gfn.states.GraphStates) – The states to preprocess.

Returns:

The preprocessed states as a tensor or GeometricBatch.

Return type:

torch.Tensor | gfn.utils.graphs.GeometricBatch
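The split between an abstract `preprocess` and a validating `__call__` can be sketched with Python's `abc` module. In this plain-Python sketch, `SketchPreprocessor` and `ScaledPreprocessor` are hypothetical classes, not the `gfn` API. The assumption that validation means "the last dimension equals `output_dim`" is this sketch's reading of "validates the output shape", and the check is skipped when `output_dim` is None.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence


class SketchPreprocessor(ABC):
    """Hypothetical stand-in for the base class, using lists instead of tensors."""

    def __init__(self, output_dim: Optional[int]) -> None:
        self.output_dim = output_dim

    @abstractmethod
    def preprocess(self, states: List[Sequence[int]]) -> List[List[float]]:
        """Map each state to a feature vector."""

    def __call__(self, states: List[Sequence[int]]) -> List[List[float]]:
        out = self.preprocess(states)
        # Assumed check: every row's length must equal output_dim, unless None.
        if self.output_dim is not None:
            for row in out:
                if len(row) != self.output_dim:
                    raise ValueError(
                        f"expected last dimension {self.output_dim}, got {len(row)}"
                    )
        return out


class ScaledPreprocessor(SketchPreprocessor):
    """Hypothetical subclass: divides integer coordinates by the grid height."""

    def __init__(self, height: int, ndim: int) -> None:
        super().__init__(output_dim=ndim)
        self.height = height

    def preprocess(self, states: List[Sequence[int]]) -> List[List[float]]:
        return [[c / self.height for c in s] for s in states]


pre = ScaledPreprocessor(height=4, ndim=2)
print(pre([[0, 2], [3, 1]]))  # [[0.0, 0.5], [0.75, 0.25]]
```

Subclasses only implement `preprocess`; callers invoke the instance directly, so the shape check runs on every call.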

__repr__()

Returns a string representation of the Preprocessor.

Returns:

A string summary of the Preprocessor.

abstract preprocess(states)

Transforms the states to the input format for neural networks.

Parameters:

states (gfn.states.States) – The states to preprocess.

Returns:

A tensor of shape (*batch_shape, output_dim) containing the preprocessed states.

Return type:

torch.Tensor

target_dtype = None