summaryrefslogtreecommitdiff
path: root/train.py
diff options
context:
space:
mode:
authorHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-02-02 16:59:33 +0900
committerHiroshiba Kazuyuki <kazuyuki_hiroshiba@dwango.co.jp>2018-02-10 20:52:52 +0900
commit0b410c0bbacc147950438423835547b74859aac1 (patch)
treea5f8dc39e4d980bcfb09b21f47bf715d860c2e45 /train.py
parent4ff5252dbdc0cdaeecc7fbe399c629e4d29de3a3 (diff)
pix2pix convertモデル
Diffstat (limited to 'train.py')
-rw-r--r--train.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/train.py b/train.py
index 26490ce..c01915a 100644
--- a/train.py
+++ b/train.py
@@ -27,12 +27,11 @@ config.save_as_json((arguments.output / 'config.json').absolute())
# model
if config.train.gpu >= 0:
cuda.get_device_from_id(config.train.gpu).use()
-predictor, aligner, discriminator = create(config.model)
-models = {'predictor': predictor}
-if aligner is not None:
- models['aligner'] = aligner
-if discriminator is not None:
- models['discriminator'] = discriminator
+predictor, discriminator = create(config.model)
+models = {
+ 'predictor': predictor,
+ 'discriminator': discriminator,
+}
# dataset
dataset = create_dataset(config.dataset)
@@ -43,7 +42,7 @@ train_eval_iter = MultiprocessIterator(dataset['train_eval'], config.train.batch
# optimizer
def create_optimizer(model):
- optimizer = optimizers.Adam()
+ optimizer = optimizers.Adam(alpha=0.0002, beta1=0.5, beta2=0.999)
optimizer.setup(model)
return optimizer
@@ -55,7 +54,6 @@ converter = partial(convert.concat_examples, padding=0)
updater = Updater(
loss_config=config.loss,
predictor=predictor,
- aligner=aligner,
discriminator=discriminator,
device=config.train.gpu,
iterator=train_iter,