path: root/train.py
Diffstat (limited to 'train.py')
-rw-r--r--  train.py  51
1 file changed, 40 insertions(+), 11 deletions(-)
diff --git a/train.py b/train.py
index a3bea0f..f3d4885 100644
--- a/train.py
+++ b/train.py
@@ -11,7 +11,7 @@ from chainer.training import extensions
from become_yukarin.config import create_from_json
from become_yukarin.dataset import create as create_dataset
-from become_yukarin.loss import Loss
+from become_yukarin.updater import Updater
from become_yukarin.model import create
parser = argparse.ArgumentParser()
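
Note: the removed become_yukarin.loss.Loss is not shown in this patch. For context, a generic Chainer loss-wrapper Chain of the kind this commit retires looks roughly like the sketch below (a minimal illustration, not the project's actual class); with StandardUpdater its report key is 'main/loss', which explains the key renames later in this diff.

import chainer
import chainer.functions as F
from chainer import reporter

class Loss(chainer.Chain):
    # Sketch of the loss-wrapper pattern this patch removes; the real
    # become_yukarin.loss.Loss is not shown in the diff.
    def __init__(self, predictor):
        super().__init__()
        with self.init_scope():
            self.predictor = predictor

    def __call__(self, x, t):
        loss = F.mean_absolute_error(self.predictor(x), t)
        reporter.report({'loss': loss}, self)  # reported as 'main/loss'
        return loss
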
@@ -26,8 +26,12 @@ config.save_as_json((arguments.output / 'config.json').absolute())
# model
if config.train.gpu >= 0:
cuda.get_device_from_id(config.train.gpu).use()
-predictor, aligner = create(config.model)
-model = Loss(config.loss, predictor=predictor, aligner=aligner)
+predictor, aligner, discriminator = create(config.model)
+models = {'predictor': predictor}
+if aligner is not None:
+ models['aligner'] = aligner
+if discriminator is not None:
+ models['discriminator'] = discriminator
# dataset
dataset = create_dataset(config.dataset)
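
The model factory now returns three components, of which aligner and discriminator may be None depending on the model config, so only the instantiated ones are collected into the models dict. A hypothetical sketch of such a factory (the real become_yukarin.model.create is not in this diff; the enable_* config flags and the stand-in network are assumptions for illustration):

import chainer
import chainer.links as L

class Net(chainer.Chain):
    # Stand-in network for illustration only.
    def __init__(self):
        super().__init__()
        with self.init_scope():
            self.out = L.Linear(None, 1)

    def __call__(self, x):
        return self.out(x)

def create(model_config):
    predictor = Net()
    aligner = Net() if getattr(model_config, 'enable_aligner', False) else None
    discriminator = Net() if getattr(model_config, 'enable_discriminator', False) else None
    return predictor, aligner, discriminator
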
@@ -35,24 +39,42 @@ train_iter = MultiprocessIterator(dataset['train'], config.train.batchsize)
test_iter = MultiprocessIterator(dataset['test'], config.train.batchsize, repeat=False, shuffle=False)
train_eval_iter = MultiprocessIterator(dataset['train_eval'], config.train.batchsize, repeat=False, shuffle=False)
+
# optimizer
-optimizer = optimizers.Adam()
-optimizer.setup(model)
+def create_optimizer(model):
+ optimizer = optimizers.Adam()
+ optimizer.setup(model)
+ return optimizer
+
+
+opts = {key: create_optimizer(model) for key, model in models.items()}
+
+# updater
+converter = partial(convert.concat_examples, padding=0)
+updater = Updater(
+ loss_config=config.loss,
+ model_config=config.model,
+ predictor=predictor,
+ aligner=aligner,
+ discriminator=discriminator,
+ device=config.train.gpu,
+ iterator=train_iter,
+ optimizer=opts,
+ converter=converter,
+)
# trainer
trigger_log = (config.train.log_iteration, 'iteration')
trigger_snapshot = (config.train.snapshot_iteration, 'iteration')
-converter = partial(convert.concat_examples, padding=0)
-updater = training.StandardUpdater(train_iter, optimizer, device=config.train.gpu, converter=converter)
trainer = training.Trainer(updater, out=arguments.output)
-ext = extensions.Evaluator(test_iter, model, converter, device=config.train.gpu)
+ext = extensions.Evaluator(test_iter, models, converter, device=config.train.gpu, eval_func=updater.forward)
trainer.extend(ext, name='test', trigger=trigger_log)
-ext = extensions.Evaluator(train_eval_iter, model, converter, device=config.train.gpu)
+ext = extensions.Evaluator(train_eval_iter, models, converter, device=config.train.gpu, eval_func=updater.forward)
trainer.extend(ext, name='train', trigger=trigger_log)
-trainer.extend(extensions.dump_graph('main/loss', out_name='graph.dot'))
+trainer.extend(extensions.dump_graph('predictor/loss', out_name='graph.dot'))
ext = extensions.snapshot_object(predictor, filename='predictor_{.updater.iteration}.npz')
trainer.extend(ext, trigger=trigger_snapshot)
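
Chainer's StandardUpdater (and its subclasses) accepts a dict of optimizers, which is why opts maps each model name to its own Adam instance; the dict keys later become report prefixes such as 'predictor/loss'. The project's Updater class is not shown in this diff; below is a minimal sketch of the pattern, covering only the predictor path and assuming the dataset yields (input, target) pairs.

import chainer.functions as F
from chainer import reporter, training

class Updater(training.StandardUpdater):
    # Sketch only: the real become_yukarin.updater.Updater also drives
    # the aligner and the discriminator (GAN) losses.
    def __init__(self, predictor, **kwargs):
        self.predictor = predictor
        super().__init__(**kwargs)

    def forward(self, x, t):
        # Reused by the Evaluator extensions via eval_func=updater.forward.
        loss = F.mean_absolute_error(self.predictor(x), t)
        reporter.report({'loss': loss}, self.predictor)  # -> 'predictor/loss'
        return loss

    def update_core(self):
        batch = self.get_iterator('main').next()
        x, t = self.converter(batch, self.device)
        loss = self.forward(x, t)
        optimizer = self.get_optimizer('predictor')  # key from the opts dict
        optimizer.target.cleargrads()
        loss.backward()
        optimizer.update()

Passing the models dict as the Evaluator target registers each link as an observer, so evaluation results are prefixed with the extension name plus the model key (e.g. 'test/predictor/loss').
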
@@ -61,7 +83,14 @@ trainer.extend(extensions.LogReport(trigger=trigger_log, log_name='log.txt'))
if extensions.PlotReport.available():
trainer.extend(extensions.PlotReport(
- y_keys=['main/loss', 'test/main/loss', 'train/main/loss'],
+ y_keys=[
+ 'predictor/loss',
+ 'predictor/l1',
+ 'test/predictor/loss',
+ 'train/predictor/loss',
+ 'discriminator/fake',
+ 'discriminator/true',
+ ],
x_key='iteration',
file_name='loss.png',
trigger=trigger_log,
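
Not part of this patch, but for completeness: the reported keys plotted above can also be printed at each log trigger with a PrintReport extension, e.g.:

trainer.extend(extensions.PrintReport([
    'iteration',
    'predictor/loss',
    'test/predictor/loss',
    'train/predictor/loss',
]), trigger=trigger_log)

The 'test/' and 'train/' prefixes come from the names the two Evaluator extensions are registered under earlier in this file.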