diff options
| author | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-20 03:06:39 +0900 |
|---|---|---|
| committer | Hiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp> | 2017-11-20 03:06:39 +0900 |
| commit | 16b4e72fe6728e2e64d4c6357b7c73ac06868c1c (patch) | |
| tree | 657f0398b9a237ab46327d08f58a230b9581669b /become_yukarin/loss.py | |
| parent | 437a869590c989c184d33990b1d788149d073ee9 (diff) | |
aligner
Diffstat (limited to 'become_yukarin/loss.py')
| -rw-r--r-- | become_yukarin/loss.py | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py index 3d89908..c59747a 100644 --- a/become_yukarin/loss.py +++ b/become_yukarin/loss.py @@ -1,18 +1,19 @@ -from .config import LossConfig -from .model import Model - import chainer - from chainer import reporter +from .config import LossConfig +from .model import Aligner +from .model import Predictor + class Loss(chainer.link.Chain): - def __init__(self, config: LossConfig, predictor: Model): + def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner): super().__init__() self.config = config with self.init_scope(): self.predictor = predictor + self.aligner = aligner def __call__(self, input, target, mask): input = chainer.as_variable(input) @@ -20,6 +21,7 @@ class Loss(chainer.link.Chain): mask = chainer.as_variable(mask) h = input + h = self.aligner(h) y = self.predictor(h) loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask) |
