From 2e308fe8e90276a892637be1bfa174e673ebf414 Mon Sep 17 00:00:00 2001
From: Piotr Kozakowski
Date: Thu, 11 May 2017 17:49:12 +0200
Subject: Implement SampleRNN

---
 optim.py | 21 +++++++++++++++++++++
 1 file changed, 21 insertions(+)
 create mode 100644 optim.py

(limited to 'optim.py')

diff --git a/optim.py b/optim.py
new file mode 100644
index 0000000..95f9df2
--- /dev/null
+++ b/optim.py
@@ -0,0 +1,21 @@
+from torch.nn.functional import hardtanh
+
+
+def gradient_clipping(optimizer, min=-1, max=1):
+
+    class OptimizerWrapper(object):
+
+        def step(self, closure):
+            def closure_wrapper():
+                loss = closure()
+                for group in optimizer.param_groups:
+                    for p in group['params']:
+                        hardtanh(p.grad, min, max, inplace=True)
+                return loss
+
+            return optimizer.step(closure_wrapper)
+
+        def __getattr__(self, attr):
+            return getattr(optimizer, attr)
+
+    return OptimizerWrapper()
--
cgit v1.2.3-70-g09d2
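
A minimal usage sketch, not part of the patch above: it assumes the new file is importable as optim, and the toy model, batch, and Adam hyperparameters are made up purely for illustration. gradient_clipping returns a thin wrapper whose step must be called with a closure; the closure computes the loss and gradients, each parameter gradient is then clipped in place to [min, max] via hardtanh, and only afterwards does the wrapped optimizer apply its update. Every other attribute access (zero_grad, param_groups, ...) falls through to the wrapped optimizer through __getattr__.

    # Hypothetical usage example; model, data and learning rate are placeholders.
    import torch
    from torch import nn

    from optim import gradient_clipping

    model = nn.Linear(4, 1)                          # toy model for illustration
    optimizer = gradient_clipping(
        torch.optim.Adam(model.parameters(), lr=1e-3))

    inputs = torch.randn(8, 4)                       # made-up batch
    targets = torch.randn(8, 1)
    criterion = nn.MSELoss()

    def closure():
        optimizer.zero_grad()                        # forwarded to Adam via __getattr__
        loss = criterion(model(inputs), targets)
        loss.backward()                              # gradients are clipped after this runs
        return loss

    loss = optimizer.step(closure)                   # clip grads to [-1, 1], then Adam updates
    print(loss.item())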