gfn.gflownet.sub_trajectory_balance¶
Attributes¶
Classes¶
SubTBGFlowNet: GFlowNet for the Sub-Trajectory Balance loss.
Module Contents¶
- type gfn.gflownet.sub_trajectory_balance.ContributionsTensor = torch.Tensor¶
- type gfn.gflownet.sub_trajectory_balance.CumulativeLogProbsTensor = torch.Tensor¶
- type gfn.gflownet.sub_trajectory_balance.LogStateFlowsTensor = torch.Tensor¶
- type gfn.gflownet.sub_trajectory_balance.LogTrajectoriesTensor = torch.Tensor¶
- type gfn.gflownet.sub_trajectory_balance.MaskTensor = torch.Tensor¶
- type gfn.gflownet.sub_trajectory_balance.PredictionsTensor = torch.Tensor¶
- class gfn.gflownet.sub_trajectory_balance.SubTBGFlowNet(pf, pb, logF, weighting='geometric_within', lamda=0.9, log_reward_clip_min=-float('inf'), forward_looking=False, constant_pb=False, debug=False, loss_fn=None)¶
Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet
GFlowNet for the Sub-Trajectory Balance loss.
An implementation of the sub-trajectory balance loss as described in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782).
- Parameters:
pb (gfn.estimators.Estimator | None)
logF (gfn.estimators.ScalarEstimator | gfn.estimators.ConditionalScalarEstimator)
weighting (Literal['DB', 'ModifiedDB', 'TB', 'geometric', 'equal', 'geometric_within', 'equal_within'])
lamda (float)
log_reward_clip_min (float)
forward_looking (bool)
constant_pb (bool)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- pf¶
The forward policy estimator.
- pb¶
The backward policy estimator, or None if the gflownet DAG is a tree, and pb is therefore always 1.
- logF¶
A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states.
- weighting¶
The sub-trajectory weighting scheme.
- “DB”: Considers all one-step transitions of each trajectory in the batch and weighs them equally (regardless of the length of the trajectory). Should be equivalent to the DetailedBalance loss.
- “ModifiedDB”: Considers all one-step transitions of each trajectory in the batch and weighs them inversely proportionally to the trajectory length. Each trajectory then contributes equally to the loss, so the loss is not dominated by long trajectories.
- “TB”: Considers only the full trajectory. Should be equivalent to the TrajectoryBalance loss.
- “equal_within”: Each sub-trajectory of each trajectory is weighed equally within the trajectory. Then each trajectory is weighed equally within the batch.
- “equal”: Each sub-trajectory of each trajectory is weighed equally within the set of all sub-trajectories.
- “geometric_within”: Each sub-trajectory of each trajectory is weighed proportionally to (lamda ** len(sub_trajectory)), within each trajectory. This is the scheme used in the paper.
- “geometric”: Each sub-trajectory of each trajectory is weighed proportionally to (lamda ** len(sub_trajectory)), within the set of all sub-trajectories.
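To make the “geometric_within” description concrete, here is a minimal sketch of the weights for a single trajectory. It is not the library's implementation: the helper name `geometric_within_weights`, the (start, end) indexing, and the choice to normalize the weights to sum to 1 are assumptions made for illustration.

```python
def geometric_within_weights(traj_len: int, lamda: float) -> dict[tuple[int, int], float]:
    """Weight each sub-trajectory (start, end) of one trajectory by lamda ** length.

    Hypothetical helper: the normalization to a total of 1 is an assumption,
    the source only states that weights are proportional to lamda ** length.
    """
    raw = {
        (start, end): lamda ** (end - start)
        for start in range(traj_len)
        for end in range(start + 1, traj_len + 1)
    }
    total = sum(raw.values())
    return {span: weight / total for span, weight in raw.items()}


# A trajectory with 3 transitions has 6 sub-trajectories.
weights = geometric_within_weights(traj_len=3, lamda=0.9)
```

With lamda below 1, each extra step in a sub-trajectory shrinks its weight by a factor of lamda, so shorter sub-trajectories dominate within each trajectory.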
- lamda¶
Discount factor for longer trajectories (used in geometric weighting).
- log_reward_clip_min¶
If finite, log rewards are clipped from below at this value.
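As an illustration of this lower clipping, the following sketch uses plain Python floats in place of tensors. The helper name `clip_log_rewards` is hypothetical, and reading the parameter as a lower bound is inferred from its name (`clip_min`).

```python
import math


def clip_log_rewards(log_rewards, clip_min=-math.inf):
    """Clip log rewards from below; a non-finite clip_min disables clipping.

    Hypothetical helper, assuming clip_min acts as a lower bound.
    """
    if math.isfinite(clip_min):
        return [max(r, clip_min) for r in log_rewards]
    return list(log_rewards)


clipped = clip_log_rewards([-200.0, -3.0, 0.5], clip_min=-100.0)
unclipped = clip_log_rewards([-200.0, -3.0, 0.5])
```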
- forward_looking¶
Whether to use the forward-looking GFN loss.
- constant_pb¶
Whether to ignore the backward policy estimator, e.g., if the gflownet DAG is a tree, and pb is therefore always 1.
- calculate_log_state_flows(env, trajectories, log_pf_trajectories)¶
Calculates log flows of each state in the trajectories.
- Parameters:
env (gfn.env.Env) – The environment object.
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
log_pf_trajectories (LogTrajectoriesTensor) – Tensor of shape (max_length, batch_size) containing the logprobs of the forward actions of the trajectories.
- Returns:
A tensor of shape (max_length, batch_size) containing the log flows of each state in the trajectories.
- Return type:
LogStateFlowsTensor
- calculate_masks(log_state_flows, trajectories)¶
Calculates masks indicating sink and terminal states.
- Parameters:
log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the log flows of the states.
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
A tuple of two mask tensors (sink_states_mask, is_terminal_mask), each of shape (max_length, batch_size).
- Return type:
Tuple[MaskTensor, MaskTensor]
- calculate_preds(log_pf_traj_cum, log_state_flows, i)¶
Calculates the predictions tensor for the current sub-trajectory length.
- Parameters:
log_pf_traj_cum (CumulativeLogProbsTensor) – Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs of the forward actions for each trajectory.
log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the estimated log flow of the states.
i (int) – The sub-trajectory length.
- Returns:
The predictions tensor of shape (max_length + 1 - i, batch_size).
- Return type:
PredictionsTensor
- calculate_targets(trajectories, preds, log_pb_traj_cum, log_state_flows, is_terminal_mask, sink_states_mask, i, log_rewards=None)¶
Calculates the targets tensor for the current sub-trajectory length.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
preds (PredictionsTensor) – Tensor of shape (max_length + 1 - i, batch_size) containing the predictions for the current sub-trajectory length.
log_pb_traj_cum (CumulativeLogProbsTensor) – Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs of the backward actions for each trajectory.
log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the estimated log flow of the states.
is_terminal_mask (MaskTensor) – A mask of shape (max_length, batch_size) indicating whether the state is terminal.
sink_states_mask (MaskTensor) – A mask of shape (max_length, batch_size) indicating whether the state is a sink state.
i (int) – The sub-trajectory length.
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The targets tensor of shape (max_length + 1 - i, batch_size).
- Return type:
TargetsTensor
- cumulative_logprobs(trajectories, log_p_trajectories)¶
Calculates the cumulative logprobs for all trajectories.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
log_p_trajectories (LogTrajectoriesTensor) – Tensor of shape (max_length, batch_size) containing the logprobs of the forward or backward actions of the trajectories.
- Returns:
A tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs for each trajectory.
- Return type:
CumulativeLogProbsTensor
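The extra row in the (max_length + 1, batch_size) output suggests a leading zero entry, so that the log-probability of any sub-trajectory can be read off as a difference of two cumulative values. The sketch below shows this for a single trajectory with plain Python lists; the leading zero and the helper name `cumulative_logprobs_1d` are assumptions, not a description of the library's code.

```python
from itertools import accumulate


def cumulative_logprobs_1d(log_p_steps):
    """Prefix sums of per-step logprobs, with an assumed leading 0.0 entry."""
    return list(accumulate(log_p_steps, initial=0.0))


cum = cumulative_logprobs_1d([-1.0, -0.5, -2.0])
# Log-probability of the sub-trajectory covering steps 1 and 2.
log_p_sub = cum[3] - cum[1]
```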
- forward_looking = False¶
- get_equal_contributions(trajectories)¶
Calculates contributions for the ‘equal’ weighting method.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).
- Return type:
ContributionsTensor
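The first dimension max_len * (max_len+1) / 2 of the contributions tensor matches the total number of sub-trajectories in a trajectory of max_len transitions: there are (max_len - k + 1) sub-trajectories of length k. A short check of that count, with a hypothetical helper name:

```python
def n_sub_trajectories(max_len: int) -> int:
    """Count sub-trajectories by summing over every length k = 1..max_len."""
    return sum(max_len - k + 1 for k in range(1, max_len + 1))


count = n_sub_trajectories(4)
```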
- get_equal_within_contributions(trajectories)¶
Calculates contributions for the ‘equal_within’ weighting method.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).
- Return type:
ContributionsTensor
- get_geometric_within_contributions(trajectories)¶
Calculates contributions for the ‘geometric_within’ weighting method.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).
- Return type:
ContributionsTensor
- get_modified_db_contributions(trajectories)¶
Calculates contributions for the ‘ModifiedDB’ weighting method.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).
- Return type:
ContributionsTensor
- get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)¶
Computes sub-trajectory balance scores for all submitted trajectories.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories to evaluate.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
env (gfn.env.Env | None) – The environment where the trajectories are sampled from.
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
- scores: A list of tensors, each holding the scores of all sub-trajectories of length k, for k in [1, …, max_length], where the score of a sub-trajectory \(\tau_{n:n+k} = (s_n, \dots, s_{n+k})\) is \(\log P_F(\tau_{n:n+k}) + \log F(s_n) - \log P_B(\tau_{n:n+k}) - \log F(s_{n+k})\). The tensor for length k has shape (max_length - k + 1, batch_size).
- flattening_masks: A list of tensors indicating which entries of each element of scores should be masked out, given that not all sub-trajectories of length k exist for each trajectory. An entry is True if the corresponding sub-trajectory does not exist.
- Return type:
A tuple (scores, flattening_masks)
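The score formula can be sketched for a single trajectory using cumulative log-probabilities. This is an illustrative simplification, not the library's code: the function name is hypothetical, plain lists replace tensors, and the list `log_flows` is assumed to hold one value per state with the log reward standing in as the flow of the final state.

```python
def sub_trajectory_scores(log_pf_cum, log_pb_cum, log_flows, k):
    """Scores of all length-k sub-trajectories of one trajectory.

    Assumes log_pf_cum and log_pb_cum start with 0.0 and that log_flows
    has one entry per state s_0..s_T (last entry: the log reward).
    """
    scores = []
    for n in range(len(log_flows) - k):
        log_pf = log_pf_cum[n + k] - log_pf_cum[n]
        log_pb = log_pb_cum[n + k] - log_pb_cum[n]
        scores.append(log_pf + log_flows[n] - log_pb - log_flows[n + k])
    return scores


log_pf_cum = [0.0, -1.0, -3.0]
log_pb_cum = [0.0, -0.5, -1.5]
log_flows = [2.0, 1.5, 0.0]

scores_k1 = sub_trajectory_scores(log_pf_cum, log_pb_cum, log_flows, k=1)
scores_k2 = sub_trajectory_scores(log_pf_cum, log_pb_cum, log_flows, k=2)
```

In this toy example the single length-2 score equals the sum of the two length-1 scores, because the intermediate flow terms cancel when adjacent sub-trajectories are joined.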
- get_tb_contributions(trajectories)¶
Calculates contributions for the ‘TB’ weighting method.
- Parameters:
trajectories (gfn.containers.Trajectories) – The batch of trajectories.
- Returns:
The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).
- Return type:
ContributionsTensor
- lamda = 0.9¶
- logF¶
- logF_named_parameters()¶
Returns a dictionary of named parameters containing ‘logF’ in their name.
- Returns:
A dictionary of named parameters containing ‘logF’ in their name.
- Return type:
dict[str, torch.Tensor]
- logF_parameters()¶
Returns a list of parameters containing ‘logF’ in their name.
- Returns:
A list of parameters containing ‘logF’ in their name.
- Return type:
list[torch.Tensor]
- loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the sub-trajectory balance loss.
- Parameters:
env (gfn.env.Env) – The environment where the trajectories are sampled from.
trajectories (gfn.containers.Trajectories) – The batch of trajectories to compute the loss with.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’). Note: for geometric-based sub-trajectory weighting, ‘mean’ is not supported and is coerced to ‘sum’ (a warning is emitted when debug=True).
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. When provided, this overrides the terminal reward term used by the loss. In particular, for forward_looking=True, the state-flow computation may still depend on env.log_reward(...), so custom log_rewards do not fully replace environment rewards in that mode. Useful for intrinsic rewards affecting the terminal boundary term (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The computed sub-trajectory balance loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
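One way to picture how scores and contributions combine is a contribution-weighted sum of per-sub-trajectory penalties. The sketch below assumes a squared-error penalty and already-flattened, already-masked inputs; both are assumptions, since the class accepts a configurable loss_fn, and the helper name is hypothetical.

```python
def weighted_subtb_loss(scores, contributions):
    """Contribution-weighted sum of squared scores (assumed penalty)."""
    return sum(w * s ** 2 for s, w in zip(scores, contributions))


# Three sub-trajectory scores weighted equally, with weights summing to 1.
loss_value = weighted_subtb_loss([0.0, 0.5, 0.5], [1 / 3, 1 / 3, 1 / 3])
```

When the contributions already sum to 1, the weighted sum is itself a weighted average, which is one plausible reading of why ‘mean’ is coerced to ‘sum’ for the geometric weighting schemes.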
- weighting = 'geometric_within'¶
- type gfn.gflownet.sub_trajectory_balance.TargetsTensor = torch.Tensor¶