diff options
Diffstat (limited to 'train.py')
| -rw-r--r-- | train.py | 15 |
1 files changed, 8 insertions, 7 deletions
@@ -1,18 +1,18 @@ import argparse +from functools import partial from pathlib import Path -from chainer.iterators import MultiprocessIterator from chainer import optimizers from chainer import training -from chainer.training import extensions from chainer.dataset import convert +from chainer.iterators import MultiprocessIterator +from chainer.training import extensions from become_yukarin.config import create_from_json from become_yukarin.dataset import create as create_dataset -from become_yukarin.model import create as create_model from become_yukarin.loss import Loss - -from functools import partial +from become_yukarin.model import create_aligner +from become_yukarin.model import create_predictor parser = argparse.ArgumentParser() parser.add_argument('config_json_path', type=Path) @@ -24,8 +24,9 @@ arguments.output.mkdir(exist_ok=True) config.save_as_json((arguments.output / 'config.json').absolute()) # model -predictor = create_model(config.model) -model = Loss(config.loss, predictor=predictor) +predictor = create_predictor(config.model) +aligner = create_aligner(config.model) +model = Loss(config.loss, predictor=predictor, aligner=aligner) # dataset dataset = create_dataset(config.dataset) |
