summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/train.py b/train.py
index 27fd1fb..08ef2d9 100644
--- a/train.py
+++ b/train.py
@@ -1,18 +1,18 @@
import argparse
+from functools import partial
from pathlib import Path
-from chainer.iterators import MultiprocessIterator
from chainer import optimizers
from chainer import training
-from chainer.training import extensions
from chainer.dataset import convert
+from chainer.iterators import MultiprocessIterator
+from chainer.training import extensions
from become_yukarin.config import create_from_json
from become_yukarin.dataset import create as create_dataset
-from become_yukarin.model import create as create_model
from become_yukarin.loss import Loss
-
-from functools import partial
+from become_yukarin.model import create_aligner
+from become_yukarin.model import create_predictor
parser = argparse.ArgumentParser()
parser.add_argument('config_json_path', type=Path)
@@ -24,8 +24,9 @@ arguments.output.mkdir(exist_ok=True)
config.save_as_json((arguments.output / 'config.json').absolute())
# model
-predictor = create_model(config.model)
-model = Loss(config.loss, predictor=predictor)
+predictor = create_predictor(config.model)
+aligner = create_aligner(config.model)
+model = Loss(config.loss, predictor=predictor, aligner=aligner)
# dataset
dataset = create_dataset(config.dataset)