summaryrefslogtreecommitdiff
path: root/become_yukarin/loss.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-20 03:06:39 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-20 03:06:39 +0900
commit16b4e72fe6728e2e64d4c6357b7c73ac06868c1c (patch)
tree657f0398b9a237ab46327d08f58a230b9581669b /become_yukarin/loss.py
parent437a869590c989c184d33990b1d788149d073ee9 (diff)
aligner
Diffstat (limited to 'become_yukarin/loss.py')
-rw-r--r--become_yukarin/loss.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py
index 3d89908..c59747a 100644
--- a/become_yukarin/loss.py
+++ b/become_yukarin/loss.py
@@ -1,18 +1,19 @@
-from .config import LossConfig
-from .model import Model
-
import chainer
-
from chainer import reporter
+from .config import LossConfig
+from .model import Aligner
+from .model import Predictor
+
class Loss(chainer.link.Chain):
- def __init__(self, config: LossConfig, predictor: Model):
+ def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner):
super().__init__()
self.config = config
with self.init_scope():
self.predictor = predictor
+ self.aligner = aligner
def __call__(self, input, target, mask):
input = chainer.as_variable(input)
@@ -20,6 +21,7 @@ class Loss(chainer.link.Chain):
mask = chainer.as_variable(mask)
h = input
+ h = self.aligner(h)
y = self.predictor(h)
loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)