summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/train.py b/train.py
index 26490ce..c01915a 100644
--- a/train.py
+++ b/train.py
@@ -27,12 +27,11 @@ config.save_as_json((arguments.output / 'config.json').absolute())
# model
if config.train.gpu >= 0:
cuda.get_device_from_id(config.train.gpu).use()
-predictor, aligner, discriminator = create(config.model)
-models = {'predictor': predictor}
-if aligner is not None:
- models['aligner'] = aligner
-if discriminator is not None:
- models['discriminator'] = discriminator
+predictor, discriminator = create(config.model)
+models = {
+ 'predictor': predictor,
+ 'discriminator': discriminator,
+}
# dataset
dataset = create_dataset(config.dataset)
@@ -43,7 +42,7 @@ train_eval_iter = MultiprocessIterator(dataset['train_eval'], config.train.batch
# optimizer
def create_optimizer(model):
- optimizer = optimizers.Adam()
+ optimizer = optimizers.Adam(alpha=0.0002, beta1=0.5, beta2=0.999)
optimizer.setup(model)
return optimizer
@@ -55,7 +54,6 @@ converter = partial(convert.concat_examples, padding=0)
updater = Updater(
loss_config=config.loss,
predictor=predictor,
- aligner=aligner,
discriminator=discriminator,
device=config.train.gpu,
iterator=train_iter,