gfn.gym.helpers.box_polar_utils¶
Polar (legacy) estimators and distributions for the Box environment.
Note: The numerical implementation has been improved from the original published version. Key changes from the reference code:
- arccos replaced with atan2 for improved numerical stability
- Removed double-precision (torch.double) casts that were previously needed
- BoxPFMLP.forward rewritten with torch.where for torch.compile compatibility
The mathematical semantics are unchanged: actions are polar vectors with L2 norm equal to delta, parameterised via quarter-circle Beta mixtures.
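To illustrate why recovering an angle with atan2 tends to be more robust than arccos, here is a minimal standard-library sketch. It is not the module's code: the helper names and the value of delta are assumptions chosen for illustration. An action is built as a polar vector with L2 norm delta, then the angle is recovered both ways for a very small angle.

```python
import math

DELTA = 0.25  # assumed step size, for illustration only


def polar_to_action(theta, delta):
    """Map an angle in [0, pi/2] to a 2D action with L2 norm delta."""
    return (delta * math.cos(theta), delta * math.sin(theta))


def angle_via_arccos(action, delta):
    """Recover the angle from the x component only (reference-style)."""
    return math.acos(action[0] / delta)


def angle_via_atan2(action):
    """Recover the angle from both components."""
    return math.atan2(action[1], action[0])


theta = 1e-9  # an angle very close to the lower end of the range
action = polar_to_action(theta, DELTA)

# cos(theta) rounds to exactly 1.0 in double precision, so arccos loses the angle.
recovered_arccos = angle_via_arccos(action, DELTA)
# atan2 also uses the y component, which still carries the small angle.
recovered_atan2 = angle_via_atan2(action)
```

Near the ends of the range the cosine is flat, so small angle differences vanish in the x component alone; atan2 avoids this because it uses both components.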
Attributes¶
CLAMP
PI_2
PI_2_INV
Classes¶
BoxPBEstimator | Estimator for P_B for the Box environment.
BoxPBMLP | An MLP for the backward policy of the Box environment.
BoxPBUniform | Uniform backward policy for the polar Box environment.
BoxPFEstimator | Estimator for P_F for the Box environment.
BoxPFMLP | An MLP for the forward policy of the Box environment.
BoxStateFlowModule | An MLP for the state flow function of the Box environment.
DistributionWrapper | A wrapper that combines QuarterDisk and QuarterCircleWithExit.
QuarterCircle | Represents distributions on quarter circles.
QuarterCircleWithExit | Extends QuarterCircle with an exit action.
QuarterDisk | Represents a distribution on the northeastern quarter disk.
Functions¶
split_PF_module_output(output, n_comp_max) | Splits the module output into the expected parameter sets.
Module Contents¶
- class gfn.gym.helpers.box_polar_utils.BoxPBEstimator(env, module, n_components, min_concentration=0.1, max_concentration=2.0, debug=False)¶
Bases: gfn.estimators.Estimator, gfn.estimators.PolicyMixin
Estimator for P_B for the Box environment.
This estimator uses the QuarterCircle(northeastern=False) distribution.
- Parameters:
env (gfn.gym.box.BoxPolar)
module (torch.nn.Module)
n_components (int)
min_concentration (float)
max_concentration (float)
debug (bool)
- n_components¶
The number of components for the mixture.
- min_concentration¶
The minimum concentration for the Beta distributions.
- max_concentration¶
The maximum concentration for the Beta distributions.
- delta¶
The radius of the quarter disk.
- delta: float¶
- property expected_output_dim: int¶
Returns the expected output dimension of the module.
- Return type:
int
- max_concentration: float¶
- min_concentration: float¶
- module¶
- n_components: int¶
- to_probability_distribution(states, module_output)¶
Converts the module output to a probability distribution.
- Parameters:
states (gfn.states.States) – the states for which to convert the module output to a probability distribution.
module_output (torch.Tensor) – the output of the module for the states as a tensor of shape (*batch_shape, output_dim).
- Returns:
The probability distribution for the states.
- Return type:
torch.distributions.Distribution
- classmethod uniform(env, n_components=1, **kwargs)¶
Create an estimator with a fixed (non-learned) uniform backward policy.
Uses alpha=beta=1 (uniform Beta) by setting skip_normalization=True so the constant module output is used directly as concentration parameters.
- Parameters:
env (gfn.gym.box.BoxPolar) – The Box environment.
n_components (int) – Number of mixture components (default 1).
**kwargs – Extra keyword arguments forwarded to the constructor.
- Returns:
A BoxPBEstimator whose module has no learnable parameters.
- Return type:
BoxPBEstimator
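The min_concentration and max_concentration bounds, together with the skip_normalization note for uniform(), suggest that raw module outputs are normally mapped into that range before being used as Beta parameters. The sketch below assumes a sigmoid squashing for that mapping. This is an assumption, since the page does not state the actual transformation, and the helper names are hypothetical. It also checks that alpha=beta=1 gives a flat Beta density.

```python
import math


def to_concentration(raw, min_concentration=0.1, max_concentration=2.0):
    """Squash an unbounded value into [min_concentration, max_concentration].

    Assumed mapping (sigmoid); the real estimator may differ.
    """
    sig = 1.0 / (1.0 + math.exp(-raw))
    return min_concentration + (max_concentration - min_concentration) * sig


def beta_pdf(x, alpha, beta):
    """Density of Beta(alpha, beta) at x in (0, 1)."""
    log_norm = math.lgamma(alpha + beta) - math.lgamma(alpha) - math.lgamma(beta)
    return math.exp(log_norm + (alpha - 1) * math.log(x) + (beta - 1) * math.log(1 - x))


mid_value = to_concentration(0.0)    # halfway between the two bounds
flat = beta_pdf(0.3, 1.0, 1.0)       # uniform Beta: density 1 everywhere
```

With skip_normalization=True no such squashing applies, so a constant output of 1.0 yields alpha=beta=1 directly.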
- class gfn.gym.helpers.box_polar_utils.BoxPBMLP(hidden_dim, n_hidden_layers, n_components, **kwargs)¶
Bases: gfn.utils.modules.MLP
An MLP for the backward policy of the Box environment.
- Parameters:
hidden_dim (int)
n_hidden_layers (int)
n_components (int)
kwargs (Any)
- n_components¶
The number of components for each distribution parameter.
- Type:
int
- _input_dim: int¶
- forward(preprocessed_states)¶
Computes the forward pass of the neural network.
- Parameters:
preprocessed_states (torch.Tensor) – The tensor states of shape (*batch_shape, 2).
- Returns:
A tensor of shape (*batch_shape, 3 * n_components).
- Return type:
torch.Tensor
- n_components: int¶
- class gfn.gym.helpers.box_polar_utils.BoxPBUniform¶
Bases: gfn.utils.modules.UniformModule
Uniform backward policy for the polar Box environment.
Backward-compatible alias for UniformModule(output_dim=3, input_dim=2, fill_value=1.0, skip_normalization=True). Used with QuarterCircle, it leads to a uniform (alpha=beta=1) Beta distribution over parents. Prefer UniformModule for new code.
- class gfn.gym.helpers.box_polar_utils.BoxPFEstimator(env, module, n_components_s0, n_components, min_concentration=0.1, max_concentration=2.0, debug=False)¶
Bases: gfn.estimators.Estimator, gfn.estimators.PolicyMixin
Estimator for P_F for the Box environment.
This estimator uses the DistributionWrapper distribution.
- Parameters:
env (gfn.gym.box.BoxPolar)
module (torch.nn.Module)
n_components_s0 (int)
n_components (int)
min_concentration (float)
max_concentration (float)
debug (bool)
- n_components_s0¶
The number of components for s0.
- n_components¶
The number of components for non-s0 states.
- min_concentration¶
The minimum concentration for the Beta distributions.
- max_concentration¶
The maximum concentration for the Beta distributions.
- delta¶
The radius of the quarter disk.
- epsilon¶
The epsilon value (the tolerance used to consider a state as being at the border of the square).
- _n_comp_max: int¶
- delta: float¶
- epsilon: float¶
- property expected_output_dim: int¶
Returns the expected output dimension of the module.
- Return type:
int
- max_concentration: float¶
- min_concentration: float¶
- n_components: int¶
- n_components_s0: int¶
- to_probability_distribution(states, module_output)¶
Converts the module output to a probability distribution.
- Parameters:
states (gfn.states.States) – the states for which to convert the module output to a probability distribution.
module_output (torch.Tensor) – the output of the module for the states as a tensor of shape (*batch_shape, output_dim).
- Returns:
The probability distribution for the states.
- Return type:
torch.distributions.Distribution
- class gfn.gym.helpers.box_polar_utils.BoxPFMLP(hidden_dim, n_hidden_layers, n_components_s0, n_components, **kwargs)¶
Bases: gfn.utils.modules.MLP
An MLP for the forward policy of the Box environment.
- Parameters:
hidden_dim (int)
n_hidden_layers (int)
n_components_s0 (int)
n_components (int)
kwargs (Any)
- n_components_s0¶
The number of components for s0.
- Type:
int
- n_components¶
The number of components for non-s0 states.
- Type:
int
- PFs0¶
The parameters for the s0 distribution.
- Type:
nn.Parameter
- PFs0: torch.nn.Parameter¶
- _input_dim: int¶
- _n_comp_max: int¶
- forward(preprocessed_states)¶
Computes the forward pass of the neural network.
- Parameters:
preprocessed_states (torch.Tensor) – The tensor states of shape (*batch_shape, 2).
- Returns:
A tensor of shape (*batch_shape, 1 + 5 * max_n_components), where max_n_components is the larger of n_components_s0 and n_components.
- Return type:
torch.Tensor
- n_components: int¶
- n_components_s0: int¶
- class gfn.gym.helpers.box_polar_utils.BoxStateFlowModule(logZ_value, **kwargs)¶
Bases: gfn.utils.modules.MLP
An MLP for the state flow function of the Box environment.
- Parameters:
logZ_value (torch.Tensor)
kwargs (Any)
- logZ_value¶
The log partition function value.
- Type:
nn.Parameter
- forward(preprocessed_states)¶
Computes the forward pass of the neural network.
- Parameters:
preprocessed_states (torch.Tensor) – The tensor states of shape (*batch_shape, input_dim).
- Returns:
A tensor of shape (*batch_shape, output_dim).
- Return type:
torch.Tensor
- logZ_value: torch.nn.Parameter¶
- gfn.gym.helpers.box_polar_utils.CLAMP: float¶
- class gfn.gym.helpers.box_polar_utils.DistributionWrapper(states, delta, epsilon, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta, exit_probability, n_components, n_components_s0, debug=False)¶
Bases: torch.distributions.Distribution
A wrapper that combines QuarterDisk and QuarterCircleWithExit.
- Parameters:
states (gfn.states.States)
delta (float)
epsilon (float)
mixture_logits (torch.Tensor)
alpha_r (torch.Tensor)
beta_r (torch.Tensor)
alpha_theta (torch.Tensor)
beta_theta (torch.Tensor)
exit_probability (torch.Tensor)
n_components (int)
n_components_s0 (int)
debug (bool)
- idx_is_initial¶
The indices of the initial states.
- idx_not_initial¶
The indices of the non-initial states.
- quarter_disk¶
The QuarterDisk distribution.
- quarter_circ¶
The QuarterCircleWithExit distribution.
- _output_shape: tuple[int, Ellipsis]¶
- debug = False¶
- idx_is_initial: torch.Tensor¶
- idx_not_initial: torch.Tensor¶
- log_prob(sampled_actions)¶
Computes the log probability of the sampled actions.
- Parameters:
sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.
- Returns:
A tensor of shape (*batch_shape) containing the log probabilities.
- Return type:
torch.Tensor
- quarter_circ: QuarterCircleWithExit | None¶
- quarter_disk: QuarterDisk | None¶
- sample(sample_shape=Size())¶
Samples from the distribution.
- Parameters:
sample_shape (torch.Size) – the shape of the samples to generate.
- Returns:
A tensor of shape (sample_shape + self._output_shape) containing the sampled actions.
- Return type:
torch.Tensor
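The idx_is_initial and idx_not_initial attributes suggest that DistributionWrapper dispatches each state to one of its two sub-distributions. The following is a simplified, plain-Python sketch of that routing idea. The function name and the sampler callables are hypothetical, and the real class works on batched tensors rather than lists.

```python
def route_actions(states, s0, sample_from_disk, sample_from_circle):
    """Sample one action per state, choosing the sampler by whether the state is s0."""
    idx_is_initial = [i for i, s in enumerate(states) if s == s0]
    idx_not_initial = [i for i, s in enumerate(states) if s != s0]

    actions = [None] * len(states)
    for i in idx_is_initial:
        # Initial states: an action anywhere in the quarter disk.
        actions[i] = sample_from_disk()
    for i in idx_not_initial:
        # Other states: an action on the quarter circle, or an exit.
        actions[i] = sample_from_circle(states[i])
    return actions


# Usage with placeholder samplers that only label which branch was taken.
states = [(0.0, 0.0), (0.3, 0.4), (0.0, 0.0)]
routed = route_actions(
    states,
    s0=(0.0, 0.0),
    sample_from_disk=lambda: "disk",
    sample_from_circle=lambda s: "circle",
)
```

Either index list may be empty, which matches the quarter_disk and quarter_circ attributes being typed as optional.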
- gfn.gym.helpers.box_polar_utils.PI_2: float¶
- gfn.gym.helpers.box_polar_utils.PI_2_INV: float¶
- class gfn.gym.helpers.box_polar_utils.QuarterCircle(delta, northeastern, centers, mixture_logits, alpha, beta, debug=False)¶
Bases: torch.distributions.Distribution
Represents distributions on quarter circles.
The distributions are mixtures of Beta distributions over the possible angle range.
When a state has norm <= delta and northeastern=False, the distribution is a Dirac at the state (i.e. the only possible parent is s_0).
Adapted from https://github.com/saleml/continuous-gfn/blob/master/sampling.py
This is useful for the Box environment.
- Parameters:
delta (float)
northeastern (bool)
centers (gfn.states.States)
mixture_logits (torch.Tensor)
alpha (torch.Tensor)
beta (torch.Tensor)
debug (bool)
- delta¶
The radius of the quarter disk.
- northeastern¶
Whether the quarter disk is northeastern or southwestern.
- n_states¶
The number of states.
- n_components¶
The number of components in the mixture.
- centers¶
The centers of the distribution.
- base_dist¶
The base distribution.
- min_angles¶
The minimum angles.
- max_angles¶
The maximum angles.
- base_dist: torch.distributions.MixtureSameFamily¶
- centers: gfn.states.States¶
- debug = False¶
- delta: float¶
- get_min_and_max_angles()¶
Computes the minimum and maximum angles for the distribution.
- Returns:
A tuple of two tensors of shape (n_states,) containing the minimum and maximum angles, respectively.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
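One way to picture the angle bounds for the northeastern case is a short geometric sketch. It assumes the box is the unit square [0, 1]^2 and that an action is (delta * cos(theta), delta * sin(theta)); it is not necessarily how get_min_and_max_angles is implemented, and the function name is hypothetical. Requiring x + delta * cos(theta) <= 1 gives a lower bound on theta, and y + delta * sin(theta) <= 1 gives an upper bound.

```python
import math


def min_max_angles_northeastern(x, y, delta):
    """Angle range keeping (x, y) + delta * (cos, sin) inside the unit square.

    Assumes 0 <= x, y <= 1 and that the resulting range is non-empty.
    """
    # Too close to the right wall: small angles would overshoot x = 1.
    min_angle = math.acos((1 - x) / delta) if 1 - x < delta else 0.0
    # Too close to the top wall: large angles would overshoot y = 1.
    max_angle = math.asin((1 - y) / delta) if 1 - y < delta else math.pi / 2
    return min_angle, max_angle


interior = min_max_angles_northeastern(0.5, 0.5, 0.25)    # full quarter circle
near_right = min_max_angles_northeastern(0.9, 0.5, 0.2)   # lower bound raised
near_top = min_max_angles_northeastern(0.5, 0.95, 0.2)    # upper bound lowered
```

A Beta sample in [0, 1] can then be rescaled linearly to [min_angle, max_angle] to obtain a valid angle.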
- log_prob(sampled_actions)¶
Computes the log probability of the sampled actions.
- Parameters:
sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.
- Returns:
The log probability of the sampled actions as a tensor of shape batch_shape.
- Return type:
torch.Tensor
- max_angles: torch.Tensor¶
- min_angles: torch.Tensor¶
- n_components: int¶
- n_states: int¶
- northeastern: bool¶
- sample(sample_shape=Size())¶
Samples from the distribution.
- Parameters:
sample_shape (torch.Size) – the shape of the samples to generate.
- Returns:
The sampled actions of shape (sample_shape, n_states, 2).
- Return type:
torch.Tensor
- class gfn.gym.helpers.box_polar_utils.QuarterCircleWithExit(delta, centers, exit_probability, mixture_logits, alpha, beta, epsilon=0.0001)¶
Bases: torch.distributions.Distribution
Extends QuarterCircle with an exit action.
When sampling, with probability exit_probability, the exit_action [-inf, -inf] is sampled. The log_prob function is adjusted accordingly.
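The adjustment described above can be sketched in plain Python as a two-way mixture: the exit action gets log(exit_probability), and any other action gets log(1 - exit_probability) plus its log probability under the distribution without exit. The helper names and callables below are hypothetical; the real class operates on batched tensors and also handles border states via epsilon, which this sketch omits.

```python
import math
import random

EXIT_ACTION = (-math.inf, -math.inf)


def _safe_log(p):
    """log(p) that returns -inf instead of raising when p is 0."""
    return math.log(p) if p > 0.0 else -math.inf


def sample_with_exit(exit_probability, sample_non_exit, rng):
    """Return the exit action with probability exit_probability, else a regular action."""
    if rng.random() < exit_probability:
        return EXIT_ACTION
    return sample_non_exit()


def log_prob_with_exit(action, exit_probability, log_prob_non_exit):
    """Log probability of an action under the exit / non-exit mixture."""
    if action == EXIT_ACTION:
        return _safe_log(exit_probability)
    return _safe_log(1.0 - exit_probability) + log_prob_non_exit(action)


rng = random.Random(0)
always_exit = sample_with_exit(1.0, lambda: (0.1, 0.1), rng)
never_exit = sample_with_exit(0.0, lambda: (0.1, 0.1), rng)
lp_exit = log_prob_with_exit(EXIT_ACTION, 0.25, lambda a: 0.0)
lp_move = log_prob_with_exit((0.1, 0.1), 0.25, lambda a: -1.5)
```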
- Parameters:
delta (float)
centers (gfn.states.States)
exit_probability (torch.Tensor)
mixture_logits (torch.Tensor)
alpha (torch.Tensor)
beta (torch.Tensor)
epsilon (float)
- delta¶
The radius of the quarter disk.
- epsilon¶
The epsilon value to consider the state as being at the border of the square.
- centers¶
The centers of the distribution.
- dist_without_exit¶
The distribution without the exit action.
- exit_probability¶
The probability of exiting.
- exit_action¶
The exit action.
- n_states¶
The number of states.
- centers: gfn.states.States¶
- delta: float¶
- dist_without_exit: QuarterCircle¶
- epsilon: float¶
- exit_action: torch.Tensor¶
- exit_probability: torch.Tensor¶
- log_prob(sampled_actions)¶
Computes the log probability of the sampled actions.
- Parameters:
sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.
- Returns:
The log probability of the sampled actions as a tensor of shape batch_shape.
- Return type:
torch.Tensor
- n_states: int¶
- sample()¶
Samples from the distribution.
- Returns:
The sampled actions with shape (n_states, 2).
- Return type:
torch.Tensor
- class gfn.gym.helpers.box_polar_utils.QuarterDisk(delta, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta)¶
Bases: torch.distributions.Distribution
Represents a distribution on the northeastern quarter disk.
The radius and the angle each follow a mixture of Beta distributions.
Adapted from https://github.com/saleml/continuous-gfn/blob/master/sampling.py
This is useful for the Box environment.
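As a rough illustration of sampling on the northeastern quarter disk, the sketch below draws a radius and an angle from Beta mixtures using only the standard library, then converts them to Cartesian coordinates. The function names are hypothetical, mixture weights are passed as probabilities rather than logits, and drawing the component independently for the radius and the angle is an assumption of this sketch.

```python
import math
import random


def sample_beta_mixture(weights, alphas, betas, rng):
    """Pick a component by weight, then sample Beta(alpha_k, beta_k) in [0, 1]."""
    k = rng.choices(range(len(weights)), weights=weights)[0]
    return rng.betavariate(alphas[k], betas[k])


def sample_quarter_disk(delta, weights, alpha_r, beta_r, alpha_theta, beta_theta, rng):
    """Sample an action inside the quarter disk of radius delta."""
    r = delta * sample_beta_mixture(weights, alpha_r, beta_r, rng)
    theta = (math.pi / 2) * sample_beta_mixture(weights, alpha_theta, beta_theta, rng)
    return (r * math.cos(theta), r * math.sin(theta))


rng = random.Random(0)
samples = [
    sample_quarter_disk(
        delta=0.25,
        weights=[0.7, 0.3],
        alpha_r=[1.0, 2.0], beta_r=[1.0, 0.5],
        alpha_theta=[1.0, 0.5], beta_theta=[1.0, 2.0],
        rng=rng,
    )
    for _ in range(1000)
]
```

Every sample has non-negative coordinates and norm at most delta, which is what makes it a valid first step from s_0.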
- Parameters:
delta (float)
mixture_logits (torch.Tensor)
alpha_r (torch.Tensor)
beta_r (torch.Tensor)
alpha_theta (torch.Tensor)
beta_theta (torch.Tensor)
- delta¶
The radius of the quarter disk.
- mixture_logits¶
The logits of the mixture of Beta distributions.
- base_r_dist¶
The base distribution for the radius.
- base_theta_dist¶
The base distribution for the angle.
- n_components¶
The number of components in the mixture.
- base_r_dist: torch.distributions.MixtureSameFamily¶
- base_theta_dist: torch.distributions.MixtureSameFamily¶
- debug = False¶
- delta: float¶
- log_prob(sampled_actions)¶
Computes the log probability of the sampled actions.
- Parameters:
sampled_actions (torch.Tensor) – Tensor of shape (*batch_shape, 2) with the actions to compute the log probability of.
- Returns:
The log probability of the sampled actions as a tensor of shape batch_shape.
- Return type:
torch.Tensor
- mixture_logits: torch.Tensor¶
- n_components: int¶
- sample(sample_shape=torch.Size())¶
Samples from the distribution.
- Parameters:
sample_shape (torch.Size) – the shape of the samples to generate.
- Returns:
The sampled actions of shape (sample_shape, 2).
- Return type:
torch.Tensor
- gfn.gym.helpers.box_polar_utils.split_PF_module_output(output, n_comp_max)¶
Splits the module output into the expected parameter sets.
- Parameters:
output (torch.Tensor) – the module_output from the P_F model as a tensor of shape (*batch_shape, output_dim).
n_comp_max (int) – the larger number of the two n_components and n_components_s0.
- Returns:
A tuple containing:
- exit_probability: A probability unique to QuarterCircleWithExit.
- mixture_logits: Parameters shared by QuarterDisk and QuarterCircleWithExit.
- alpha_r: Parameters shared by QuarterDisk and QuarterCircleWithExit.
- beta_r: Parameters shared by QuarterDisk and QuarterCircleWithExit.
- alpha_theta: Parameters unique to QuarterDisk.
- beta_theta: Parameters unique to QuarterDisk.
- Return type:
tuple
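The split can be pictured with a plain-Python sketch over a single flat output vector of length 1 + 5 * n_comp_max. The memory layout used here (exit probability first, then five equal chunks in the order listed above) is an assumption made for illustration, and the function names are hypothetical; the real function operates on batched tensors.

```python
def expected_output_dim(n_components_s0, n_components):
    """One exit value plus five parameter groups of n_comp_max entries each."""
    return 1 + 5 * max(n_components_s0, n_components)


def split_pf_output(output, n_comp_max):
    """Split a flat output vector into the six parameter groups (assumed layout)."""
    if len(output) != 1 + 5 * n_comp_max:
        raise ValueError("unexpected output length")
    exit_probability = output[0]
    # Five consecutive chunks of n_comp_max values each, after the first entry.
    chunks = [
        output[1 + i * n_comp_max : 1 + (i + 1) * n_comp_max] for i in range(5)
    ]
    mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta = chunks
    return exit_probability, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta


dim = expected_output_dim(n_components_s0=2, n_components=1)
parts = split_pf_output([float(v) for v in range(dim)], n_comp_max=2)
```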