gfn.gym.helpers.box_utils
=========================

.. py:module:: gfn.gym.helpers.box_utils

.. autoapi-nested-parse::

   Backward-compatibility shim for Box environment utilities.

   Import from ``box_cartesian_utils`` or ``box_polar_utils`` directly for new code.


Classes
-------

.. autoapisummary::

   gfn.gym.helpers.box_utils.BoxCartesianDistribution
   gfn.gym.helpers.box_utils.BoxCartesianPBDistribution
   gfn.gym.helpers.box_utils.BoxCartesianPBEstimator
   gfn.gym.helpers.box_utils.BoxCartesianPBMLP
   gfn.gym.helpers.box_utils.BoxCartesianPFEstimator
   gfn.gym.helpers.box_utils.BoxCartesianPFMLP
   gfn.gym.helpers.box_utils.BoxPBEstimator
   gfn.gym.helpers.box_utils.BoxPBMLP
   gfn.gym.helpers.box_utils.BoxPBUniform
   gfn.gym.helpers.box_utils.BoxPFEstimator
   gfn.gym.helpers.box_utils.BoxPFMLP
   gfn.gym.helpers.box_utils.BoxStateFlowModule
   gfn.gym.helpers.box_utils.DistributionWrapper
   gfn.gym.helpers.box_utils.QuarterCircle
   gfn.gym.helpers.box_utils.QuarterCircleWithExit
   gfn.gym.helpers.box_utils.QuarterDisk
   gfn.gym.helpers.box_utils.UniformBoxCartesianPBModule


Functions
---------

.. autoapisummary::

   gfn.gym.helpers.box_utils.split_PF_module_output


Module Contents
---------------

.. py:class:: BoxCartesianDistribution(states, exit_logits, mixture_logits, alpha, beta, delta, epsilon = 1e-06, temperature = 1.0)

   Bases: :py:obj:`_BetaMixtureMixin`, :py:obj:`torch.distributions.Distribution`

   Cartesian increment distribution for the Box environment.

   Uses ``MixtureSameFamily(Categorical, Beta)`` per dimension to sample increments.
   This is much simpler than the polar parameterization: it samples a relative
   increment ``r`` in each dimension and converts it to an absolute increment with
   ``action = min_incr + r * max_range``.

   .. attribute:: delta

      Minimum step size.

   .. attribute:: epsilon

      Small value for numerical stability.

   .. py:attribute:: alpha

   .. py:attribute:: arg_constraints
      :type: dict

   .. py:attribute:: at_boundary

   .. py:attribute:: beta

   .. py:attribute:: delta

   .. py:attribute:: epsilon
      :value: 1e-06

   .. py:attribute:: exit_logits_scaled

   .. py:attribute:: is_s0

   .. py:method:: log_prob(actions)

      Compute the log probability using the Cartesian per-dimension approach.

   .. py:attribute:: log_weights

   .. py:attribute:: max_range

   .. py:attribute:: min_incr

   .. py:attribute:: n_dim

   .. py:method:: sample(sample_shape = Size())

      Sample actions using Cartesian per-dimension increments.

   .. py:attribute:: states


.. py:class:: BoxCartesianPBDistribution(states, bts_logits, mixture_logits, alpha, beta, delta, epsilon = 1e-06, temperature = 1.0)

   Bases: :py:obj:`_BetaMixtureMixin`, :py:obj:`torch.distributions.Distribution`

   Backward Cartesian distribution for the Box environment.

   In torchgfn's design, the source state is the origin ``[0, 0]``. The BTS
   (back-to-source) action moves directly to s0 by setting ``action = state``.

   WHY THE LEARNED BTS BERNOULLI IS CRITICAL
   -----------------------------------------

   The Trajectory Balance (TB) loss is:

   .. math::

      \mathcal{L} = \left(\log P_F(\tau) + \log Z - \log P_B(\tau) - \log R(x_T)\right)^2

   The BTS step (x_1 → s0) is the *last* step of every backward trajectory.
   Before this fix, BTS was always deterministic (forced): log P_B(BTS | x_1) = 0.
   A constant zero drops out of the gradient, so P_B received no gradient from the
   BTS step, regardless of trajectory length.

   For 1-step trajectories (s0 → x_T, then immediately BTS back to s0), P_B was
   *entirely* gradient-free: log P_B(τ) = 0 for every such trajectory, which made
   P_B invisible to the TB loss. This is particularly harmful because the reward
   landscape is dominated by states reachable from s0 in one step (the high-reward
   ring at |x - 0.5| ∈ (0.3, 0.4)).

   With a learned Bernoulli P(BTS | x):

   - log P_B(BTS | x_1) = log P(BTS=1 | x_1), so gradient flows into P_B.
   - log P_B(~BTS | x_1) = log P(BTS=0 | x_1), which also receives gradient.

   This closes the TB loop: P_B now has an incentive to assign higher probability
   to BTS from states that are close to the reward modes, and lower probability
   from states that are far from them.

   FORCED vs. STOCHASTIC BTS
   -------------------------

   When any dimension of the state is <= delta, the valid backward increment range
   [delta, state[d]] is empty for that dimension. BTS is then the *only* valid
   action, so it is forced (deterministic, log_prob = 0). For all other states, BTS
   is an optional choice sampled from the learned Bernoulli.

   .. py:attribute:: alpha

   .. py:attribute:: any_dim_near_origin

   .. py:attribute:: arg_constraints
      :type: dict

   .. py:attribute:: beta

   .. py:attribute:: bts_logits_scaled

   .. py:attribute:: delta

   .. py:attribute:: dim_near_origin

   .. py:attribute:: epsilon
      :value: 1e-06

   .. py:method:: log_prob(actions)

      Compute the log probability of backward actions.

      - Forced BTS (``any_dim_near_origin``): log_prob = 0 (deterministic).
      - Stochastic BTS: log_prob = log P(BTS=1 | s) from the Bernoulli.
      - Non-BTS: log_prob = log P(BTS=0 | s) + log_p_beta + log_jacobian.

   .. py:attribute:: log_weights

   .. py:attribute:: max_range

   .. py:attribute:: min_incr

   .. py:attribute:: n_dim

   .. py:method:: sample(sample_shape = Size())

      Sample backward actions.

      BTS (``action = state``) is forced when the state is near the origin; otherwise
      it is sampled from the learned Bernoulli. Non-BTS samples come from the Beta
      mixture.

   .. py:attribute:: states


.. py:class:: BoxCartesianPBEstimator(env, module, n_components, min_concentration = 0.1, max_concentration = 100.0, numerical_epsilon = 1e-06, debug = False)

   Bases: :py:obj:`gfn.estimators.Estimator`, :py:obj:`gfn.estimators.PolicyMixin`

   Simplified P_B estimator using Cartesian increments with back-to-source.

   .. py:attribute:: delta

   .. py:attribute:: epsilon

   .. py:property:: expected_output_dim
      :type: int

      Expected output dimension: one BTS logit plus
      ``(weights + alpha + beta) * n_dim * n_components`` mixture parameters.

   .. py:attribute:: max_concentration
      :value: 100.0

   .. py:attribute:: min_concentration
      :value: 0.1

   .. py:attribute:: n_components

   .. py:attribute:: n_dim
      :value: 2

   .. py:attribute:: numerical_epsilon
      :value: 1e-06

   .. py:attribute:: temperature
      :type: float
      :value: 1.0

   .. py:method:: to_probability_distribution(states, module_output)

      Convert the module output to a backward probability distribution.

   .. py:method:: uniform(env, n_components, **kwargs)
      :classmethod:

      Create an estimator with a fixed (non-learned) uniform backward policy.

      :param env: The Box environment.
      :param n_components: Number of mixture components.
      :param \*\*kwargs: Extra keyword arguments forwarded to the constructor.

      :returns: A ``BoxCartesianPBEstimator`` whose module has no learnable parameters.


.. py:class:: BoxCartesianPBMLP(hidden_dim, n_hidden_layers, n_components, n_dim = 2, **kwargs)

   Bases: :py:obj:`_BoxCartesianMLP`

   MLP for the Box backward policy. The first output is the back-to-source (BTS) logit.


.. py:class:: BoxCartesianPFEstimator(env, module, n_components, min_concentration = 0.1, max_concentration = 100.0, numerical_epsilon = 1e-06, debug = False)

   Bases: :py:obj:`gfn.estimators.Estimator`, :py:obj:`gfn.estimators.PolicyMixin`

   Simplified P_F estimator using Cartesian increments.

   Much simpler than ``BoxPFEstimator``: it uses a single MLP and
   ``BoxCartesianDistribution``.

   .. py:attribute:: delta

   .. py:attribute:: epsilon

   .. py:property:: expected_output_dim
      :type: int

      Expected output dimension: one exit logit plus
      ``(weights + alpha + beta) * n_dim * n_components`` mixture parameters.

   .. py:attribute:: max_concentration
      :value: 100.0

   .. py:attribute:: min_concentration
      :value: 0.1

   .. py:attribute:: n_components

   .. py:attribute:: n_dim
      :value: 2

   .. py:attribute:: numerical_epsilon
      :value: 1e-06

   .. py:attribute:: temperature
      :type: float
      :value: 1.0

   .. py:method:: to_probability_distribution(states, module_output)

      Convert the module output to a probability distribution.

      :param states: The states.
      :param module_output: Output from the module, of shape ``(batch, expected_output_dim)``.

      :returns: A ``BoxCartesianDistribution`` instance.


.. py:class:: BoxCartesianPFMLP(hidden_dim, n_hidden_layers, n_components, n_dim = 2, **kwargs)

   Bases: :py:obj:`_BoxCartesianMLP`

   MLP for the Box forward policy. The first output is the exit logit.


.. py:class:: BoxPBEstimator(env, module, n_components, min_concentration = 0.1, max_concentration = 2.0, debug = False)

   Bases: :py:obj:`gfn.estimators.Estimator`, :py:obj:`gfn.estimators.PolicyMixin`

   Estimator for `P_B` for the Box environment.

   This estimator uses the `QuarterCircle(northeastern=False)` distribution.

   .. attribute:: n_components

      The number of components for the mixture.

   .. attribute:: min_concentration

      The minimum concentration for the Beta distributions.

   .. attribute:: max_concentration

      The maximum concentration for the Beta distributions.

   .. attribute:: delta

      The radius of the quarter disk.

   .. py:attribute:: delta
      :type: float

   .. py:property:: expected_output_dim
      :type: int

      Returns the expected output dimension of the module.

   .. py:attribute:: max_concentration
      :type: float

   .. py:attribute:: min_concentration
      :type: float

   .. py:attribute:: module

   .. py:attribute:: n_components
      :type: int

   .. py:method:: to_probability_distribution(states, module_output)

      Converts the module output to a probability distribution.

      :param states: The states for which to convert the module output to a probability distribution.
      :param module_output: The output of the module for the states, as a tensor of shape `(*batch_shape, output_dim)`.

      :returns: The probability distribution for the states.

   .. py:method:: uniform(env, n_components = 1, **kwargs)
      :classmethod:

      Create an estimator with a fixed (non-learned) uniform backward policy.

      Uses ``alpha=beta=1`` (uniform Beta) by setting ``skip_normalization=True``,
      so the constant module output is used directly as the concentration parameters.

      :param env: The Box environment.
      :param n_components: Number of mixture components (default 1).
      :param \*\*kwargs: Extra keyword arguments forwarded to the constructor.

      :returns: A ``BoxPBEstimator`` whose module has no learnable parameters.


.. py:class:: BoxPBMLP(hidden_dim, n_hidden_layers, n_components, **kwargs)

   Bases: :py:obj:`gfn.utils.modules.MLP`

   An MLP for the backward policy of the Box environment.

   .. attribute:: n_components

      The number of components for each distribution parameter.

      :type: int

   .. py:attribute:: _input_dim
      :type: int

   .. py:method:: forward(preprocessed_states)

      Computes the forward pass of the neural network.

      :param preprocessed_states: The tensor states of shape `(*batch_shape, 2)`.

      :returns: A tensor of shape `(*batch_shape, 3 * n_components)`.

   .. py:attribute:: n_components
      :type: int


.. py:class:: BoxPBUniform

   Bases: :py:obj:`gfn.utils.modules.UniformModule`

   Uniform backward policy for the polar Box environment.

   Backward-compatible alias for
   ``UniformModule(output_dim=3, input_dim=2, fill_value=1.0, skip_normalization=True)``.
   Used with ``QuarterCircle``, it yields a uniform (alpha=beta=1) Beta distribution
   over parents. Prefer ``UniformModule`` for new code.


.. py:class:: BoxPFEstimator(env, module, n_components_s0, n_components, min_concentration = 0.1, max_concentration = 2.0, debug = False)

   Bases: :py:obj:`gfn.estimators.Estimator`, :py:obj:`gfn.estimators.PolicyMixin`

   Estimator for `P_F` for the Box environment.

   This estimator uses the `DistributionWrapper` distribution.

   .. attribute:: n_components_s0

      The number of components for s0.

   .. attribute:: n_components

      The number of components for non-s0 states.

   .. attribute:: min_concentration

      The minimum concentration for the Beta distributions.

   .. attribute:: max_concentration

      The maximum concentration for the Beta distributions.

   .. attribute:: delta

      The radius of the quarter disk.

   .. attribute:: epsilon

      The epsilon value.

   .. py:attribute:: _n_comp_max
      :type: int

   .. py:attribute:: delta
      :type: float

   .. py:attribute:: epsilon
      :type: float

   .. py:property:: expected_output_dim
      :type: int

      Returns the expected output dimension of the module.

   .. py:attribute:: max_concentration
      :type: float

   .. py:attribute:: min_concentration
      :type: float

   .. py:attribute:: n_components
      :type: int

   .. py:attribute:: n_components_s0
      :type: int

   .. py:method:: to_probability_distribution(states, module_output)

      Converts the module output to a probability distribution.

      :param states: The states for which to convert the module output to a probability distribution.
      :param module_output: The output of the module for the states, as a tensor of shape `(*batch_shape, output_dim)`.

      :returns: The probability distribution for the states.


.. py:class:: BoxPFMLP(hidden_dim, n_hidden_layers, n_components_s0, n_components, **kwargs)

   Bases: :py:obj:`gfn.utils.modules.MLP`

   An MLP for the forward policy of the Box environment.

   .. attribute:: n_components_s0

      The number of components for s0.

      :type: int

   .. attribute:: n_components

      The number of components for non-s0 states.

      :type: int

   .. attribute:: PFs0

      The parameters for the s0 distribution.

      :type: nn.Parameter

   .. py:attribute:: PFs0
      :type: torch.nn.Parameter

   .. py:attribute:: _input_dim
      :type: int

   .. py:attribute:: _n_comp_max
      :type: int

   .. py:method:: forward(preprocessed_states)

      Computes the forward pass of the neural network.

      :param preprocessed_states: The tensor states of shape `(*batch_shape, 2)`.

      :returns: A tensor of shape `(*batch_shape, 1 + 5 * max_n_components)`.

   .. py:attribute:: n_components
      :type: int

   .. py:attribute:: n_components_s0
      :type: int


.. py:class:: BoxStateFlowModule(logZ_value, **kwargs)

   Bases: :py:obj:`gfn.utils.modules.MLP`

   An MLP for the state flow function of the Box environment.

   .. attribute:: logZ_value

      The log partition function value.

      :type: nn.Parameter

   .. py:method:: forward(preprocessed_states)

      Computes the forward pass of the neural network.

      :param preprocessed_states: The tensor states of shape `(*batch_shape, input_dim)`.

      :returns: A tensor of shape `(*batch_shape, output_dim)`.

   .. py:attribute:: logZ_value
      :type: torch.nn.Parameter


.. py:class:: DistributionWrapper(states, delta, epsilon, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta, exit_probability, n_components, n_components_s0, debug = False)

   Bases: :py:obj:`torch.distributions.Distribution`

   A wrapper that combines `QuarterDisk` and `QuarterCircleWithExit`.

   .. attribute:: idx_is_initial

      The indices of the initial states.

   .. attribute:: idx_not_initial

      The indices of the non-initial states.

   .. attribute:: quarter_disk

      The `QuarterDisk` distribution.

   .. attribute:: quarter_circ

      The `QuarterCircleWithExit` distribution.

   .. py:attribute:: _output_shape
      :type: tuple[int, Ellipsis]

   .. py:attribute:: debug
      :value: False

   .. py:attribute:: idx_is_initial
      :type: torch.Tensor

   .. py:attribute:: idx_not_initial
      :type: torch.Tensor

   .. py:method:: log_prob(sampled_actions)

      Computes the log probability of the sampled actions.

      :param sampled_actions: Tensor of shape `(*batch_shape, 2)` with the actions to compute the log probability of.

      :returns: A tensor of shape `(*batch_shape)` containing the log probabilities.

   .. py:attribute:: quarter_circ
      :type: Optional[QuarterCircleWithExit]

   .. py:attribute:: quarter_disk
      :type: Optional[QuarterDisk]

   .. py:method:: sample(sample_shape = Size())

      Samples from the distribution.

      :param sample_shape: The shape of the samples to generate.

      :returns: A tensor of shape `(sample_shape + self._output_shape)` containing the sampled actions.


.. py:class:: QuarterCircle(delta, northeastern, centers, mixture_logits, alpha, beta, debug = False)

   Bases: :py:obj:`torch.distributions.Distribution`

   Represents distributions on quarter circles.

   The distributions are mixtures of Beta distributions over the admissible angle
   range.

   When a state has norm <= delta and northeastern=False, the distribution is a
   Dirac at the state (i.e., the only possible parent is s_0).

   Adapted from https://github.com/saleml/continuous-gfn/blob/master/sampling.py

   This is useful for the `Box` environment.

   .. attribute:: delta

      The radius of the quarter disk.

   .. attribute:: northeastern

      Whether the quarter disk is northeastern or southwestern.

   .. attribute:: n_states

      The number of states.

   .. attribute:: n_components

      The number of components in the mixture.

   .. attribute:: centers

      The centers of the distribution.

   .. attribute:: base_dist

      The base distribution.

   .. attribute:: min_angles

      The minimum angles.

   .. attribute:: max_angles

      The maximum angles.

   .. py:attribute:: base_dist
      :type: torch.distributions.MixtureSameFamily

   .. py:attribute:: centers
      :type: gfn.states.States

   .. py:attribute:: debug
      :value: False

   .. py:attribute:: delta
      :type: float

   .. py:method:: get_min_and_max_angles()

      Computes the minimum and maximum angles for the distribution.

      :returns: A tuple of two tensors of shape `(n_states,)` containing the minimum and maximum angles, respectively.

   .. py:method:: log_prob(sampled_actions)

      Computes the log probability of the sampled actions.

      :param sampled_actions: Tensor of shape `(*batch_shape, 2)` with the actions to compute the log probability of.

      :returns: The log probability of the sampled actions, as a tensor of shape `batch_shape`.

   .. py:attribute:: max_angles
      :type: torch.Tensor

   .. py:attribute:: min_angles
      :type: torch.Tensor

   .. py:attribute:: n_components
      :type: int

   .. py:attribute:: n_states
      :type: int

   .. py:attribute:: northeastern
      :type: bool

   .. py:method:: sample(sample_shape = Size())

      Samples from the distribution.

      :param sample_shape: The shape of the samples to generate.

      :returns: The sampled actions, of shape `(sample_shape, n_states, 2)`.


.. py:class:: QuarterCircleWithExit(delta, centers, exit_probability, mixture_logits, alpha, beta, epsilon = 0.0001)

   Bases: :py:obj:`torch.distributions.Distribution`

   Extends `QuarterCircle` with an exit action.

   When sampling, the `exit_action` `[-inf, -inf]` is drawn with probability
   `exit_probability`. The `log_prob` method is adjusted accordingly.

   .. attribute:: delta

      The radius of the quarter disk.

   .. attribute:: epsilon

      The tolerance used to decide whether a state is at the border of the square.

   .. attribute:: centers

      The centers of the distribution.

   .. attribute:: dist_without_exit

      The distribution without the exit action.

   .. attribute:: exit_probability

      The probability of exiting.

   .. attribute:: exit_action

      The exit action.

   .. attribute:: n_states

      The number of states.

   .. py:attribute:: centers
      :type: gfn.states.States

   .. py:attribute:: delta
      :type: float

   .. py:attribute:: dist_without_exit
      :type: QuarterCircle

   .. py:attribute:: epsilon
      :type: float

   .. py:attribute:: exit_action
      :type: torch.Tensor

   .. py:attribute:: exit_probability
      :type: torch.Tensor

   .. py:method:: log_prob(sampled_actions)

      Computes the log probability of the sampled actions.

      :param sampled_actions: Tensor of shape `(*batch_shape, 2)` with the actions to compute the log probability of.

      :returns: The log probability of the sampled actions, as a tensor of shape `batch_shape`.

   .. py:attribute:: n_states
      :type: int

   .. py:method:: sample()

      Samples from the distribution.

      :returns: The sampled actions, of shape `(n_states, 2)`.


.. py:class:: QuarterDisk(delta, mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta)

   Bases: :py:obj:`torch.distributions.Distribution`

   Represents a distribution on the northeastern quarter disk.

   The radius and the angle each follow a mixture of Beta distributions.

   Adapted from https://github.com/saleml/continuous-gfn/blob/master/sampling.py

   This is useful for the `Box` environment.

   .. attribute:: delta

      The radius of the quarter disk.

   .. attribute:: mixture_logits

      The logits of the mixture of Beta distributions.

   .. attribute:: base_r_dist

      The base distribution for the radius.

   .. attribute:: base_theta_dist

      The base distribution for the angle.

   .. attribute:: n_components

      The number of components in the mixture.

   .. py:attribute:: base_r_dist
      :type: torch.distributions.MixtureSameFamily

   .. py:attribute:: base_theta_dist
      :type: torch.distributions.MixtureSameFamily

   .. py:attribute:: debug
      :value: False

   .. py:attribute:: delta
      :type: float

   .. py:method:: log_prob(sampled_actions)

      Computes the log probability of the sampled actions.

      :param sampled_actions: Tensor of shape `(*batch_shape, 2)` with the actions to compute the log probability of.

      :returns: The log probability of the sampled actions, as a tensor of shape `batch_shape`.

   .. py:attribute:: mixture_logits
      :type: torch.Tensor

   .. py:attribute:: n_components
      :type: int

   .. py:method:: sample(sample_shape = torch.Size())

      Samples from the distribution.

      :param sample_shape: The shape of the samples to generate.

      :returns: The sampled actions, of shape `(sample_shape, 2)`.


.. py:class:: UniformBoxCartesianPBModule(n_components, n_dim = 2)

   Bases: :py:obj:`gfn.utils.modules.UniformModule`

   Fixed (non-learned) backward policy module for the Cartesian Box environment.

   Backward-compatible alias for ``UniformModule``. Prefer
   ``BoxCartesianPBEstimator.uniform(env, n_components)`` for new code.


.. py:function:: split_PF_module_output(output, n_comp_max)

   Splits the module output into the expected parameter sets.

   :param output: The module output from the P_F model, as a tensor of shape `(*batch_shape, output_dim)`.
   :param n_comp_max: The larger of `n_components` and `n_components_s0`.

   :returns:
      A tuple containing:

      - `exit_probability`: A probability unique to `QuarterCircleWithExit`.
      - `mixture_logits`: Parameters shared by `QuarterDisk` and `QuarterCircleWithExit`.
      - `alpha_r`: Parameters shared by `QuarterDisk` and `QuarterCircleWithExit`.
      - `beta_r`: Parameters shared by `QuarterDisk` and `QuarterCircleWithExit`.
      - `alpha_theta`: Parameters unique to `QuarterDisk`.
      - `beta_theta`: Parameters unique to `QuarterDisk`.
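The increment rule documented for ``BoxCartesianDistribution`` (``action = min_incr + r * max_range``, with ``r`` drawn from a ``MixtureSameFamily(Categorical, Beta)``) can be illustrated in pure Python. This is a minimal, scalar sketch for one dimension, not the library's tensor implementation. It assumes ``min_incr`` and ``max_range`` are given positive scalars and that densities are expressed in absolute-increment space, so the affine change of variables contributes ``-log(max_range)``. The helper names ``beta_log_pdf``, ``beta_mixture_log_pdf`` and ``increment_log_prob`` are hypothetical.

.. code-block:: python

   import math


   def beta_log_pdf(x: float, a: float, b: float) -> float:
       """Log density of Beta(a, b) at x, for 0 < x < 1."""
       return (
           math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
           + (a - 1.0) * math.log(x)
           + (b - 1.0) * math.log1p(-x)
       )


   def logsumexp(values: list[float]) -> float:
       m = max(values)
       return m + math.log(sum(math.exp(v - m) for v in values))


   def beta_mixture_log_pdf(x: float, mixture_logits: list[float],
                            alphas: list[float], betas: list[float]) -> float:
       """log sum_k softmax(mixture_logits)_k * Beta(x; alphas[k], betas[k])."""
       log_norm = logsumexp(mixture_logits)
       return logsumexp([
           (logit - log_norm) + beta_log_pdf(x, a, b)
           for logit, a, b in zip(mixture_logits, alphas, betas)
       ])


   def increment_log_prob(increment: float, min_incr: float, max_range: float,
                          mixture_logits: list[float], alphas: list[float],
                          betas: list[float]) -> float:
       """Log density of increment = min_incr + r * max_range, r ~ Beta mixture.

       The map r -> increment is affine with slope max_range, so the change of
       variables subtracts log(max_range).
       """
       r = (increment - min_incr) / max_range
       return beta_mixture_log_pdf(r, mixture_logits, alphas, betas) - math.log(max_range)

With ``alpha = beta = 1`` for every component, the mixture is uniform on ``(0, 1)``, and the increment density reduces to ``1 / max_range``.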
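The gradient argument in ``BoxCartesianPBDistribution`` can be checked numerically on a single 1-step trajectory. The sketch below assumes a hypothetical scalar BTS logit and arbitrary example values for log P_F, log Z and log R. With forced BTS, log P_B is the constant 0 and the derivative of the TB loss with respect to the logit is zero. With a learned Bernoulli, log P_B(BTS | x_1) = log sigmoid(logit), and the analytic derivative matches a central finite difference. The function names are illustrative only.

.. code-block:: python

   import math


   def log_sigmoid(z: float) -> float:
       """Numerically stable log(sigmoid(z))."""
       return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


   def tb_loss(log_pf: float, log_z: float, log_reward: float,
               bts_logit: float, forced: bool) -> float:
       """TB loss for the 1-step trajectory s0 -> x_T, scored backward by BTS."""
       log_pb = 0.0 if forced else log_sigmoid(bts_logit)
       return (log_pf + log_z - log_pb - log_reward) ** 2


   def tb_grad_bts_logit(log_pf: float, log_z: float, log_reward: float,
                         bts_logit: float, forced: bool) -> float:
       """Analytic d(tb_loss)/d(bts_logit)."""
       if forced:
           return 0.0  # log P_B is the constant 0: nothing to differentiate.
       residual = log_pf + log_z - log_sigmoid(bts_logit) - log_reward
       sigma = 1.0 / (1.0 + math.exp(-bts_logit))
       # d/dz log sigmoid(z) = 1 - sigmoid(z); log P_B enters the residual with a minus sign.
       return -2.0 * residual * (1.0 - sigma)

In the forced case the loss does not change at all as the logit varies. This is exactly why P_B received no learning signal from such trajectories.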
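The three cases listed for ``BoxCartesianPBDistribution.log_prob`` can be written out for one state in plain Python. This is a simplified sketch under stated assumptions, not the library code:

- Each dimension uses a single Beta component instead of a mixture.
- A backward increment in dimension ``d`` lies in ``[delta, state[d]]`` and is mapped to ``r`` affinely, so the log-Jacobian is ``-log(state[d] - delta)``.
- The caller passes an explicit ``is_bts`` flag. The library instead detects BTS from ``action = state``.

``pb_log_prob`` is a hypothetical name.

.. code-block:: python

   import math


   def log_sigmoid(z: float) -> float:
       return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


   def beta_log_pdf(x: float, a: float, b: float) -> float:
       return (math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
               + (a - 1.0) * math.log(x) + (b - 1.0) * math.log1p(-x))


   def pb_log_prob(state, action, is_bts, bts_logit, delta, alphas, betas):
       """Hypothetical scalar version of the three backward cases (one Beta per dim)."""
       if any(s <= delta for s in state):
           return 0.0  # Forced BTS: deterministic, so log-probability 0.
       if is_bts:
           return log_sigmoid(bts_logit)  # log P(BTS=1 | s)
       logp = log_sigmoid(-bts_logit)  # log P(BTS=0 | s) = log(1 - sigmoid(logit))
       for s, incr, a, b in zip(state, action, alphas, betas):
           width = s - delta  # length of the increment range [delta, state[d]]
           r = (incr - delta) / width  # relative increment in (0, 1)
           logp += beta_log_pdf(r, a, b) - math.log(width)  # log_p_beta + log_jacobian
       return logp

Because ``log_sigmoid(z)`` and ``log_sigmoid(-z)`` are the log-probabilities of the two Bernoulli outcomes, the BTS and non-BTS branches together account for all of the probability mass at a non-forced state.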
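The ``expected_output_dim`` of ``BoxCartesianPFEstimator`` and ``BoxCartesianPBEstimator`` counts one exit (or BTS) logit plus a weight, an alpha and a beta for every (dimension, component) pair, that is ``1 + 3 * n_dim * n_components``. The sketch below shows that count and one possible way to slice a flat output vector. The block order ``[logit | weights | alpha | beta]`` and the per-dimension grouping are assumptions made for illustration, not the documented layout.

.. code-block:: python

   def cartesian_expected_output_dim(n_dim: int, n_components: int) -> int:
       # 1 exit/BTS logit + (weights + alpha + beta) for every (dimension, component).
       return 1 + 3 * n_dim * n_components


   def split_cartesian_output(output, n_dim, n_components):
       """Hypothetical layout: [logit | weights | alpha | beta], each block grouped by dim."""
       if len(output) != cartesian_expected_output_dim(n_dim, n_components):
           raise ValueError("unexpected module output length")
       block = n_dim * n_components

       def per_dim(flat):
           return [list(flat[d * n_components:(d + 1) * n_components]) for d in range(n_dim)]

       logit = output[0]
       weights = per_dim(output[1:1 + block])
       alpha = per_dim(output[1 + block:1 + 2 * block])
       beta = per_dim(output[1 + 2 * block:])
       return logit, weights, alpha, beta

With the defaults ``n_dim = 2`` and, say, ``n_components = 3``, the module must emit 19 values per state.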
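``BoxPBUniform`` and ``BoxPBEstimator.uniform`` rely on the fact that Beta(1, 1) is the uniform distribution on ``(0, 1)``. The constant module emits ``fill_value = 1.0`` in every output slot, and with ``skip_normalization=True`` those values are used directly as parameters. Whichever slots are read as alpha and beta, both therefore equal 1, and with a single component the mixture weight is trivially 1. The slot positions below are hypothetical; the check only confirms that the log density is 0 everywhere.

.. code-block:: python

   import math


   def beta_log_pdf(x: float, a: float, b: float) -> float:
       return (math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
               + (a - 1.0) * math.log(x) + (b - 1.0) * math.log1p(-x))


   # A constant module with output_dim=3 and fill_value=1.0 emits [1.0, 1.0, 1.0]
   # for every state; the positions read as alpha and beta here are hypothetical.
   module_output = [1.0, 1.0, 1.0]
   alpha, beta = module_output[1], module_output[2]

   # Beta(1, 1) has density 1 on (0, 1): every relative angle is equally likely.
   grid = [(i + 0.5) / 10 for i in range(10)]
   log_densities = [beta_log_pdf(x, alpha, beta) for x in grid]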
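``DistributionWrapper`` holds a ``QuarterDisk`` and a ``QuarterCircleWithExit`` and tracks ``idx_is_initial`` and ``idx_not_initial``. The sketch below illustrates that routing for a batch of 2-D states, taking the source state to be the origin as in torchgfn's Box design. It rests on two assumptions:

- Each quarter-disk draw scales one Beta sample to a radius in ``[0, delta]`` and another to an angle in ``[0, pi/2]``. The library uses Beta mixtures for both.
- Non-initial states are handed to a caller-supplied ``circle_sampler`` standing in for ``QuarterCircleWithExit``.

All names are hypothetical.

.. code-block:: python

   import math
   import random


   def sample_quarter_disk(delta, a_r, b_r, a_theta, b_theta, rng):
       """Point in the northeastern quarter disk of radius delta (single-Beta stand-in)."""
       radius = delta * rng.betavariate(a_r, b_r)
       theta = 0.5 * math.pi * rng.betavariate(a_theta, b_theta)
       return radius * math.cos(theta), radius * math.sin(theta)


   def wrapper_sample(states, delta, circle_sampler, rng):
       """Route initial states to the quarter disk and all others to circle_sampler."""
       actions = []
       for state in states:
           if state == (0.0, 0.0):  # plays the role of idx_is_initial
               actions.append(sample_quarter_disk(delta, 2.0, 2.0, 1.0, 1.0, rng))
           else:  # idx_not_initial: QuarterCircleWithExit in the library
               actions.append(circle_sampler(state, rng))
       return actions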
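For ``QuarterCircle``, the admissible angle range follows from keeping a step of length ``delta`` inside the unit square. Going forward (northeastern), the child ``(x + delta*cos t, y + delta*sin t)`` must stay below 1. Going backward (southwestern), the parent ``(x - delta*cos t, y - delta*sin t)`` must stay above 0. The bounds below are derived from that geometry and are not copied from ``get_min_and_max_angles``; a single Beta draw stands in for the mixture. When the norm of the state is below ``delta`` in the backward direction, the range comes out empty (``min_angle > max_angle``). This matches the Dirac-at-the-state case described for ``northeastern=False``.

.. code-block:: python

   import math
   import random


   def angle_range(x, y, delta, northeastern):
       """Admissible angles t in [0, pi/2] for a step of length delta from (x, y)."""
       if northeastern:
           # Child (x + delta*cos t, y + delta*sin t) must satisfy <= 1.
           cos_bound, sin_bound = (1.0 - x) / delta, (1.0 - y) / delta
       else:
           # Parent (x - delta*cos t, y - delta*sin t) must satisfy >= 0.
           cos_bound, sin_bound = x / delta, y / delta
       min_angle = math.acos(cos_bound) if cos_bound < 1.0 else 0.0
       max_angle = math.asin(sin_bound) if sin_bound < 1.0 else math.pi / 2
       return min_angle, max_angle


   def sample_step(x, y, delta, northeastern, a, b, rng):
       """Draw an angle in the admissible range via Beta(a, b); return the step vector."""
       lo, hi = angle_range(x, y, delta, northeastern)
       theta = lo + rng.betavariate(a, b) * (hi - lo)
       return delta * math.cos(theta), delta * math.sin(theta)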
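The exit mechanism of ``QuarterCircleWithExit`` can be sketched independently of the angle distribution. With probability ``exit_probability`` the sentinel ``[-inf, -inf]`` is returned. Otherwise a regular step is drawn, and its log-probability is shifted by ``log(1 - exit_probability)``. In this sketch ``step_sampler`` and ``step_log_prob`` are hypothetical callables standing in for ``dist_without_exit``, and the border handling controlled by ``epsilon`` is not modeled.

.. code-block:: python

   import math
   import random

   EXIT_ACTION = (-math.inf, -math.inf)


   def sample_with_exit(exit_probability, step_sampler, rng):
       """Return EXIT_ACTION with probability exit_probability, else a regular step."""
       if rng.random() < exit_probability:
           return EXIT_ACTION
       return step_sampler(rng)


   def log_prob_with_exit(action, exit_probability, step_log_prob):
       """Exit: log p_exit. Regular step: log(1 - p_exit) + log p_step(action)."""
       if action == EXIT_ACTION:
           return math.log(exit_probability)
       return math.log1p(-exit_probability) + step_log_prob(action)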
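``BoxPFMLP.forward`` returns ``1 + 5 * max_n_components`` values per state, and ``split_PF_module_output`` turns them into six parameter sets: one exit value plus five blocks of ``n_comp_max`` entries. The pure-Python analogue below operates on a single state. It assumes the blocks are contiguous and appear in the same order as the documented return values. It only slices, so any activation the library applies to produce a probability or valid concentrations is omitted. ``split_pf_output`` is a hypothetical name.

.. code-block:: python

   def split_pf_output(output, n_comp_max):
       """Hypothetical single-state analogue of split_PF_module_output (slicing only)."""
       if len(output) != 1 + 5 * n_comp_max:
           raise ValueError("expected 1 + 5 * n_comp_max values")
       blocks = [output[1 + k * n_comp_max:1 + (k + 1) * n_comp_max] for k in range(5)]
       mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta = blocks
       # First value: raw exit parameter (unique to QuarterCircleWithExit).
       return output[0], mixture_logits, alpha_r, beta_r, alpha_theta, beta_theta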