summaryrefslogtreecommitdiff
path: root/optim.py
blob: 95f9df2bb9384256fd87ecd32e910fc10712dd76 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from torch.nn.functional import hardtanh


def gradient_clipping(optimizer, min=-1, max=1):
    """Wrap *optimizer* so gradients are clipped element-wise to [min, max]
    before each parameter update (using torch.nn.functional.hardtanh in place).

    The returned wrapper delegates every other attribute access
    (``zero_grad``, ``param_groups``, ``state_dict``, ...) to the wrapped
    optimizer, so it is a drop-in replacement.

    Args:
        optimizer: a ``torch.optim.Optimizer`` instance.
        min: lower clipping bound. (The names ``min``/``max`` shadow the
            builtins but are kept for backward compatibility with callers
            passing them as keywords.)
        max: upper clipping bound.

    Returns:
        An object exposing the optimizer interface whose ``step`` clips
        gradients before updating parameters.
    """

    def _clip_grads():
        # Clip in place; skip parameters that received no gradient,
        # since hardtanh(None, ...) would raise.
        for group in optimizer.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    hardtanh(p.grad, min, max, inplace=True)

    class OptimizerWrapper(object):

        def step(self, closure=None):
            """Clip gradients, then perform a single optimization step.

            ``closure`` is optional, matching the standard
            ``Optimizer.step()`` contract (the original wrapper made it
            mandatory, which broke the plain ``step()`` call path).
            """
            if closure is None:
                _clip_grads()
                return optimizer.step()

            def closure_wrapper():
                # Re-evaluate the loss, then clip the freshly computed
                # gradients before the optimizer consumes them.
                loss = closure()
                _clip_grads()
                return loss

            return optimizer.step(closure_wrapper)

        def __getattr__(self, attr):
            # Fall through to the wrapped optimizer for everything else.
            return getattr(optimizer, attr)

    return OptimizerWrapper()