summaryrefslogtreecommitdiff
path: root/inversion/image_inversion.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-11 10:26:48 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-11 10:26:48 +0100
commit5edd80d1ac7673f31e503226529b0a826f907944 (patch)
tree645a5d11c746351a63f6671c0c135f02b734223b /inversion/image_inversion.py
parent7aad285136f97d737ef251ae35ed77404a847bb7 (diff)
paths
Diffstat (limited to 'inversion/image_inversion.py')
-rw-r--r--inversion/image_inversion.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py
index 5556778..2927eac 100644
--- a/inversion/image_inversion.py
+++ b/inversion/image_inversion.py
@@ -68,11 +68,12 @@ params = params.Params(sys.argv[1])
# --------------------------
# Global directories.
# --------------------------
+LATENT_TAG = 'latent_' if params.inv_layer == 'latent' else 'dense_'
BATCH_SIZE = params.batch_size
SAMPLE_SIZE = params.sample_size
-LOGS_DIR = 'logs'
-SAMPLES_DIR = 'samples'
-INVERSES_DIR = 'inverses'
+LOGS_DIR = os.path.join('inverses', params.tag, LATENT_TAG, 'logs')
+SAMPLES_DIR = os.path.join('inverses', params.tag, LATENT_TAG, 'samples')
+INVERSES_DIR = os.path.join('inverses', params.tag)
if not os.path.exists(LOGS_DIR):
os.makedirs(LOGS_DIR)
if not os.path.exists(SAMPLES_DIR):
@@ -394,7 +395,7 @@ for image_batch, label_batch in image_gen:
sess.run(clip_latent)
# Every 100 iterations save logs with training information.
- if it % 100 == 99:
+ if it < 100 or it % 100 == 0:
# Log losses.
etime = time.time() - start_time
print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
@@ -436,7 +437,7 @@ for image_batch, label_batch in image_gen:
inv_batch = vs.interleave(vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
inv_batch = vs.grid_transform(inv_batch)
- vs.save_image('{}/progress_{}.png'.format(SAMPLES_DIR, it), inv_batch)
+ vs.save_image('{}/progress_{}_{}.png'.format(SAMPLES_DIR, params.tag, it), inv_batch)
# Save linear interpolation between the actual and generated encodings.
if params.dist_loss and it % 1000 == 999:
@@ -448,7 +449,7 @@ for image_batch, label_batch in image_gen:
inv_batch = vs.interleave(vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
inv_batch = vs.grid_transform(inv_batch)
- vs.save_image('{}/progress_{}_lat_{}.png'.format(SAMPLES_DIR,it,j),
+ vs.save_image('{}/progress_{}_{}_lat_{}.png'.format(SAMPLES_DIR,params.tag,it,j),
inv_batch)
sess.run(encoding.assign(enc_batch))
@@ -462,7 +463,7 @@ for image_batch, label_batch in image_gen:
inv_batch = vs.interleave(vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
inv_batch = vs.grid_transform(inv_batch)
- vs.save_image('{}/{}.png'.format(SAMPLES_DIR, out_pos), inv_batch)
+ vs.save_image('{}/{}_{}.png'.format(SAMPLES_DIR, params.tag, out_pos), inv_batch)
print('Saved samples for out_pos: {}.'.format(out_pos))
# Save images that are ready.