gfn.utils.prob_calculations
Functions
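| Function | Description |
|---|---|
| get_trajectory_pbs | Calculate PB log-probabilities for trajectories. |
| get_trajectory_pfs | Calculate PF log-probabilities for trajectories. |
| get_trajectory_pfs_and_pbs | Calculate PF and PB log-probabilities for trajectories. |
| get_transition_pbs | Calculate PB log-probabilities for transitions. |
| get_transition_pfs | Calculate PF log-probabilities for transitions. |
| get_transition_pfs_and_pbs | Calculate PF and PB log-probabilities for transitions. |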
Module Contents
- gfn.utils.prob_calculations.get_trajectory_pbs(pb, trajectories, **policy_kwargs)
Calculate PB log‑probabilities for trajectories.
- Non-vectorized (per-step) evaluation is supported when specifically needed (estimator.is_vectorized=False). It aligns the action at t with the state at t+1, evaluates only the entries where ~is_sink_state[t+1] & ~is_initial_state[t+1] & ~is_dummy[t] & ~is_exit[t] holds, and skips t==0.
- Parameters:
pb (gfn.estimators.Estimator | None) – Backward policy estimator, or None for trees (PB=1).
trajectories (gfn.containers.Trajectories) – Trajectories to evaluate.
**policy_kwargs (Any) – Extra kwargs for to_probability_distribution.
- Returns:
log_pb of shape (T, N).
- Raises:
ValueError – If backward trajectories are provided.
- Return type:
torch.Tensor
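The alignment and mask described for get_trajectory_pbs can be sketched with plain Python lists standing in for boolean tensors. This is a hypothetical illustration, not the library's code: the helper name pb_step_mask and the flag layout (T + 1 state rows, T action rows, N trajectory columns) are assumptions.

```python
from typing import List

Grid = List[List[bool]]  # indexed [step][trajectory]


def pb_step_mask(
    is_sink_state: Grid,     # (T + 1, N): flags for states s_0 .. s_T
    is_initial_state: Grid,  # (T + 1, N)
    is_dummy: Grid,          # (T, N): flags for actions a_0 .. a_{T-1}
    is_exit: Grid,           # (T, N)
) -> Grid:
    """Hypothetical helper: which (t, n) entries of log_pb are evaluated.

    The action at step t is paired with the state it leads to, s_{t+1},
    so state flags are read one row ahead of action flags and state row 0
    is never consulted.
    """
    T = len(is_dummy)
    N = len(is_dummy[0]) if T else 0
    return [
        [
            not is_sink_state[t + 1][n]
            and not is_initial_state[t + 1][n]
            and not is_dummy[t][n]
            and not is_exit[t][n]
            for n in range(N)
        ]
        for t in range(T)
    ]


# Trajectory 0: two regular actions, then exit at t = 2.
# Trajectory 1: one regular action, exit at t = 1, padding (dummy) at t = 2.
mask = pb_step_mask(
    is_sink_state=[[False, False], [False, False], [False, True], [True, True]],
    is_initial_state=[[True, True], [False, False], [False, False], [False, False]],
    is_dummy=[[False, False], [False, False], [False, True]],
    is_exit=[[False, False], [False, True], [True, False]],
)
```

Each trajectory gets a PB term only for its non-exit actions that land on a regular state: two for trajectory 0 and one for trajectory 1.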
- gfn.utils.prob_calculations.get_trajectory_pfs(pf, trajectories, recalculate_all_logprobs=True, **policy_kwargs)
Calculate PF log‑probabilities for trajectories.
Non-vectorized (per-step) evaluation is supported when specifically needed (estimator.is_vectorized=False). It uses the mask ~is_sink_state[t] & ~is_dummy[t] and no action-id indexing.
- Parameters:
pf (gfn.estimators.Estimator) – Forward policy estimator.
trajectories (gfn.containers.Trajectories) – Trajectories to evaluate.
recalculate_all_logprobs (bool) – If True, recompute PF even if cached. Useful for off-policy training.
**policy_kwargs (Any) – Extra kwargs for to_probability_distribution.
- Returns:
log_pf of shape (T, N).
- Raises:
ValueError – If backward trajectories are provided.
- Return type:
torch.Tensor
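A minimal per-step sketch of the PF mask ~is_sink_state[t] & ~is_dummy[t], assuming masked-out entries are left at 0.0 so that summing over t gives a trajectory's total log-probability. ToyCategorical and per_step_log_pfs are hypothetical names; the documented function obtains its distributions through to_probability_distribution rather than from raw logits.

```python
import math
from typing import List, Sequence


class ToyCategorical:
    """Hypothetical stand-in for a categorical distribution over actions."""

    def __init__(self, logits: Sequence[float]) -> None:
        m = max(logits)  # subtract the max before exponentiating, for stability
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        self._log_probs = [x - log_z for x in logits]

    def log_prob(self, action: int) -> float:
        return self._log_probs[action]


def per_step_log_pfs(
    logits: List[List[List[float]]],  # (T, N, n_actions): policy outputs at s_t
    actions: List[List[int]],         # (T, N): action taken at step t
    is_sink_state: List[List[bool]],  # (T, N): flags for s_t
    is_dummy: List[List[bool]],       # (T, N): flags for a_t
) -> List[List[float]]:
    log_pf = []
    for t in range(len(actions)):
        row = []
        for n in range(len(actions[t])):
            if not is_sink_state[t][n] and not is_dummy[t][n]:
                # Valid step: score the taken action under the distribution at s_t.
                row.append(ToyCategorical(logits[t][n]).log_prob(actions[t][n]))
            else:
                # Padding: left at 0.0 (log 1) so sums over t are unaffected.
                row.append(0.0)
        log_pf.append(row)
    return log_pf


# Two trajectories, T = 2; the second exits at t = 0 and is padded at t = 1.
uniform = [0.0, 0.0, 0.0]
log_pf = per_step_log_pfs(
    logits=[[uniform, uniform], [uniform, uniform]],
    actions=[[1, 2], [0, -1]],  # -1 marks the dummy (padding) action in this toy
    is_sink_state=[[False, False], [False, True]],
    is_dummy=[[False, False], [False, True]],
)
```

With uniform logits over three actions, every valid entry equals -log 3 and the padded entry stays 0.0.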
- gfn.utils.prob_calculations.get_trajectory_pfs_and_pbs(pf, pb, trajectories, recalculate_all_logprobs=True, **policy_kwargs)
Calculate PF and PB log‑probabilities for trajectories.
Delegates to get_trajectory_pfs and get_trajectory_pbs while forwarding policy kwargs.
- Parameters:
pf (gfn.estimators.Estimator) – Forward policy estimator.
pb (gfn.estimators.Estimator | None) – Backward policy estimator, or None for trees (PB=1).
trajectories (gfn.containers.Trajectories) – Trajectories to evaluate.
recalculate_all_logprobs (bool) – If True, recompute PF even if cached.
**policy_kwargs (Any) – Extra kwargs for to_probability_distribution.
- Returns:
(log_pf, log_pb), each of shape (T, N).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
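The delegation in get_trajectory_pfs_and_pbs can be sketched as follows. fake_pfs and fake_pbs are hypothetical stand-ins for get_trajectory_pfs and get_trajectory_pbs, and temperature is a made-up policy kwarg. Summing each (T, N) output over t to get per-trajectory totals assumes padded steps hold 0.0.

```python
from typing import Any, Callable, Dict, List, Tuple

Matrix = List[List[float]]  # indexed [t][n], shape (T, N)


def pfs_and_pbs(
    get_pfs: Callable[..., Matrix],
    get_pbs: Callable[..., Matrix],
    trajectories: Any,
    recalculate_all_logprobs: bool = True,
    **policy_kwargs: Any,
) -> Tuple[Matrix, Matrix]:
    # Forward the same policy kwargs to both calls; only PF gets the cache flag.
    log_pf = get_pfs(
        trajectories, recalculate_all_logprobs=recalculate_all_logprobs, **policy_kwargs
    )
    log_pb = get_pbs(trajectories, **policy_kwargs)
    return log_pf, log_pb


def sum_over_steps(m: Matrix) -> List[float]:
    # Collapse the T axis: one total per trajectory (column sums).
    return [sum(column) for column in zip(*m)]


seen: List[Dict[str, Any]] = []  # records the kwargs each stub received


def fake_pfs(trajectories: Any, **kwargs: Any) -> Matrix:
    seen.append(dict(kwargs))
    return [[-1.0, -0.5], [-2.0, 0.0]]


def fake_pbs(trajectories: Any, **kwargs: Any) -> Matrix:
    seen.append(dict(kwargs))
    return [[-0.5, 0.0], [0.0, 0.0]]


log_pf, log_pb = pfs_and_pbs(fake_pfs, fake_pbs, trajectories=None, temperature=1.0)
per_traj = [f - b for f, b in zip(sum_over_steps(log_pf), sum_over_steps(log_pb))]
```

Only the PF call receives recalculate_all_logprobs, since get_trajectory_pbs has no such parameter.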
- gfn.utils.prob_calculations.get_transition_pbs(pb, transitions, **policy_kwargs)
Calculate PB log‑probabilities for transitions.
- Parameters:
pb (gfn.estimators.Estimator | None) – Backward policy estimator, or None for trees (PB=1).
transitions (gfn.containers.Transitions) – Transitions to evaluate.
**policy_kwargs (Any) – Extra kwargs for to_probability_distribution.
- Returns:
log_pb of shape (M,).
- Return type:
torch.Tensor
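A sketch of the flat (M,) PB output. It assumes log PB stays at 0.0 (PB = 1) both when pb is None, as documented for trees, and, by analogy with the trajectory mask that excludes exit actions, for terminating transitions. transition_log_pbs and log_pb_fn are hypothetical names.

```python
import math
from typing import Callable, List, Optional


def transition_log_pbs(
    log_pb_fn: Optional[Callable[[int], float]],  # hypothetical: log PB for transition i, or None for trees
    is_terminating: List[bool],                   # (M,): True where the action was the exit action
) -> List[float]:
    log_pb = [0.0] * len(is_terminating)  # log(1) = 0 wherever PB is not evaluated
    if log_pb_fn is None:
        # Trees: each state has exactly one parent, so PB = 1 everywhere.
        return log_pb
    for i, terminating in enumerate(is_terminating):
        if not terminating:
            log_pb[i] = log_pb_fn(i)
    return log_pb


tree_log_pb = transition_log_pbs(None, [False, False, True])
dag_log_pb = transition_log_pbs(lambda i: math.log(0.5), [False, True, False])
```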
- gfn.utils.prob_calculations.get_transition_pfs(pf, transitions, recalculate_all_logprobs=True, **policy_kwargs)
Calculate PF log‑probabilities for transitions.
- Parameters:
pf (gfn.estimators.Estimator) – Forward policy estimator.
transitions (gfn.containers.Transitions) – Transitions to evaluate.
recalculate_all_logprobs (bool) – If True, recompute PF even if cached. Useful for off-policy training.
**policy_kwargs (Any) – Extra kwargs for to_probability_distribution.
- Returns:
log_pf of shape (M,).
- Return type:
torch.Tensor
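The effect of recalculate_all_logprobs can be sketched on a toy batch. ToyTransitionBatch, transition_log_pfs, and the exact caching condition are assumptions for illustration; the sketch only shows why log-probabilities cached at sampling time are replaced when the flag is True, which matters when the batch came from a different (off-policy) sampler.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class ToyTransitionBatch:
    """Hypothetical stand-in for a batch of M transitions."""

    states: List[int]
    actions: List[int]
    log_probs: Optional[List[float]] = None  # values cached when the batch was sampled


def transition_log_pfs(
    log_pf_fn: Callable[[int, int], float],  # hypothetical: log PF(a | s) under the current policy
    transitions: ToyTransitionBatch,
    recalculate_all_logprobs: bool = True,
) -> List[float]:
    cached = transitions.log_probs
    if not recalculate_all_logprobs and cached:
        # Reuse the cached values (only valid if the sampler was the current policy).
        return list(cached)
    # Re-score every transition with the current policy.
    return [log_pf_fn(s, a) for s, a in zip(transitions.states, transitions.actions)]


def current_policy(state: int, action: int) -> float:
    return -1.0  # toy policy: every action has log-probability -1


batch = ToyTransitionBatch(states=[0, 1], actions=[1, 0], log_probs=[-9.0, -9.0])
reused = transition_log_pfs(current_policy, batch, recalculate_all_logprobs=False)
rescored = transition_log_pfs(current_policy, batch, recalculate_all_logprobs=True)
```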
- gfn.utils.prob_calculations.get_transition_pfs_and_pbs(pf, pb, transitions, recalculate_all_logprobs=True, **policy_kwargs)
Calculate PF and PB log‑probabilities for transitions.
- Parameters:
pf (gfn.estimators.Estimator) – Forward policy estimator.
pb (gfn.estimators.Estimator | None) – Backward policy estimator, or None for trees (PB=1).
transitions (gfn.containers.Transitions) – Transitions to evaluate.
recalculate_all_logprobs (bool) – If True, recompute PF even if cached. Useful for off-policy training.
**policy_kwargs (Any) – Extra kwargs for to_probability_distribution.
- Returns:
(log_pf, log_pb), each of shape (M,).
- Raises:
ValueError – If backward transitions are provided.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
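A sketch of the documented ValueError guard and the returned pair, using hypothetical stand-ins: ToyTransitions carries an assumed is_backward flag, and pf_stub and pb_stub play the roles of get_transition_pfs and get_transition_pbs.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass
class ToyTransitions:
    """Hypothetical stand-in for gfn.containers.Transitions."""

    n_transitions: int         # M
    is_backward: bool = False  # assumed flag: True if sampled by the backward policy


LogProbFn = Callable[[ToyTransitions], List[float]]


def transition_pfs_and_pbs(
    get_pfs: LogProbFn, get_pbs: LogProbFn, transitions: ToyTransitions
) -> Tuple[List[float], List[float]]:
    # Reject backward transitions up front, mirroring the documented ValueError.
    if transitions.is_backward:
        raise ValueError("Backward transitions are not supported.")
    # Both outputs are flat, one entry per transition: shape (M,).
    return get_pfs(transitions), get_pbs(transitions)


def pf_stub(t: ToyTransitions) -> List[float]:
    return [-1.0] * t.n_transitions


def pb_stub(t: ToyTransitions) -> List[float]:
    return [0.0] * t.n_transitions


pair = transition_pfs_and_pbs(pf_stub, pb_stub, ToyTransitions(n_transitions=2))

try:
    transition_pfs_and_pbs(pf_stub, pb_stub, ToyTransitions(n_transitions=2, is_backward=True))
    raised = False
except ValueError:
    raised = True
```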