| field | value | date |
|---|---|---|
| author | jules@lens <julescarbon@gmail.com> | 2019-12-11 21:23:37 +0100 |
| committer | jules@lens <julescarbon@gmail.com> | 2019-12-11 21:23:37 +0100 |
| commit | db900052fa41e9872ddcd0057ee81cd508731593 (patch) | |
| tree | 8bc89c225460b48d4e7f2bafb2971de66b896d3f /inversion | |
| parent | 24840cc289f87fb1d01659aaa6eb00f2003f44b1 (diff) | |
| parent | bc9cb9fc5d5889c7251e9f0b997fddcb97e3d69e (diff) | |
Merge branch 'master' of asdf.us:biggan
Diffstat (limited to 'inversion')
| -rw-r--r-- | inversion/image_inversion.py | 52 |

1 file changed, 30 insertions, 22 deletions
diff --git a/inversion/image_inversion.py b/inversion/image_inversion.py
index 5560172..a9c9fc3 100644
--- a/inversion/image_inversion.py
+++ b/inversion/image_inversion.py
@@ -343,15 +343,20 @@ sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
 sess.run(tf.global_variables_initializer())
 sess.run(tf.tables_initializer())
 
+if params.max_batches > 0:
+    NUM_IMGS_TO_PROCESS = params.max_batches * BATCH_SIZE
+else:
+    NUM_IMGS_TO_PROCESS = NUM_IMGS
+
 # Output file.
 out_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'w')
-out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, dtype='float32')
-out_enc = out_file.create_dataset('encoding', [NUM_IMGS,] + ENC_SHAPE, dtype='float32')
-out_lat = out_file.create_dataset('latent', [NUM_IMGS, Z_DIM], dtype='float32')
-out_fns = out_file.create_dataset('fn', [NUM_IMGS], dtype=h5py.string_dtype())
+out_images = out_file.create_dataset('xtrain', [NUM_IMGS_TO_PROCESS,] + IMG_SHAPE, dtype='float32')
+out_enc = out_file.create_dataset('encoding', [NUM_IMGS_TO_PROCESS,] + ENC_SHAPE, dtype='float32')
+out_lat = out_file.create_dataset('latent', [NUM_IMGS_TO_PROCESS, Z_DIM], dtype='float32')
+out_fns = out_file.create_dataset('fn', [NUM_IMGS_TO_PROCESS], dtype=h5py.string_dtype())
 if COND_GAN:
-    out_labels = out_file.create_dataset('ytrain', (NUM_IMGS, N_CLASS,), dtype='float32')
-out_err = out_file.create_dataset('err', (NUM_IMGS,))
+    out_labels = out_file.create_dataset('ytrain', (NUM_IMGS_TO_PROCESS, N_CLASS,), dtype='float32')
+out_err = out_file.create_dataset('err', (NUM_IMGS_TO_PROCESS,))
 
 out_fns[:] = sample_fns
@@ -397,8 +402,8 @@ for image_batch, label_batch in image_gen:
         if params.clipping or params.stochastic_clipping:
             sess.run(clip_latent)
 
-        # Every 100 iterations save logs with training information.
-        if it < 100 or it % 100 == 0:
+        # Save logs with training information.
+        if it % 500 == 0:
             # Log losses.
             etime = time.time() - start_time
             print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
@@ -442,23 +447,24 @@ for image_batch, label_batch in image_gen:
             inv_batch = vs.grid_transform(inv_batch)
             vs.save_image('{}/progress_{}_{}.png'.format(SAMPLES_DIR, params.tag, it), inv_batch)
 
-        # Save linear interpolation between the actual and generated encodings.
-        if params.dist_loss and it % 1000 == 0:
-            enc_batch, gen_enc = sess.run([encoding, gen_encoding])
-            for j in range(10):
-                custom_enc = gen_enc * (1-(j/10.0)) + enc_batch * (j/10.0)
-                sess.run(encoding.assign(custom_enc))
-                gen_images = sess.run(gen_img)
-                inv_batch = vs.interleave(vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
-                                          vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
-                inv_batch = vs.grid_transform(inv_batch)
-                vs.save_image('{}/progress_{}_{}_lat_{}.png'.format(SAMPLES_DIR,params.tag,it,j),
-                              inv_batch)
-            sess.run(encoding.assign(enc_batch))
-
         # It counter.
         it += 1
 
+    if params.save_progress:
+        # Save linear interpolation between the actual and generated encodings.
+        if params.dist_loss:
+            enc_batch, gen_enc = sess.run([encoding, gen_encoding])
+            for j in range(10):
+                custom_enc = gen_enc * (1-(j/10.0)) + enc_batch * (j/10.0)
+                sess.run(encoding.assign(custom_enc))
+                gen_images = sess.run(gen_img)
+                inv_batch = vs.interleave(vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
+                                          vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+                inv_batch = vs.grid_transform(inv_batch)
+                vs.save_image('{}/progress_{}_{}_lat_{}.png'.format(SAMPLES_DIR,params.tag,it,j),
+                              inv_batch)
+            sess.run(encoding.assign(enc_batch))
+
     # Save samples of inverted images.
     if SAMPLE_SIZE > 0:
         assert SAMPLE_SIZE <= BATCH_SIZE
@@ -479,6 +485,8 @@ for image_batch, label_batch in image_gen:
         out_labels[out_pos:out_pos+BATCH_SIZE] = label_batch
     out_err[out_pos:out_pos+BATCH_SIZE] = rec_err_batch
     out_pos += BATCH_SIZE
+    if params.max_batches > 0 and (out_pos / BATCH_SIZE) >= params.max_batches:
+        break
 
 print('Mean reconstruction error: {}'.format(np.mean(out_err)))
 print('Stdev reconstruction error: {}'.format(np.std(out_err)))
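The first half of the patch caps how many images get inverted: when `params.max_batches` is positive, only `max_batches * BATCH_SIZE` images are processed, the HDF5 output datasets are sized to that count, and the batch loop breaks early once the cap is reached. Below is a minimal standalone sketch of that pattern, not the repository's script; the file name, image shape, and toy data are placeholders.

```python
# Sketch: size HDF5 datasets to the capped image count and stop early,
# mirroring the NUM_IMGS_TO_PROCESS / max_batches change in the patch.
import h5py
import numpy as np

num_imgs = 1000      # total images available (placeholder)
batch_size = 16
max_batches = 4      # 0 would mean "process everything"

# Same capping rule as the patch: only allocate space for what will be written.
num_to_process = max_batches * batch_size if max_batches > 0 else num_imgs

with h5py.File('inverses_sketch.h5', 'w') as out_file:
    out_images = out_file.create_dataset('xtrain', (num_to_process, 64, 64, 3), dtype='float32')
    out_err = out_file.create_dataset('err', (num_to_process,), dtype='float32')

    out_pos = 0
    for batch_idx in range(num_imgs // batch_size):
        image_batch = np.random.rand(batch_size, 64, 64, 3).astype('float32')
        rec_err_batch = np.random.rand(batch_size).astype('float32')

        out_images[out_pos:out_pos + batch_size] = image_batch
        out_err[out_pos:out_pos + batch_size] = rec_err_batch
        out_pos += batch_size

        # Stop once max_batches worth of images have been written,
        # matching the early `break` added at the end of the loop.
        if max_batches > 0 and (out_pos // batch_size) >= max_batches:
            break

    print('wrote {} images'.format(out_pos))
```

Sizing the datasets to the capped count up front avoids allocating HDF5 space for images the early `break` would never write.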

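The block moved behind `params.save_progress` renders a linear interpolation between the generator's current encoding and the recovered one in ten steps before restoring the original encoding. A rough sketch of just the blending arithmetic, with random NumPy arrays standing in for the real `gen_encoding` / `encoding` tensors and a placeholder encoding shape:

```python
# Sketch: ten-step linear blend between two encodings, as in
# gen_enc * (1-(j/10.0)) + enc_batch * (j/10.0) in the diff above.
import numpy as np

enc_shape = (16, 128)                      # placeholder encoding shape
gen_enc = np.random.randn(*enc_shape)      # encoding produced by the generator
enc_batch = np.random.randn(*enc_shape)    # target encoding recovered by inversion

blends = []
for j in range(10):
    t = j / 10.0
    blends.append(gen_enc * (1 - t) + enc_batch * t)

# In the script, each blend would be assigned back to the encoding variable
# and rendered through the generator before saving a progress_*_lat_*.png grid.
print(len(blends), blends[0].shape)
```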