summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py6
1 files changed, 2 insertions, 4 deletions
diff --git a/train.py b/train.py
index a9f4e79..a3bea0f 100644
--- a/train.py
+++ b/train.py
@@ -12,8 +12,7 @@ from chainer.training import extensions
from become_yukarin.config import create_from_json
from become_yukarin.dataset import create as create_dataset
from become_yukarin.loss import Loss
-from become_yukarin.model import create_aligner
-from become_yukarin.model import create_predictor
+from become_yukarin.model import create
parser = argparse.ArgumentParser()
parser.add_argument('config_json_path', type=Path)
@@ -27,8 +26,7 @@ config.save_as_json((arguments.output / 'config.json').absolute())
# model
if config.train.gpu >= 0:
cuda.get_device_from_id(config.train.gpu).use()
-predictor = create_predictor(config.model)
-aligner = create_aligner(config.model)
+predictor, aligner = create(config.model)
model = Loss(config.loss, predictor=predictor, aligner=aligner)
# dataset