gfn.samplers
============

.. py:module:: gfn.samplers


Classes
-------

.. autoapisummary::

   gfn.samplers.LocalSearchSampler
   gfn.samplers.Sampler


Module Contents
---------------

.. py:class:: LocalSearchSampler(pf_estimator, pb_estimator)

   Bases: :py:obj:`Sampler`

   Sampler equipped with local search capabilities.

   The LocalSearchSampler extends the basic Sampler with local search functionality based on the back-and-forth heuristic. This approach was first proposed by `Zhang et al. 2022 <https://arxiv.org/abs/2202.01361>`_ and further explored by `Kim et al. 2023 <https://arxiv.org/abs/2310.02710>`_.

   The local search process involves:

   1. Taking a trajectory and performing K backward steps using a backward policy
   2. Reconstructing the trajectory from the junction state using the forward policy
   3. Optionally applying a Metropolis-Hastings acceptance criterion

   .. attribute:: estimator

      The forward policy estimator (inherited from Sampler).

   .. attribute:: backward_sampler

      A Sampler instance with the backward policy estimator.

   .. py:method:: _combine_prev_and_recon_trajectories(n_prevs, prev_trajectories, recon_trajectories, prev_trajectories_log_pf = None, recon_trajectories_log_pf = None, prev_trajectories_log_pb = None, recon_trajectories_log_pb = None, debug = False)
      :staticmethod:

      Combines previous and reconstructed trajectories to create new trajectories.

      This static method combines two trajectory segments, `prev_trajectories` and `recon_trajectories`, to create `new_trajectories`. Specifically, `new_trajectories` is constructed by replacing a portion of `prev_trajectories` with `recon_trajectories`. See `local_search` for how `prev_trajectories` and `recon_trajectories` are generated.

      :param n_prevs: Tensor indicating how many steps to take from `prev_trajectories` for each trajectory in the batch.
      :param prev_trajectories: Trajectories obtained from backward sampling.
      :param recon_trajectories: Trajectories obtained from forward reconstruction.
      :param prev_trajectories_log_pf: Optional log probabilities for forward policy on `prev_trajectories`.
      :param recon_trajectories_log_pf: Optional log probabilities for forward policy on `recon_trajectories`.
      :param prev_trajectories_log_pb: Optional log probabilities for backward policy on `prev_trajectories`.
      :param recon_trajectories_log_pb: Optional log probabilities for backward policy on `recon_trajectories`.
      :param debug: If True, performs additional validation checks for debugging.

      :returns: A tuple containing:

         - the `new_trajectories` Trajectories object with the combined trajectories
         - the `new_trajectories_log_pf` tensor of combined forward log probabilities
         - the `new_trajectories_log_pb` tensor of combined backward log probabilities

      .. note::

         This method performs complex tensor operations to efficiently combine trajectory segments. The debug mode compares the vectorized approach with a for-loop implementation to ensure correctness.

   .. py:attribute:: backward_sampler

   .. py:method:: local_search(env, trajectories, save_estimator_outputs = False, save_logprobs = False, back_steps = None, back_ratio = None, use_metropolis_hastings = True, debug = False, **policy_kwargs)

      Performs local search on a batch of trajectories.

      This method implements the local search algorithm by:

      1. For each trajectory, performing K backward steps to reach a junction state
      2. Reconstructing the trajectory from the junction state using the forward policy
      3. Optionally applying a Metropolis-Hastings acceptance criterion to decide whether to accept the new trajectory

      :param env: The environment to sample trajectories from.
      :param trajectories: The batch of trajectories to perform local search on.
      :param save_estimator_outputs: If True, saves the estimator outputs for each step. Useful for off-policy training with tempered policies.
      :param save_logprobs: If True, calculates and saves the log probabilities of sampled actions. Useful for on-policy training.
      :param back_steps: Number of backward steps to take. Must be provided if `back_ratio` is None.
      :param back_ratio: Ratio of trajectory length to use for backward steps. Must be provided if `back_steps` is None.
      :param use_metropolis_hastings: If True, applies Metropolis-Hastings acceptance criterion. If False, accepts new trajectories if they have higher rewards.
      :param debug: If True, performs additional validation checks for debugging.
      :param \*\*policy_kwargs: Keyword arguments passed to the policy estimator. See `sample_actions` for details.

      :returns: A tuple containing:

         - A Trajectories object refined by local search
         - A boolean tensor indicating which trajectories were updated

   .. py:method:: sample_trajectories(env, n = None, states = None, conditions = None, save_estimator_outputs = False, save_logprobs = False, n_local_search_loops = 0, back_steps = None, back_ratio = None, use_metropolis_hastings = False, **policy_kwargs)

      Samples trajectories with optional local search operations.

      This method extends the basic trajectory sampling with local search operations. After sampling initial trajectories, it performs `n_local_search_loops` rounds of local search to potentially improve the trajectories in terms of reward.

      :param env: The environment to sample trajectories from.
      :param n: Number of trajectories to sample, all starting from s0. Must be provided if `states` is None.
      :param states: Initial states to start trajectories from. It should have a batch_shape of length 1 (no trajectory dim). If `None`, `n` must be provided and `n` trajectories are initialized with the environment's initial state.
      :param conditions: Optional tensor of conditioning information for conditional policies. Must match the batch shape of states.
      :param save_estimator_outputs: If True, saves the estimator outputs for each step. Useful for off-policy training with tempered policies.
      :param save_logprobs: If True, calculates and saves the log probabilities of sampled actions.
         Useful for on-policy training.
      :param n_local_search_loops: Number of local search loops to perform after initial sampling. Each loop adds one new trajectory per initial trajectory.
      :param back_steps: Number of backward steps to take. Must be provided if `back_ratio` is None.
      :param back_ratio: Ratio of trajectory length to use for backward steps. Must be provided if `back_steps` is None.
      :param use_metropolis_hastings: If True, applies Metropolis-Hastings acceptance criterion. If False, accepts new trajectories if they have higher rewards.
      :param \*\*policy_kwargs: Keyword arguments passed to the policy estimator. See `sample_actions` for details.

      :returns: A Trajectories object representing the batch of sampled trajectories, where the number of trajectories is n * (1 + n_local_search_loops).

      .. note::

         The final trajectories container contains both the initial trajectories and the improved trajectories from local search.

.. py:class:: Sampler(estimator)

   Estimator‑driven sampler for GFlowNet environments.

   The estimator builds action distributions, computes step log‑probs, and records artifacts into a rollout context via method flags. Direction (forward/backward) is determined by ``estimator.is_backward``.

   .. attribute:: estimator

      The underlying policy estimator. Must expose the methods contained in the `PolicyMixin` mixin.

   .. py:attribute:: estimator

   .. py:method:: sample_actions(env, states, conditions = None, save_estimator_outputs = False, save_logprobs = False, ctx = None, **policy_kwargs)

      Sample one step from ``states`` via the estimator.

      Initializes or reuses a rollout context with ``estimator.init_context``, builds a Distribution with ``estimator.compute_dist``, and optionally computes log‑probs with ``estimator.log_probs``. Per‑step artifacts are recorded by the estimator when the corresponding flags are set.

      :param env: Environment providing action/state conversion utilities.
      :param states: Batch of states to act on.
      :param conditions: Optional condition vector for conditional policies.
      :param save_estimator_outputs: If True, return the raw estimator outputs cached by the PolicyMixin for this step. Useful for off-policy training with tempered policies.
      :param save_logprobs: If True, return per‑step log‑probs padded to batch. Useful for on-policy training.
      :param \*\*policy_kwargs: Extra kwargs forwarded to ``to_probability_distribution``.

      :returns: ``(Actions, log_probs | None, estimator_outputs | None)``. The estimator outputs come from ``PolicyMixin.get_current_estimator_output(ctx)`` when requested.

   .. py:method:: sample_trajectories(env, n = None, states = None, conditions = None, save_estimator_outputs = False, save_logprobs = False, **policy_kwargs)

      Roll out complete trajectories using the estimator.

      Reuses a single rollout context across steps, calling ``compute_dist`` and ``log_probs`` each iteration. Uses ``estimator.is_backward`` to choose the environment step function.

      :param env: Environment to sample in.
      :param n: Number of trajectories if ``states`` is None.
      :param states: Starting states (batch shape length 1) or ``None``.
      :param conditions: Optional condition tensor for conditional environments, with shape (batch_size, condition_dim); each row is the condition vector for one trajectory.
      :param save_estimator_outputs: If True, store per‑step estimator outputs. Useful for off-policy training with tempered policies.
      :param save_logprobs: If True, store per‑step log‑probs. Useful for on-policy training.
      :param \*\*policy_kwargs: Extra kwargs forwarded to the policy.

      :returns: A ``Trajectories`` with stacked states/actions and any artifacts.

      .. note::

         For backward trajectories, the reward is computed at the initial state (s0) rather than the terminal state (sf).
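
The back-and-forth procedure documented for ``LocalSearchSampler`` can be illustrated with a small, library-independent sketch. It does not use the ``gfn`` API: the bit-string environment, the uniform forward policy, and the names ``LENGTH``, ``log_reward``, ``sample_forward``, ``local_search`` and ``sample_with_local_search`` are hypothetical stand-ins chosen for illustration. The acceptance rule is the generic Metropolis-Hastings ratio for a destroy-and-reconstruct proposal, and assuming it matches what the library computes is exactly that, an assumption.

```python
import math
import random

# Toy stand-in for a GFlowNet environment (hypothetical, not the gfn API):
# a trajectory is a list of bits built left to right, and it terminates
# once it holds LENGTH bits.
LENGTH = 6


def log_reward(traj):
    """Log reward of a complete trajectory: more ones is better."""
    return math.log(1 + sum(traj))


def sample_forward(prefix, rng):
    """Forward policy: append uniform random bits until LENGTH is reached.

    Returns the completed trajectory and the log-probability of the new steps.
    """
    n_new = LENGTH - len(prefix)
    suffix = [rng.randint(0, 1) for _ in range(n_new)]
    return prefix + suffix, n_new * math.log(0.5)


def local_search(traj, back_steps, use_metropolis_hastings, rng):
    """One back-and-forth move. Returns (trajectory, was_updated)."""
    # 1. Backward: undo the last `back_steps` actions to reach the junction.
    #    Removing the last bit is the only backward move here, so log P_B = 0.
    junction = traj[: LENGTH - back_steps]
    log_pb_prev = log_pb_new = 0.0
    # Forward log-probability of the segment that was just destroyed.
    log_pf_prev = back_steps * math.log(0.5)

    # 2. Forward: reconstruct a full trajectory from the junction state.
    candidate, log_pf_new = sample_forward(junction, rng)

    # 3. Accept or reject the candidate.
    log_ratio = log_reward(candidate) - log_reward(traj)
    if use_metropolis_hastings:
        # Proposal correction: reverse move divided by forward move.
        log_ratio += (log_pb_new + log_pf_prev) - (log_pb_prev + log_pf_new)
        accept = rng.random() < math.exp(min(0.0, log_ratio))
    else:
        # Greedy rule: keep the candidate only if its reward is higher.
        accept = log_ratio > 0.0
    return (candidate, True) if accept else (traj, False)


def sample_with_local_search(n, n_local_search_loops, back_steps,
                             use_metropolis_hastings=False, seed=0):
    """Sample n trajectories, then append n more per local search loop."""
    rng = random.Random(seed)
    trajectories = [sample_forward([], rng)[0] for _ in range(n)]
    # Position of the most recently accepted trajectory of each chain.
    current = list(range(n))
    for _ in range(n_local_search_loops):
        start = len(trajectories)
        for chain in range(n):
            new_traj, updated = local_search(
                trajectories[current[chain]], back_steps,
                use_metropolis_hastings, rng)
            trajectories.append(new_traj)
            if updated:
                current[chain] = start + chain
    return trajectories


if __name__ == "__main__":
    batch = sample_with_local_search(n=4, n_local_search_loops=3, back_steps=2)
    print(len(batch))
```

With ``n=4`` and ``n_local_search_loops=3`` the returned list holds 16 trajectories, which matches the ``n * (1 + n_local_search_loops)`` count stated for ``LocalSearchSampler.sample_trajectories``. In this toy setting the forward and backward probabilities cancel in the Metropolis-Hastings ratio, so acceptance depends only on the reward ratio; with learned policies the correction term would generally not vanish.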