trw.train.optimizer_clipping

Module Contents

Classes

ClippingGradientNorm: Clips the gradient norm during optimization

class trw.train.optimizer_clipping.ClippingGradientNorm(optimizer_base: torch.optim.Optimizer, max_norm: float = 1.0, norm_type: float = 2.0)

Bases: torch.optim.Optimizer

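Wraps optimizer_base and clips the norm of the gradients during optimization. max_norm is the maximum allowed gradient norm, and norm_type is the order of the p-norm used to measure it (2.0 for the Euclidean norm).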
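For reference, gradient-norm clipping usually treats the gradients of all parameters as one long vector. If the norm of that vector exceeds max_norm, every gradient is shrunk by the same factor. The plain-Python sketch below assumes ClippingGradientNorm follows the semantics of torch.nn.utils.clip_grad_norm_ (one total norm over all parameters, a small eps in the denominator). The function name clip_grad_norm and the list-of-lists gradient layout are hypothetical stand-ins for tensors and are not part of trw.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float = 1.0,
                   norm_type: float = 2.0, eps: float = 1e-6) -> float:
    """Rescale all gradients in place so that their combined norm is at most max_norm.

    grads holds one flat list of gradient values per parameter (a hypothetical
    stand-in for the .grad tensors). Returns the total norm measured before clipping.
    """
    values = [g for grad in grads for g in grad]
    if math.isinf(norm_type):
        # Infinity norm: the largest absolute gradient entry
        total = max((abs(g) for g in values), default=0.0)
    else:
        # p-norm over all parameters taken together, not per parameter
        total = sum(abs(g) ** norm_type for g in values) ** (1.0 / norm_type)
    scale = max_norm / (total + eps)
    if scale < 1.0:
        # Only shrink: gradients already within the budget are left untouched
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= scale
    return total


grads = [[3.0, 0.0], [0.0, 4.0]]
total = clip_grad_norm(grads, max_norm=1.0)  # total norm is 5.0
```

Because a single factor is applied to every parameter, the direction of the overall update is preserved and only its length is capped.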

step(self, closure=None)

Performs a single optimization step (parameter update).

Parameters

closure (callable) – A closure that reevaluates the model and returns the loss. Optional for most optimizers.

Note

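Unless otherwise specified, this function should not modify the .grad field of the parameters. Gradient-norm clipping is such an exception: when the total gradient norm exceeds max_norm, the gradients are rescaled in place before the update of optimizer_base.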
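The wrapping pattern suggested by the constructor (a base optimizer plus max_norm) can be sketched without PyTorch. Param, PlainSGD and ClipThenStep are hypothetical illustrations, not trw classes, and the choice to evaluate the closure before clipping is an assumption of this sketch rather than documented trw behavior.

```python
from typing import Callable, List, Optional


class Param:
    """Hypothetical stand-in for a parameter: a scalar value and its gradient."""
    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


class PlainSGD:
    """Hypothetical base optimizer: value -= lr * grad."""
    def __init__(self, params: List[Param], lr: float) -> None:
        self.params = params
        self.lr = lr

    def step(self) -> None:
        for p in self.params:
            p.value -= self.lr * p.grad


class ClipThenStep:
    """Hypothetical wrapper: clip the total L2 gradient norm, then delegate."""
    def __init__(self, base: PlainSGD, max_norm: float = 1.0) -> None:
        self.base = base
        self.max_norm = max_norm

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        loss = None
        if closure is not None:
            # Re-evaluate the model first so the fresh gradients are the ones clipped
            loss = closure()
        total = sum(p.grad ** 2 for p in self.base.params) ** 0.5
        scale = self.max_norm / (total + 1e-6)
        if scale < 1.0:
            for p in self.base.params:
                p.grad *= scale  # the .grad field is modified in place
        self.base.step()
        return loss


# Minimize w**2 starting from w = 10: the raw gradient 20 is clipped to about 1
w = Param(10.0)

def closure() -> float:
    w.grad = 2 * w.value
    return w.value ** 2

opt = ClipThenStep(PlainSGD([w], lr=0.1), max_norm=1.0)
loss = opt.step(closure)  # w.value moves to about 9.9 instead of 8.0
```

Evaluating the closure before clipping ensures that the gradients used by the base update are the clipped ones. Optimizers such as torch.optim.LBFGS, which call the closure several times within one step, are outside what this sketch handles.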