gfn.samplers

Classes

LocalSearchSampler

Sampler equipped with local search capabilities.

Sampler

Estimator‑driven sampler for GFlowNet environments.

Module Contents

class gfn.samplers.LocalSearchSampler(pf_estimator, pb_estimator)

Bases: Sampler

Sampler equipped with local search capabilities.

The LocalSearchSampler extends the basic Sampler with local search functionality based on the back-and-forth heuristic. This approach was first proposed by [Zhang et al. 2022](https://arxiv.org/abs/2202.01361) and further explored by [Kim et al. 2023](https://arxiv.org/abs/2310.02710).

The local search process involves:

1. Taking a trajectory and performing K backward steps using a backward policy.
2. Reconstructing the trajectory from the junction state using the forward policy.
3. Optionally applying a Metropolis-Hastings acceptance criterion.

Attributes:
estimator

The forward policy estimator (inherited from Sampler).

backward_sampler

A Sampler instance with the backward policy estimator.
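The back-and-forth heuristic described above can be illustrated with a minimal, library-free sketch. Everything in it is a stand-in for illustration: the trajectory is a plain list of binary actions, `reward`, `back_and_forth`, and `local_search_step` are hypothetical helpers, and the forward policy is assumed to be uniform. It does not use the gfn API and only shows the control flow of one greedy local search step.

```python
import random


def reward(bits):
    # Toy reward (assumption): number of ones, shifted to stay positive.
    return sum(bits) + 1


def back_and_forth(bits, k, rng):
    """Undo the last k actions, then resample k actions uniformly."""
    junction = bits[: len(bits) - k]  # prefix kept after k backward steps
    recon = [rng.randint(0, 1) for _ in range(k)]  # forward reconstruction
    return junction + recon


def local_search_step(bits, k, rng):
    """Return (trajectory, was_updated) after one greedy local search step."""
    candidate = back_and_forth(bits, k, rng)
    # Greedy acceptance: keep the candidate only if its reward is higher.
    if reward(candidate) > reward(bits):
        return candidate, True
    return bits, False


rng = random.Random(0)
trajectory = [0, 1, 0, 0, 1, 0]
refined, updated = local_search_step(trajectory, k=3, rng=rng)
```

The prefix before the junction is always preserved, so only the last `k` actions can change between `trajectory` and `refined`.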

static _combine_prev_and_recon_trajectories(n_prevs, prev_trajectories, recon_trajectories, prev_trajectories_log_pf=None, recon_trajectories_log_pf=None, prev_trajectories_log_pb=None, recon_trajectories_log_pb=None, debug=False)

Combines previous and reconstructed trajectories to create new trajectories.

This static method combines two trajectory segments, prev_trajectories and recon_trajectories, to create new_trajectories. Specifically, new_trajectories is constructed by replacing a certain portion of prev_trajectories with recon_trajectories. See self.local_search for how prev_trajectories and recon_trajectories are generated.

Parameters:
  • n_prevs (torch.Tensor) – Tensor indicating how many steps to take from prev_trajectories for each trajectory in the batch.

  • prev_trajectories (gfn.containers.Trajectories) – Trajectories obtained from backward sampling.

  • recon_trajectories (gfn.containers.Trajectories) – Trajectories obtained from forward reconstruction.

  • prev_trajectories_log_pf (torch.Tensor | None) – Optional log probabilities for forward policy on prev_trajectories.

  • recon_trajectories_log_pf (torch.Tensor | None) – Optional log probabilities for forward policy on recon_trajectories.

  • prev_trajectories_log_pb (torch.Tensor | None) – Optional log probabilities for backward policy on prev_trajectories.

  • recon_trajectories_log_pb (torch.Tensor | None) – Optional log probabilities for backward policy on recon_trajectories.

  • debug (bool) – If True, performs additional validation checks for debugging.

Returns:

A tuple containing:

  • new_trajectories: the Trajectories object with the combined trajectories

  • new_trajectories_log_pf: the tensor of combined forward log probabilities

  • new_trajectories_log_pb: the tensor of combined backward log probabilities

Return type:

tuple

Note

This method performs complex tensor operations to efficiently combine trajectory segments. The debug mode compares the vectorized approach with a for-loop implementation to ensure correctness.
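As a rough picture of what the combination does, the sketch below splices plain Python lists with a simple loop, similar in spirit to the for-loop reference mentioned in the Note. It is not the library's vectorized implementation: `splice` is a hypothetical helper, the data are made-up labels and numbers, and both segments are assumed to be stored in forward order.

```python
def splice(n_prevs, prev_batch, recon_batch):
    """For each batch item, keep the first n_prev entries of prev, then append recon."""
    return [
        prev[:n_prev] + recon
        for n_prev, prev, recon in zip(n_prevs, prev_batch, recon_batch)
    ]


# Number of steps kept from each previous trajectory (illustrative values).
n_prevs = [2, 1]

prev_actions = [["a0", "a1", "a2", "a3"], ["b0", "b1", "b2"]]
recon_actions = [["x0", "x1"], ["y0", "y1", "y2"]]
new_actions = splice(n_prevs, prev_actions, recon_actions)

# The same splice applies to per-step log probabilities.
prev_log_pf = [[-0.1, -0.2, -0.3, -0.4], [-0.5, -0.6, -0.7]]
recon_log_pf = [[-1.0, -1.1], [-1.2, -1.3, -1.4]]
new_log_pf = splice(n_prevs, prev_log_pf, recon_log_pf)
```

Note that the combined trajectories may differ in length from the originals, since each reconstruction can terminate earlier or later than the segment it replaces.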

local_search(env, trajectories, save_estimator_outputs, save_logprobs, back_steps, back_ratio, use_metropolis_hastings, debug, **policy_kwargs)

Performs local search on a batch of trajectories.

This method implements the local search algorithm by:

1. For each trajectory, performing K backward steps to reach a junction state.
2. Reconstructing the trajectory from the junction state using the forward policy.
3. Optionally applying a Metropolis-Hastings acceptance criterion to decide whether to accept the new trajectory.

Parameters:
  • env (gfn.env.Env) – The environment to sample trajectories from.

  • trajectories (gfn.containers.Trajectories) – The batch of trajectories to perform local search on.

  • save_estimator_outputs (bool) – If True, saves the estimator outputs for each step. Useful for off-policy training with tempered policies.

  • save_logprobs (bool) – If True, calculates and saves the log probabilities of sampled actions. Useful for on-policy training.

  • back_steps (torch.Tensor | None) – Number of backward steps to take. Must be provided if back_ratio is None.

  • back_ratio (float | None) – Ratio of trajectory length to use for backward steps. Must be provided if back_steps is None.

  • use_metropolis_hastings (bool) – If True, applies Metropolis-Hastings acceptance criterion. If False, accepts new trajectories if they have higher rewards.

  • debug (bool) – If True, performs additional validation checks for debugging.

  • **policy_kwargs (Any) – Keyword arguments passed to the policy estimator. See sample_actions for details.

Returns:

A tuple containing:

  • A Trajectories object refined by local search

  • A boolean tensor indicating which trajectories were updated

Return type:

tuple
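The two acceptance modes selected by use_metropolis_hastings can be sketched as follows. This is a textbook Metropolis-Hastings rule written with generic proposal log probabilities, not the library's exact expression: `greedy_accept`, `mh_accept`, `log_q_forward` (log probability of proposing the new trajectory from the old one), and `log_q_reverse` (log probability of the reverse proposal) are hypothetical names introduced here.

```python
import math


def greedy_accept(prev_log_reward, new_log_reward):
    """Accept only if the new trajectory has a strictly higher reward."""
    return new_log_reward > prev_log_reward


def mh_accept(prev_log_reward, new_log_reward, log_q_forward, log_q_reverse, u):
    """Accept with probability min(1, R(x') q(x | x') / (R(x) q(x' | x))).

    u is a uniform random number in [0, 1) supplied by the caller.
    """
    log_ratio = (new_log_reward + log_q_reverse) - (prev_log_reward + log_q_forward)
    accept_prob = math.exp(min(0.0, log_ratio))
    return u < accept_prob


# Boolean mask over a batch, analogous to the "which trajectories were updated" output.
prev_log_rewards = [0.0, 1.0, 2.0]
new_log_rewards = [0.5, 1.0, 1.5]
is_updated = [greedy_accept(p, n) for p, n in zip(prev_log_rewards, new_log_rewards)]
```

With symmetric proposals the rule reduces to comparing `u` against the reward ratio, so a candidate with half the reward is accepted about half of the time, whereas the greedy rule would always reject it.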

sample_trajectories(env, n=None, states=None, conditions=None, save_estimator_outputs=False, save_logprobs=False, n_local_search_loops=0, back_steps=None, back_ratio=None, use_metropolis_hastings=False, **policy_kwargs)

Samples trajectories with optional local search operations.

This method extends the basic trajectory sampling with local search operations. After sampling initial trajectories, it performs multiple rounds of local search to potentially improve the trajectory quality in terms of the reward.

Parameters:
  • env (gfn.env.Env) – The environment to sample trajectories from.

  • n (Optional[int]) – Number of trajectories to sample, all starting from s0. Must be provided if states is None.

  • states (Optional[gfn.states.States]) – Initial states to start trajectories from. It should have batch_shape of length 1 (no trajectory dim). If None, n must be provided and we initialize n trajectories with the environment’s initial state.

  • conditions (Optional[torch.Tensor]) – Optional tensor of conditioning information for conditional policies. Must match the batch shape of states.

  • save_estimator_outputs (bool) – If True, saves the estimator outputs for each step. Useful for off-policy training with tempered policies.

  • save_logprobs (bool) – If True, calculates and saves the log probabilities of sampled actions. Useful for on-policy training.

  • n_local_search_loops (int) – Number of local search loops to perform after initial sampling. Each loop creates additional trajectories.

  • back_steps (torch.Tensor | None) – Number of backward steps to take. Must be provided if back_ratio is None.

  • back_ratio (float | None) – Ratio of trajectory length to use for backward steps. Must be provided if back_steps is None.

  • use_metropolis_hastings (bool) – If True, applies Metropolis-Hastings acceptance criterion. If False, accepts new trajectories if they have higher rewards.

  • **policy_kwargs (Any) – Keyword arguments passed to the policy estimator. See sample_actions for details.

Returns:

A Trajectories object representing the batch of sampled trajectories, where the number of trajectories is n * (1 + n_local_search_loops).

Return type:

gfn.containers.Trajectories

Note

The final trajectories container contains both the initial trajectories and the improved trajectories from local search.
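The size of the returned batch, n * (1 + n_local_search_loops), follows from keeping the initial batch and appending one refined batch per loop. The sketch below reproduces only that bookkeeping on toy data. `sample_with_local_search` is a hypothetical function, trajectories are lists of bits, and it assumes that each loop refines the output of the previous loop, which may differ from the library's behavior.

```python
import random


def sample_with_local_search(n, n_local_search_loops, length=4, seed=0):
    """Return the initial batch followed by one refined batch per loop."""
    rng = random.Random(seed)
    current = [[rng.randint(0, 1) for _ in range(length)] for _ in range(n)]
    collected = list(current)  # start with the n initial trajectories
    for _ in range(n_local_search_loops):
        refined = []
        for traj in current:
            # Toy local search: resample the last action, keep it if the reward improves.
            candidate = traj[:-1] + [rng.randint(0, 1)]
            refined.append(candidate if sum(candidate) > sum(traj) else traj)
        collected.extend(refined)  # each loop adds n more trajectories
        current = refined
    return collected


batch = sample_with_local_search(n=3, n_local_search_loops=2)
```

Here `len(batch)` is 3 * (1 + 2) = 9, and with `n_local_search_loops=0` the function returns only the 3 initial trajectories.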

class gfn.samplers.Sampler(estimator)

Estimator‑driven sampler for GFlowNet environments.

The estimator builds action distributions, computes step log‑probs, and records artifacts into a rollout context via method flags. Direction (forward/backward) is determined by estimator.is_backward.

Parameters:

estimator (gfn.estimators.Estimator)

estimator

The underlying policy estimator. Must expose the methods contained in the PolicyMixin mixin.

sample_actions(env, states, conditions=None, save_estimator_outputs=False, save_logprobs=False, ctx=None, **policy_kwargs)

Sample one step from states via the estimator.

Initializes or reuses a rollout context with estimator.init_context, builds a Distribution with estimator.compute_dist, and optionally computes log‑probs with estimator.log_probs. Per‑step artifacts are recorded by the estimator when the corresponding flags are set.

Parameters:
  • env (gfn.env.Env) – Environment providing action/state conversion utilities.

  • states (gfn.states.States) – Batch of states to act on.

  • conditions (torch.Tensor | None) – Optional condition vector for conditional policies.

  • save_estimator_outputs (bool) – If True, return the raw estimator outputs cached by the PolicyMixin for this step. Useful for off-policy training with tempered policies.

  • save_logprobs (bool) – If True, return per‑step log‑probs padded to batch. Useful for on-policy training.

  • ctx (Any | None) – Optional rollout context to reuse. If None, a new one is initialized with estimator.init_context.

  • **policy_kwargs (Any) – Extra kwargs forwarded to to_probability_distribution.

Returns:

(Actions, log_probs | None, estimator_outputs | None). The estimator outputs come from PolicyMixin.get_current_estimator_output(ctx) when requested.

Return type:

Tuple[gfn.actions.Actions, torch.Tensor | None, torch.Tensor | None]
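The shape of the return value, with the optional entries set to None unless the matching flag is True, can be mimicked in a few lines. The sketch below is a pure-Python stand-in and not the gfn implementation: `softmax` and `sample_step` are hypothetical helpers, the "estimator outputs" are assumed to be raw logits, and the distribution is a simple categorical.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_step(batch_logits, save_logprobs=False, save_estimator_outputs=False, rng=None):
    """Sample one action per state; return (actions, log_probs | None, outputs | None)."""
    rng = rng or random.Random()
    actions, log_probs = [], []
    for logits in batch_logits:
        probs = softmax(logits)
        action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
        actions.append(action)
        log_probs.append(math.log(probs[action]))
    return (
        actions,
        log_probs if save_logprobs else None,
        batch_logits if save_estimator_outputs else None,
    )


logits = [[0.0, 0.0], [5.0, -5.0]]
actions, log_probs, outputs = sample_step(logits, save_logprobs=True, rng=random.Random(0))
```

Returning None for unrequested artifacts keeps the tuple shape fixed, so callers can unpack three values regardless of which flags were set.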

sample_trajectories(env, n=None, states=None, conditions=None, save_estimator_outputs=False, save_logprobs=False, **policy_kwargs)

Roll out complete trajectories using the estimator.

Reuses a single rollout context across steps, calling compute_dist and log_probs at each iteration. Uses estimator.is_backward to choose the environment step function.

Parameters:
  • env (gfn.env.Env) – Environment to sample in.

  • n (Optional[int]) – Number of trajectories if states is None.

  • states (Optional[gfn.states.States]) – Starting states (batch shape length 1) or None.

  • conditions (Optional[torch.Tensor]) – Optional condition tensor for conditional environments, with shape (batch_size, condition_dim); each row is the condition vector for each trajectory.

  • save_estimator_outputs (bool) – If True, store per‑step estimator outputs. Useful for off-policy training with tempered policies.

  • save_logprobs (bool) – If True, store per‑step log‑probs. Useful for on-policy training.

  • **policy_kwargs (Any) – Extra kwargs forwarded to the policy.

Returns:

A Trajectories with stacked states/actions and any artifacts.

Return type:

gfn.containers.Trajectories

Note

For backward trajectories, the reward is computed at the initial state (s0) rather than the terminal state (sf).
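How a single rollout loop can serve both directions is sketched below on a toy counting environment. All names are hypothetical and unrelated to the gfn API: states are integers, a forward action adds 1 or 2 until a target of 3 is reached, and a backward action subtracts 1 or 2 until the state returns to 0. The `is_backward` flag plays the role described above by selecting the step function and the stopping condition.

```python
import random


def forward_step(state, action):
    return state + action


def backward_step(state, action):
    return state - action


def rollout(start, is_backward, rng, target=3):
    """Roll out one trajectory; the direction flag picks the step function."""
    step = backward_step if is_backward else forward_step
    done = (lambda s: s == 0) if is_backward else (lambda s: s == target)
    states, actions = [start], []
    state = start
    while not done(state):
        # Toy policy: pick a step size that cannot overshoot the end state.
        remaining = state if is_backward else target - state
        action = rng.randint(1, min(2, remaining))
        state = step(state, action)
        states.append(state)
        actions.append(action)
    return states, actions


rng = random.Random(0)
fwd_states, fwd_actions = rollout(start=0, is_backward=False, rng=rng)
bwd_states, bwd_actions = rollout(start=3, is_backward=True, rng=rng)
```

A forward rollout starts at the initial state and ends at the target, while a backward rollout starts at the target and ends at the initial state. In both cases there is one more state than there are actions.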