gfn.gflownet.sub_trajectory_balance

Attributes

ContributionsTensor

CumulativeLogProbsTensor

LogStateFlowsTensor

LogTrajectoriesTensor

MaskTensor

PredictionsTensor

TargetsTensor

Classes

SubTBGFlowNet

GFlowNet for the Sub-Trajectory Balance loss.

Module Contents

type gfn.gflownet.sub_trajectory_balance.ContributionsTensor = torch.Tensor
type gfn.gflownet.sub_trajectory_balance.CumulativeLogProbsTensor = torch.Tensor
type gfn.gflownet.sub_trajectory_balance.LogStateFlowsTensor = torch.Tensor
type gfn.gflownet.sub_trajectory_balance.LogTrajectoriesTensor = torch.Tensor
type gfn.gflownet.sub_trajectory_balance.MaskTensor = torch.Tensor
type gfn.gflownet.sub_trajectory_balance.PredictionsTensor = torch.Tensor
class gfn.gflownet.sub_trajectory_balance.SubTBGFlowNet(pf, pb, logF, weighting='geometric_within', lamda=0.9, log_reward_clip_min=-float('inf'), forward_looking=False, constant_pb=False, debug=False, loss_fn=None)

Bases: gfn.gflownet.base.TrajectoryBasedGFlowNet

GFlowNet for the Sub-Trajectory Balance loss.

An implementation of the sub-trajectory balance loss as described in [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782).

Parameters:
pf

The forward policy estimator.

pb

The backward policy estimator, or None if the GFlowNet DAG is a tree, in which case pb is always 1.

logF

A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states.

weighting

The sub-trajectory weighting scheme. One of:

  • “DB”: Considers all one-step transitions of each trajectory in the batch and weights them equally (regardless of trajectory length). Should be equivalent to the DetailedBalance loss.

  • “ModifiedDB”: Considers all one-step transitions of each trajectory in the batch and weights them inversely proportionally to the trajectory length. This ensures that the loss is not dominated by long trajectories: each trajectory contributes equally to the loss.

  • “TB”: Considers only the full trajectory. Should be equivalent to the TrajectoryBalance loss.

  • “equal_within”: Each sub-trajectory of each trajectory is weighted equally within the trajectory. Then each trajectory is weighted equally within the batch.

  • “equal”: Each sub-trajectory of each trajectory is weighted equally within the set of all sub-trajectories.

  • “geometric_within”: Each sub-trajectory of each trajectory is weighted proportionally to (lamda ** len(sub_trajectory)), within each trajectory. This is the scheme used in the paper.

  • “geometric”: Each sub-trajectory of each trajectory is weighted proportionally to (lamda ** len(sub_trajectory)), within the set of all sub-trajectories.
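The “geometric_within” scheme can be illustrated for a single trajectory with a minimal pure-Python sketch. The helper name `geometric_within_weights` and the `(start, length)` keying are hypothetical and chosen only for illustration; the library may use a different exponent offset or memory layout. Any constant factor cancels once the weights are normalized within the trajectory.

```python
def geometric_within_weights(n_steps, lamda=0.9):
    """Normalized weights for every sub-trajectory of one trajectory.

    Hypothetical helper (not part of the library). Returns a dict mapping
    (start, length) -> weight, with weights proportional to lamda ** length
    and summing to 1 within the trajectory.
    """
    raw = {
        (start, length): lamda ** length
        for length in range(1, n_steps + 1)
        for start in range(n_steps - length + 1)
    }
    total = sum(raw.values())
    return {key: value / total for key, value in raw.items()}


# A 3-step trajectory has 3 + 2 + 1 = 6 sub-trajectories.
weights = geometric_within_weights(n_steps=3, lamda=0.5)
```

With lamda < 1, longer sub-trajectories receive geometrically smaller weight than shorter ones starting at the same state.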

lamda

Discount factor applied to longer sub-trajectories (used by the “geometric” and “geometric_within” weighting schemes).

log_reward_clip_min

If finite, log rewards are clipped from below at this value.

forward_looking

Whether to use the forward-looking GFN loss.

constant_pb

Whether to ignore the backward policy estimator, e.g., when the GFlowNet DAG is a tree and pb is therefore always 1.

calculate_log_state_flows(env, trajectories, log_pf_trajectories)

Calculates log flows of each state in the trajectories.

Parameters:
  • env (gfn.env.Env) – The environment object.

  • trajectories (gfn.containers.Trajectories) – The batch of trajectories.

  • log_pf_trajectories (LogTrajectoriesTensor) – Tensor of shape (max_length, batch_size) containing the logprobs of the forward actions of the trajectories.

Returns:

A tensor of shape (max_length, batch_size) containing the log flows of each state in the trajectories.

Return type:

LogStateFlowsTensor

calculate_masks(log_state_flows, trajectories)

Calculates masks indicating sink and terminal states.

Parameters:
  • log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the log flows of the states.

  • trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

A tuple of two mask tensors (sink_states_mask, is_terminal_mask), each of shape (max_length, batch_size).

Return type:

Tuple[MaskTensor, MaskTensor]

calculate_preds(log_pf_traj_cum, log_state_flows, i)

Calculates the predictions tensor for the current sub-trajectory length.

Parameters:
  • log_pf_traj_cum (CumulativeLogProbsTensor) – Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs of the forward actions for each trajectory.

  • log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the estimated log flow of the states.

  • i (int) – The sub-trajectory length.

Returns:

The predictions tensor of shape (max_length + 1 - i, batch_size).

Return type:

PredictionsTensor
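The documented shapes suggest how the prediction term \(\log P_F(\tau_{n:n+i}) + \log F(s_n)\) can be read off the cumulative sums. The sketch below is an assumption consistent with those shapes, not the library's code: it handles one trajectory (no batch dimension), ignores padding, and uses the hypothetical name `preds_for_length`.

```python
def preds_for_length(log_pf_cum, log_state_flows, i):
    """log P_F(tau_{n:n+i}) + log F(s_n) for every valid start index n.

    Hypothetical single-trajectory helper. log_pf_cum has max_length + 1
    entries (a leading 0 followed by running sums); log_state_flows has
    max_length entries. The result has max_length + 1 - i entries.
    """
    max_length = len(log_state_flows)
    return [
        (log_pf_cum[n + i] - log_pf_cum[n]) + log_state_flows[n]
        for n in range(max_length + 1 - i)
    ]


log_pf_cum = [0.0, -1.0, -3.0, -6.0]   # cumulative forward log-probs
log_state_flows = [2.0, 1.0, 0.5]      # log F(s_0), log F(s_1), log F(s_2)
preds = preds_for_length(log_pf_cum, log_state_flows, i=2)
```

The difference of two cumulative sums gives the log-probability of a sub-trajectory without re-summing its steps.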

calculate_targets(trajectories, preds, log_pb_traj_cum, log_state_flows, is_terminal_mask, sink_states_mask, i, log_rewards=None)

Calculates the targets tensor for the current sub-trajectory length.

Parameters:
  • trajectories (gfn.containers.Trajectories) – The batch of trajectories.

  • preds (PredictionsTensor) – Tensor of shape (max_length + 1 - i, batch_size) containing the predictions for the current sub-trajectory length.

  • log_pb_traj_cum (CumulativeLogProbsTensor) – Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs of the backward actions for each trajectory.

  • log_state_flows (LogStateFlowsTensor) – Tensor of shape (max_length, batch_size) containing the estimated log flow of the states.

  • is_terminal_mask (MaskTensor) – A mask of shape (max_length, batch_size) indicating whether the state is terminal.

  • sink_states_mask (MaskTensor) – A mask of shape (max_length, batch_size) indicating whether the state is a sink state.

  • i (int) – The sub-trajectory length.

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The targets tensor of shape (max_length + 1 - i, batch_size).

Return type:

TargetsTensor

cumulative_logprobs(trajectories, log_p_trajectories)

Calculates the cumulative logprobs for all trajectories.

Parameters:
  • trajectories (gfn.containers.Trajectories) – The batch of trajectories.

  • log_p_trajectories (LogTrajectoriesTensor) – Tensor of shape (max_length, batch_size) containing the logprobs of the forward or backward actions of the trajectories.

Returns:

A tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs for each trajectory.

Return type:

CumulativeLogProbsTensor
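The extra row in the (max_length + 1, batch_size) output is consistent with a running sum that starts at zero. As a hedged, single-trajectory illustration (the helper name `cumulative_logprobs_1d` is hypothetical, and the padding of shorter trajectories is not modeled):

```python
from itertools import accumulate


def cumulative_logprobs_1d(log_p):
    """Running sum of per-step log-probs with a leading 0.0 (length n + 1)."""
    return list(accumulate(log_p, initial=0.0))


log_pf = [-1.0, -2.0, -3.0]
cum = cumulative_logprobs_1d(log_pf)
# Log-probability of the sub-trajectory covering the second and third steps.
sub_logprob = cum[3] - cum[1]
```

The leading zero lets the log-probability of any sub-trajectory from step n to step m be written as `cum[m] - cum[n]`, including sub-trajectories that start at the first state.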

forward_looking = False
get_equal_contributions(trajectories)

Calculates contributions for the ‘equal’ weighting method.

Parameters:

trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).

Return type:

ContributionsTensor
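The first dimension, max_len * (max_len+1) / 2, is the number of (start, length) pairs for a trajectory of max_len steps. The following sketch builds an “equal” contributions table from a list of trajectory lengths. The row ordering (grouped by sub-trajectory length, then by start index) and the zero entries for non-existent sub-trajectories are assumptions made for illustration, and `equal_contributions` is a hypothetical name.

```python
def equal_contributions(traj_lengths):
    """Rows: one per (length k, start) pair; columns: one per trajectory.

    Hypothetical helper. Every existing sub-trajectory in the batch gets the
    same weight; slots for sub-trajectories that do not exist get 0.0.
    """
    max_len = max(traj_lengths)
    n_existing = sum(n * (n + 1) // 2 for n in traj_lengths)
    rows = []
    for k in range(1, max_len + 1):
        for start in range(max_len - k + 1):
            rows.append([
                1.0 / n_existing if start + k <= n else 0.0
                for n in traj_lengths
            ])
    return rows


# Two trajectories of 3 and 1 steps: 6 + 1 = 7 existing sub-trajectories.
contributions = equal_contributions([3, 1])
```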

get_equal_within_contributions(trajectories)

Calculates contributions for the ‘equal_within’ weighting method.

Parameters:

trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).

Return type:

ContributionsTensor

get_geometric_within_contributions(trajectories)

Calculates contributions for the ‘geometric_within’ weighting method.

Parameters:

trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).

Return type:

ContributionsTensor

get_modified_db_contributions(trajectories)

Calculates contributions for the ‘ModifiedDB’ weighting method.

Parameters:

trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).

Return type:

ContributionsTensor

get_scores(trajectories, recalculate_all_logprobs=True, env=None, *, log_rewards=None)

Computes sub-trajectory balance scores for all submitted trajectories.

Parameters:
  • trajectories (gfn.containers.Trajectories) – The batch of trajectories to evaluate.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • env (gfn.env.Env | None) – The environment where the trajectories are sampled from.

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

A tuple (scores, flattening_masks):

  • scores: A list of tensors, one per sub-trajectory length k in [1, …, max_length]. The score of a sub-trajectory \(\tau_{n:n+k} = (s_n, \dots, s_{n+k})\) is \(\log P_F(\tau_{n:n+k}) + \log F(s_n) - \log P_B(\tau_{n:n+k}) - \log F(s_{n+k})\). The tensor for length k has shape (max_length - k + 1, batch_size).

  • flattening_masks: A list of tensors indicating which entries of each element of scores should be masked out, given that not all sub-trajectories of length k exist for each trajectory. An entry is True if the corresponding sub-trajectory does not exist.

Return type:

Tuple[List[torch.Tensor], List[torch.Tensor]]
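The score formula can be checked by hand on a tiny example. The sketch below evaluates it for a single trajectory in pure Python. The name `subtb_scores` is hypothetical, and treating the last entry of `log_f` as the log reward of the terminal state is an assumption of this example. When the flows are perfectly balanced, every score is zero.

```python
import math


def subtb_scores(log_pf, log_pb, log_f):
    """Scores grouped by sub-trajectory length k = 1..n_steps.

    Hypothetical helper. log_pf and log_pb hold one entry per step;
    log_f holds one entry per state (n_steps + 1 entries).
    """
    n_steps = len(log_pf)
    scores = []
    for k in range(1, n_steps + 1):
        row = []
        for n in range(n_steps - k + 1):
            lp_f = sum(log_pf[n:n + k])
            lp_b = sum(log_pb[n:n + k])
            row.append(lp_f + log_f[n] - lp_b - log_f[n + k])
        scores.append(row)
    return scores


# Balanced toy flow: F(s_0) = 4, P_F(s_1 | s_0) = 0.25, so F(s_1) = 1,
# followed by a deterministic step to a terminal state with reward 1.
scores = subtb_scores(
    log_pf=[math.log(0.25), 0.0],
    log_pb=[0.0, 0.0],
    log_f=[math.log(4.0), 0.0, 0.0],
)
```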

get_tb_contributions(trajectories)

Calculates contributions for the ‘TB’ weighting method.

Parameters:

trajectories (gfn.containers.Trajectories) – The batch of trajectories.

Returns:

The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).

Return type:

ContributionsTensor

lamda = 0.9
logF
logF_named_parameters()

Returns a dictionary of named parameters containing ‘logF’ in their name.

Returns:

A dictionary of named parameters containing ‘logF’ in their name.

Return type:

dict[str, torch.Tensor]

logF_parameters()

Returns a list of parameters containing ‘logF’ in their name.

Returns:

A list of parameters containing ‘logF’ in their name.

Return type:

list[torch.Tensor]

loss(env, trajectories, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the sub-trajectory balance loss.

Parameters:
  • env (gfn.env.Env) – The environment where the trajectories are sampled from.

  • trajectories (gfn.containers.Trajectories) – The batch of trajectories to compute the loss with.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs.

  • reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’). Note: for geometric-based sub-trajectory weighting, ‘mean’ is not supported and is coerced to ‘sum’ (a warning is emitted when debug=True).

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_trajectories,). When None, uses the environment rewards. When provided, this overrides the terminal reward term used by the loss. In particular, for forward_looking=True, the state-flow computation may still depend on env.log_reward(...), so custom log_rewards do not fully replace environment rewards in that mode. Useful for intrinsic rewards affecting the terminal boundary term (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The computed sub-trajectory balance loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
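As a rough picture of how scores and contributions combine, the sketch below computes a contribution-weighted sum of squared scores over flattened sub-trajectories. Both the squared-error penalty and the name `weighted_subtb_loss` are assumptions of this example; the behavior when loss_fn is None is not specified in this reference.

```python
def weighted_subtb_loss(scores, contributions):
    """Sum of contribution-weighted squared scores.

    Hypothetical helper: scores and contributions are flat sequences with
    one entry per existing sub-trajectory.
    """
    return sum(c * s ** 2 for s, c in zip(scores, contributions))


flat_scores = [1.0, -2.0, 0.5]
flat_contributions = [0.5, 0.25, 0.25]   # already normalized to sum to 1
loss_value = weighted_subtb_loss(flat_scores, flat_contributions)
```

When the contributions already sum to one, the weighted sum is itself an average, which is one plausible reading of why ‘mean’ is coerced to ‘sum’ for the geometric weightings.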

weighting = 'geometric_within'
type gfn.gflownet.sub_trajectory_balance.TargetsTensor = torch.Tensor