blob: c59747aa40bfc4d204174f1488970e41fd7f9f32 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
import chainer
from chainer import reporter
from .config import LossConfig
from .model import Aligner
from .model import Predictor
class Loss(chainer.link.Chain):
    """Chain that turns an aligner + predictor pair into a masked L1 loss.

    The input is passed through ``aligner`` and then ``predictor``; the
    absolute error between that prediction and the target is weighted by a
    mask, summed, and divided by the sum of the mask.
    """

    def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner):
        """Store the loss configuration and register both sub-networks.

        Args:
            config: Loss settings; ``config.l1`` scales the returned loss.
            predictor: Link applied to the aligner output.
            aligner: Link applied to the raw input.
        """
        super().__init__()
        self.config = config

        # Child links are assigned inside init_scope() so the chain
        # registers them as its children.
        with self.init_scope():
            self.predictor = predictor
            self.aligner = aligner

    def __call__(self, input, target, mask):
        """Compute the masked absolute-error loss for one batch.

        Args:
            input: Data fed to the aligner (converted with
                ``chainer.as_variable``).
            target: Reference values compared against the prediction.
            mask: Per-element weights multiplied into the absolute error.
                NOTE(review): presumably has the same shape as ``target`` --
                confirm against the caller.

        Returns:
            The mask-normalized absolute error multiplied by
            ``self.config.l1``.
        """
        functions = chainer.functions

        source = chainer.as_variable(input)
        expected = chainer.as_variable(target)
        weights = chainer.as_variable(mask)

        predicted = self.predictor(self.aligner(source))

        # NOTE(review): a mask that sums to zero makes the denominator zero;
        # confirm callers never pass an all-zero mask.
        masked_error = functions.absolute_error(predicted, expected) * weights
        loss = functions.sum(masked_error) / functions.sum(weights)

        # The value reported under 'loss' is the unscaled one; only the
        # returned value carries the l1 weight.
        reporter.report({'loss': loss}, self)
        return loss * self.config.l1
|