| author | Piotr Kozakowski <kozak000@gmail.com> | 2017-05-11 17:49:12 +0200 |
|---|---|---|
| committer | Piotr Kozakowski <kozak000@gmail.com> | 2017-06-29 15:37:26 +0200 |
| commit | 2e308fe8e90276a892637be1bfa174e673ebf414 (patch) | |
| tree | 4ff187b37d16476cc936aba84184b8feca9c8612 /optim.py | |
| parent | 253860fdb0949f0eab6abff09369b0a1236b541a (diff) | |
Implement SampleRNN
Diffstat (limited to 'optim.py')
| -rw-r--r-- | optim.py | 21 |
1 file changed, 21 insertions, 0 deletions
```diff
diff --git a/optim.py b/optim.py
new file mode 100644
index 0000000..95f9df2
--- /dev/null
+++ b/optim.py
@@ -0,0 +1,21 @@
+from torch.nn.functional import hardtanh
+
+
+def gradient_clipping(optimizer, min=-1, max=1):
+
+    class OptimizerWrapper(object):
+
+        def step(self, closure):
+            def closure_wrapper():
+                loss = closure()
+                for group in optimizer.param_groups:
+                    for p in group['params']:
+                        hardtanh(p.grad, min, max, inplace=True)
+                return loss
+
+            return optimizer.step(closure_wrapper)
+
+        def __getattr__(self, attr):
+            return getattr(optimizer, attr)
+
+    return OptimizerWrapper()
```
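The new `optim.py` wraps an arbitrary PyTorch optimizer so that every parameter gradient is clamped to `[min, max]` (via `hardtanh`) just before the update step. Below is a minimal usage sketch; the model, batch, and `SGD` settings are hypothetical placeholders, while `gradient_clipping` and its defaults come from the commit above. Note that the clamping happens inside the closure, so the wrapped `step` must always be called with one:

```python
import torch
from torch import nn

from optim import gradient_clipping  # the module added in this commit

# Hypothetical model and batch, for illustration only.
model = nn.Linear(10, 1)
inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)
criterion = nn.MSELoss()

# Wrap a stock optimizer; gradients are clamped to [-1, 1] by default.
optimizer = gradient_clipping(torch.optim.SGD(model.parameters(), lr=0.01))

def closure():
    optimizer.zero_grad()  # forwarded to the wrapped SGD via __getattr__
    loss = criterion(model(inputs), targets)
    loss.backward()        # fills p.grad, which closure_wrapper then clips in place
    return loss

loss = optimizer.step(closure)  # clamps every p.grad with hardtanh, then steps
```

Two caveats worth noting: calling `step()` without a closure raises a `TypeError`, because the clamping lives inside `closure_wrapper`; and the inner loop assumes every parameter received a gradient, so a parameter with `p.grad` still `None` (e.g. one unused in the forward pass) would make `hardtanh` fail.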
