blob: 3d89908befc69413ddfb8c8c447cd03a9d1c57f0 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
from .config import LossConfig
from .model import Model
import chainer
from chainer import reporter
class Loss(chainer.link.Chain):
    """Masked L1 (mean absolute error) training loss around a predictor model.

    The predictor is registered as a child link inside ``init_scope`` so that
    its parameters are owned by this chain.
    """

    # Lower bound applied to ``sum(mask)`` before dividing. Without it an
    # all-zero mask gives 0/0 = NaN, which would propagate into the reported
    # metric and the gradients. For any realistic nonzero mask sum the bound
    # is inactive, so results are unchanged.
    # NOTE(review): assumes mask values are non-negative (a negative total
    # would also be clipped) — confirm against the data pipeline.
    _MASK_SUM_EPS = 1e-8

    def __init__(self, config: LossConfig, predictor: Model):
        """Store the loss config and register ``predictor`` as a child link.

        Args:
            config: Loss settings; ``config.l1`` scales the returned loss.
            predictor: Model applied to the input in ``__call__``.
        """
        super().__init__()
        self.config = config
        with self.init_scope():
            self.predictor = predictor

    def __call__(self, input, target, mask):
        """Compute the masked mean absolute error of the predictor's output.

        The loss is ``sum(|predictor(input) - target| * mask) / sum(mask)``.
        The denominator is clipped to ``_MASK_SUM_EPS``, so an all-zero mask
        yields 0 instead of NaN.

        Args:
            input: Model input; converted with ``chainer.as_variable``.
            target: Regression target compared against the prediction.
            mask: Elementwise weights for the error. Presumably it has the
                same shape and dtype as the prediction, since it is multiplied
                elementwise with it — TODO confirm.

        Returns:
            The masked mean absolute error multiplied by ``config.l1``.
            The value reported under ``'loss'`` is the unscaled error.
        """
        input = chainer.as_variable(input)
        target = chainer.as_variable(target)
        mask = chainer.as_variable(mask)

        y = self.predictor(input)

        masked_error_sum = chainer.functions.sum(
            chainer.functions.absolute_error(y, target) * mask)
        # Guard against division by zero when no element is selected.
        mask_sum = chainer.functions.clip(
            chainer.functions.sum(mask), self._MASK_SUM_EPS, float('inf'))
        loss = masked_error_sum / mask_sum

        # Report the unscaled error; only the returned value is scaled by l1.
        reporter.report({'loss': loss}, self)
        return loss * self.config.l1
|