gfn.gflownet.flow_matching

Classes

FMGFlowNet

GFlowNet for the Flow Matching loss with an edge flow estimator.

Module Contents

class gfn.gflownet.flow_matching.FMGFlowNet(logF, alpha=1.0, debug=False, loss_fn=None)

Bases: gfn.gflownet.base.GFlowNet[gfn.containers.StatesContainer[gfn.states.DiscreteStates]]

GFlowNet for the Flow Matching loss with an edge flow estimator.

\(\mathcal{O}_{edge}\) is the set of functions from the non-terminating edges to \(\mathbb{R}^+\). This is equivalent to the set of functions from the internal nodes (i.e., excluding \(s_f\)) to \(\mathbb{R}^{n_{actions}}\), without the exit action. Positivity need not be enforced if log-flows are parametrized.

The flow matching loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).
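The equivalence between edge-indexed and node-indexed flows can be illustrated without the library. The following is a minimal sketch in plain Python; the toy graph, the state names, and the helper `to_node_log_flows` are hypothetical and are not part of `gfn`.

```python
import math

# Hypothetical toy graph: each edge is keyed by (state, action_index)
# and carries a strictly positive flow.
edge_flows = {("s0", 0): 2.0, ("s0", 1): 1.0, ("s1", 0): 2.0}
n_actions = 2  # non-exit actions only


def to_node_log_flows(edge_flows, n_actions):
    """Regroup edge flows into one log-flow vector per internal node."""
    table = {}
    for (state, action), flow in edge_flows.items():
        # Actions with no edge keep -inf, i.e. zero flow once exponentiated.
        row = table.setdefault(state, [-math.inf] * n_actions)
        row[action] = math.log(flow)
    return table


node_log_flows = to_node_log_flows(edge_flows, n_actions)
```

Each internal node maps to a vector with one entry per action, and the entries may be any real number because they are log-flows.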

Parameters:
logF

A DiscretePolicyEstimator or ConditionalDiscretePolicyEstimator for estimating the log flow of the edges (states -> next_states).

alpha

A scalar weight for the reward matching loss.

Flow Matching does not rely on PF/PB probability recomputation. Any trajectory sampling provided by this class is for diagnostics/visualization and can only use the default (non-recurrent) PolicyMixin interface.

alpha = 1.0
flow_matching_loss(env, states, reduction='mean')

Computes the flow matching loss for the (non-initial) states.

The Flow Matching loss is computed from the difference between the log of the summed incoming flows and the log of the summed outgoing flows of each state. The states must not include \(s_0\), and the batch shape must be (n_states,). Currently, only discrete environments are supported.

Parameters:
  • env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from.

  • states (gfn.states.DiscreteStates) – The DiscreteStates object to evaluate (should not include \(s_0\)).

  • reduction (str) – The reduction method to use ('mean', 'sum', or 'none').

Returns:

The computed flow matching loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
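To make the log-sum residual and the `reduction` argument concrete, here is a minimal sketch in plain Python rather than torch. The function `toy_flow_matching_loss` is hypothetical, and squaring the residual is an assumption (the class accepts a `loss_fn` that may behave differently).

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a non-empty list."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def toy_flow_matching_loss(incoming, outgoing, reduction="mean"):
    """incoming[i] / outgoing[i] hold the log edge flows into / out of state i."""
    residuals = [logsumexp(i) - logsumexp(o) for i, o in zip(incoming, outgoing)]
    scores = [r ** 2 for r in residuals]  # assumption: squared error
    if reduction == "mean":
        return sum(scores) / len(scores)
    if reduction == "sum":
        return sum(scores)
    if reduction == "none":
        return scores
    raise ValueError(f"unknown reduction: {reduction!r}")


# Two non-initial states. The first conserves flow (2 + 1 in, 3 out);
# the second does not (3 in, 1 + 1 out).
incoming = [[math.log(2.0), math.log(1.0)], [math.log(3.0)]]
outgoing = [[math.log(3.0)], [math.log(1.0), math.log(1.0)]]
per_state = toy_flow_matching_loss(incoming, outgoing, reduction="none")
```

With `reduction="none"` the result has one entry per state; 'mean' and 'sum' collapse it to a scalar.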

logF
loss(env, states_container, recalculate_all_logprobs=True, reduction='mean', *, log_rewards=None)

Computes the flow matching loss for a batch of states.

The flow matching loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266). Unlike the original implementation, we allow more flexibility by treating the intermediary and terminating states separately.

Parameters:
  • env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from.

  • states_container (gfn.containers.StatesContainer[gfn.states.DiscreteStates]) – The StatesContainer object containing both intermediary and terminating states.

  • recalculate_all_logprobs (bool) – Whether to re-evaluate all logprobs (unused for FM).

  • reduction (str) – The reduction method to use ('mean', 'sum', or 'none').

  • log_rewards (torch.Tensor | None) – Optional custom log rewards tensor of shape (n_terminating_states,). When None, uses the environment rewards from the states container. Useful for intrinsic rewards (see “Towards Improving Exploration through Sibling Augmented GFlowNets”, Madan et al., ICLR 2025).

Returns:

The computed flow matching loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
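The interplay between `alpha` and `log_rewards` can be sketched as follows. This is plain Python with hypothetical helper names; in particular, adding the reward matching term weighted by `alpha` to the flow matching term is an assumption based on the description of `alpha`, not a statement about the library's internals.

```python
def pick_log_rewards(container_log_rewards, log_rewards=None):
    """Use custom log rewards when given, else the ones stored with the states."""
    if log_rewards is None:
        return container_log_rewards
    if len(log_rewards) != len(container_log_rewards):
        raise ValueError("expected one log reward per terminating state")
    return log_rewards


def combine_losses(fm_loss, rm_loss, alpha=1.0):
    """Assumed combination: flow matching term plus alpha-weighted reward term."""
    return fm_loss + alpha * rm_loss


stored = [0.0, 1.0]                                   # rewards from the container
chosen = pick_log_rewards(stored, log_rewards=[0.5, 1.5])  # custom override
total = combine_losses(fm_loss=0.2, rm_loss=0.4, alpha=0.5)
```

Passing `log_rewards=None` falls back to the stored values, which mirrors the documented default.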

reward_matching_loss(env, terminating_states, log_rewards, reduction='mean')

Computes the reward matching loss for the terminating states.

Parameters:
  • env (gfn.env.DiscreteEnv) – The discrete environment where the states are sampled from (unused).

  • terminating_states (gfn.states.DiscreteStates) – The DiscreteStates object containing terminating states.

  • conditions – Optional conditions tensor for conditional environments.

  • log_rewards (torch.Tensor | None) – The log rewards for the terminating states.

  • reduction (str) – The reduction method to use ('mean', 'sum', or 'none').

Returns:

The computed reward matching loss as a tensor. The shape depends on the reduction method.

Return type:

torch.Tensor
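As an illustration, reward matching can be read as regressing the predicted log flow into each terminating state onto its log reward. The sketch below is plain Python; `toy_reward_matching_loss` is a hypothetical name and the squared-error form is an assumption.

```python
import math


def toy_reward_matching_loss(log_terminating_flows, log_rewards, reduction="mean"):
    """Compare predicted log terminating flows with log rewards, state by state."""
    scores = [(f - r) ** 2 for f, r in zip(log_terminating_flows, log_rewards)]
    if reduction == "none":
        return scores
    if reduction == "sum":
        return sum(scores)
    if reduction == "mean":
        return sum(scores) / len(scores)
    raise ValueError(f"unknown reduction: {reduction!r}")


# The first state matches its reward exactly; the second overestimates it.
log_flows = [math.log(2.0), math.log(4.0)]
log_rewards = [math.log(2.0), math.log(1.0)]
rm_scores = toy_reward_matching_loss(log_flows, log_rewards, reduction="none")
rm_mean = toy_reward_matching_loss(log_flows, log_rewards)
```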

sample_trajectories(env, n, conditions=None, save_logprobs=False, save_estimator_outputs=False, **policy_kwargs)

Samples trajectories using the edge flow estimator.

Parameters:
  • env (gfn.env.DiscreteEnv) – The discrete environment to sample trajectories from.

  • n (int) – Number of trajectories to sample.

  • conditions (torch.Tensor | None) – Optional conditions tensor for conditional environments.

  • save_logprobs (bool) – Whether to save the log-probabilities of the actions.

  • save_estimator_outputs (bool) – Whether to save the estimator outputs.

  • **policy_kwargs (Any) – Additional keyword arguments for the sampler.

Returns:

A Trajectories object containing the sampled trajectories.

Return type:

gfn.containers.Trajectories
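Sampling from an edge flow estimator amounts to choosing each action with probability proportional to its outgoing flow. A minimal sketch in plain Python, assuming invalid actions are masked with a log-flow of negative infinity; both helper names are hypothetical and this is not the library's sampler.

```python
import math
import random


def policy_from_log_flows(log_flows):
    """Softmax over log edge flows; -inf entries get probability zero."""
    m = max(log_flows)
    weights = [math.exp(v - m) for v in log_flows]
    total = sum(weights)
    return [w / total for w in weights]


def sample_action(log_flows, rng):
    """Draw one action index in proportion to its edge flow."""
    probs = policy_from_log_flows(log_flows)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


# Flows of 3 and 1 on two valid actions; the third action is masked out.
log_flows = [math.log(3.0), math.log(1.0), -math.inf]
probs = policy_from_log_flows(log_flows)
action = sample_action(log_flows, random.Random(0))
```

Repeating `sample_action` from the resulting next state until the exit action is drawn would yield one trajectory.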

to_training_samples(trajectories)

Converts trajectories to a StatesContainer for flow matching loss.

Parameters:

trajectories (gfn.containers.Trajectories) – The Trajectories object to convert.

Returns:

A StatesContainer object containing all states from the trajectories.

Return type:

gfn.containers.StatesContainer[gfn.states.DiscreteStates]
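The conversion can be pictured as flattening trajectories into two pools of states. The sketch below is plain Python and rests on two assumptions: the initial state is dropped because the flow matching loss excludes \(s_0\), and the last state of each trajectory is treated as terminating. The function name and the dictionary layout are hypothetical.

```python
def toy_to_training_samples(trajectories, s0="s0"):
    """Split non-empty trajectories into intermediary and terminating states."""
    intermediary, terminating = [], []
    for states in trajectories:
        *inner, last = states
        # Assumption: s0 is excluded from the intermediary pool.
        intermediary.extend(s for s in inner if s != s0)
        terminating.append(last)
    return {"intermediary": intermediary, "terminating": terminating}


samples = toy_to_training_samples([["s0", "a", "b"], ["s0", "c"]])
```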