diff options
Diffstat (limited to 'optim.py')
| -rw-r--r-- | optim.py | 21 |
1 file changed, 21 insertions, 0 deletions
from torch.nn.functional import hardtanh


def gradient_clipping(optimizer, min=-1, max=1):
    """Wrap *optimizer* so every ``step`` clips gradients to ``[min, max]``.

    The returned object proxies all attribute access to the wrapped
    optimizer; only ``step`` is intercepted to clamp each parameter's
    gradient element-wise (via ``hardtanh(..., inplace=True)``) before the
    optimizer consumes it.

    Parameters kept as ``min``/``max`` (shadowing the builtins) for
    backward compatibility with existing keyword callers.

    :param optimizer: a ``torch.optim.Optimizer`` instance to wrap.
    :param min: lower clipping bound for every gradient element.
    :param max: upper clipping bound for every gradient element.
    :return: a wrapper exposing the same interface as ``optimizer``.
    """

    def _clip_gradients():
        # Clamp every gradient in place; skip parameters that received no
        # gradient in this backward pass (p.grad is None would crash hardtanh).
        for group in optimizer.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    hardtanh(p.grad, min, max, inplace=True)

    class OptimizerWrapper(object):

        def step(self, closure=None):
            # BUGFIX: `closure` was previously mandatory, so the common
            # `optimizer.step()` call pattern raised TypeError. Without a
            # closure, clip the already-computed gradients, then delegate.
            if closure is None:
                _clip_gradients()
                return optimizer.step()

            def closure_wrapper():
                # Closure-based optimizers (e.g. LBFGS) recompute gradients
                # inside the closure, so clipping must happen after it runs.
                loss = closure()
                _clip_gradients()
                return loss

            return optimizer.step(closure_wrapper)

        def __getattr__(self, attr):
            # Delegate everything else (zero_grad, state_dict, param_groups,
            # ...) to the wrapped optimizer.
            return getattr(optimizer, attr)

    return OptimizerWrapper()
