summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-22 23:50:31 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2017-11-22 23:50:31 +0900
commit9f87a74de09e38f9d8f3e7ebb5fd26fac44a3b0e (patch)
treeae662b5319256e3864877cacbd21c527f33448f0
parentd6af2a851644afe253b97461b35138011a479a95 (diff)
can remove aligner
-rw-r--r--become_yukarin/config.py2
-rw-r--r--become_yukarin/loss.py5
-rw-r--r--become_yukarin/model.py10
-rw-r--r--train.py6
4 files changed, 17 insertions, 6 deletions
diff --git a/become_yukarin/config.py b/become_yukarin/config.py
index 07f35fd..0efbf04 100644
--- a/become_yukarin/config.py
+++ b/become_yukarin/config.py
@@ -34,6 +34,7 @@ class ModelConfig(NamedTuple):
out_size: int
aligner_out_time_length: int
disable_last_rnn: bool
+ enable_aligner: bool
class LossConfig(NamedTuple):
@@ -104,6 +105,7 @@ def create_from_json(s: Union[str, Path]):
out_size=d['model']['out_size'],
aligner_out_time_length=d['model']['aligner_out_time_length'],
disable_last_rnn=d['model']['disable_last_rnn'],
+ enable_aligner=d['model']['enable_aligner'],
),
loss=LossConfig(
l1=d['loss']['l1'],
diff --git a/become_yukarin/loss.py b/become_yukarin/loss.py
index c59747a..b2b03fc 100644
--- a/become_yukarin/loss.py
+++ b/become_yukarin/loss.py
@@ -7,7 +7,7 @@ from .model import Predictor
class Loss(chainer.link.Chain):
- def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner):
+ def __init__(self, config: LossConfig, predictor: Predictor, aligner: Aligner = None):
super().__init__()
self.config = config
@@ -21,7 +21,8 @@ class Loss(chainer.link.Chain):
mask = chainer.as_variable(mask)
h = input
- h = self.aligner(h)
+ if self.aligner is not None:
+ h = self.aligner(h)
y = self.predictor(h)
loss = chainer.functions.sum(chainer.functions.absolute_error(y, target) * mask)
diff --git a/become_yukarin/model.py b/become_yukarin/model.py
index 3b5102e..c475685 100644
--- a/become_yukarin/model.py
+++ b/become_yukarin/model.py
@@ -212,8 +212,18 @@ def create_predictor(config: ModelConfig):
def create_aligner(config: ModelConfig):
+ assert config.enable_aligner
aligner = Aligner(
in_size=config.in_channels,
out_time_length=config.aligner_out_time_length,
)
return aligner
+
+
+def create(config: ModelConfig):
+ predictor = create_predictor(config)
+ if config.enable_aligner:
+ aligner = create_aligner(config)
+ else:
+ aligner = None
+ return predictor, aligner
diff --git a/train.py b/train.py
index a9f4e79..a3bea0f 100644
--- a/train.py
+++ b/train.py
@@ -12,8 +12,7 @@ from chainer.training import extensions
from become_yukarin.config import create_from_json
from become_yukarin.dataset import create as create_dataset
from become_yukarin.loss import Loss
-from become_yukarin.model import create_aligner
-from become_yukarin.model import create_predictor
+from become_yukarin.model import create
parser = argparse.ArgumentParser()
parser.add_argument('config_json_path', type=Path)
@@ -27,8 +26,7 @@ config.save_as_json((arguments.output / 'config.json').absolute())
# model
if config.train.gpu >= 0:
cuda.get_device_from_id(config.train.gpu).use()
-predictor = create_predictor(config.model)
-aligner = create_aligner(config.model)
+predictor, aligner = create(config.model)
model = Loss(config.loss, predictor=predictor, aligner=aligner)
# dataset