gfn.samplers
Classes
| LocalSearchSampler | Sampler equipped with local search capabilities. |
| Sampler | Estimator‑driven sampler for GFlowNet environments. |
Module Contents
- class gfn.samplers.LocalSearchSampler(pf_estimator, pb_estimator)
Bases: Sampler
Sampler equipped with local search capabilities.
The LocalSearchSampler extends the basic Sampler with local search functionality based on the back-and-forth heuristic. This approach was first proposed by [Zhang et al. 2022](https://arxiv.org/abs/2202.01361) and further explored by [Kim et al. 2023](https://arxiv.org/abs/2310.02710).
The local search process involves:
1. Taking a trajectory and performing K backward steps using the backward policy.
2. Reconstructing the trajectory from the junction state using the forward policy.
3. Optionally applying the Metropolis-Hastings acceptance criterion.
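The back-and-forth procedure can be illustrated on a toy problem. The following is a minimal sketch, not the library's implementation: it assumes a hypothetical subset-building environment (choose SET_SIZE of N_ITEMS items, one per forward step), uniform forward and backward policies, a made-up additive reward, and greedy acceptance instead of Metropolis-Hastings. It tracks only terminal objects rather than full Trajectories containers.

```python
import random

# Hypothetical toy environment (not part of gfn): build a subset of SET_SIZE
# items out of range(N_ITEMS), adding one item per forward step.
N_ITEMS = 8
SET_SIZE = 4
WEIGHTS = [0.5, 2.0, 1.0, 3.0, 0.1, 2.5, 1.5, 0.7]  # made-up item weights


def reward(items):
    # Hypothetical reward: total weight of the chosen items.
    return sum(WEIGHTS[i] for i in items)


def backward_steps(terminal, k, rng):
    # K steps of a uniform backward policy: remove k randomly chosen items.
    state = set(terminal)
    for _ in range(k):
        state.remove(rng.choice(sorted(state)))
    return state  # the junction state


def forward_reconstruct(junction, rng):
    # Uniform forward policy: add random unused items until the set is full.
    state = set(junction)
    while len(state) < SET_SIZE:
        state.add(rng.choice([i for i in range(N_ITEMS) if i not in state]))
    return frozenset(state)


def local_search_step(terminal, k, rng):
    junction = backward_steps(terminal, k, rng)
    candidate = forward_reconstruct(junction, rng)
    # Greedy acceptance (the non-Metropolis-Hastings variant): keep the
    # candidate only if its reward is strictly higher.
    if reward(candidate) > reward(terminal):
        return candidate, True
    return terminal, False


rng = random.Random(0)
x0 = forward_reconstruct(set(), rng)  # initial sample from the forward policy
x = x0
for _ in range(20):
    x, updated = local_search_step(x, k=2, rng=rng)
print(sorted(x), reward(x))
```

Because rejected proposals leave the current object unchanged, the reward in this greedy variant can only stay the same or increase across iterations.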
- Parameters:
pf_estimator (gfn.estimators.Estimator) – The forward policy estimator.
pb_estimator (gfn.estimators.Estimator) – The backward policy estimator, used by backward_sampler.
- estimator
The forward policy estimator (inherited from Sampler).
- backward_sampler
A Sampler instance with the backward policy estimator.
- static _combine_prev_and_recon_trajectories(n_prevs, prev_trajectories, recon_trajectories, prev_trajectories_log_pf=None, recon_trajectories_log_pf=None, prev_trajectories_log_pb=None, recon_trajectories_log_pb=None, debug=False)
Combines previous and reconstructed trajectories to create new trajectories.
This static method combines two trajectory segments, prev_trajectories and recon_trajectories, into new_trajectories: for each trajectory in the batch, the first n_prevs steps are kept from prev_trajectories and the remainder is replaced by the corresponding recon_trajectories segment. See local_search for how prev_trajectories and recon_trajectories are generated.
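The splicing rule can be sketched with plain Python lists. This is an illustration under assumptions, not gfn's tensor-based implementation: each trajectory is a hypothetical list of states in forward order (s0 first), prev trajectories are assumed to be already in forward order, per-step forward log-probs are lists with one entry per step, and backward log-probs (which would be spliced the same way) are omitted.

```python
from typing import List, Optional, Sequence, Tuple


def combine_prev_and_recon(
    n_prevs: Sequence[int],
    prev_trajs: Sequence[List[str]],
    recon_trajs: Sequence[List[str]],
    prev_log_pf: Optional[Sequence[List[float]]] = None,
    recon_log_pf: Optional[Sequence[List[float]]] = None,
) -> Tuple[List[List[str]], Optional[List[List[float]]]]:
    """Splice the first n_prev steps of each prev trajectory with its
    reconstruction (hypothetical list-based layout, forward order)."""
    has_log_pf = prev_log_pf is not None and recon_log_pf is not None
    new_trajs, new_log_pf = [], []
    for i, n_prev in enumerate(n_prevs):
        prev, recon = prev_trajs[i], recon_trajs[i]
        # The reconstruction must start at the junction state prev[n_prev].
        assert prev[n_prev] == recon[0], "junction states disagree"
        # Keep n_prev steps (n_prev + 1 states); skip recon's copy of the junction.
        new_trajs.append(prev[: n_prev + 1] + recon[1:])
        if has_log_pf:
            # One log-prob per step: prefix from prev, the rest from recon.
            new_log_pf.append(prev_log_pf[i][:n_prev] + recon_log_pf[i])
    return new_trajs, (new_log_pf if has_log_pf else None)


trajs, log_pf = combine_prev_and_recon(
    n_prevs=[1],
    prev_trajs=[["s0", "a", "ab", "abc"]],
    recon_trajs=[["a", "ax", "axy"]],
    prev_log_pf=[[-0.1, -0.2, -0.3]],
    recon_log_pf=[[-0.5, -0.6]],
)
print(trajs)   # [['s0', 'a', 'ax', 'axy']]
print(log_pf)  # [[-0.1, -0.5, -0.6]]
```

The junction assertion plays a role similar to a debug check: if the reconstruction did not start where the kept prefix ends, the spliced trajectory would be invalid.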
- Parameters:
n_prevs (torch.Tensor) – Tensor indicating how many steps to take from prev_trajectories for each trajectory in the batch.
prev_trajectories (gfn.containers.Trajectories) – Trajectories obtained from backward sampling.
recon_trajectories (gfn.containers.Trajectories) – Trajectories obtained from forward reconstruction.
prev_trajectories_log_pf (torch.Tensor | None) – Optional forward-policy log probabilities for prev_trajectories.
recon_trajectories_log_pf (torch.Tensor | None) – Optional forward-policy log probabilities for recon_trajectories.
prev_trajectories_log_pb (torch.Tensor | None) – Optional backward-policy log probabilities for prev_trajectories.
recon_trajectories_log_pb (torch.Tensor | None) – Optional backward-policy log probabilities for recon_trajectories.
debug (bool) – If True, performs additional validation checks for debugging.
- Returns:
A tuple containing:
new_trajectories: a Trajectories object with the combined trajectories
new_trajectories_log_pf: a tensor of combined forward log probabilities
new_trajectories_log_pb: a tensor of combined backward log probabilities
- Return type:
tuple
Note
This method uses vectorized tensor operations to combine trajectory segments efficiently. In debug mode, it also compares the vectorized result against a for-loop implementation to verify correctness.
- local_search(env, trajectories, save_estimator_outputs=False, save_logprobs=False, back_steps=None, back_ratio=None, use_metropolis_hastings=True, debug=False, **policy_kwargs)
Performs local search on a batch of trajectories.
This method implements the local search algorithm by:
1. For each trajectory, performing K backward steps to reach a junction state.
2. Reconstructing the trajectory from the junction state using the forward policy.
3. Optionally applying the Metropolis-Hastings acceptance criterion to decide whether to accept the new trajectory.
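A minimal sketch of the acceptance test in step 3, assuming the usual Metropolis-Hastings ratio for a back-and-forth proposal with target proportional to R(x) P_B(τ|x), written over trajectory-level sums of log-rewards and forward/backward log-probs. The function names and argument layout are hypothetical, and the greedy rule used when use_metropolis_hastings is False is included for comparison (strict inequality is an assumption).

```python
import math
import random


def mh_accept(old_log_r, new_log_r, old_log_pf, new_log_pf,
              old_log_pb, new_log_pb, rng=random):
    """Metropolis-Hastings test for a back-and-forth proposal (sketch).

    All arguments are trajectory-level sums for the old trajectory tau
    (ending in x) and the proposed trajectory tau' (ending in x'):
      log A = log R(x') + log P_B(tau'|x') + log P_F(tau)
            - log R(x)  - log P_B(tau|x)   - log P_F(tau')
    Terms for the shared prefix before the junction state cancel out.
    """
    log_ratio = (new_log_r + new_log_pb + old_log_pf) - (
        old_log_r + old_log_pb + new_log_pf
    )
    accept_prob = math.exp(min(0.0, log_ratio))  # min(1, ratio)
    return rng.random() < accept_prob, accept_prob


def greedy_accept(old_log_r, new_log_r):
    # Alternative rule when Metropolis-Hastings is disabled.
    return new_log_r > old_log_r


# Same policy terms, half the reward: accepted with probability 0.5.
accepted, p = mh_accept(0.0, math.log(0.5), -3.0, -3.0, -2.0, -2.0)
print(accepted, p)
```

Working in log space avoids overflow when rewards or trajectory probabilities span many orders of magnitude; clamping the log-ratio at 0 implements the min(1, ·).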
- Parameters:
env (gfn.env.Env) – The environment to sample trajectories from.
trajectories (gfn.containers.Trajectories) – The batch of trajectories to perform local search on.
save_estimator_outputs (bool) – If True, saves the estimator outputs for each step. Useful for off-policy training with tempered policies.
save_logprobs (bool) – If True, calculates and saves the log probabilities of sampled actions. Useful for on-policy training.
back_steps (torch.Tensor | None) – Number of backward steps to take. Must be provided if back_ratio is None.
back_ratio (float | None) – Ratio of trajectory length to use for backward steps. Must be provided if back_steps is None.
use_metropolis_hastings (bool) – If True, applies Metropolis-Hastings acceptance criterion. If False, accepts new trajectories if they have higher rewards.
debug (bool) – If True, performs additional validation checks for debugging.
**policy_kwargs (Any) – Keyword arguments passed to the policy estimator. See sample_actions for details.
- Returns:
A tuple containing:
a Trajectories object refined by local search
a boolean tensor indicating which trajectories were updated
- Return type:
tuple
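A sketch of how back_steps and back_ratio might translate into a per-trajectory number of backward steps K and a kept prefix length n_prev (analogous to the n_prevs argument of _combine_prev_and_recon_trajectories). The helper name, the ceil rounding rule, clamping K to the trajectory length, and letting back_steps win when both are given are all assumptions of this sketch; lengths are plain ints rather than tensors.

```python
import math
from typing import List, Optional, Sequence, Tuple


def resolve_back_steps(
    lengths: Sequence[int],
    back_steps: Optional[int] = None,
    back_ratio: Optional[float] = None,
) -> Tuple[List[int], List[int]]:
    """Return (K, n_prev) for each trajectory length (hypothetical helper).

    K is the number of backward steps; n_prev = length - K is the number of
    original steps kept before the junction state.
    """
    if back_steps is None and back_ratio is None:
        raise ValueError("provide back_steps or back_ratio")
    ks = []
    for length in lengths:
        if back_steps is not None:
            k = back_steps  # this sketch lets back_steps win if both are given
        else:
            k = math.ceil(back_ratio * length)  # assumed rounding rule
        ks.append(min(k, length))  # cannot step back past s0
    n_prevs = [length - k for length, k in zip(lengths, ks)]
    return ks, n_prevs


print(resolve_back_steps([4, 6, 3], back_ratio=0.5))  # ([2, 3, 2], [2, 3, 1])
```

A ratio keeps the perturbation proportional to trajectory length, whereas a fixed back_steps perturbs short trajectories relatively more.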
- sample_trajectories(env, n=None, states=None, conditions=None, save_estimator_outputs=False, save_logprobs=False, n_local_search_loops=0, back_steps=None, back_ratio=None, use_metropolis_hastings=False, **policy_kwargs)
Samples trajectories with optional local search operations.
This method extends basic trajectory sampling with local search. After sampling the initial trajectories, it performs n_local_search_loops rounds of local search, which may improve the trajectories in terms of reward.
- Parameters:
env (gfn.env.Env) – The environment to sample trajectories from.
n (Optional[int]) – Number of trajectories to sample, all starting from s0. Must be provided if states is None.
states (Optional[gfn.states.States]) – Initial states to start trajectories from. It should have batch_shape of length 1 (no trajectory dim). If None, n must be provided and we initialize n trajectories with the environment’s initial state.
conditions (Optional[torch.Tensor]) – Optional tensor of condition information for conditional policies. Must match the batch shape of states.
save_estimator_outputs (bool) – If True, saves the estimator outputs for each step. Useful for off-policy training with tempered policies.
save_logprobs (bool) – If True, calculates and saves the log probabilities of sampled actions. Useful for on-policy training.
n_local_search_loops (int) – Number of local search loops to perform after initial sampling. Each loop appends another batch of trajectories to the output.
back_steps (torch.Tensor | None) – Number of backward steps to take. Must be provided if back_ratio is None.
back_ratio (float | None) – Ratio of trajectory length to use for backward steps. Must be provided if back_steps is None.
use_metropolis_hastings (bool) – If True, applies Metropolis-Hastings acceptance criterion. If False, accepts new trajectories if they have higher rewards.
**policy_kwargs (Any) – Keyword arguments passed to the policy estimator. See sample_actions for details.
- Returns:
A Trajectories object representing the batch of sampled trajectories, where the number of trajectories is n * (1 + n_local_search_loops).
- Return type:
gfn.containers.Trajectories
Note
The final trajectories container contains both the initial trajectories and the improved trajectories from local search.
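The bookkeeping behind the n * (1 + n_local_search_loops) count can be sketched as follows. This is a hypothetical outline with plain Python callables standing in for the sampler and local_search; in particular, seeding each loop with the trajectories accepted in the previous loop is an assumption, not something the text above states.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def sample_with_local_search(
    sample: Callable[[int], List[T]],
    local_search: Callable[[Sequence[T]], Tuple[List[T], List[bool]]],
    n: int,
    n_local_search_loops: int,
) -> List[T]:
    """Outline of the sampling loop; returns n * (1 + n_local_search_loops) items."""
    trajectories = list(sample(n))  # initial batch of n trajectories
    current = list(trajectories)    # trajectories that seed the next search
    for _ in range(n_local_search_loops):
        refined, is_updated = local_search(current)
        trajectories.extend(refined)  # each loop appends n trajectories
        # Assumption: the next loop searches from the accepted trajectories,
        # keeping the previous one wherever the proposal was rejected.
        current = [new if upd else old
                   for old, new, upd in zip(current, refined, is_updated)]
    return trajectories


# Toy usage: "trajectories" are ints and local search always proposes x + 1.
out = sample_with_local_search(
    sample=lambda k: [0] * k,
    local_search=lambda batch: ([x + 1 for x in batch], [True] * len(batch)),
    n=3,
    n_local_search_loops=2,
)
print(out)  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
```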
- class gfn.samplers.Sampler(estimator)
Estimator‑driven sampler for GFlowNet environments.
The estimator builds action distributions, computes step log‑probs, and records artifacts into a rollout context via method flags. Direction (forward/backward) is determined by estimator.is_backward.
- Parameters:
estimator (gfn.estimators.Estimator)
- estimator
The underlying policy estimator. Must expose the methods contained in the PolicyMixin mixin.
- sample_actions(env, states, conditions=None, save_estimator_outputs=False, save_logprobs=False, ctx=None, **policy_kwargs)
Sample one step from states via the estimator.
Initializes or reuses a rollout context with estimator.init_context, builds a Distribution with estimator.compute_dist, and optionally computes log‑probs with estimator.log_probs. Per‑step artifacts are recorded by the estimator when the corresponding flags are set.
- Parameters:
env (gfn.env.Env) – Environment providing action/state conversion utilities.
states (gfn.states.States) – Batch of states to act on.
conditions (torch.Tensor | None) – Optional condition vector for conditional policies.
save_estimator_outputs (bool) – If True, return the raw estimator outputs cached by the PolicyMixin for this step. Useful for off-policy training with tempered policies.
save_logprobs (bool) – If True, return per‑step log‑probs padded to batch. Useful for on-policy training.
**policy_kwargs (Any) – Extra kwargs forwarded to to_probability_distribution.
ctx (Any | None) – Optional rollout context to reuse; if None, a new one is initialized with estimator.init_context.
- Returns:
(Actions, log_probs | None, estimator_outputs | None). The estimator outputs come from PolicyMixin.get_current_estimator_output(ctx) when requested.
- Return type:
Tuple[gfn.actions.Actions, torch.Tensor | None, torch.Tensor | None]
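As a rough illustration of what one sampling step involves (build a distribution, sample, optionally compute log-probs), here is a toy batched version over masked logits. It does not use gfn's estimator, context, or Distribution objects; the function name toy_sample_actions, the list-based inputs, and the masked categorical policy are assumptions made for the sketch.

```python
import math
import random
from typing import List, Optional, Tuple


def toy_sample_actions(
    batch_logits: List[List[float]],
    batch_masks: List[List[bool]],
    rng: random.Random,
    save_logprobs: bool = False,
) -> Tuple[List[int], Optional[List[float]]]:
    """One sampling step for a batch of states (hypothetical stand-in).

    Builds a masked categorical distribution per state, samples an action,
    and optionally returns the log-probability of each sampled action.
    """
    actions, log_probs = [], []
    for logits, mask in zip(batch_logits, batch_masks):
        # Log-normalizer over valid actions only (stable log-sum-exp).
        valid = [logit for logit, ok in zip(logits, mask) if ok]
        m = max(valid)
        log_z = m + math.log(sum(math.exp(logit - m) for logit in valid))
        log_p = [logit - log_z if ok else -math.inf
                 for logit, ok in zip(logits, mask)]
        # Invalid actions get weight exp(-inf) = 0 and are never sampled.
        a = rng.choices(range(len(logits)), weights=[math.exp(lp) for lp in log_p])[0]
        actions.append(a)
        log_probs.append(log_p[a])
    return actions, (log_probs if save_logprobs else None)


acts, lps = toy_sample_actions(
    [[0.0, 0.0], [5.0, 1.0, -2.0]],
    [[True, True], [False, False, True]],
    random.Random(0),
    save_logprobs=True,
)
print(acts, lps)
```

Masking before normalizing keeps the log-probs consistent with the actions that can actually be sampled, which matters when these values feed into a training loss.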
- sample_trajectories(env, n=None, states=None, conditions=None, save_estimator_outputs=False, save_logprobs=False, **policy_kwargs)
Roll out complete trajectories using the estimator.
Reuses a single rollout context across steps, calling compute_dist and log_probs at each iteration. Uses estimator.is_backward to choose the environment step function.
- Parameters:
env (gfn.env.Env) – Environment to sample in.
n (Optional[int]) – Number of trajectories to sample if states is None.
states (Optional[gfn.states.States]) – Starting states (batch shape of length 1) or None.
conditions (Optional[torch.Tensor]) – Optional condition tensor for conditional environments, with shape (batch_size, condition_dim); each row is the condition vector for the corresponding trajectory.
save_estimator_outputs (bool) – If True, store per‑step estimator outputs. Useful for off-policy training with tempered policies.
save_logprobs (bool) – If True, store per‑step log‑probs. Useful for on-policy training.
**policy_kwargs (Any) – Extra kwargs forwarded to the policy.
- Returns:
A Trajectories object with stacked states/actions and any recorded artifacts.
- Return type:
gfn.containers.Trajectories
Note
For backward trajectories, the reward is computed at the initial state (s0) rather than the terminal state (sf).
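The rollout loop can be pictured with a toy forward-only example: every unfinished trajectory takes a step each iteration, finished ones are padded, and per-step log-probs are stacked alongside states and actions. Everything here (the integer-state environment, the uniform {+1, +2} policy, the TERMINAL_AT threshold, and the None / 0.0 padding convention) is hypothetical, and the backward direction selected by estimator.is_backward is not modeled.

```python
import math
import random
from typing import List, Optional

TERMINAL_AT = 5  # hypothetical: a trajectory ends once its state is >= 5


def forward_step(state: int, action: int) -> int:
    # Hypothetical toy environment: an action adds 1 or 2 to an integer state.
    return state + action


def rollout(n: int, rng: random.Random, save_logprobs: bool = True):
    """Step every unfinished trajectory until all are done.

    Returns per-time-step rows of states, actions and log-probs. Rows for
    finished trajectories are padded with None (states, actions) and 0.0
    (log-probs).
    """
    states = [0] * n  # every trajectory starts at s0 = 0
    done = [False] * n
    all_states: List[List[Optional[int]]] = [list(states)]
    all_actions: List[List[Optional[int]]] = []
    all_log_probs: List[List[float]] = []
    while not all(done):
        step_actions: List[Optional[int]] = []
        step_log_probs: List[float] = []
        for i in range(n):
            if done[i]:
                step_actions.append(None)  # padding for a finished trajectory
                step_log_probs.append(0.0)
                continue
            action = rng.choice([1, 2])  # uniform policy over {+1, +2}
            step_actions.append(action)
            step_log_probs.append(math.log(0.5))
            states[i] = forward_step(states[i], action)
            done[i] = states[i] >= TERMINAL_AT
        all_actions.append(step_actions)
        all_states.append([s if a is not None else None
                           for s, a in zip(states, step_actions)])
        if save_logprobs:
            all_log_probs.append(step_log_probs)
    return all_states, all_actions, (all_log_probs if save_logprobs else None)


batch_states, batch_actions, batch_log_probs = rollout(4, random.Random(0))
print(len(batch_actions), "time steps")
```

Padding finished trajectories with a zero log-prob means summing along the time axis yields each trajectory's total log-probability without extra masking.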