gfn.utils.training
==================

.. py:module:: gfn.utils.training


Attributes
----------

.. autoapisummary::

   gfn.utils.training.logger


Functions
---------

.. autoapisummary::

   gfn.utils.training.get_terminating_state_dist
   gfn.utils.training.grad_norm
   gfn.utils.training.lr_grad_ratio
   gfn.utils.training.param_norm
   gfn.utils.training.states_actions_tns_to_traj
   gfn.utils.training.validate
   gfn.utils.training.warm_up


Module Contents
---------------

.. py:function:: get_terminating_state_dist(env, states)

   [DEPRECATED] Use ``env.get_terminating_state_dist(states)`` instead.


.. py:function:: grad_norm(params, p = 2)

   Returns the total p-norm of all gradients in ``params``, skipping parameters that have no gradient.

   .. rubric:: Example

   .. code-block:: python

      grad_norm(model.parameters())                  # total L2 norm
      grad_norm(model.parameters(), p=float("inf"))  # largest absolute gradient entry


.. py:data:: logger


.. py:function:: lr_grad_ratio(optimizer)

   Returns (lr·‖g‖₂)/‖θ‖₂ for each parameter group, where g and θ are the gradients and parameters of that group. For plain SGD (no momentum or weight decay), whose update is Δθ = −lr·g, this equals the relative update size ‖Δθ‖₂/‖θ‖₂.


.. py:function:: param_norm(params, p = 2)

   Returns the total p-norm of a collection of parameters.

   .. rubric:: Example

   .. code-block:: python

      model_pnorm = param_norm(model.parameters())              # total L2 norm
      max_abs = param_norm(model.parameters(), p=float("inf"))  # largest absolute entry


.. py:function:: states_actions_tns_to_traj(states_tns, actions_tns, env, conditions = None)

   Converts raw state and action tensors into a ``Trajectories`` object.

   This utility helps integrate external data (e.g., expert demonstrations) into the GFlowNet framework. Because raw tensors carry no log-probabilities, the downstream GFlowNet must be able to recalculate all of them (e.g., ``PFBasedGFlowNet`` subclasses).

   :param states_tns: A tensor of shape ``[traj_len, *state_shape]`` containing the states of a single trajectory.
   :param actions_tns: A tensor of shape ``[traj_len]`` containing discrete action indices.
   :param env: The discrete environment that defines the state and action spaces.
   :param conditions: An optional tensor of shape ``[traj_len, *conditions_shape]`` containing the condition information for a single trajectory.
   :returns: A ``Trajectories`` object containing the converted states and actions.
   :raises ValueError: If the tensor shapes are invalid or inconsistent with each other.


.. py:function:: validate(env, gflownet, n_validation_samples = 1000, visited_terminating_states = None)

   [DEPRECATED] Use ``env.validate(gflownet, ...)`` instead.


.. py:function:: warm_up(replay_buf, optimizer, gflownet, env, n_epochs, batch_size, recalculate_all_logprobs = True)

   Performs a warm-up training phase for a GFlowNet agent on trajectories drawn from a replay buffer.

   This utility provides an example implementation of pre-training for GFlowNet agents.

   :param replay_buf: The replay buffer that collects the ``Trajectories`` to train on.
   :param optimizer: Any ``torch.optim`` optimizer (e.g., Adam, SGD).
   :param gflownet: The GFlowNet instance to train.
   :param env: The environment instance.
   :param n_epochs: The number of epochs in the warm-up phase.
   :param batch_size: The number of trajectories sampled from the replay buffer per step.
   :param recalculate_all_logprobs: For ``PFBasedGFlowNet`` subclasses only; forces recalculation of all log-probabilities. Useful when the trajectories do not already carry log-probabilities.
   :returns: The trained GFlowNet instance.
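
The aggregation behind ``grad_norm``, ``param_norm`` and ``lr_grad_ratio`` can be sketched in plain PyTorch. This is a minimal sketch, not the module's implementation: the ``sketch_*`` helpers are hypothetical, and the code assumes that a "total" norm means the p-norm of all entries treated as one flattened vector. It relies on the identity that the p-norm of the per-tensor p-norms equals the p-norm of the concatenated entries, so no large tensor has to be built.

```python
from typing import Iterable, List

import torch


def sketch_param_norm(params: Iterable[torch.Tensor], p: float = 2.0) -> float:
    """Hypothetical helper: total p-norm of all entries of ``params``."""
    tensors = [t.detach() for t in params]
    if not tensors:
        return 0.0
    # The p-norm of the per-tensor p-norms equals the p-norm of all entries
    # concatenated into one vector; for p = inf both reduce to the max |entry|.
    per_tensor = torch.stack([torch.linalg.vector_norm(t, ord=p) for t in tensors])
    return torch.linalg.vector_norm(per_tensor, ord=p).item()


def sketch_grad_norm(params: Iterable[torch.nn.Parameter], p: float = 2.0) -> float:
    """Hypothetical helper: total p-norm of the gradients, skipping params with no grad."""
    grads = [t.grad for t in params if t.grad is not None]
    return sketch_param_norm(grads, p)


def sketch_lr_grad_ratio(optimizer: torch.optim.Optimizer) -> List[float]:
    """Hypothetical helper: lr * ||g||_2 / ||theta||_2 for each param group."""
    ratios = []
    for group in optimizer.param_groups:
        g_norm = sketch_grad_norm(group["params"])
        theta_norm = sketch_param_norm(group["params"])
        # Guard against division by zero for all-zero parameters.
        ratios.append(group["lr"] * g_norm / theta_norm if theta_norm > 0 else float("nan"))
    return ratios


# Usage: a bias-free linear layer with weight [3, 0, 4], so ||theta||_2 = 5.
model = torch.nn.Linear(3, 1, bias=False)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[3.0, 0.0, 4.0]]))
unused = torch.nn.Parameter(torch.ones(2))  # never receives a gradient

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer.zero_grad()
model(torch.tensor([[1.0, 2.0, 2.0]])).sum().backward()  # weight grad = [1, 2, 2], norm 3

ratios = sketch_lr_grad_ratio(optimizer)  # 0.1 * 3 / 5, i.e. about 0.06
```

Computing per-tensor norms first mirrors how gradient-clipping utilities such as ``torch.nn.utils.clip_grad_norm_`` aggregate norms across parameters.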
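
The warm-up phase described for ``warm_up`` follows the usual pattern of repeatedly sampling a mini-batch from a buffer and taking an optimizer step. The sketch below is generic and does not use the torchgfn API: ``sketch_warm_up`` is hypothetical, the replay buffer is a plain Python list of ``(input, target)`` pairs, a mean-squared-error loss stands in for the GFlowNet loss, and treating one epoch as one sampled batch is an assumption.

```python
import random
from typing import Callable, List, Tuple

import torch

Pair = Tuple[torch.Tensor, torch.Tensor]


def sketch_warm_up(
    buffer: List[Pair],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    n_epochs: int,
    batch_size: int,
) -> torch.nn.Module:
    """Hypothetical warm-up loop: one sampled mini-batch and one optimizer step per epoch."""
    for _ in range(n_epochs):
        # Sample without replacement, capped at the buffer size.
        batch = random.sample(buffer, k=min(batch_size, len(buffer)))
        xs = torch.stack([x for x, _ in batch])
        ys = torch.stack([y for _, y in batch])
        optimizer.zero_grad()
        loss = loss_fn(model(xs), ys)
        loss.backward()
        optimizer.step()
    return model


# Usage: a toy buffer of (x, 2x) pairs stands in for stored trajectories.
random.seed(0)
torch.manual_seed(0)
buffer = [(x, 2 * x) for x in torch.randn(64, 1)]
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

xs_all = torch.stack([x for x, _ in buffer])
ys_all = torch.stack([y for _, y in buffer])
with torch.no_grad():
    loss_before = loss_fn(model(xs_all), ys_all).item()

sketch_warm_up(buffer, model, optimizer, loss_fn, n_epochs=200, batch_size=16)

with torch.no_grad():
    loss_after = loss_fn(model(xs_all), ys_all).item()
```

Returning the trained model, as ``warm_up`` returns the trained GFlowNet, lets the warm-up call be chained directly into the main training loop.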