blob: 95f9df2bb9384256fd87ecd32e910fc10712dd76 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
from torch.nn.functional import hardtanh
def gradient_clipping(optimizer, min=-1, max=1):
    """Wrap *optimizer* so every ``step`` clips gradients element-wise first.

    Each existing ``p.grad`` in ``optimizer.param_groups`` is clamped in place
    to ``[min, max]`` (via ``hardtanh``) before the wrapped optimizer updates
    the parameters.

    Args:
        optimizer: The optimizer to wrap. It must expose ``param_groups`` and
            ``step``.
        min: Lower clipping bound. The name shadows the builtin but is kept
            because callers may pass it by keyword.
        max: Upper clipping bound. The same note applies.

    Returns:
        A wrapper object. ``step(closure=None)`` clips and then steps. Every
        other attribute read is forwarded to *optimizer*. The wrapper is not
        an ``Optimizer`` subclass, so ``isinstance`` checks against
        ``torch.optim.Optimizer`` will not accept it.
    """

    def _clip_gradients():
        # Clamp every populated gradient in place. Parameters that received
        # no gradient (p.grad is None) are skipped instead of crashing.
        for group in optimizer.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    hardtanh(p.grad, min, max, inplace=True)

    class OptimizerWrapper(object):
        def step(self, closure=None):
            """Clip gradients, then delegate to ``optimizer.step``.

            With no closure, the gradients already accumulated (e.g. by a
            prior ``loss.backward()``) are clipped immediately. With a closure,
            the gradients are clipped after each closure evaluation. Some
            optimizers re-evaluate the closure several times per step.
            """
            if closure is None:
                _clip_gradients()
                return optimizer.step()

            def closure_wrapper():
                loss = closure()
                _clip_gradients()
                return loss

            return optimizer.step(closure_wrapper)

        def __getattr__(self, attr):
            # Only called for attributes not found on the wrapper itself,
            # so everything except `step` is served by the wrapped optimizer.
            return getattr(optimizer, attr)

    return OptimizerWrapper()
|