Modules, Estimators, and Samplers
Modules & Estimators
Training GFlowNets requires one or more Estimators. Estimator is an abstract subclass of torch.nn.Module. In addition to the usual forward function, an Estimator must implement an expected_output_dim attribute, which ensures that its outputs have the dimension required for the task at hand; some (but not all) Estimators must also implement a to_probability_distribution function.
DiscretePolicyEstimator is an Estimator that defines the policies \(P_F(. \mid s)\) and \(P_B(. \mid s)\) for discrete environments. When is_backward=False, the required output dimension is n = env.n_actions, and when is_backward=True, it is n = env.n_actions - 1. These n numbers represent the logits of a Categorical distribution. The corresponding to_probability_distribution function transforms the logits by masking illegal actions (according to the forward or backward masks), then returns a Categorical distribution. The masking is done by setting the corresponding logits to \(-\infty\). The function also accepts exploration parameters, in order to define a tempered version of \(P_F\), or a mixture of \(P_F\) with a uniform distribution. A DiscretePolicyEstimator with is_backward=False can also be used to represent log-edge-flow estimators \(\log F(s \rightarrow s')\).
ScalarEstimator is a simple Estimator with required output dimension 1. It is useful for defining log-state flows \(\log F(s)\).
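The masking step described for DiscretePolicyEstimator can be illustrated without the library. The following is a minimal sketch in plain Python, not the library's implementation (which operates on batched tensors and returns a Categorical distribution). The helper names masked_categorical_probs and log_prob are hypothetical and exist only for this illustration.

```python
import math
import random


def masked_categorical_probs(logits, mask):
    """Turn raw logits into action probabilities with zero mass on illegal actions.

    mask[i] is True when action i is legal. Illegal logits are replaced by
    -inf, and exp(-inf) == 0.0 removes them from the softmax.
    """
    if not any(mask):
        raise ValueError("at least one action must be legal")
    masked = [l if ok else -math.inf for l, ok in zip(logits, mask)]
    peak = max(masked)  # subtract the max for numerical stability
    weights = [math.exp(l - peak) for l in masked]
    total = sum(weights)
    return [w / total for w in weights]


def log_prob(probs, action):
    """Log-probability of an action; -inf for actions with zero mass."""
    p = probs[action]
    return math.log(p) if p > 0.0 else -math.inf


logits = [1.0, 2.0, 0.5, 1.0]
mask = [True, False, True, True]  # action 1 is illegal in this state
probs = masked_categorical_probs(logits, mask)

# Sampling never returns the masked action because its weight is zero.
rng = random.Random(0)
action = rng.choices(range(len(probs)), weights=probs)[0]
```

Even though action 1 has the largest raw logit, it receives zero probability, and the remaining mass is renormalized over the legal actions.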
For non-discrete environments, the user needs to specify their own policies \(P_F\) and \(P_B\). The module, which takes a batch of states (as a States object) as input, should return the batched parameters of a torch.distributions.Distribution. The choice of distribution depends on the environment. The to_probability_distribution function converts these parameter outputs into an actual batched Distribution object that implements at least the sample and log_prob functions. An example is provided here, for a square environment in which the forward policy has support either on a quarter disk or on an arc of a circle, such that the angle and the radius (for the quarter-disk part) are scaled samples from a mixture of Beta distributions. This example shows an intricate scenario; user-defined environments are not expected to need this level of detail.
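To make the contract concrete, here is a much simpler sketch than the example referenced above: a single scaled Beta distribution rather than a mixture, written in plain Python rather than as a batched torch Distribution. The names ScaledBeta, to_distribution, and max_step are hypothetical, and the exponential used to keep the parameters positive is just one possible choice.

```python
import math
import random


class ScaledBeta:
    """Distribution of x = scale * u, with u ~ Beta(alpha, beta)."""

    def __init__(self, alpha, beta, scale):
        self.alpha, self.beta, self.scale = alpha, beta, scale

    def sample(self, rng):
        # rng is a random.Random instance
        return self.scale * rng.betavariate(self.alpha, self.beta)

    def log_prob(self, x):
        u = x / self.scale
        if not 0.0 < u < 1.0:
            return -math.inf  # outside the support
        a, b = self.alpha, self.beta
        log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
        # The change of variables x = scale * u contributes -log(scale).
        return (log_norm + (a - 1) * math.log(u)
                + (b - 1) * math.log1p(-u) - math.log(self.scale))


def to_distribution(raw_params, max_step):
    """Map unconstrained module outputs to a distribution object."""
    raw_alpha, raw_beta = raw_params
    # exp() keeps both Beta parameters strictly positive.
    return ScaledBeta(math.exp(raw_alpha), math.exp(raw_beta), max_step)


dist = to_distribution((0.0, 0.0), max_step=0.25)  # Beta(1, 1) scaled to (0, 0.25)
step = dist.sample(random.Random(0))
step_log_prob = dist.log_prob(step)
```

The point of the split is that the module only produces unconstrained numbers, while the conversion function owns every environment-specific decision (support, scaling, parameter constraints).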
In general, the to_probability_distribution method computes a probability distribution from a policy's outputs. Therefore, in order to go off-policy, one needs to modify the computations in this method during sampling. This is done with policy_kwargs, a dict of keyword-value pairs that the Estimator uses when calculating the new policy. In the discrete case, where common settings apply, DiscretePolicyEstimator's to_probability_distribution method accepts a softmax temperature, sf_bias (a scalar to subtract from the exit action logit), and epsilon, which allows for epsilon-greedy style exploration. In the continuous case, it is not possible to foresee the methods used for off-policy exploration (they depend on the details of the to_probability_distribution method, which is not generic for continuous GFNs), so this must be handled by the user, using custom policy_kwargs.
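The effect of the three discrete exploration settings can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: the function name exploratory_probs is hypothetical, the exit action is assumed to be the last index, and the order in which the settings are applied is one reasonable choice.

```python
import math


def exploratory_probs(logits, mask, temperature=1.0, sf_bias=0.0, epsilon=0.0):
    """Masked softmax with optional tempering, exit bias, and uniform mixing."""
    masked = [l if ok else -math.inf for l, ok in zip(logits, mask)]
    masked[-1] -= sf_bias  # assumption: the exit action is the last one
    scaled = [l / temperature for l in masked]  # temperature > 1 flattens the policy
    peak = max(scaled)
    weights = [math.exp(l - peak) for l in scaled]
    total = sum(weights)
    policy = [w / total for w in weights]
    # Mix with a uniform distribution over the legal actions only.
    n_legal = sum(mask)
    uniform = [1.0 / n_legal if ok else 0.0 for ok in mask]
    return [(1.0 - epsilon) * p + epsilon * u for p, u in zip(policy, uniform)]


logits = [1.0, 2.0, 0.5, 1.0]
mask = [True, False, True, True]

on_policy = exploratory_probs(logits, mask)
policy_kwargs = {"temperature": 2.0, "sf_bias": 1.0, "epsilon": 0.1}
off_policy = exploratory_probs(logits, mask, **policy_kwargs)
```

With the default arguments the function reduces to the on-policy masked softmax, so the same code path serves both on-policy and off-policy sampling. Illegal actions keep zero probability in both cases.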
In all Estimators, the input of the forward function is a States object, which must first be converted to tensors. However, states.tensor does not necessarily expose the structure that a neural network can use to generalize. In such cases, it is common to transform the raw state tensors into a representation where the structure is clearer, via a Preprocessor object that is part of the environment. More on this here. The preprocessor is stored as an attribute of the module, and the forward pass calls it on the States before applying any other transformation. If it is not explicitly defined, it is set to the identity preprocessor.
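As an illustration of why preprocessing helps, consider a grid state stored as raw integer coordinates. The sketch below, in plain Python, contrasts an identity preprocessor with one that encodes each coordinate as a one-hot vector. The function names and the height parameter are hypothetical and are not taken from the library.

```python
def identity_preprocess(state):
    """Pass the raw state through unchanged."""
    return [float(coord) for coord in state]


def khot_preprocess(state, height):
    """Concatenate one one-hot vector of length `height` per coordinate."""
    encoded = []
    for coord in state:
        one_hot = [0.0] * height
        one_hot[coord] = 1.0
        encoded.extend(one_hot)
    return encoded


raw_state = [2, 0]  # a cell in a 3 x 3 grid
features = khot_preprocess(raw_state, height=3)
```

The raw coordinates give the network two numbers, while the encoded version gives it one input per coordinate value, which is often easier to learn from.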
For discrete environments, a Tabular module is provided, where a lookup table is used instead of a neural network. Additionally, a UniformPB module is provided, implementing a uniform backward policy. These modules are provided here.
Samplers
A Sampler object defines how actions are sampled at each state (sample_actions()) and how trajectories are sampled (sample_trajectories()). The latter can sample a batch of trajectories starting either from a given set of initial states or from \(s_0\). A Sampler requires an Estimator that implements the to_probability_distribution function. For simple off-policy sampling (e.g., epsilon-noisy or tempering), you can pass appropriate policy_kwargs to the Sampler object, which will be used by the Estimator. If you need more complex off-policy sampling, you can subclass the Sampler object and override the sample_actions and sample_trajectories methods.
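The division of labor between sampling one action and sampling whole trajectories can be sketched on a toy problem. Everything below is hypothetical and independent of the library: a chain of states 0..MAX_STATE with two actions (move right or exit), a uniform policy over legal actions, and plain Python functions standing in for the Sampler's methods.

```python
import random

EXIT = 1       # hypothetical action indices: 0 = move right, 1 = exit
MAX_STATE = 3  # the chain has states 0, 1, 2, 3


def legal_actions(state):
    """Moving right is illegal in the last state; exiting is always legal."""
    return [state < MAX_STATE, True]


def uniform_policy(state):
    """Uniform probabilities over the legal actions of a state."""
    mask = legal_actions(state)
    n_legal = sum(mask)
    return [1.0 / n_legal if ok else 0.0 for ok in mask]


def sample_action(policy, state, rng):
    """Draw one action from the policy's distribution at this state."""
    probs = policy(state)
    return rng.choices(range(len(probs)), weights=probs)[0]


def sample_trajectory(policy, rng, initial_state=0):
    """Roll the policy forward until the exit action is drawn."""
    states, actions = [initial_state], []
    state = initial_state
    while True:
        action = sample_action(policy, state, rng)
        actions.append(action)
        if action == EXIT:
            return states, actions
        state += 1
        states.append(state)


rng = random.Random(0)
batch = [sample_trajectory(uniform_policy, rng) for _ in range(4)]
```

Swapping uniform_policy for an exploratory variant changes the sampling behavior without touching the rollout loop, which mirrors the role policy_kwargs play for simple off-policy sampling.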
Currently, the library provides two samplers.