summaryrefslogtreecommitdiff
path: root/become_yukarin/loss.py
diff options
context:
space:
mode:
Diffstat (limited to 'become_yukarin/loss.py')
-rw-r--r--become_yukarin/loss.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py
index 3d89908..c59747a 100644
--- a/become_yukarin/loss.py
+++ b/become_yukarin/loss.py
@@ -1,18 +1,19 @@
-from .config import LossConfig
-from .model import Model
-
import chainer
-
from chainer import reporter
+from .config import LossConfig
+from .model import Aligner
+from .model import Predictor
+
class Loss(chainer.link.Chain):
- def __init__(self, config: LossConfig, predictor: Model):
+ def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner):
super().__init__()
self.config = config
with self.init_scope():
self.predictor = predictor
+ self.aligner = aligner
def __call__(self, input, target, mask):
input = chainer.as_variable(input)
@@ -20,6 +21,7 @@ class Loss(chainer.link.Chain):
mask = chainer.as_variable(mask)
h = input
+ h = self.aligner(h)
y = self.predictor(h)
loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)