blob: b2b03fc97d461b1943f23cc8c5d2b7396f044613 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
import chainer
from chainer import reporter
from .config import LossConfig
from .model import Aligner
from .model import Predictor
class Loss(chainer.link.Chain):
    """Masked absolute-error training loss, scaled by ``config.l1``.

    Wraps a ``predictor`` link and an optional ``aligner`` link. Both are
    assigned inside ``init_scope`` so that link-valued children are
    registered with this chain.
    """

    def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner = None):
        """Store the loss configuration and register the child links.

        Args:
            config: Loss settings; ``config.l1`` is read in ``__call__`` as the
                multiplier applied to the returned loss.
            predictor: Link that maps the (optionally aligned) input to a
                prediction.
            aligner: Optional link applied to the input before the predictor;
                skipped entirely when ``None``.
        """
        super().__init__()
        # Assigned outside init_scope, so config is not registered as a child link.
        self.config = config
        with self.init_scope():
            self.predictor = predictor
            self.aligner = aligner

    def __call__(self, input, target, mask):
        """Compute the masked mean absolute error and return it scaled by ``config.l1``.

        The error is ``sum(|prediction - target| * mask) / sum(mask)``. The
        unscaled value is reported under the key ``'loss'``. The returned
        value is that quantity multiplied by ``config.l1``.

        Args:
            input: Model input, converted to a ``chainer.Variable``.
            target: Ground truth that the prediction is compared against.
            mask: Per-element weights multiplied into the error. Presumably
                the same shape as ``target`` with 0/1 entries -- TODO confirm
                against callers.

        Returns:
            chainer.Variable: The scaled loss (both sums reduce over all axes).
        """
        F = chainer.functions
        input, target, mask = (chainer.as_variable(v) for v in (input, target, mask))

        features = input if self.aligner is None else self.aligner(input)
        prediction = self.predictor(features)

        masked_error = F.absolute_error(prediction, target) * mask
        # NOTE(review): an all-zero mask makes this 0 / 0 (NaN); nothing guards it here.
        mean_error = F.sum(masked_error) / F.sum(mask)

        reporter.report({'loss': mean_error}, self)
        return mean_error * self.config.l1
|