trw.train.optimizer_clipping
Module Contents¶
Classes¶
ClippingGradientNorm | Clips the gradient norm during optimization
- class trw.train.optimizer_clipping.ClippingGradientNorm(optimizer_base: torch.optim.Optimizer, max_norm: float = 1.0, norm_type: float = 2.0)¶
Bases: torch.optim.Optimizer
Clips the gradient norm during optimization.
- step(self, closure=None)¶
Performs a single optimization step (parameter update).
- Parameters
closure (callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.
Note
Unless otherwise specified, this function should not modify the .grad field of the parameters.
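The wrapper above can be sketched as follows. This is a hypothetical illustration, not the trw implementation: the class name `ClippingGradientNormSketch` and its internals are assumptions; it simply applies `torch.nn.utils.clip_grad_norm_` to the base optimizer's parameters before delegating to its `step`.

```python
import torch


class ClippingGradientNormSketch:
    """Hypothetical sketch of an optimizer wrapper that clips the global
    gradient norm of all managed parameters before each update."""

    def __init__(self, optimizer_base, max_norm=1.0, norm_type=2.0):
        self.optimizer_base = optimizer_base
        self.max_norm = max_norm
        self.norm_type = norm_type

    def step(self, closure=None):
        # Gather every parameter with a gradient from the base optimizer's
        # parameter groups, then rescale gradients in place so their total
        # norm does not exceed max_norm.
        params = [p for group in self.optimizer_base.param_groups
                  for p in group['params'] if p.grad is not None]
        torch.nn.utils.clip_grad_norm_(params, self.max_norm, self.norm_type)
        # Delegate the actual parameter update to the wrapped optimizer.
        return self.optimizer_base.step(closure)

    def zero_grad(self):
        self.optimizer_base.zero_grad()
```

A minimal usage sketch: wrap any `torch.optim.Optimizer` and train as usual.

```python
model = torch.nn.Linear(4, 1)
opt = ClippingGradientNormSketch(
    torch.optim.SGD(model.parameters(), lr=0.1), max_norm=1.0)
loss = (model(torch.randn(8, 4)) ** 2).mean()
opt.zero_grad()
loss.backward()
opt.step()
```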