summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--inversion/image_inversion.py52
1 files changed, 30 insertions, 22 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py
index 5560172..a9c9fc3 100644
--- a/inversion/image_inversion.py
+++ b/inversion/image_inversion.py
@@ -343,15 +343,20 @@ sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())
+if params.max_batches > 0:
+ NUM_IMGS_TO_PROCESS = params.max_batches * BATCH_SIZE
+else:
+ NUM_IMGS_TO_PROCESS = NUM_IMGS
+
# Output file.
out_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'w')
-out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, dtype='float32')
-out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE, dtype='float32')
-out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM], dtype='float32')
-out_fns = out_file.create_dataset('fn', [NUM_IMGS], dtype=h5py.string_dtype())
+out_images = out_file.create_dataset('xtrain', [NUM_IMGS_TO_PROCESS,] + IMG_SHAPE, dtype='float32')
+out_enc = out_file.create_dataset('encoding', [NUM_IMGS_TO_PROCESS,] + ENC_SHAPE, dtype='float32')
+out_lat = out_file.create_dataset('latent', [NUM_IMGS_TO_PROCESS, Z_DIM], dtype='float32')
+out_fns = out_file.create_dataset('fn', [NUM_IMGS_TO_PROCESS], dtype=h5py.string_dtype())
if COND_GAN:
- out_labels = out_file.create_dataset('ytrain', (NUM_IMGS, N_CLASS,), dtype='float32')
-out_err = out_file.create_dataset('err', (NUM_IMGS,))
+ out_labels = out_file.create_dataset('ytrain', (NUM_IMGS_TO_PROCESS, N_CLASS,), dtype='float32')
+out_err = out_file.create_dataset('err', (NUM_IMGS_TO_PROCESS,))
out_fns[:] = sample_fns
@@ -397,8 +402,8 @@ for image_batch, label_batch in image_gen:
if params.clipping or params.stochastic_clipping:
sess.run(clip_latent)
- # Every 100 iterations save logs with training information.
- if it < 100 or it % 100 == 0:
+ # Save logs with training information.
+ if it % 500 == 0:
# Log losses.
etime = time.time() - start_time
print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
@@ -442,23 +447,24 @@ for image_batch, label_batch in image_gen:
inv_batch = vs.grid_transform(inv_batch)
vs.save_image('{}/progress_{}_{}.png'.format(SAMPLES_DIR, params.tag, it), inv_batch)
- # Save linear interpolation between the actual and generated encodings.
- if params.dist_loss and it % 1000 == 0:
- enc_batch, gen_enc = sess.run([encoding, gen_encoding])
- for j in range(10):
- custom_enc = gen_enc * (1-(j/10.0)) + enc_batch * (j/10.0)
- sess.run(encoding.assign(custom_enc))
- gen_images = sess.run(gen_img)
- inv_batch = vs.interleave(vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
- vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
- inv_batch = vs.grid_transform(inv_batch)
- vs.save_image('{}/progress_{}_{}_lat_{}.png'.format(SAMPLES_DIR,params.tag,it,j),
- inv_batch)
- sess.run(encoding.assign(enc_batch))
-
# It counter.
it += 1
+ if params.save_progress:
+ # Save linear interpolation between the actual and generated encodings.
+ if params.dist_loss:
+ enc_batch, gen_enc = sess.run([encoding, gen_encoding])
+ for j in range(10):
+ custom_enc = gen_enc * (1-(j/10.0)) + enc_batch * (j/10.0)
+ sess.run(encoding.assign(custom_enc))
+ gen_images = sess.run(gen_img)
+ inv_batch = vs.interleave(vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
+ vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+ inv_batch = vs.grid_transform(inv_batch)
+ vs.save_image('{}/progress_{}_{}_lat_{}.png'.format(SAMPLES_DIR,params.tag,it,j),
+ inv_batch)
+ sess.run(encoding.assign(enc_batch))
+
# Save samples of inverted images.
if SAMPLE_SIZE > 0:
assert SAMPLE_SIZE <= BATCH_SIZE
@@ -479,6 +485,8 @@ for image_batch, label_batch in image_gen:
out_labels[out_pos:out_pos+BATCH_SIZE] = label_batch
out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_batch
out_pos += BATCH_SIZE
+ if params.max_batches > 0 and (out_pos / BATCH_SIZE) >= params.max_batches:
+ break
print('Mean reconstruction error: {}'.format(np.mean(out_err)))
print('Stdev reconstruction error: {}'.format(np.std(out_err)))