gfn.gflownet.flow_matching¶
Classes¶
FMGFlowNet: GFlowNet for the Flow Matching loss with an edge flow estimator.
Module Contents¶
- class gfn.gflownet.flow_matching.FMGFlowNet(logF, alpha=1.0, debug=False, loss_fn=None)¶
Bases: gfn.gflownet.base.GFlowNet[gfn.containers.StatesContainer[gfn.states.DiscreteStates]]
GFlowNet for the Flow Matching loss with an edge flow estimator.
\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). This is equivalent to the set of functions from the internal nodes (i.e., all nodes except \(s_f\)) to \(\mathbb{R}^{n_{\text{actions}}}\), excluding the exit action. No positivity constraint is needed when the log-flows are parametrized.
The flow matching loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
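The remark about positivity can be illustrated with a minimal, library-free sketch. It assumes a hypothetical environment with three actions, the last being the exit action, and reads "excluding the exit action" as one output per non-exit action; the variable names are illustrative and are not part of the `gfn` API.

```python
import math

# Hypothetical toy setup: 3 actions in total, the last one being the exit action.
n_actions = 3

# Unconstrained outputs for one internal state, one per non-exit action
# (an assumption about the layout). Any real numbers are allowed because
# they are interpreted as log-flows rather than flows.
log_edge_flows = [-2.0, 0.5]
assert len(log_edge_flows) == n_actions - 1

# Exponentiating recovers edge flows that are positive by construction.
edge_flows = [math.exp(v) for v in log_edge_flows]
```

Because the exponential maps every real number to a positive one, the estimator itself needs no output constraint.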
- Parameters:
alpha (float)
debug (bool)
loss_fn (gfn.gflownet.losses.RegressionLoss | None)
- logF¶
A DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator for estimating the log flow of the edges (states -> next_states).
- alpha¶
A scalar weight for the reward matching loss.
Flow Matching does not rely on PF/PB probability recomputation. Any trajectory sampling provided by this class is for diagnostics/visualization and can only use the default (non-recurrent) PolicyMixin interface.
- alpha = 1.0¶
- flow_matching_loss(env, states, reduction='mean')¶
Computes the flow matching loss for the (non-initial) states.
The flow matching loss is defined as the log of the summed incoming flows minus the log of the summed outgoing flows. The states should not include \(s_0\), and the batch shape should be (n_states,). Currently, only discrete environments are supported.
- Parameters:
env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from.
states (gfn.states.DiscreteStates) – The DiscreteStates object to evaluate (should not include \(s_0\)).
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
- Returns:
The computed flow matching loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
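As a rough illustration of the quantity described above, the following sketch computes a per-state mismatch between incoming and outgoing log-flows and applies the three reduction modes. It uses plain Python lists instead of tensors, and it assumes a squared-error penalty on the difference; the actual penalty depends on `loss_fn`. The helper names `logsumexp`, `fm_scores`, and `reduce_scores` are hypothetical and do not belong to `gfn`.

```python
import math

def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a non-empty list of log-flows."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def fm_scores(incoming, outgoing):
    """Per-state squared gap between log total inflow and log total outflow.

    The squared error is an assumption made for this sketch.
    """
    return [(logsumexp(i) - logsumexp(o)) ** 2 for i, o in zip(incoming, outgoing)]

def reduce_scores(scores, reduction="mean"):
    """Mimics the 'mean', 'sum', and 'none' reduction options."""
    if reduction == "mean":
        return sum(scores) / len(scores)
    if reduction == "sum":
        return sum(scores)
    if reduction == "none":
        return scores
    raise ValueError(f"Unknown reduction: {reduction}")

# Two non-initial states.
# State A: incoming flows 1 and 3, outgoing flows 2 and 2 (balanced).
# State B: incoming flow 2, outgoing flows 1 and 3 (outflow is twice the inflow).
incoming = [[math.log(1.0), math.log(3.0)], [math.log(2.0)]]
outgoing = [[math.log(2.0), math.log(2.0)], [math.log(1.0), math.log(3.0)]]
scores = fm_scores(incoming, outgoing)
```

In this toy case the balanced state contributes approximately zero, while the unbalanced state contributes \((\log 2)^2\), since its outflow is twice its inflow.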
- logF¶
- loss(env, states_container, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)¶
Computes the flow matching loss for a batch of states.
The flow matching loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266). Unlike the original implementation, we allow more flexibility by treating the intermediary and terminating states separately.
- Parameters:
env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from.
states_container (gfn.containers.StatesContainer[gfn.states.DiscreteStates]) – The StatesContainer object containing both intermediary and terminating states.
recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs (unused for FM).
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_terminating_states,). When None, uses the environment rewards from the states container. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).
- Returns:
The computed flow matching loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
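One plausible reading of treating the two kinds of states separately, assuming the two terms are simply added and that `alpha` scales the reward matching term as its attribute description suggests (the exact combination is not stated on this page):

\[
\mathcal{L} = \mathcal{L}_{\text{FM}}(\text{intermediary states}) + \alpha \, \mathcal{L}_{\text{RM}}(\text{terminating states})
\]

Here \(\mathcal{L}_{\text{FM}}\) and \(\mathcal{L}_{\text{RM}}\) are shorthand introduced for this sketch, standing for the outputs of `flow_matching_loss` and `reward_matching_loss` respectively.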
- reward_matching_loss(env, terminating_states, log_rewards, reduction='mean')¶
Computes the reward matching loss for the terminating states.
- Parameters:
env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from (unused).
terminating_states (gfn.states.DiscreteStates) – The DiscreteStates object containing terminating states.
conditions – Optional conditions tensor for conditional environments.
log_rewards (torch.Tensor | None) – The log rewards for the terminating states.
reduction (str) – The reduction method to use (‘mean’, ‘sum’, or ‘none’).
- Returns:
The computed reward matching loss as a tensor. The shape depends on the reduction method.
- Return type:
torch.Tensor
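A minimal sketch of the idea behind reward matching, under the assumption that the estimated log-flow leaving each terminating state is compared to its log reward with a squared error. The function name and the numeric values are hypothetical, and plain lists stand in for tensors.

```python
import math

def reward_matching_scores(log_terminating_flows, log_rewards):
    """Squared gap between each estimated log terminating flow and its log reward.

    The squared error is an assumption made for this sketch.
    """
    if len(log_terminating_flows) != len(log_rewards):
        raise ValueError("one log reward is needed per terminating state")
    return [(f - r) ** 2 for f, r in zip(log_terminating_flows, log_rewards)]

# Hypothetical values for three terminating states.
log_terminating_flows = [math.log(2.0), math.log(0.5), 0.0]
log_rewards = [math.log(2.0), math.log(1.0), 1.0]

scores = reward_matching_scores(log_terminating_flows, log_rewards)
rm_loss_mean = sum(scores) / len(scores)  # corresponds to reduction='mean'
```

The first state matches its reward exactly and contributes nothing; the other two contribute in proportion to how far the estimated log-flow is from the log reward.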
- sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)¶
Samples trajectories using the edge flow estimator.
- Parameters:
env (gfn.env.DiscreteEnv) – The discrete environment to sample trajectories from.
n (int) – Number of trajectories to sample.
conditions (torch.Tensor | None) – Optional conditions tensor for conditional environments.
save_logprobs (bool) – Whether to save the log-probabilities of the actions.
save_estimator_outputs (bool) – Whether to save the estimator outputs.
**policy_kwargs (Any) – Additional keyword arguments for the sampler.
- Returns:
A Trajectories object containing the sampled trajectories.
- Return type:
gfn.containers.Trajectories
- to_training_samples(trajectories)¶
Converts trajectories to a StatesContainer for flow matching loss.
- Parameters:
trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.
- Returns:
A StatesContainer object containing all states from the trajectories.
- Return type:
gfn.containers.StatesContainer[gfn.states.DiscreteStates]