gfn.utils.training
Attributes
logger
Functions
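get_terminating_state_dist(env, states) | [DEPRECATED] Use env.get_terminating_state_dist(states) instead.
grad_norm(params, p=2) | Returns the p-norm of all gradients in params.
lr_grad_ratio(optimizer) | Return (lr·‖g‖₂)/‖θ‖₂ for each param group.
param_norm(params, p=2) | Total p-norm of a collection of parameters.
states_actions_tns_to_traj(states_tns, actions_tns, env, conditions=None) | Converts raw state and action tensors into a Trajectories object.
validate(env, gflownet, n_validation_samples=1000, visited_terminating_states=None) | [DEPRECATED] Use env.validate(gflownet, ...) instead.
warm_up(replay_buf, optimizer, gflownet, env, n_epochs, batch_size, recalculate_all_logprobs=True) | Performs a warm-up training phase for a GFlowNet agent.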
Module Contents
- gfn.utils.training.get_terminating_state_dist(env, states)
[DEPRECATED] Use env.get_terminating_state_dist(states) instead.
- Parameters:
env (gfn.env.DiscreteEnv)
states (gfn.states.DiscreteStates)
- Return type:
torch.Tensor
- gfn.utils.training.grad_norm(params, p=2)
Returns the p-norm of all gradients in params (ignores params with no grad).
Example:
grad_norm(model.parameters())  # total L2 norm
grad_norm(model.parameters(), p=float('inf'))  # largest absolute gradient entry
- Parameters:
params (Iterable[torch.nn.Parameter])
p (float)
- Return type:
float
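A minimal sketch of the quantity described above, written in plain PyTorch. The helper name `total_grad_norm` is hypothetical and this is not the library's implementation. Returning 0.0 when no parameter has a gradient is an assumption.

```python
import torch


def total_grad_norm(params, p=2.0):
    """Hypothetical stand-in for the behaviour described above."""
    # Per-parameter p-norms; parameters without a gradient are skipped.
    norms = [q.grad.detach().norm(p) for q in params if q.grad is not None]
    if not norms:
        return 0.0  # assumption: no gradients at all gives 0.0
    # The p-norm of the per-parameter p-norms equals the p-norm of all
    # gradient entries taken together (this also holds for p = inf).
    return torch.stack(norms).norm(p).item()


w = torch.nn.Parameter(torch.tensor([3.0, 0.0]))
b = torch.nn.Parameter(torch.tensor([4.0]))
unused = torch.nn.Parameter(torch.tensor([1.0]))  # never receives a gradient

loss = (w * torch.tensor([3.0, 0.0])).sum() + 4.0 * b.sum()
loss.backward()  # w.grad = [3, 0], b.grad = [4]

l2 = total_grad_norm([w, b, unused])  # sqrt(3^2 + 4^2) = 5
max_abs = total_grad_norm([w, b, unused], p=float("inf"))  # max(3, 0, 4) = 4
```

Logging this value once per optimisation step is a common way to spot exploding or vanishing gradients.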
- gfn.utils.training.logger
- gfn.utils.training.lr_grad_ratio(optimizer)
Returns (lr·‖g‖₂)/‖θ‖₂ for each param group, where lr is the group's learning rate, g its gradients, and θ its parameters.
- Parameters:
optimizer (torch.optim.Optimizer)
- Return type:
list[float]
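The ratio can be worked through by hand on a single param group. The sketch below is an illustration under assumptions, not the library's code. The name `lr_grad_ratio_sketch` and the `eps` guard against a zero parameter norm are hypothetical. Counting parameters without a gradient in ‖θ‖₂ is also an assumption.

```python
import torch


def lr_grad_ratio_sketch(optimizer, eps=1e-12):
    """Hypothetical re-derivation of (lr * ||g||_2) / ||theta||_2 per group."""
    ratios = []
    for group in optimizer.param_groups:
        params = list(group["params"])
        # Flatten and concatenate so each norm is taken over the whole group.
        grads = [q.grad.detach().flatten() for q in params if q.grad is not None]
        weights = [q.detach().flatten() for q in params]
        grad_norm = torch.cat(grads).norm(2) if grads else torch.tensor(0.0)
        weight_norm = torch.cat(weights).norm(2)
        # eps is an assumption of this sketch; it avoids division by zero.
        ratios.append(float(group["lr"] * grad_norm / (weight_norm + eps)))
    return ratios


theta = torch.nn.Parameter(torch.tensor([3.0, 4.0]))  # ||theta||_2 = 5
opt = torch.optim.SGD([theta], lr=0.1)
(theta * torch.tensor([6.0, 8.0])).sum().backward()  # grad = [6, 8], norm 10

ratios = lr_grad_ratio_sketch(opt)  # one group: 0.1 * 10 / 5 = 0.2
```

For plain SGD this ratio approximates the relative size of one update compared with the weights themselves, which makes it a convenient check on the learning rate.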
- gfn.utils.training.param_norm(params, p=2)
Total p-norm of a collection of parameters.
Example:
model_pnorm = param_norm(model.parameters())  # L2 norm
max_abs = param_norm(model.parameters(), p=float('inf'))
- Parameters:
params (Iterable[torch.nn.Parameter])
p (float)
- Return type:
float
- gfn.utils.training.states_actions_tns_to_traj(states_tns, actions_tns, env, conditions=None)
Converts raw state and action tensors into a Trajectories object.
This utility function helps integrate external data (e.g., expert demonstrations) into the GFlowNet framework by converting raw tensors into proper Trajectories objects. The downstream GFlowNet must be able to recalculate all log probabilities itself (e.g., a PFBasedGFlowNet).
- Parameters:
states_tns (torch.Tensor) – A tensor of shape [traj_len, *state_shape] containing states for a single trajectory.
actions_tns (torch.Tensor) – A tensor of shape [traj_len] containing discrete action indices.
env (gfn.env.DiscreteEnv) – The discrete environment that defines the state/action spaces.
conditions (torch.Tensor | None) – An optional tensor of shape [traj_len, *conditions_shape] containing condition information for a single trajectory.
- Returns:
A Trajectories object containing the converted states and actions.
- Raises:
ValueError – If tensor shapes are invalid or inconsistent.
- Return type:
gfn.samplers.Trajectories
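Before calling the function on external data, it can help to check the tensors against the documented shapes. The sketch below takes those shapes at face value. The helper `check_single_trajectory` is hypothetical, and the state and action values are made up for illustration. The call to the library function is left as a comment because it needs a concrete DiscreteEnv.

```python
import torch


def check_single_trajectory(states_tns, actions_tns, conditions=None):
    """Hypothetical pre-check mirroring the documented shapes."""
    if states_tns.dim() < 1:
        raise ValueError("states_tns must have shape [traj_len, *state_shape]")
    if actions_tns.dim() != 1:
        raise ValueError("actions_tns must have shape [traj_len]")
    traj_len = states_tns.shape[0]
    if actions_tns.shape[0] != traj_len:
        raise ValueError("states_tns and actions_tns disagree on traj_len")
    if conditions is not None and conditions.shape[0] != traj_len:
        raise ValueError("conditions must have traj_len as leading dimension")
    return traj_len


# One made-up trajectory of length 3 in a 2-dimensional state space.
states = torch.tensor([[0, 0], [1, 0], [1, 1]])  # [traj_len, *state_shape]
actions = torch.tensor([0, 1, 2])                # [traj_len]

traj_len = check_single_trajectory(states, actions)

# With a concrete environment `env` (not constructed here), the conversion
# would then be:
# trajs = states_actions_tns_to_traj(states, actions, env)
```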
- gfn.utils.training.validate(env, gflownet, n_validation_samples=1000, visited_terminating_states=None)
[DEPRECATED] Use env.validate(gflownet, …) instead.
- Parameters:
env (gfn.env.DiscreteEnv)
gflownet (gfn.gflownet.GFlowNet)
n_validation_samples (int)
visited_terminating_states (Optional[gfn.states.DiscreteStates])
- Return type:
Tuple[Dict[str, float], gfn.states.DiscreteStates | None]
- gfn.utils.training.warm_up(replay_buf, optimizer, gflownet, env, n_epochs, batch_size, recalculate_all_logprobs=True)
Performs a warm-up training phase for a GFlowNet agent.
This utility function provides an example implementation of pre-training for GFlowNet agents.
- Parameters:
replay_buf (gfn.containers.ReplayBuffer) – The replay buffer, which collects Trajectories.
optimizer (torch.optim.Optimizer) – Any torch.optim optimizer (e.g., Adam, SGD).
gflownet (gfn.gflownet.GFlowNet) – The GFlowNet instance to train.
env (gfn.env.Env) – The environment instance.
n_epochs (int) – The number of epochs for the warm-up phase.
batch_size (int) – The number of trajectories to sample from the replay buffer per step.
recalculate_all_logprobs (bool) – For PFBasedGFlowNets only, forces recalculation of all log probabilities. Useful when trajectories do not already have log probabilities.
- Returns:
The trained GFlowNet instance.
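The general shape of such a warm-up loop can be sketched as follows. This is an illustration under assumptions, not the library's implementation. `ToyBuffer` and `ToyModel` are hypothetical stand-ins for the replay buffer and the GFlowNet, and the method names `sample` and `loss(env, batch)` are assumed for the sketch. The `recalculate_all_logprobs` flag is left out.

```python
import torch


class ToyBuffer:
    """Toy stand-in for a replay buffer; `sample` is an assumed method name."""

    def __init__(self, data):
        self.data = data

    def sample(self, batch_size):
        # Draw a random batch (with replacement) from the stored data.
        idx = torch.randint(0, len(self.data), (batch_size,))
        return self.data[idx]


class ToyModel(torch.nn.Module):
    """Toy stand-in for a GFlowNet; `loss(env, batch)` is an assumed signature."""

    def __init__(self):
        super().__init__()
        self.theta = torch.nn.Parameter(torch.zeros(1))

    def loss(self, env, batch):
        # Squared error between the single parameter and the batch values.
        return ((self.theta - batch) ** 2).mean()


def warm_up_sketch(replay_buf, optimizer, model, env, n_epochs, batch_size):
    """One sampled batch and one optimiser step per epoch."""
    for _ in range(n_epochs):
        batch = replay_buf.sample(batch_size)
        optimizer.zero_grad()
        loss = model.loss(env, batch)
        loss.backward()
        optimizer.step()
    return model  # the trained model is returned, as documented above


buffer = ToyBuffer(torch.ones(32))  # every stored value is 1.0
model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
trained = warm_up_sketch(buffer, optimizer, model, env=None,
                         n_epochs=50, batch_size=8)
```

In this toy setup the single parameter moves from 0 towards the buffered value 1.0, which mirrors how warm-up fits the model to pre-collected trajectories before regular training starts.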