gfn.gflownet.base
=================

.. py:module:: gfn.gflownet.base


Attributes
----------

.. autoapisummary::

   gfn.gflownet.base.TrainingSampleType


Classes
-------

.. autoapisummary::

   gfn.gflownet.base.GFlowNet
   gfn.gflownet.base.PFBasedGFlowNet
   gfn.gflownet.base.TrajectoryBasedGFlowNet


Functions
---------

.. autoapisummary::

   gfn.gflownet.base.loss_reduce


Module Contents
---------------

.. py:class:: GFlowNet(debug = False, loss_fn = None)

   Bases: :py:obj:`abc.ABC`, :py:obj:`torch.nn.Module`, :py:obj:`Generic`\ [\ :py:obj:`TrainingSampleType`\ ]

   Abstract base class for GFlowNets.

   A formal definition of GFlowNets is given in Sec. 3 of
   `GFlowNet Foundations <https://arxiv.org/pdf/2111.09266>`_.

   .. py:method:: assert_finite_gradients()

      Asserts that the gradients are finite.

   .. py:method:: assert_finite_parameters()

      Asserts that the parameters are finite.

   .. py:attribute:: debug
      :value: False

   .. py:attribute:: log_reward_clip_min

   .. py:method:: loss(env, training_objects, recalculate_all_logprobs = True)
      :abstractmethod:

      Computes the loss given the training objects.

      :param env: The environment from which the training objects are sampled.
      :param training_objects: The objects to compute the loss with.
      :param recalculate_all_logprobs: If True, always recalculate logprobs even if they exist. If False, use existing logprobs when available.

      :returns: The computed loss as a tensor.

   .. py:attribute:: loss_fn

   .. py:method:: loss_from_trajectories(env, trajectories, recalculate_all_logprobs = True)

      Helper method to compute the loss directly from trajectories.

      This method converts the trajectories to the appropriate training samples and computes the loss with the arguments required by the specific GFlowNet subclass.

      :param env: The environment from which the training objects are sampled.
      :param trajectories: The trajectories to compute the loss with.
      :param recalculate_all_logprobs: If True, always recalculate logprobs even if they exist. If False, use existing logprobs when available.
      :returns: The computed loss as a tensor.

   .. py:method:: sample_terminating_states(env, n)

      Rolls out the policy and returns the terminating states.

      :param env: The environment to sample terminating states from.
      :param n: Number of terminating states to sample.

      :returns: The sampled terminating states as a States object.

   .. py:method:: sample_trajectories(env, n, conditions = None, save_logprobs = False, save_estimator_outputs = False, **policy_kwargs)
      :abstractmethod:

      Samples ``n`` complete trajectories from the environment.

      :param env: The environment to sample trajectories from.
      :param n: Number of trajectories to sample.
      :param conditions: Optional conditions tensor for conditional environments.
      :param save_logprobs: Whether to save the logprobs of the actions (useful for on-policy learning).
      :param save_estimator_outputs: Whether to save the estimator outputs (useful for off-policy learning with a tempered policy).

      :returns: A Trajectories object containing the sampled trajectories.

   .. py:method:: to_training_samples(trajectories)
      :abstractmethod:

      Converts trajectories to training samples.

      :param trajectories: The Trajectories object to convert.

      :returns: The training samples, whose type depends on the GFlowNet subclass.

.. py:class:: PFBasedGFlowNet(pf, pb, constant_pb = False, log_reward_clip_min = float('-inf'), debug = False, loss_fn = None)

   Bases: :py:obj:`GFlowNet`\ [\ :py:obj:`TrainingSampleType`\ ], :py:obj:`abc.ABC`

   A GFlowNet that uses forward (PF) and backward (PB) policy networks.

   .. attribute:: pf

      The forward policy estimator.

   .. attribute:: pb

      The backward policy estimator, or None if it can be ignored (e.g., when the GFlowNet DAG is a tree, in which case PB is always 1).

   .. attribute:: constant_pb

      Whether to ignore the backward policy estimator.

   .. attribute:: log_reward_clip_min

      If finite, log rewards below this value are clipped to it.

   .. py:attribute:: constant_pb
      :value: False

   .. py:attribute:: log_reward_clip_min

   .. py:attribute:: pb

   .. py:attribute:: pf

   .. py:method:: pf_pb_named_parameters()

      Returns a dictionary of named parameters containing 'pf' or 'pb' in their name.

      :returns: A dictionary of named parameters containing 'pf' or 'pb' in their name.

   .. py:method:: pf_pb_parameters()

      Returns a list of parameters containing 'pf' or 'pb' in their name.

      :returns: A list of parameters containing 'pf' or 'pb' in their name.

   .. py:method:: sample_trajectories(env, n, conditions = None, save_logprobs = False, save_estimator_outputs = False, **policy_kwargs)

      Samples trajectories using the forward policy network.

      :param env: The environment to sample trajectories from.
      :param n: Number of trajectories to sample.
      :param conditions: Optional conditions tensor for conditional environments.
      :param save_logprobs: Whether to save the logprobs of the actions.
      :param save_estimator_outputs: Whether to save the estimator outputs.
      :param \*\*policy_kwargs: Additional keyword arguments for the sampler.

      :returns: A Trajectories object containing the sampled trajectories.

.. py:data:: TrainingSampleType

.. py:class:: TrajectoryBasedGFlowNet(pf, pb, constant_pb = False, log_reward_clip_min = float('-inf'), debug = False, loss_fn = None)

   Bases: :py:obj:`PFBasedGFlowNet`\ [\ :py:obj:`gfn.containers.Trajectories`\ ], :py:obj:`abc.ABC`

   A GFlowNet that operates on complete trajectories.

   .. attribute:: pf

      The forward policy module.

   .. attribute:: pb

      The backward policy module, or None if the GFlowNet DAG is a tree (in which case PB is always 1).

   .. attribute:: constant_pb

      Whether to ignore the backward policy estimator, e.g., if the GFlowNet DAG is a tree (in which case PB is always 1).

   .. attribute:: log_reward_clip_min

      If finite, log rewards below this value are clipped to it.

   .. py:method:: get_pfs_and_pbs(trajectories, recalculate_all_logprobs = True)

      Evaluates forward and backward logprobs for each trajectory in the batch.
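
      The outputs are per step, with shape (max_length, batch_size). As an illustration only, here is a minimal sketch in plain PyTorch (not the library's code, with made-up toy values) of how such per-step values can be summed into :math:`\log P_F(\tau)` and :math:`\log P_B(\tau \mid x)` and then combined with :math:`\log R(x)` into the score defined in :py:meth:`get_scores`. It assumes that padded steps past a trajectory's end hold 0, so that summing over the time dimension is safe.

      .. code-block:: python

         import torch

         # Toy per-step log-probs with shape (max_length, batch_size) = (3, 2).
         # Assumption: padded steps past a trajectory's end hold 0
         # (column 1 is a 2-step trajectory, so its last row is padding).
         log_pf = torch.tensor([[-0.5, -1.0],
                                [-0.25, -0.5],
                                [-0.25, 0.0]])
         log_pb = torch.tensor([[0.0, 0.0],
                                [-0.5, -0.25],
                                [-0.5, 0.0]])
         log_rewards = torch.tensor([-2.0, -1.0])  # toy log R(x) per trajectory

         # Sum over the time dimension (dim 0) to get trajectory-level log-probs.
         log_pf_traj = log_pf.sum(dim=0)  # log P_F(tau), values: [-1.0, -1.5]
         log_pb_traj = log_pb.sum(dim=0)  # log P_B(tau | x), values: [-1.0, -0.25]

         # Score as defined in get_scores: log P_F(tau) - log P_B(tau | x) - log R(x).
         scores = log_pf_traj - log_pb_traj - log_rewards
         print(scores)  # values: [2.0, -0.25]
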
      More specifically, it evaluates :math:`\log P_F(s' \mid s)` and :math:`\log P_B(s \mid s')` for each transition in each trajectory in the batch.

      If ``recalculate_all_logprobs=True``, the logprobs of the trajectories are re-evaluated using the current ``self.pf``. Otherwise, the following applies:

      - If the trajectories have a logprobs attribute, use it. This is usually the case for on-policy learning.
      - Else, if the trajectories have an ``estimator_outputs`` attribute, transform the outputs into logprobs. This is usually the case for off-policy learning with a tempered policy.
      - Else (the trajectories have neither), re-evaluate the logprobs using the current ``self.pf``. This is usually the case for off-policy learning with a replay buffer.

      :param trajectories: The Trajectories object to evaluate.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.

      :returns: A tuple of tensors of shape (max_length, batch_size) containing the log_pf and log_pb for each action in each trajectory.

   .. py:method:: get_scores(trajectories, recalculate_all_logprobs = True, env = None, *, log_rewards = None)

      Calculates scores for a batch of trajectories.

      The score of each trajectory is defined as :math:`\log \left( \frac{P_F(\tau)}{P_B(\tau \mid x) R(x)} \right)`.

      :param trajectories: The Trajectories object to evaluate.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs.
      :param env: The environment (unused in base TB, but required by some subclasses such as RTB and SubTB).
      :param log_rewards: Optional custom log rewards tensor of shape (n_trajectories,). When None, the environment rewards stored in the trajectories are used. Useful for intrinsic rewards (see "Towards Improving Exploration through Sibling Augmented GFlowNets", Madan et al., ICLR 2025).

      :returns: A tensor of shape (batch_size,) containing the score of each trajectory.

   .. py:method:: logz_named_parameters()

      Returns named parameters containing 'logZ' in their name.

      Works for any subclass that registers a logZ parameter (e.g., :class:`TBGFlowNet`, :class:`RelativeTrajectoryBalanceGFlowNet`). Returns an empty dict for subclasses without logZ.

   .. py:method:: logz_parameters()

      Returns parameters containing 'logZ' in their name.

      Works for any subclass that registers a logZ parameter (e.g., :class:`TBGFlowNet`, :class:`RelativeTrajectoryBalanceGFlowNet`). Returns an empty list for subclasses without logZ.

   .. py:method:: to_training_samples(trajectories)

      Returns the input trajectories as training samples.

      :param trajectories: The Trajectories object to use as training samples.

      :returns: The same Trajectories object.

   .. py:method:: trajectory_log_probs_backward(trajectories)

      Evaluates only the backward logprobs for each trajectory in the batch.

   .. py:method:: trajectory_log_probs_forward(trajectories, recalculate_all_logprobs = True)

      Evaluates only the forward logprobs for each trajectory in the batch.

.. py:function:: loss_reduce(loss, method)

   Reduces a loss tensor according to the given aggregation strategy.

   :param loss: The tensor to reduce.
   :param method: The reduction method to use ('mean', 'sum', or 'none').

   :returns: The reduced tensor.
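
   A minimal sketch of this kind of reduction, assuming PyTorch tensors and the three method strings listed above. The name ``reduce_loss_sketch`` is hypothetical, and raising ``ValueError`` for an unknown method is an assumption, not necessarily what ``loss_reduce`` does.

   .. code-block:: python

      import torch

      def reduce_loss_sketch(loss: torch.Tensor, method: str) -> torch.Tensor:
          """Hypothetical stand-in for loss_reduce: aggregate a per-sample loss tensor."""
          if method == "mean":
              return loss.mean()
          if method == "sum":
              return loss.sum()
          if method == "none":
              # Keep the per-sample losses, e.g. for custom weighting downstream.
              return loss
          # Assumption: unknown methods are rejected loudly.
          raise ValueError(f"Unknown reduction method: {method!r}")

      per_sample = torch.tensor([1.0, 2.0, 3.0])
      print(reduce_loss_sketch(per_sample, "mean"))  # tensor(2.)
      print(reduce_loss_sketch(per_sample, "sum"))   # tensor(6.)

   Returning the unreduced tensor for ``'none'`` mirrors the ``reduction`` convention of PyTorch's built-in loss functions, which lets callers apply their own weighting before aggregating.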