Advanced: Defining a New GFlowNet¶
To define a new GFlowNet, the user needs to define a class which subclasses GFlowNet
and implements the following methods:
sample_trajectories: Sample a specific number of complete trajectories.loss: Compute the loss given the training objects.to_training_samples: Convert trajectories to training samples.
Based on the type of training samples returned by to_training_samples, the user should
define the generic type TrainingSampleType when subclassing GFlowNet. For example,
if the training sample is an instance of Trajectories, the GFlowNet class should be
subclassed as GFlowNet[Trajectories]. Thus, the class definition should look like this:
class MyGFlowNet(GFlowNet[Trajectories]):
...
Example: Flow Matching GFlowNet
Let’s consider the example of the FMGFlowNet class, which is a subclass of
GFlowNet that implements the Flow Matching GFlowNet. The training samples are
pairs of states managed by the StatePairs container:
class FMGFlowNet(GFlowNet[StatePairs[DiscreteStates]]):
...
def to_training_samples(
self, trajectories: Trajectories
) -> StatePairs[DiscreteStates]:
"""Converts a batch of trajectories into a batch of training samples."""
return trajectories.to_state_pairs()
This means that the loss method of FMGFlowNet will receive a
StatePairs[DiscreteStates] object as its training samples argument:
def loss(self, env: DiscreteEnv, states: StatePairs[DiscreteStates]) -> torch.Tensor:
...
Adding New Training Sample Types
If your GFlowNet returns a unique type of training samples, you’ll need to
expand the TrainingSampleType bound. This ensures type-safety and better code
clarity.
Implementing Class Methods
As mentioned earlier, your new GFlowNet must implement the following methods:
sample_trajectories: Sample a specific number of complete trajectories.loss: Compute the loss given the training objects.to_training_samples: Convert trajectories to training samples.
These methods are defined in
src/gfn/gflownet/base.py
and are abstract methods, so they must be implemented in your new GFlowNet. If
your GFlowNet has unique functionality which should be represented as additional
class methods, implement them as required. Remember to document new methods to
ensure other developers understand their purposes and use-cases!
Testing
Remember to create unit tests for your new GFlowNet to ensure it works as intended and integrates seamlessly with other parts of the codebase. This ensures maintainability and reliability of the code!