gfn.utils.training

Attributes

logger

Functions

get_terminating_state_dist(env, states)

[DEPRECATED] Use env.get_terminating_state_dist(states) instead.

grad_norm(params[, p])

Returns the p-norm of all gradients in params (ignores params with no grad).

lr_grad_ratio(optimizer)

Return (lr·‖g‖₂)/‖θ‖₂ for each param group.

param_norm(params[, p])

Total p-norm of a collection of parameters.

states_actions_tns_to_traj(states_tns, actions_tns, env)

Converts raw state and action tensors into a Trajectories object.

validate(env, gflownet[, n_validation_samples, ...])

[DEPRECATED] Use env.validate(gflownet, ...) instead.

warm_up(replay_buf, optimizer, gflownet, env, ...[, ...])

Performs a warm-up training phase for a GFlowNet agent.

Module Contents

gfn.utils.training.get_terminating_state_dist(env, states)

[DEPRECATED] Use env.get_terminating_state_dist(states) instead.

Parameters:
  • env

  • states

Return type:

torch.Tensor

gfn.utils.training.grad_norm(params, p=2)

Returns the p-norm of all gradients in params (ignores params with no grad).

Example:

grad_norm(model.parameters())  # total L2 norm
grad_norm(model.parameters(), p=float("inf"))  # largest absolute gradient entry

Parameters:
  • params (Iterable[torch.nn.Parameter])

  • p (float)

Return type:

float
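To make "the p-norm of all gradients" concrete, here is a minimal pure-Python sketch. It assumes the norm is taken over all gradient entries pooled together. Plain lists stand in for gradient tensors and None stands in for a parameter whose .grad is None. grad_norm_sketch is a hypothetical name, not part of gfn, and returning 0.0 when no gradients exist is an assumption.

```python
import math


def grad_norm_sketch(grads, p=2.0):
    """Hypothetical: p-norm of all gradient entries pooled together.

    `grads` is a list of flat gradient lists; None marks a parameter with no
    gradient and is skipped, mirroring "ignores params with no grad".
    """
    values = [abs(x) for g in grads if g is not None for x in g]
    if not values:
        return 0.0  # assumption: no gradients at all -> norm 0
    if math.isinf(p):
        return max(values)  # infinity norm: largest absolute entry
    return sum(v ** p for v in values) ** (1.0 / p)


grads = [[3.0, -4.0], None, [0.0]]
print(grad_norm_sketch(grads))                # 5.0 (L2 over 3, -4, 0)
print(grad_norm_sketch(grads, float("inf")))  # 4.0 (largest |entry|)
```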

gfn.utils.training.logger

gfn.utils.training.lr_grad_ratio(optimizer)

Returns, for each parameter group, the ratio (lr·‖g‖₂)/‖θ‖₂, where lr is the group's learning rate, g its gradients, and θ its parameters.

Parameters:

optimizer (torch.optim.Optimizer)

Return type:

list[float]
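A minimal pure-Python sketch of the per-group ratio, assuming g and θ are the gradients and values of all parameters in a group pooled into one vector. Each group is a dict with "lr" and "params" keys, mirroring the layout of torch.optim.Optimizer.param_groups. Each parameter is a hypothetical {"data": ..., "grad": ...} dict standing in for a tensor. The eps guard against a zero parameter norm is an illustrative choice, not necessarily gfn's behavior.

```python
import math


def _l2(values):
    """Euclidean norm of a flat list of numbers."""
    return math.sqrt(sum(v * v for v in values))


def lr_grad_ratio_sketch(param_groups, eps=1e-12):
    """Hypothetical: lr * ||g||_2 / ||theta||_2 for each parameter group."""
    ratios = []
    for group in param_groups:
        params = group["params"]
        # Pool all gradient entries in the group, skipping params with no grad.
        g = [x for prm in params if prm["grad"] is not None for x in prm["grad"]]
        # Pool all parameter values in the group.
        theta = [x for prm in params for x in prm["data"]]
        # eps avoids division by zero when every parameter is zero (illustrative).
        ratios.append(group["lr"] * _l2(g) / (_l2(theta) + eps))
    return ratios


groups = [
    {"lr": 0.1, "params": [{"data": [3.0, 4.0], "grad": [0.6, 0.8]}]},
    {"lr": 0.01, "params": [{"data": [1.0], "grad": None}]},
]
print(lr_grad_ratio_sketch(groups))  # roughly [0.02, 0.0]
```

For plain SGD without momentum or weight decay, lr·g is exactly the update applied to θ, so the ratio is the relative step size ‖Δθ‖₂/‖θ‖₂; for adaptive optimizers such as Adam it is only a proxy.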

gfn.utils.training.param_norm(params, p=2)

Total p-norm of a collection of parameters.

Example:

model_pnorm = param_norm(model.parameters())  # L2 norm
max_abs = param_norm(model.parameters(), p=float("inf"))  # largest absolute parameter entry

Parameters:
  • params (Iterable[torch.nn.Parameter])

  • p (float)

Return type:

float
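For finite p, ‖x‖_p^p is a sum over entries, so the total p-norm of several tensors equals the p-norm of their individual p-norms (for p = inf, the max of the per-tensor maxima). This is the same "norm of norms" that torch.nn.utils.clip_grad_norm_ uses for its total norm. It lets the total be computed one tensor at a time without concatenating anything. The sketch below checks the identity in pure Python, with lists standing in for parameter tensors. p_norm and param_norm_sketch are hypothetical helpers, not gfn functions, and whether gfn computes the norm exactly this way is an assumption.

```python
import math


def p_norm(values, p=2.0):
    """p-norm of a flat list of numbers; p may be float('inf')."""
    if math.isinf(p):
        return max((abs(v) for v in values), default=0.0)
    return sum(abs(v) ** p for v in values) ** (1.0 / p)


def param_norm_sketch(params, p=2.0):
    """Hypothetical: total p-norm as the p-norm of per-tensor p-norms."""
    return p_norm([p_norm(t, p) for t in params], p)


params = [[3.0, 4.0], [12.0]]  # stand-ins for two parameter tensors
flat = [x for t in params for x in t]
for p in (1.0, 2.0, float("inf")):
    # Norm-of-norms and the norm of the concatenated entries agree.
    print(p, param_norm_sketch(params, p), p_norm(flat, p))
```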

gfn.utils.training.states_actions_tns_to_traj(states_tns, actions_tns, env, conditions=None)

Converts raw state and action tensors into a Trajectories object.

This utility function helps integrate external data (e.g., expert demonstrations) into the GFlowNet framework by converting raw tensors into proper Trajectories objects. The downstream GFlowNet must be able to recompute all log-probabilities from these trajectories (e.g., a PFBasedGFlowNet).

Parameters:
  • states_tns (torch.Tensor) – A tensor of shape [traj_len, *state_shape] containing states for a single trajectory.

  • actions_tns (torch.Tensor) – A tensor of shape [traj_len] containing discrete action indices.

  • env (gfn.env.DiscreteEnv) – The discrete environment that defines the state/action spaces.

  • conditions (torch.Tensor | None) – An optional tensor of shape [traj_len, *conditions_shape] containing condition information for a single trajectory.

Returns:

A Trajectories object containing the converted states and actions.

Raises:

ValueError – If tensor shapes are invalid or inconsistent.

Return type:

gfn.samplers.Trajectories
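The documented shape contract can be summarized as a small validator. This is a hypothetical helper that operates on shape tuples (for a tensor t, pass tuple(t.shape)). It mirrors only the shapes listed above and may check less than the real function, or check it differently.

```python
def check_traj_shapes(states_shape, actions_shape, state_shape, conditions_shape=None):
    """Hypothetical validator for the documented shapes; returns traj_len.

    states: [traj_len, *state_shape], actions: [traj_len],
    conditions (optional): [traj_len, *conditions_shape].
    """
    states_shape, actions_shape = tuple(states_shape), tuple(actions_shape)
    if len(states_shape) < 1 or states_shape[1:] != tuple(state_shape):
        raise ValueError(
            f"states must be [traj_len, *{tuple(state_shape)}], got {states_shape}"
        )
    traj_len = states_shape[0]
    if actions_shape != (traj_len,):
        raise ValueError(f"actions must be [{traj_len}], got {actions_shape}")
    if conditions_shape is not None:
        conditions_shape = tuple(conditions_shape)
        if len(conditions_shape) < 1 or conditions_shape[0] != traj_len:
            raise ValueError(
                f"conditions must start with traj_len={traj_len}, got {conditions_shape}"
            )
    return traj_len


# e.g. a trajectory of length 4 in an environment whose states are 2-D vectors
print(check_traj_shapes((4, 2), (4,), (2,)))  # 4
```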

gfn.utils.training.validate(env, gflownet, n_validation_samples=1000, visited_terminating_states=None)

[DEPRECATED] Use env.validate(gflownet, …) instead.

Parameters:
  • env

  • gflownet

  • n_validation_samples

  • visited_terminating_states

Return type:

Tuple[Dict[str, float], gfn.states.DiscreteStates | None]

gfn.utils.training.warm_up(replay_buf, optimizer, gflownet, env, n_epochs, batch_size, recalculate_all_logprobs=True)

Performs a warm-up training phase for a GFlowNet agent.

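This utility function provides an example implementation of pre-training for GFlowNet agents: it trains gflownet on batches of trajectories sampled from replay_buf for n_epochs epochs before regular training begins.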

Parameters:
  • replay_buf (gfn.containers.ReplayBuffer) – The replay buffer, which collects Trajectories.

  • optimizer (torch.optim.Optimizer) – Any torch.optim optimizer (e.g., Adam, SGD).

  • gflownet (gfn.gflownet.GFlowNet) – The GFlowNet instance to train.

  • env (gfn.env.Env) – The environment instance.

  • n_epochs (int) – The number of epochs for the warm-up phase.

  • batch_size (int) – The number of trajectories to sample from the replay buffer per step.

  • recalculate_all_logprobs (bool) – For PFBasedGFlowNets only, forces recalculation of all log probabilities. Useful when trajectories do not already have log probabilities.

Returns:

The trained GFlowNet instance.
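The warm-up phase can be outlined as an ordinary training loop over replay-buffer batches. This is a hypothetical sketch, not gfn's implementation. It assumes one batch of batch_size trajectories per epoch, a replay_buf.sample(n) method, and a gflownet.loss(env, trajectories) method that returns a scalar with .backward(). It also omits the documented recalculate_all_logprobs option.

```python
def warm_up_sketch(replay_buf, optimizer, gflownet, env, n_epochs, batch_size):
    """Hypothetical outline of a replay-buffer warm-up loop.

    Assumes one sampled batch per epoch, replay_buf.sample(n), and
    gflownet.loss(env, trajectories) returning a scalar with .backward().
    """
    for _ in range(n_epochs):
        # Draw a batch of stored trajectories from the replay buffer.
        trajectories = replay_buf.sample(batch_size)
        # Standard PyTorch update: clear old grads, backpropagate, step.
        optimizer.zero_grad()
        loss = gflownet.loss(env, trajectories)
        loss.backward()
        optimizer.step()
    # Like the documented function, return the (now warmed-up) GFlowNet.
    return gflownet
```

Because the loop only relies on these duck-typed methods, it can be exercised with stand-in objects before wiring it to a real replay buffer and optimizer.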