gfn.gflownet.flow_matching
==========================

.. py:module:: gfn.gflownet.flow_matching


Classes
-------

.. autoapisummary::

   gfn.gflownet.flow_matching.FMGFlowNet


Module Contents
---------------

.. py:class:: FMGFlowNet(logF, alpha = 1.0, debug = False, loss_fn = None)

   Bases: :py:obj:`gfn.gflownet.base.GFlowNet`\ [\ :py:obj:`gfn.containers.StatesContainer`\ [\ :py:obj:`gfn.states.DiscreteStates`\ ]\ ]


   GFlowNet for the Flow Matching loss with an edge flow estimator.

   :math:`\mathcal{O}_{edge}` is the set of functions from the non-terminating edges
   to :math:`\mathbb{R}^+`. This is equivalent to the set of functions from the
   internal nodes (i.e., excluding :math:`s_f`) to :math:`\mathbb{R}^{n_{\text{actions}}}`,
   without the exit action. Positivity is not required when log-flows are parametrized.

   The flow matching loss is described in section 3.2 of
   `GFlowNet Foundations <https://arxiv.org/abs/2111.09266>`__.

   .. attribute:: logF

      A DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator for estimating
      the log flow of the edges (states -> next_states).

   .. attribute:: alpha

      A scalar weight for the reward matching loss.

   Flow Matching does not require recomputing PF/PB probabilities. Any trajectory
   sampling provided by this class is for diagnostics/visualization and can only use
   the default (non-recurrent) PolicyMixin interface.


   .. py:attribute:: alpha
      :value: 1.0


   .. py:method:: flow_matching_loss(env, states, reduction = 'mean')

      Computes the flow matching loss for the (non-initial) states.

      The per-state Flow Matching residual is the log of the summed incoming edge
      flows minus the log of the summed outgoing edge flows. The states should not
      include :math:`s_0`, and the batch shape should be ``(n_states,)``. Currently,
      only discrete environments are supported.

      :param env: The discrete environment from which the states are sampled.
      :param states: The DiscreteStates object to evaluate (should not include :math:`s_0`).
      :param reduction: The reduction method to use ('mean', 'sum', or 'none').
      :returns: The computed flow matching loss as a tensor.
                The shape depends on the reduction method.


   .. py:attribute:: logF


   .. py:method:: loss(env, states_container, recalculate_all_logprobs = True, reduction = 'mean', *, log_rewards = None)

      Computes the flow matching loss for a batch of states.

      The flow matching loss is described in section 3.2 of
      `GFlowNet Foundations <https://arxiv.org/abs/2111.09266>`__. Unlike the original
      implementation, this class allows more flexibility by treating intermediary and
      terminating states separately: the total loss is the flow matching loss on the
      intermediary states plus ``alpha`` times the reward matching loss on the
      terminating states.

      :param env: The discrete environment from which the states are sampled.
      :param states_container: The StatesContainer object containing both intermediary
                               and terminating states.
      :param recalculate_all_logprobs: Whether to re-evaluate all logprobs (unused for FM).
      :param reduction: The reduction method to use ('mean', 'sum', or 'none').
      :param log_rewards: Optional custom log rewards tensor of shape ``(n_terminating_states,)``.
                          When None, the environment rewards from the states container are
                          used. Useful for intrinsic rewards (see "Towards Improving
                          Exploration through Sibling Augmented GFlowNets", Madan et al.,
                          ICLR 2025).
      :returns: The computed flow matching loss as a tensor.
                The shape depends on the reduction method.


   .. py:method:: reward_matching_loss(env, terminating_states, log_rewards, reduction = 'mean')

      Computes the reward matching loss for the terminating states.

      :param env: The discrete environment from which the states are sampled (unused).
      :param terminating_states: The DiscreteStates object containing terminating states.
      :param log_rewards: The log rewards for the terminating states.
      :param reduction: The reduction method to use ('mean', 'sum', or 'none').
      :returns: The computed reward matching loss as a tensor.
                The shape depends on the reduction method.


   .. py:method:: sample_trajectories(env, n, conditions = None, save_logprobs = False, save_estimator_outputs = False, **policy_kwargs)

      Samples trajectories using the edge flow estimator.

      :param env: The discrete environment from which to sample trajectories.
      :param n: Number of trajectories to sample.
      :param conditions: Optional conditions tensor for conditional environments.
      :param save_logprobs: Whether to save the log-probabilities of the actions.
      :param save_estimator_outputs: Whether to save the estimator outputs.
      :param \*\*policy_kwargs: Additional keyword arguments for the sampler.
      :returns: A Trajectories object containing the sampled trajectories.


   .. py:method:: to_training_samples(trajectories)

      Converts trajectories to a StatesContainer for the flow matching loss.

      :param trajectories: The Trajectories object to convert.
      :returns: A StatesContainer object containing all states from the trajectories.
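
The following plain-Python sketch illustrates the quantities that ``flow_matching_loss``, ``reward_matching_loss``, and ``loss`` are documented to combine, on a tiny hand-built DAG. It is not the library's implementation: the toy graph, the ``log_edge_flow`` table, the helper names, the squared residuals with mean reduction, and the use of the exit-edge flow for reward matching are all assumptions made for illustration. The real class operates on batched ``DiscreteStates`` with the ``logF`` estimator.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a non-empty list of log-values."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


# Hypothetical toy DAG: s0 -> {a, b} -> c, and a, b, c may each exit to s_f.
# Keys are edges (parent, child); values are made-up log edge flows log F(s -> s').
log_edge_flow = {
    ("s0", "a"): math.log(2.0),
    ("s0", "b"): math.log(1.0),
    ("a", "c"): math.log(1.5),
    ("b", "c"): math.log(0.5),
    ("a", "s_f"): math.log(0.5),
    ("b", "s_f"): math.log(0.5),
    ("c", "s_f"): math.log(2.0),
}


def flow_matching_residual(state):
    """log(sum of incoming flows) - log(sum of outgoing flows); state must not be s0."""
    incoming = [f for (_, dst), f in log_edge_flow.items() if dst == state]
    outgoing = [f for (src, _), f in log_edge_flow.items() if src == state]
    return logsumexp(incoming) - logsumexp(outgoing)


def reward_matching_residual(state, log_reward):
    """Log flow of the exit edge state -> s_f minus the state's log reward."""
    return log_edge_flow[(state, "s_f")] - log_reward


def fm_objective(states, terminating, log_rewards, alpha=1.0):
    """Mean squared FM residual plus alpha times the mean squared reward matching residual."""
    fm = sum(flow_matching_residual(s) ** 2 for s in states) / len(states)
    rm = sum(
        reward_matching_residual(s, r) ** 2 for s, r in zip(terminating, log_rewards)
    ) / len(terminating)
    return fm + alpha * rm
```

With these made-up flows, inflow equals outflow at ``a``, ``b``, and ``c``, so every flow matching residual is (numerically) zero and only the reward matching term can contribute. For example, giving ``c`` a log reward of ``0.0`` instead of ``math.log(2.0)`` leaves a single residual of ``math.log(2.0)`` among three terminating states, so the objective with ``alpha = 1.0`` is ``math.log(2.0) ** 2 / 3``.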
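
For ``sample_trajectories``, a common way to turn edge flows into a sampling policy is to make each forward transition probability proportional to its outgoing edge flow, which amounts to a softmax over the log edge flows of the allowed actions. The sketch below assumes that convention; ``forward_policy`` and ``sample_action`` are hypothetical helpers, and plain Python lists stand in for the estimator's tensor outputs and the action masks.

```python
import math
import random


def forward_policy(log_flows, mask):
    """Probabilities proportional to exp(log_flow) over allowed actions; masked actions get 0."""
    # Shift by the largest allowed log flow for numerical stability.
    m = max(lf for lf, ok in zip(log_flows, mask) if ok)
    weights = [math.exp(lf - m) if ok else 0.0 for lf, ok in zip(log_flows, mask)]
    total = sum(weights)
    return [w / total for w in weights]


def sample_action(log_flows, mask, rng):
    """Draw one action index from the flow-induced forward policy."""
    probs = forward_policy(log_flows, mask)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Example: a state with two allowed actions (move to another state, or exit).
rng = random.Random(0)
action = sample_action([math.log(1.5), math.log(0.5)], [True, True], rng)
```

With outgoing log flows ``log 1.5`` and ``log 0.5``, the induced policy is ``[0.75, 0.25]``; masked actions always receive probability zero and are therefore never sampled.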