gfn.utils.prob_calculations

Functions

get_trajectory_pbs(pb, trajectories, **policy_kwargs)

Calculate PB log‑probabilities for trajectories.

get_trajectory_pfs(pf, trajectories[, ...])

Calculate PF log‑probabilities for trajectories.

get_trajectory_pfs_and_pbs(pf, pb, trajectories[, ...])

Calculate PF and PB log‑probabilities for trajectories.

get_transition_pbs(pb, transitions, **policy_kwargs)

Calculate PB log‑probabilities for transitions.

get_transition_pfs(pf, transitions[, ...])

Calculate PF log‑probabilities for transitions.

get_transition_pfs_and_pbs(pf, pb, transitions[, ...])

Calculate PF and PB log‑probabilities for transitions.

Module Contents

gfn.utils.prob_calculations.get_trajectory_pbs(pb, trajectories, **policy_kwargs)

Calculate PB log‑probabilities for trajectories.

Non‑vectorized (per‑step) evaluation is supported when specifically needed (estimator.is_vectorized=False). It aligns the action at step t with the state at t+1, skips t == 0, and keeps only entries where ~is_sink_state[t+1] & ~is_initial_state[t+1] & ~is_dummy[t] & ~is_exit[t].
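
To make the per-step rule concrete, here is a minimal pure-Python sketch of one literal reading of it. It is not torchgfn's implementation. Nested lists stand in for the boolean tensors, `per_step_log_pb` and its `step_log_pb` argument are hypothetical (the callable returns the PB log-probability of action t evaluated at state t+1), and filling masked entries with 0.0 is an assumption.

```python
from typing import Callable, List

BoolGrid = List[List[bool]]  # indexed [t][n]


def per_step_log_pb(
    step_log_pb: Callable[[int, int], float],  # hypothetical scorer
    is_sink_state: BoolGrid,     # states, shape (T + 1, N)
    is_initial_state: BoolGrid,  # states, shape (T + 1, N)
    is_dummy: BoolGrid,          # actions, shape (T, N)
    is_exit: BoolGrid,           # actions, shape (T, N)
    fill_value: float = 0.0,     # assumed value for masked entries
) -> List[List[float]]:
    """Toy per-step PB loop returning log_pb with shape (T, N)."""
    T, N = len(is_dummy), len(is_dummy[0])
    log_pb = [[fill_value] * N for _ in range(T)]
    for t in range(1, T):  # t == 0 is skipped, per the docstring
        for n in range(N):
            # Action t is paired with state t + 1.
            keep = (
                not is_sink_state[t + 1][n]
                and not is_initial_state[t + 1][n]
                and not is_dummy[t][n]
                and not is_exit[t][n]
            )
            if keep:
                log_pb[t][n] = step_log_pb(t, n)
    return log_pb


# Two trajectories (N = 2) padded to T = 3 actions.
# Trajectory 0: two moves, then exit. Trajectory 1: one move, exit, padding.
is_sink_state = [[False, False], [False, False], [False, True], [True, True]]
is_initial_state = [[True, True], [False, False], [False, False], [False, False]]
is_dummy = [[False, False], [False, False], [False, True]]
is_exit = [[False, False], [False, True], [True, False]]

result = per_step_log_pb(
    lambda t, n: -0.5, is_sink_state, is_initial_state, is_dummy, is_exit
)
print(result)  # [[0.0, 0.0], [-0.5, 0.0], [0.0, 0.0]]
```

Only trajectory 0 at t = 1 survives the mask: t = 0 is skipped, every other entry pairs with a sink state at t+1.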

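Parameters:
  • pb (gfn.estimators.Estimator) – Backward policy estimator.

  • trajectories (gfn.containers.Trajectories) – Trajectories to evaluate.

  • **policy_kwargs (Any) – Extra kwargs for to_probability_distribution.

Returns: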

log_pb of shape (T, N).

Raises:

ValueError – If backward trajectories are provided.

Return type:

torch.Tensor

gfn.utils.prob_calculations.get_trajectory_pfs(pf, trajectories, recalculate_all_logprobs=True, **policy_kwargs)

Calculate PF log‑probabilities for trajectories.

Non‑vectorized (per‑step) evaluation is supported when specifically needed (estimator.is_vectorized=False). It keeps only entries where ~is_sink_state[t] & ~is_dummy[t] and uses no action‑id indexing.
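
A minimal pure-Python sketch of this forward mask follows, under stated assumptions: it is not torchgfn's code, `per_step_log_pf` and its `step_log_pf` callable are hypothetical, and masked entries are assumed to hold 0.0. Unlike the PB mask documented for get_trajectory_pbs, exit actions are not masked out here, because action t is scored at state t, which is not yet a sink.

```python
from typing import Callable, List

BoolGrid = List[List[bool]]  # indexed [t][n]


def per_step_log_pf(
    step_log_pf: Callable[[int, int], float],  # hypothetical scorer
    is_sink_state: BoolGrid,  # states, shape (T + 1, N)
    is_dummy: BoolGrid,       # actions, shape (T, N)
    fill_value: float = 0.0,  # assumed value for masked entries
) -> List[List[float]]:
    """Toy per-step PF loop returning log_pf with shape (T, N)."""
    T, N = len(is_dummy), len(is_dummy[0])
    log_pf = [[fill_value] * N for _ in range(T)]
    for t in range(T):
        for n in range(N):
            # Action t is scored at state t; exit actions are kept.
            if not is_sink_state[t][n] and not is_dummy[t][n]:
                log_pf[t][n] = step_log_pf(t, n)
    return log_pf


# Two trajectories padded to T = 3: trajectory 0 moves twice then exits;
# trajectory 1 moves once, exits, and is padded at t = 2.
is_sink_state = [[False, False], [False, False], [False, True], [True, True]]
is_dummy = [[False, False], [False, False], [False, True]]

result = per_step_log_pf(lambda t, n: -1.0, is_sink_state, is_dummy)
print(result)  # [[-1.0, -1.0], [-1.0, -1.0], [-1.0, 0.0]]
```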

Parameters:
  • pf (gfn.estimators.Estimator) – Forward policy estimator.

  • trajectories (gfn.containers.Trajectories) – Trajectories to evaluate.

  • recalculate_all_logprobs (bool) – If True, recompute PF even if cached. Useful for off-policy training.

  • **policy_kwargs (Any) – Extra kwargs for to_probability_distribution.

Returns:

log_pf of shape (T, N).

Raises:

ValueError – If backward trajectories are provided.

Return type:

torch.Tensor
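
The recalculate_all_logprobs flag decides whether cached PF log-probabilities may be reused. The sketch below shows that decision in isolation. `resolve_log_pf` and `recompute` are hypothetical names, and falling back to recomputation when nothing is cached is an assumption rather than documented behaviour. The off-policy rationale is that a cache filled at sampling time reflects the policy that generated the samples, which need not be the current pf.

```python
from typing import Callable, List, Optional


def resolve_log_pf(
    cached_log_pf: Optional[List[float]],  # log-probs stored at sampling time
    recompute: Callable[[], List[float]],  # hypothetical: re-run the current pf
    recalculate_all_logprobs: bool = True,
) -> List[float]:
    """Return recomputed PF log-probs unless the cache may be reused."""
    if recalculate_all_logprobs or cached_log_pf is None:
        # Off-policy samples carry the sampler's log-probs, not the current pf's.
        return recompute()
    return cached_log_pf


sampled = [-1.2, -0.7]  # cached by an exploratory sampler


def current_pf() -> List[float]:
    # What the current forward policy assigns to the same actions.
    return [-1.0, -0.9]


print(resolve_log_pf(sampled, current_pf))  # [-1.0, -0.9]
print(resolve_log_pf(sampled, current_pf, recalculate_all_logprobs=False))  # [-1.2, -0.7]
```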

gfn.utils.prob_calculations.get_trajectory_pfs_and_pbs(pf, pb, trajectories, recalculate_all_logprobs=True, **policy_kwargs)

Calculate PF and PB log‑probabilities for trajectories.

Delegates to get_trajectory_pfs and get_trajectory_pbs while forwarding policy kwargs.

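Parameters:
  • pf (gfn.estimators.Estimator) – Forward policy estimator.

  • pb (gfn.estimators.Estimator) – Backward policy estimator.

  • trajectories (gfn.containers.Trajectories) – Trajectories to evaluate.

  • recalculate_all_logprobs (bool) – If True, recompute PF even if cached. Useful for off-policy training.

  • **policy_kwargs (Any) – Extra kwargs for to_probability_distribution.

Returns: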

(log_pf[T,N], log_pb[T,N])

Return type:

Tuple[torch.Tensor, torch.Tensor]
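
Because both outputs have one row per step, trajectory-level objectives typically sum them over the time axis to get one log-probability per trajectory. The sketch below mirrors the documented delegation with stand-in helpers, then performs that reduction. `pfs_and_pbs`, `sum_over_time`, the fake helpers, and the `temperature` kwarg are all hypothetical, and the sum equals a trajectory log-probability only if masked entries hold 0.0, which is assumed here.

```python
from typing import Any, Callable, List, Tuple

Grid = List[List[float]]  # indexed [t][n], i.e. shape (T, N)


def pfs_and_pbs(
    get_pfs: Callable[..., Grid],  # stand-in for get_trajectory_pfs
    get_pbs: Callable[..., Grid],  # stand-in for get_trajectory_pbs
    pf: Any,
    pb: Any,
    trajectories: Any,
    recalculate_all_logprobs: bool = True,
    **policy_kwargs: Any,
) -> Tuple[Grid, Grid]:
    """Forward the same policy kwargs to both per-direction helpers."""
    log_pf = get_pfs(
        pf, trajectories,
        recalculate_all_logprobs=recalculate_all_logprobs,
        **policy_kwargs,
    )
    log_pb = get_pbs(pb, trajectories, **policy_kwargs)
    return log_pf, log_pb


def sum_over_time(log_p: Grid) -> List[float]:
    """Reduce (T, N) per-step log-probs to one value per trajectory, shape (N,)."""
    return [sum(column) for column in zip(*log_p)]


seen_kwargs = []  # records the policy kwargs each fake helper received


def fake_pfs(pf, trajectories, recalculate_all_logprobs=True, **kw):
    seen_kwargs.append(kw)
    return [[-1.0, -1.0], [-1.0, -1.0], [-1.0, 0.0]]


def fake_pbs(pb, trajectories, **kw):
    seen_kwargs.append(kw)
    return [[0.0, 0.0], [-0.5, 0.0], [0.0, 0.0]]


# `temperature` is an illustrative policy kwarg, not a documented one.
log_pf, log_pb = pfs_and_pbs(fake_pfs, fake_pbs, "pf", "pb", "trajs", temperature=2.0)
print(sum_over_time(log_pf), sum_over_time(log_pb))  # [-3.0, -2.0] [-0.5, 0.0]
```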

gfn.utils.prob_calculations.get_transition_pbs(pb, transitions, **policy_kwargs)

Calculate PB log‑probabilities for transitions.

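Parameters:
  • pb (gfn.estimators.Estimator) – Backward policy estimator.

  • transitions (gfn.containers.Transitions) – Transitions to evaluate.

  • **policy_kwargs (Any) – Extra kwargs for to_probability_distribution.

Returns: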

log_pb of shape (M,).

Return type:

torch.Tensor

gfn.utils.prob_calculations.get_transition_pfs(pf, transitions, recalculate_all_logprobs=True, **policy_kwargs)

Calculate PF log‑probabilities for transitions.

Parameters:
  • pf (gfn.estimators.Estimator) – Forward policy estimator.

  • transitions (gfn.containers.Transitions) – Transitions to evaluate.

  • recalculate_all_logprobs (bool) – If True, recompute PF even if cached. Useful for off-policy training.

  • **policy_kwargs (Any) – Extra kwargs for to_probability_distribution.

Returns:

log_pf of shape (M,).

Return type:

torch.Tensor

gfn.utils.prob_calculations.get_transition_pfs_and_pbs(pf, pb, transitions, recalculate_all_logprobs=True, **policy_kwargs)

Calculate PF and PB log‑probabilities for transitions.

Parameters:
  • pf (gfn.estimators.Estimator) – Forward policy estimator.

  • pb (gfn.estimators.Estimator | None) – Backward policy estimator, or None for trees (PB=1).

  • transitions (gfn.containers.Transitions) – Transitions to evaluate.

  • recalculate_all_logprobs (bool) – If True, recompute PF even if cached. Useful for off-policy training.

  • **policy_kwargs (Any) – Extra kwargs for to_probability_distribution.

Returns:

(log_pf[M], log_pb[M]).

Raises:

ValueError – If backward transitions are provided.

Return type:

Tuple[torch.Tensor, torch.Tensor]
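
Two documented behaviours of get_transition_pfs_and_pbs can be illustrated without torchgfn: passing pb=None yields log PB = 0 for every transition (in a tree each state has exactly one parent, so PB = 1), and backward transitions raise ValueError. The sketch is hypothetical throughout: callables stand in for the estimators, and a plain boolean `is_backward` stands in for however the container marks its direction.

```python
from typing import Callable, List, Optional, Tuple


def transition_pfs_and_pbs(
    log_pf_fn: Callable[[int], float],            # hypothetical PF scorer per transition
    log_pb_fn: Optional[Callable[[int], float]],  # None models a tree-shaped DAG
    n_transitions: int,                           # M
    is_backward: bool = False,                    # hypothetical direction flag
) -> Tuple[List[float], List[float]]:
    """Return (log_pf, log_pb), each of length M."""
    if is_backward:
        raise ValueError("Backward transitions are not supported.")
    log_pf = [log_pf_fn(i) for i in range(n_transitions)]
    if log_pb_fn is None:
        # In a tree each state has exactly one parent, so PB = 1 and log PB = 0.
        log_pb = [0.0] * n_transitions
    else:
        log_pb = [log_pb_fn(i) for i in range(n_transitions)]
    return log_pf, log_pb


log_pf, log_pb = transition_pfs_and_pbs(lambda i: -0.5 * (i + 1), None, n_transitions=3)
print(log_pf, log_pb)  # [-0.5, -1.0, -1.5] [0.0, 0.0, 0.0]

backward_error = None
try:
    transition_pfs_and_pbs(lambda i: 0.0, None, 1, is_backward=True)
except ValueError as err:
    backward_error = str(err)
print(backward_error)  # Backward transitions are not supported.
```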