From 0b410c0bbacc147950438423835547b74859aac1 Mon Sep 17 00:00:00 2001 From: Hiroshiba Kazuyuki Date: Fri, 2 Feb 2018 16:59:33 +0900 Subject: pix2pix convertモデル MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) (limited to 'train.py') diff --git a/train.py b/train.py index 26490ce..c01915a 100644 --- a/train.py +++ b/train.py @@ -27,12 +27,11 @@ config.save_as_json((arguments.output / 'config.json').absolute()) # model if config.train.gpu >= 0: cuda.get_device_from_id(config.train.gpu).use() -predictor, aligner, discriminator = create(config.model) -models = {'predictor': predictor} -if aligner is not None: - models['aligner'] = aligner -if discriminator is not None: - models['discriminator'] = discriminator +predictor, discriminator = create(config.model) +models = { + 'predictor': predictor, + 'discriminator': discriminator, +} # dataset dataset = create_dataset(config.dataset) @@ -43,7 +42,7 @@ train_eval_iter = MultiprocessIterator(dataset['train_eval'], config.train.batch # optimizer def create_optimizer(model): - optimizer = optimizers.Adam() + optimizer = optimizers.Adam(alpha=0.0002, beta1=0.5, beta2=0.999) optimizer.setup(model) return optimizer @@ -55,7 +54,6 @@ converter = partial(convert.concat_examples, padding=0) updater = Updater( loss_config=config.loss, predictor=predictor, - aligner=aligner, discriminator=discriminator, device=config.train.gpu, iterator=train_iter, -- cgit v1.2.3-70-g09d2