gfn.gflownet.sub_trajectory_balance
===================================

.. py:module:: gfn.gflownet.sub_trajectory_balance


Attributes
----------

.. autoapisummary::

   gfn.gflownet.sub_trajectory_balance.ContributionsTensor
   gfn.gflownet.sub_trajectory_balance.CumulativeLogProbsTensor
   gfn.gflownet.sub_trajectory_balance.LogStateFlowsTensor
   gfn.gflownet.sub_trajectory_balance.LogTrajectoriesTensor
   gfn.gflownet.sub_trajectory_balance.MaskTensor
   gfn.gflownet.sub_trajectory_balance.PredictionsTensor
   gfn.gflownet.sub_trajectory_balance.TargetsTensor


Classes
-------

.. autoapisummary::

   gfn.gflownet.sub_trajectory_balance.SubTBGFlowNet


Module Contents
---------------

.. py:type:: ContributionsTensor
   :canonical: torch.Tensor


.. py:type:: CumulativeLogProbsTensor
   :canonical: torch.Tensor


.. py:type:: LogStateFlowsTensor
   :canonical: torch.Tensor


.. py:type:: LogTrajectoriesTensor
   :canonical: torch.Tensor


.. py:type:: MaskTensor
   :canonical: torch.Tensor


.. py:type:: PredictionsTensor
   :canonical: torch.Tensor


.. py:class:: SubTBGFlowNet(pf, pb, logF, weighting = 'geometric_within', lamda = 0.9, log_reward_clip_min = -float('inf'), forward_looking = False, constant_pb = False, debug = False, loss_fn = None)

   Bases: :py:obj:`gfn.gflownet.base.TrajectoryBasedGFlowNet`


   GFlowNet for the Sub-Trajectory Balance loss.

   An implementation of the sub-trajectory balance loss as described in
   `Learning GFlowNets from partial episodes for improved convergence and stability <https://arxiv.org/abs/2209.12782>`_.

   .. attribute:: pf

      The forward policy estimator.

   .. attribute:: pb

      The backward policy estimator, or None if the GFlowNet DAG is a tree, in which case pb is always 1.

   .. attribute:: logF

      A ScalarEstimator or ConditionalScalarEstimator for estimating the log flow of the states.

   .. attribute:: weighting

      The sub-trajectory weighting scheme:

      - "DB": Considers all one-step transitions of each trajectory in the batch and weights them equally, regardless of trajectory length. Should be equivalent to the DetailedBalance loss.
      - "ModifiedDB": Considers all one-step transitions of each trajectory in the batch and weights them in inverse proportion to the trajectory length, so that each trajectory contributes equally to the loss and long trajectories do not dominate it.
      - "TB": Considers only the full trajectory. Should be equivalent to the TrajectoryBalance loss.
      - "equal_within": Each sub-trajectory is weighted equally within its trajectory; each trajectory is then weighted equally within the batch.
      - "equal": Each sub-trajectory is weighted equally within the set of all sub-trajectories in the batch.
      - "geometric_within": Each sub-trajectory is weighted proportionally to ``lamda ** len(sub_trajectory)``, within its trajectory. This is the weighting used in the paper.
      - "geometric": Each sub-trajectory is weighted proportionally to ``lamda ** len(sub_trajectory)``, within the set of all sub-trajectories in the batch.

   .. attribute:: lamda

      Discount factor applied to longer sub-trajectories (used by the geometric weightings).

   .. attribute:: log_reward_clip_min

      If finite, log rewards are clipped from below to this value.

   .. attribute:: forward_looking

      Whether to use the forward-looking GFlowNet loss.

   .. attribute:: constant_pb

      Whether to ignore the backward policy estimator, e.g., if the GFlowNet DAG is a tree, in which case pb is always 1.


   .. py:method:: calculate_log_state_flows(env, trajectories, log_pf_trajectories)

      Calculates the log flow of each state in the trajectories.

      :param env: The environment object.
      :param trajectories: The batch of trajectories.
      :param log_pf_trajectories: Tensor of shape (max_length, batch_size) containing the logprobs of the forward actions of the trajectories.

      :returns: A tensor of shape (max_length, batch_size) containing the log flow of each state in the trajectories.


   .. py:method:: calculate_masks(log_state_flows, trajectories)

      Calculates masks indicating sink and terminal states.

      :param log_state_flows: Tensor of shape (max_length, batch_size) containing the log flows of the states.
      :param trajectories: The batch of trajectories.

      :returns: A tuple of two mask tensors ``(sink_states_mask, is_terminal_mask)``, each of shape (max_length, batch_size).


   .. py:method:: calculate_preds(log_pf_traj_cum, log_state_flows, i)

      Calculates the predictions tensor for the current sub-trajectory length.

      :param log_pf_traj_cum: Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of the logprobs of the forward actions of each trajectory.
      :param log_state_flows: Tensor of shape (max_length, batch_size) containing the estimated log flows of the states.
      :param i: The sub-trajectory length.

      :returns: The predictions tensor of shape (max_length + 1 - i, batch_size).


   .. py:method:: calculate_targets(trajectories, preds, log_pb_traj_cum, log_state_flows, is_terminal_mask, sink_states_mask, i, log_rewards = None)

      Calculates the targets tensor for the current sub-trajectory length.

      :param trajectories: The batch of trajectories.
      :param preds: Tensor of shape (max_length + 1 - i, batch_size) containing the predictions for the current sub-trajectory length.
      :param log_pb_traj_cum: Tensor of shape (max_length + 1, batch_size) containing the cumulative sum of the logprobs of the backward actions of each trajectory.
      :param log_state_flows: Tensor of shape (max_length, batch_size) containing the estimated log flows of the states.
      :param is_terminal_mask: A mask of shape (max_length, batch_size) indicating whether each state is terminal.
      :param sink_states_mask: A mask of shape (max_length, batch_size) indicating whether each state is a sink state.
      :param i: The sub-trajectory length.
      :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, the environment rewards are used. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025).

      :returns: The targets tensor of shape (max_length + 1 - i, batch_size).


   .. py:method:: cumulative_logprobs(trajectories, log_p_trajectories)

      Calculates the cumulative logprobs of all trajectories.

      :param trajectories: The batch of trajectories.
      :param log_p_trajectories: Tensor of shape (max_length, batch_size) containing the logprobs of the forward or backward actions of the trajectories.

      :returns: A tensor of shape (max_length + 1, batch_size) containing the cumulative sum of logprobs for each trajectory.


   .. py:attribute:: forward_looking
      :value: False


   .. py:method:: get_equal_contributions(trajectories)

      Calculates contributions for the 'equal' weighting method.

      :param trajectories: The batch of trajectories.

      :returns: The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).


   .. py:method:: get_equal_within_contributions(trajectories)

      Calculates contributions for the 'equal_within' weighting method.

      :param trajectories: The batch of trajectories.

      :returns: The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).


   .. py:method:: get_geometric_within_contributions(trajectories)

      Calculates contributions for the 'geometric_within' weighting method.

      :param trajectories: The batch of trajectories.

      :returns: The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).


   .. py:method:: get_modified_db_contributions(trajectories)

      Calculates contributions for the 'ModifiedDB' weighting method.

      :param trajectories: The batch of trajectories.

      :returns: The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).


   .. py:method:: get_scores(trajectories, recalculate_all_logprobs = True, env = None, *, log_rewards = None)

      Computes the sub-trajectory balance scores of all trajectories in the batch.

      :param trajectories: The batch of trajectories to evaluate.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
      :param env: The environment the trajectories were sampled from.
      :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, the environment rewards are used. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025).

      :returns:
         - scores: A list of tensors, each containing the scores of all sub-trajectories of length k, for k in [1, ..., max_length], where the score of a sub-trajectory :math:`\tau_{n:n+k} = (s_n, \ldots, s_{n+k})` is :math:`\log P_F(\tau_{n:n+k}) + \log F(s_n) - \log P_B(\tau_{n:n+k}) - \log F(s_{n+k})`. The score tensor for sub-trajectories of length k has shape (max_length - k + 1, batch_size).
         - flattening_masks: A list of tensors indicating what should be masked out of each element of the first list (scores), given that not every trajectory contains a sub-trajectory of length k at every position. An entry is True if the corresponding sub-trajectory does not exist.
      :rtype: A tuple (scores, flattening_masks)


   .. py:method:: get_tb_contributions(trajectories)

      Calculates contributions for the 'TB' weighting method.

      :param trajectories: The batch of trajectories.

      :returns: The contributions tensor of shape (max_len * (max_len+1) / 2, batch_size).


   .. py:attribute:: lamda
      :value: 0.9


   .. py:attribute:: logF


   .. py:method:: logF_named_parameters()

      Returns a dictionary of the named parameters containing 'logF' in their name.

      :returns: A dictionary of the named parameters containing 'logF' in their name.


   .. py:method:: logF_parameters()

      Returns a list of the parameters containing 'logF' in their name.

      :returns: A list of the parameters containing 'logF' in their name.


   .. py:method:: loss(env, trajectories, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None)

      Computes the sub-trajectory balance loss.

      :param env: The environment the trajectories were sampled from.
      :param trajectories: The batch of trajectories to compute the loss with.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
      :param reduction: The reduction method to use ('mean', 'sum', or 'none'). Note: for the geometric sub-trajectory weightings, 'mean' is not supported and is coerced to 'sum' (a warning is emitted when debug=True).
      :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, the environment rewards are used. When provided, it overrides the terminal reward term used by the loss. Note, however, that with ``forward_looking=True`` the state-flow computation may still depend on ``env.log_reward(...)``, so custom ``log_rewards`` do not fully replace the environment rewards in that mode. Useful for intrinsic rewards that affect the terminal boundary term (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025).

      :returns: The computed sub-trajectory balance loss as a tensor. Its shape depends on the reduction method.


   .. py:attribute:: weighting
      :value: 'geometric_within'



.. py:type:: TargetsTensor
   :canonical: torch.Tensor
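
The score formula given for ``SubTBGFlowNet.get_scores`` can be evaluated for every sub-trajectory length from prefix sums of per-step log-probabilities, which is the role of the ``(max_length + 1, batch_size)`` tensor returned by ``cumulative_logprobs``. The following is a minimal, dependency-free sketch for a single trajectory, not the library's implementation: it assumes there is no padding or sink state, it uses the log reward in place of :math:`\log F` at the terminal state, and the helper names ``cumulative`` and ``subtb_scores`` are hypothetical.

.. code-block:: python

   from itertools import accumulate


   def cumulative(log_probs):
       """Prefix sums with a leading 0.0: entry t is the sum of the first t steps."""
       return [0.0] + list(accumulate(log_probs))


   def subtb_scores(log_pf, log_pb, log_flows, log_reward):
       """Score of every sub-trajectory of one trajectory, keyed by length k.

       Hypothetical helper (sketch only, not part of gfn).
       log_pf, log_pb: per-step forward / backward log-probabilities (length n).
       log_flows: log F(s_0), ..., log F(s_{n-1}) (length n).
       log_reward: log R(s_n), used in place of log F at the terminal state.
       """
       n = len(log_pf)
       cum_pf, cum_pb = cumulative(log_pf), cumulative(log_pb)
       flows = list(log_flows) + [log_reward]  # log F(s_0), ..., log F(s_n)
       scores = {}
       for k in range(1, n + 1):
           # There are n - k + 1 sub-trajectories (s_m, ..., s_{m+k}) of length k.
           scores[k] = [
               flows[m]
               + (cum_pf[m + k] - cum_pf[m])  # log P_F of the sub-trajectory
               - (cum_pb[m + k] - cum_pb[m])  # log P_B of the sub-trajectory
               - flows[m + k]
               for m in range(n - k + 1)
           ]
       return scores


   scores = subtb_scores(
       log_pf=[-1.0, -0.5, -0.25],
       log_pb=[0.0, -0.5, -0.5],
       log_flows=[2.0, 1.0, 0.5],
       log_reward=0.0,
   )

Because every score is a difference of prefix sums, the scores of consecutive length-1 sub-trajectories add up to the score of the sub-trajectory they span: in the call above, the length-1 scores are ``[0.0, 0.5, 0.75]`` and the single length-3 score is ``1.25``.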
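
To make the ``"geometric_within"`` scheme of ``SubTBGFlowNet`` concrete: a trajectory with :math:`n` transitions has :math:`n - k + 1` sub-trajectories of length :math:`k`, hence :math:`n(n+1)/2` in total, which matches the documented ``max_len * (max_len+1) / 2`` size of the contributions tensors. The plain-Python sketch below gives each length-:math:`k` sub-trajectory a weight proportional to ``lamda ** k`` and normalizes the weights to sum to one over the trajectory. It is an illustration only: the function names are hypothetical, a squared-error penalty is assumed for the loss (the class also accepts a custom ``loss_fn``), and ``get_geometric_within_contributions`` may normalize or pad differently.

.. code-block:: python

   def geometric_within_weights(n_steps, lamda=0.9):
       """Hypothetical helper: per-sub-trajectory weights for one trajectory, keyed by length k.

       Each of the (n_steps - k + 1) sub-trajectories of length k gets a weight
       proportional to lamda ** k; the weights over the whole trajectory sum to 1.
       """
       raw = {k: lamda ** k for k in range(1, n_steps + 1)}
       total = sum((n_steps - k + 1) * w for k, w in raw.items())
       return {k: w / total for k, w in raw.items()}


   def weighted_squared_loss(scores, weights):
       """Hypothetical helper: weighted sum of squared scores ({k: [score, ...]})."""
       return sum(weights[k] * s ** 2 for k, row in scores.items() for s in row)


   weights = geometric_within_weights(3, lamda=0.5)
   # Raw weights 0.5, 0.25, 0.125 for k = 1, 2, 3; normalizer 3*0.5 + 2*0.25 + 1*0.125 = 2.125.
   loss = weighted_squared_loss({1: [0.0, 0.5, 0.75], 2: [0.5, 1.25], 3: [1.25]}, weights)

Setting ``lamda=1.0`` gives every sub-trajectory of an :math:`n`-step trajectory the same weight :math:`2 / (n(n+1))`, which matches the per-trajectory part of the ``"equal_within"`` description; the ``"geometric"`` variant would normalize over all sub-trajectories in the batch instead of per trajectory.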
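
The difference between the ``"DB"`` and ``"ModifiedDB"`` weightings of ``SubTBGFlowNet`` shows up in how much each trajectory contributes in total. The sketch below assumes hypothetical helpers that return one weight per transition, normalized to sum to one over the batch (the library's own normalization may differ). Under that assumption, a batch with one 1-step and one 4-step trajectory gives the long trajectory 80% of the total weight under ``"DB"`` but only 50% under ``"ModifiedDB"``.

.. code-block:: python

   def db_weights(lengths):
       """Hypothetical 'DB'-style weights: every transition in the batch weighs the same."""
       total = sum(lengths)
       return [[1.0 / total] * n for n in lengths]


   def modified_db_weights(lengths):
       """Hypothetical 'ModifiedDB'-style weights: each transition is weighted by
       1 / trajectory length, so every trajectory gets the same total weight."""
       batch_size = len(lengths)
       return [[1.0 / (batch_size * n)] * n for n in lengths]


   lengths = [1, 4]  # number of transitions in each trajectory of the batch
   db_totals = [sum(w) for w in db_weights(lengths)]  # approximately [0.2, 0.8]
   mod_totals = [sum(w) for w in modified_db_weights(lengths)]  # [0.5, 0.5]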