gfn.preprocessors
Attributes
- GetStatesIndicesFn
Classes
- EnumPreprocessor: Preprocessor for environments with discrete, enumerable states.
- IdentityPreprocessor: Simple preprocessor that returns states without modification.
- KHotPreprocessor: Preprocessor for grid-structured discrete states with multi-dimensional encoding.
- OneHotPreprocessor: Preprocessor that converts discrete states to one-hot encoded vectors.
- Preprocessor: Base class for state preprocessors.
Module Contents
- class gfn.preprocessors.EnumPreprocessor(get_states_indices)
Bases: Preprocessor
Preprocessor for environments with discrete, enumerable states.
This preprocessor converts discrete states to their unique integer indices, making them suitable for neural network processing. It is designed for environments with a finite number of states where each state can be uniquely identified by an index.
- Parameters:
get_states_indices (GetStatesIndicesFn) – Function that returns unique indices for states.
- output_dim
Always 1, as states are represented by single indices.
- get_states_indices
Function that returns unique indices for states.
- preprocess(states)
Preprocesses the states by returning their unique indices.
- Parameters:
states (gfn.states.DiscreteStates) – The discrete states to preprocess.
- Returns:
A tensor of shape (*batch_shape, 1) containing the unique indices of the states.
- Return type:
torch.Tensor
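To make the index mapping concrete, here is a framework-free sketch (plain Python lists instead of torch tensors) of what EnumPreprocessor.preprocess produces. The function `grid_states_indices` is hypothetical: it assumes a grid of side `height` and flattens each state's coordinates in row-major order. A real environment supplies its own `get_states_indices`, which may use a different ordering.

```python
from typing import Callable, List, Sequence

State = Sequence[int]  # one grid state, e.g. (x, y)


def grid_states_indices(states: List[State], height: int) -> List[int]:
    """Hypothetical GetStatesIndicesFn: row-major flattening of grid coordinates."""
    indices = []
    for state in states:
        index = 0
        for coord in state:
            # Each coordinate acts as one base-`height` digit of the flat index.
            index = index * height + coord
        indices.append(index)
    return indices


def enum_preprocess(
    states: List[State], get_states_indices: Callable[[List[State]], List[int]]
) -> List[List[int]]:
    """Mimics EnumPreprocessor.preprocess: one index per state, trailing dim of 1."""
    return [[index] for index in get_states_indices(states)]


batch = [(0, 0), (1, 2), (2, 0)]
print(enum_preprocess(batch, lambda s: grid_states_indices(s, height=3)))
# [[0], [5], [6]]
```

The trailing dimension of size 1 matches output_dim = 1. The batch here is a flat list, whereas the library preserves the full *batch_shape of the input states.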
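- gfn.preprocessors.GetStatesIndicesFn
Type of the state-indexing function accepted by EnumPreprocessor and OneHotPreprocessor; such a function maps states to their unique integer indices.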
- class gfn.preprocessors.IdentityPreprocessor(output_dim, target_dtype=None)
Bases: Preprocessor
Simple preprocessor that returns states without modification.
This preprocessor serves as the default preprocessor. It can handle both graph and tensor-based states by returning them as-is.
- Parameters:
output_dim (int | None)
target_dtype (torch.dtype | None)
- output_dim
The dimensionality of the input states.
- preprocess(states)
Returns the states without any preprocessing.
- Parameters:
states (gfn.states.States) – The states to preprocess.
- Returns:
Tensor or GeometricBatch representing the states.
- Return type:
torch.Tensor | gfn.utils.graphs.GeometricBatch
- class gfn.preprocessors.KHotPreprocessor(height, ndim)
Bases: Preprocessor
Preprocessor for grid-structured discrete states with multi-dimensional encoding.
This preprocessor is designed for environments with grid-like state spaces where each dimension can take on a finite number of values. It creates a k-hot encoding where each dimension is one-hot encoded and then concatenated.
- Parameters:
height (int) – Number of unique values per dimension.
ndim (int) – Number of dimensions in the state space.
- output_dim: int
The total output dimension (height * ndim).
- height
Number of unique values per dimension.
- ndim
Number of dimensions in the state space.
- preprocess(states)
Preprocesses the states by creating k-hot encoded vectors.
Each dimension of the state is one-hot encoded and then concatenated into a single vector.
- Parameters:
states (gfn.states.DiscreteStates) – The discrete states to preprocess.
- Returns:
A tensor of shape (*batch_shape, height * ndim) containing k-hot encoded states.
- Return type:
torch.Tensor
Note
This preprocessor only works for integer state tensors.
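The k-hot encoding described above can be sketched per state in plain Python (lists instead of tensors). `one_hot` and `k_hot` are hypothetical helpers written for illustration, not functions from gfn.preprocessors; they assume each coordinate is an integer in [0, height).

```python
from typing import List, Sequence


def one_hot(value: int, length: int) -> List[int]:
    """Return a list of `length` zeros with a 1 at position `value`."""
    if not 0 <= value < length:
        raise ValueError(f"value {value} is out of range for length {length}")
    vec = [0] * length
    vec[value] = 1
    return vec


def k_hot(state: Sequence[int], height: int, ndim: int) -> List[int]:
    """Mimics the per-state KHotPreprocessor encoding: one-hot each coordinate, then concatenate."""
    if len(state) != ndim:
        raise ValueError(f"expected {ndim} coordinates, got {len(state)}")
    encoded: List[int] = []
    for coord in state:
        encoded.extend(one_hot(coord, height))
    return encoded  # length == height * ndim


print(k_hot((2, 0), height=3, ndim=2))
# [0, 0, 1, 1, 0, 0]
```

Every encoded state contains exactly ndim ones, one per dimension, and the vector length is height * ndim, matching output_dim.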
- class gfn.preprocessors.OneHotPreprocessor(n_states, get_states_indices)
Bases: Preprocessor
Preprocessor that converts discrete states to one-hot encoded vectors.
This preprocessor is designed for environments with enumerable states where each state is represented as a one-hot vector. The output dimension equals the total number of possible states.
- Parameters:
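n_states (int) – The total number of states in the environment.
get_states_indices (GetStatesIndicesFn) – Function that returns unique indices for states.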
- output_dim: int
The total number of states in the environment.
- get_states_indices
Function that returns unique indices for states.
- preprocess(states)
Preprocesses the states by converting them to one-hot encoded vectors.
- Parameters:
states (gfn.states.DiscreteStates) – The discrete states to preprocess.
- Returns:
A tensor of shape (*batch_shape, n_states) containing one-hot encoded states.
- Return type:
torch.Tensor
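A plain-Python sketch of the one-hot path: each state is first mapped to an index, then expanded into a vector of length n_states. The helper `flat_index` is hypothetical (row-major flattening on a full grid, so n_states = height ** ndim); in the library, the index comes from whatever `get_states_indices` the environment provides.

```python
from typing import List, Sequence


def flat_index(state: Sequence[int], height: int) -> int:
    """Hypothetical per-state index function: row-major flattening."""
    index = 0
    for coord in state:
        index = index * height + coord
    return index


def one_hot_preprocess(
    states: List[Sequence[int]], height: int, ndim: int
) -> List[List[int]]:
    """Mimics OneHotPreprocessor.preprocess, assuming n_states = height ** ndim."""
    n_states = height ** ndim
    encoded = []
    for state in states:
        vec = [0] * n_states
        vec[flat_index(state, height)] = 1  # single 1 at the state's index
        encoded.append(vec)
    return encoded


out = one_hot_preprocess([(2, 0)], height=3, ndim=2)
print(len(out[0]), out[0].index(1))
# 9 6
```

The output size grows with the number of states: for a grid with height = 8 and ndim = 4, the one-hot vector has 8 ** 4 = 4096 entries, a KHotPreprocessor encoding has 8 * 4 = 32, and EnumPreprocessor emits a single index.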
- class gfn.preprocessors.Preprocessor(output_dim, target_dtype=None)
Bases: abc.ABC
Base class for state preprocessors.
Preprocessors transform raw state tensors into formats suitable for neural network inputs. They handle the conversion from environment-specific state representations to standardized tensor formats that can be processed by neural networks.
- Parameters:
output_dim (int | None)
target_dtype (torch.dtype | None)
- output_dim
The dimensionality of the preprocessed output tensor; it must match the input dimension of the neural network that consumes it. If None, the output dimension is not checked.
- __call__(states)
Calls the preprocess method and validates the output shape.
- Parameters:
states (gfn.states.States | gfn.states.GraphStates) – The states to preprocess.
- Returns:
The preprocessed states as a tensor or GeometricBatch.
- Return type:
torch.Tensor | gfn.utils.graphs.GeometricBatch
- __repr__()
Returns a string representation of the Preprocessor.
- Returns:
A string summary of the Preprocessor.
- abstract preprocess(states)
Transforms the states to the input format for neural networks.
- Parameters:
states (gfn.states.States) – The states to preprocess.
- Returns:
A tensor of shape (*batch_shape, output_dim) containing the preprocessed states.
- Return type:
torch.Tensor
- target_dtype = None
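The base-class contract can be sketched without torch: subclasses implement `preprocess`, and `__call__` runs it and checks the trailing dimension against `output_dim` unless that is None. `SketchPreprocessor` and `ScaledIdentity` are hypothetical names used only for illustration. The real class operates on `States` objects and tensors, and this sketch does not reproduce its exact validation logic or any handling of `target_dtype`.

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence


class SketchPreprocessor(ABC):
    """Hypothetical stand-in for Preprocessor, using lists of rows instead of tensors."""

    def __init__(self, output_dim: Optional[int]) -> None:
        self.output_dim = output_dim

    @abstractmethod
    def preprocess(self, states: List[Sequence[float]]) -> List[List[float]]:
        """Subclasses turn raw states into network inputs."""

    def __call__(self, states: List[Sequence[float]]) -> List[List[float]]:
        out = self.preprocess(states)
        # Only validate the trailing dimension when output_dim is known.
        if self.output_dim is not None:
            for row in out:
                if len(row) != self.output_dim:
                    raise ValueError(
                        f"expected last dim {self.output_dim}, got {len(row)}"
                    )
        return out


class ScaledIdentity(SketchPreprocessor):
    """Example subclass: divides each coordinate by a fixed scale."""

    def __init__(self, ndim: int, scale: float) -> None:
        super().__init__(output_dim=ndim)
        self.scale = scale

    def preprocess(self, states: List[Sequence[float]]) -> List[List[float]]:
        return [[x / self.scale for x in state] for state in states]


pre = ScaledIdentity(ndim=2, scale=4.0)
print(pre([(2, 0), (4, 4)]))
# [[0.5, 0.0], [1.0, 1.0]]
```

Keeping the shape check in `__call__` rather than in each subclass means every preprocessor gets the same validation for free; passing a 3-coordinate state to `pre` raises a ValueError.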