gfn.gym.helpers.box_polar_utils
===============================

.. py:module:: gfn.gym.helpers.box_polar_utils

.. autoapi-nested-parse::

   Polar (legacy) estimators and distributions for the Box environment.

   Note: The numerical implementation has been improved from the original published
   version. Key changes from the reference code:

   - arccos replaced with atan2 for improved numerical stability
   - Removed double-precision (torch.double) casts that were previously needed
   - BoxPFMLP.forward rewritten with torch.where for torch.compile compatibility

   The mathematical semantics are unchanged: actions are polar vectors with L2 norm
   equal to delta, parameterised via quarter-circle Beta mixtures.


Attributes
----------

.. autoapisummary::

   gfn.gym.helpers.box_polar_utils.CLAMP
   gfn.gym.helpers.box_polar_utils.PI_2
   gfn.gym.helpers.box_polar_utils.PI_2_INV


Classes
-------

.. autoapisummary::

   gfn.gym.helpers.box_polar_utils.BoxPBEstimator
   gfn.gym.helpers.box_polar_utils.BoxPBMLP
   gfn.gym.helpers.box_polar_utils.BoxPBUniform
   gfn.gym.helpers.box_polar_utils.BoxPFEstimator
   gfn.gym.helpers.box_polar_utils.BoxPFMLP
   gfn.gym.helpers.box_polar_utils.BoxStateFlowModule
   gfn.gym.helpers.box_polar_utils.DistributionWrapper
   gfn.gym.helpers.box_polar_utils.QuarterCircle
   gfn.gym.helpers.box_polar_utils.QuarterCircleWithExit
   gfn.gym.helpers.box_polar_utils.QuarterDisk


Functions
---------

.. autoapisummary::

   gfn.gym.helpers.box_polar_utils.split_PF_module_output


Module Contents
---------------

.. py:class:: BoxPBEstimator(env, module, n_components, min_concentration = 0.1, max_concentration = 2.0, debug = False)

   Bases: :py:obj:`gfn.estimators.Estimator`, :py:obj:`gfn.estimators.PolicyMixin`

   Estimator for `P_B` in the Box environment.

   This estimator uses the `QuarterCircle(northeastern=False)` distribution.

   .. attribute:: n_components

      The number of components for the mixture.

   .. attribute:: min_concentration

      The minimum concentration for the Beta distributions.

   .. attribute:: max_concentration

      The maximum concentration for the Beta distributions.

   .. attribute:: delta

      The radius of the quarter disk.

   .. py:attribute:: delta
      :type: float

   .. py:property:: expected_output_dim
      :type: int

      Returns the expected output dimension of the module.

   .. py:attribute:: max_concentration
      :type: float

   .. py:attribute:: min_concentration
      :type: float

   .. py:attribute:: module

   .. py:attribute:: n_components
      :type: int

   .. py:method:: to_probability_distribution(states, module_output)

      Converts the module output to a probability distribution.

      :param states: The states for which to convert the module output to a
                     probability distribution.
      :param module_output: The output of the module for the states as a tensor of
                            shape `(*batch_shape, output_dim)`.

      :returns: The probability distribution for the states.

   .. py:method:: uniform(env, n_components = 1, **kwargs)
      :classmethod:

      Create an estimator with a fixed (non-learned) uniform backward policy.

      Uses ``alpha=beta=1`` (uniform Beta) by setting ``skip_normalization=True``
      so the constant module output is used directly as concentration parameters.

      :param env: The Box environment.
      :param n_components: Number of mixture components (default 1).
      :param \*\*kwargs: Extra keyword arguments forwarded to the constructor.

      :returns: A ``BoxPBEstimator`` whose module has no learnable parameters.


.. py:class:: BoxPBMLP(hidden_dim, n_hidden_layers, n_components, **kwargs)

   Bases: :py:obj:`gfn.utils.modules.MLP`

   An MLP for the backward policy of the Box environment.

   .. attribute:: n_components

      The number of components for each distribution parameter.

      :type: int

   .. py:attribute:: _input_dim
      :type: int

   .. py:method:: forward(preprocessed_states)

      Computes the forward pass of the neural network.

      :param preprocessed_states: The tensor states of shape `(*batch_shape, 2)`.

      :returns: A tensor of shape `(*batch_shape, 3 * n_components)`.

   .. py:attribute:: n_components
      :type: int


.. py:class:: BoxPBUniform

   Bases: :py:obj:`gfn.utils.modules.UniformModule`

   Uniform backward policy for the polar Box environment.

   Backward-compatible alias for
   ``UniformModule(output_dim=3, input_dim=2, fill_value=1.0, skip_normalization=True)``.
   Used with ``QuarterCircle``, it yields a uniform (alpha=beta=1) Beta distribution
   over parents.

   Prefer ``UniformModule`` for new code.


.. py:class:: BoxPFEstimator(env, module, n_components_s0, n_components, min_concentration = 0.1, max_concentration = 2.0, debug = False)

   Bases: :py:obj:`gfn.estimators.Estimator`, :py:obj:`gfn.estimators.PolicyMixin`

   Estimator for `P_F` in the Box environment.

   This estimator uses the `DistributionWrapper` distribution.

   .. attribute:: n_components_s0

      The number of components for s0.

   .. attribute:: n_components

      The number of components for non-s0 states.

   .. attribute:: min_concentration

      The minimum concentration for the Beta distributions.

   .. attribute:: max_concentration

      The maximum concentration for the Beta distributions.

   .. attribute:: delta

      The radius of the quarter disk.

   .. attribute:: epsilon

      The epsilon value used to consider a state as being at the border of the square.

   .. py:attribute:: _n_comp_max
      :type: int

   .. py:attribute:: delta
      :type: float

   .. py:attribute:: epsilon
      :type: float

   .. py:property:: expected_output_dim
      :type: int

      Returns the expected output dimension of the module.

   .. py:attribute:: max_concentration
      :type: float

   .. py:attribute:: min_concentration
      :type: float

   .. py:attribute:: n_components
      :type: int

   .. py:attribute:: n_components_s0
      :type: int

   .. py:method:: to_probability_distribution(states, module_output)

      Converts the module output to a probability distribution.

      :param states: The states for which to convert the module output to a
                     probability distribution.
      :param module_output: The output of the module for the states as a tensor of
                            shape `(*batch_shape, output_dim)`.

      :returns: The probability distribution for the states.


.. py:class:: BoxPFMLP(hidden_dim, n_hidden_layers, n_components_s0, n_components, **kwargs)

   Bases: :py:obj:`gfn.utils.modules.MLP`

   An MLP for the forward policy of the Box environment.

   .. attribute:: n_components_s0

      The number of components for s0.

      :type: int

   .. attribute:: n_components

      The number of components for non-s0 states.

      :type: int

   .. attribute:: PFs0

      The parameters for the s0 distribution.

      :type: nn.Parameter

   .. py:attribute:: PFs0
      :type: torch.nn.Parameter

   .. py:attribute:: _input_dim
      :type: int

   .. py:attribute:: _n_comp_max
      :type: int

   .. py:method:: forward(preprocessed_states)

      Computes the forward pass of the neural network.

      :param preprocessed_states: The tensor states of shape `(*batch_shape, 2)`.

      :returns: A tensor of shape `(*batch_shape, 1 + 5 * max_n_components)`.

   .. py:attribute:: n_components
      :type: int

   .. py:attribute:: n_components_s0
      :type: int


.. py:class:: BoxStateFlowModule(logZ_value, **kwargs)

   Bases: :py:obj:`gfn.utils.modules.MLP`

   An MLP for the state flow function of the Box environment.

   .. attribute:: logZ_value

      The log partition function value.

      :type: nn.Parameter

   .. py:method:: forward(preprocessed_states)

      Computes the forward pass of the neural network.

      :param preprocessed_states: The tensor states of shape
                                  `(*batch_shape, input_dim)`.

      :returns: A tensor of shape `(*batch_shape, output_dim)`.

   .. py:attribute:: logZ_value
      :type: torch.nn.Parameter


.. py:data:: CLAMP
   :type: float

.. py:class:: DistributionWrapper(states, delta, epsilon, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta, exit_probability, n_components, n_components_s0, debug = False)

   Bases: :py:obj:`torch.distributions.Distribution`

   A wrapper that combines `QuarterDisk` and `QuarterCircleWithExit`.

   .. attribute:: idx_is_initial

      The indices of the initial states.

   .. attribute:: idx_not_initial

      The indices of the non-initial states.

   .. attribute:: quarter_disk

      The `QuarterDisk` distribution.

   .. attribute:: quarter_circ

      The `QuarterCircleWithExit` distribution.

   .. py:attribute:: _output_shape
      :type: tuple[int, Ellipsis]

   .. py:attribute:: debug
      :value: False

   .. py:attribute:: idx_is_initial
      :type: torch.Tensor

   .. py:attribute:: idx_not_initial
      :type: torch.Tensor

   .. py:method:: log_prob(sampled_actions)

      Computes the log probability of the sampled actions.

      :param sampled_actions: Tensor of shape `(*batch_shape, 2)` with the actions
                              to compute the log probability of.

      :returns: A tensor of shape `(*batch_shape)` containing the log probabilities.

   .. py:attribute:: quarter_circ
      :type: Optional[QuarterCircleWithExit]

   .. py:attribute:: quarter_disk
      :type: Optional[QuarterDisk]

   .. py:method:: sample(sample_shape = Size())

      Samples from the distribution.

      :param sample_shape: The shape of the samples to generate.

      :returns: A tensor of shape `(sample_shape + self._output_shape)` containing
                the sampled actions.


.. py:data:: PI_2
   :type: float

.. py:data:: PI_2_INV
   :type: float

.. py:class:: QuarterCircle(delta, northeastern, centers, mixture_logits, alpha, beta, debug = False)

   Bases: :py:obj:`torch.distributions.Distribution`

   Represents distributions on quarter circles.

   The distributions are mixtures of Beta distributions over the possible angle
   range.

   When a state has norm <= delta and northeastern=False, the distribution is a
   Dirac at the state (i.e. the only possible parent is s_0).

   Adapted from https://github.com/saleml/continuous-gfn/blob/master/sampling.py

   This is useful for the `Box` environment.

   .. attribute:: delta

      The radius of the quarter disk.

   .. attribute:: northeastern

      Whether the quarter disk is northeastern or southwestern.

   .. attribute:: n_states

      The number of states.

   .. attribute:: n_components

      The number of components in the mixture.

   .. attribute:: centers

      The centers of the distribution.

   .. attribute:: base_dist

      The base distribution.

   .. attribute:: min_angles

      The minimum angles.

   .. attribute:: max_angles

      The maximum angles.

   .. py:attribute:: base_dist
      :type: torch.distributions.MixtureSameFamily

   .. py:attribute:: centers
      :type: gfn.states.States

   .. py:attribute:: debug
      :value: False

   .. py:attribute:: delta
      :type: float

   .. py:method:: get_min_and_max_angles()

      Computes the minimum and maximum angles for the distribution.

      :returns: A tuple of two tensors of shape `(n_states,)` containing the minimum
                and maximum angles, respectively.

   .. py:method:: log_prob(sampled_actions)

      Computes the log probability of the sampled actions.

      :param sampled_actions: Tensor of shape `(*batch_shape, 2)` with the actions
                              to compute the log probability of.

      :returns: The log probability of the sampled actions as a tensor of shape
                `batch_shape`.

   .. py:attribute:: max_angles
      :type: torch.Tensor

   .. py:attribute:: min_angles
      :type: torch.Tensor

   .. py:attribute:: n_components
      :type: int

   .. py:attribute:: n_states
      :type: int

   .. py:attribute:: northeastern
      :type: bool

   .. py:method:: sample(sample_shape = Size())

      Samples from the distribution.

      :param sample_shape: The shape of the samples to generate.

      :returns: The sampled actions of shape `(sample_shape, n_states, 2)`.


.. py:class:: QuarterCircleWithExit(delta, centers, exit_probability, mixture_logits, alpha, beta, epsilon = 0.0001)

   Bases: :py:obj:`torch.distributions.Distribution`

   Extends `QuarterCircle` with an exit action.

   When sampling, the `exit_action` `[-inf, -inf]` is sampled with probability
   `exit_probability`. The `log_prob` function is adjusted accordingly.

   .. attribute:: delta

      The radius of the quarter disk.

   .. attribute:: epsilon

      The epsilon value to consider the state as being at the border of the square.

   .. attribute:: centers

      The centers of the distribution.

   .. attribute:: dist_without_exit

      The distribution without the exit action.

   .. attribute:: exit_probability

      The probability of exiting.

   .. attribute:: exit_action

      The exit action.

   .. attribute:: n_states

      The number of states.

   .. py:attribute:: centers
      :type: gfn.states.States

   .. py:attribute:: delta
      :type: float

   .. py:attribute:: dist_without_exit
      :type: QuarterCircle

   .. py:attribute:: epsilon
      :type: float

   .. py:attribute:: exit_action
      :type: torch.Tensor

   .. py:attribute:: exit_probability
      :type: torch.Tensor

   .. py:method:: log_prob(sampled_actions)

      Computes the log probability of the sampled actions.

      :param sampled_actions: Tensor of shape `(*batch_shape, 2)` with the actions
                              to compute the log probability of.

      :returns: The log probability of the sampled actions as a tensor of shape
                `batch_shape`.

   .. py:attribute:: n_states
      :type: int

   .. py:method:: sample()

      Samples from the distribution.

      :returns: The sampled actions with shape `(n_states, 2)`.


.. py:class:: QuarterDisk(delta, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta)

   Bases: :py:obj:`torch.distributions.Distribution`

   Represents a distribution on the northeastern quarter disk.

   The radius and the angle each follow a mixture of Beta distributions.

   Adapted from https://github.com/saleml/continuous-gfn/blob/master/sampling.py

   This is useful for the `Box` environment.

   .. attribute:: delta

      The radius of the quarter disk.

   .. attribute:: mixture_logits

      The logits of the mixture of Beta distributions.

   .. attribute:: base_r_dist

      The base distribution for the radius.

   .. attribute:: base_theta_dist

      The base distribution for the angle.

   .. attribute:: n_components

      The number of components in the mixture.

   .. py:attribute:: base_r_dist
      :type: torch.distributions.MixtureSameFamily

   .. py:attribute:: base_theta_dist
      :type: torch.distributions.MixtureSameFamily

   .. py:attribute:: debug
      :value: False

   .. py:attribute:: delta
      :type: float

   .. py:method:: log_prob(sampled_actions)

      Computes the log probability of the sampled actions.

      :param sampled_actions: Tensor of shape `(*batch_shape, 2)` with the actions
                              to compute the log probability of.

      :returns: The log probability of the sampled actions as a tensor of shape
                `batch_shape`.

   .. py:attribute:: mixture_logits
      :type: torch.Tensor

   .. py:attribute:: n_components
      :type: int

   .. py:method:: sample(sample_shape = torch.Size())

      Samples from the distribution.

      :param sample_shape: The shape of the samples to generate.

      :returns: The sampled actions of shape `(sample_shape, 2)`.


.. py:function:: split_PF_module_output(output, n_comp_max)

   Splits the module output into the expected parameter sets.

   :param output: The module output from the P_F model as a tensor of shape
                  `(*batch_shape, output_dim)`.
   :param n_comp_max: The larger of `n_components` and `n_components_s0`.

   :returns: A tuple containing:

             - `exit_probability`: A probability unique to `QuarterCircleWithExit`.
             - `mixture_logits`: Parameters shared by `QuarterDisk` and `QuarterCircleWithExit`.
             - `alpha_r`: Parameters shared by `QuarterDisk` and `QuarterCircleWithExit`.
             - `beta_r`: Parameters shared by `QuarterDisk` and `QuarterCircleWithExit`.
             - `alpha_theta`: Parameters unique to `QuarterDisk`.
             - `beta_theta`: Parameters unique to `QuarterDisk`.
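
The module note describes actions as polar vectors with L2 norm equal to delta, with the angle drawn from a Beta mixture on the quarter circle and recovered via atan2. The following is a minimal standard-library sketch of that idea, not the library's implementation. The helper names ``sample_quarter_circle_action`` and ``angle_to_unit_interval`` are hypothetical, and ``PI_2`` is assumed to equal pi/2; the real classes operate on batched ``torch`` tensors and also restrict the angle range per state.

```python
import math
import random

# Assumption: mirrors the module-level PI_2 constant as pi / 2.
PI_2 = math.pi / 2


def sample_quarter_circle_action(delta, weights, alphas, betas, rng):
    """Draw one 2D action of L2 norm `delta` (hypothetical helper).

    The angle is a Beta-mixture variate in [0, 1] scaled to [0, pi/2].
    """
    # Pick a mixture component according to the mixture weights.
    k = rng.choices(range(len(weights)), weights=weights)[0]
    # Draw the Beta variate for that component.
    u = rng.betavariate(alphas[k], betas[k])
    theta = PI_2 * u
    return (delta * math.cos(theta), delta * math.sin(theta))


def angle_to_unit_interval(action):
    """Map an action back to [0, 1] using atan2 (hypothetical helper).

    atan2 uses both coordinates, so it does not need the norm and avoids
    feeding a ratio slightly outside [-1, 1] into arccos.
    """
    return math.atan2(action[1], action[0]) / PI_2


rng = random.Random(0)
delta = 0.25
action = sample_quarter_circle_action(
    delta, weights=[0.5, 0.5], alphas=[1.0, 2.0], betas=[1.0, 0.5], rng=rng
)
norm = math.hypot(*action)
u = angle_to_unit_interval(action)
```

Here ``norm`` matches ``delta`` up to floating-point error and ``u`` lies in ``[0, 1]``, which is the value a Beta-mixture ``log_prob`` would be evaluated at in this simplified setting.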