Diffstat (limited to 'cli/app/search/search_class.py')
-rw-r--r--   cli/app/search/search_class.py   29
1 file changed, 18 insertions, 11 deletions
diff --git a/cli/app/search/search_class.py b/cli/app/search/search_class.py
index 691a0a8..7961a0c 100644
--- a/cli/app/search/search_class.py
+++ b/cli/app/search/search_class.py
@@ -65,13 +65,13 @@ def find_nearest_vector_for_images(paths, opt_dims, opt_steps, opt_video, opt_ta
     if index == opt_limit:
       break
     out_fns[index] = os.path.basename(path)
-    fp_frames = find_nearest_vector(sess, generator, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index,
+    fp_frames = find_nearest_vector(sess, generator, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index, opt_tag,
       opt_stochastic_clipping, opt_label_clipping, opt_use_feature_detector, opt_feature_layers, opt_snapshot_interval, opt_clip_interval)
     if opt_video:
       export_video(fp_frames)
 
   sess.close()
 
-def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index,
+def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index, opt_tag,
     opt_stochastic_clipping, opt_label_clipping, opt_use_feature_detector, opt_feature_layers, opt_snapshot_interval, opt_clip_interval):
   """ Find the closest latent and class vectors for an image. Store the class vector in an HDF5.
@@ -87,9 +87,6 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
   z_initial = truncated_z_sample(batch_size, z_dim, truncation/2)
   y_initial = create_labels_uniform(batch_size, vocab_size)
 
-  z_lr = 0.001
-  y_lr = 0.001
-
   input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
   input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
   input_trunc = tf.compat.v1.constant(1.0)
@@ -101,7 +98,7 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
 
   ## clip the Z encoding
 
-  opt_clip = 1.0
+  opt_clip = 2.0
   clipped_encoding = tf.where(tf.abs(input_z) >= opt_clip,
     tf.random.uniform([batch_size, z_dim], minval=-opt_clip, maxval=opt_clip), input_z)
 
@@ -145,12 +142,17 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
       feat_loss += tf.reduce_mean(feat_square_diff) / len(opt_feature_layers)
 
     loss = 1.0 * mse_loss + 1.0 * feat_loss
+
+    z_lr = 0.001
+    y_lr = 0.001
     train_step_z = tf.train.AdamOptimizer(learning_rate=z_lr, beta1=0.9, beta2=0.999) \
       .minimize(loss, var_list=[input_z])
     train_step_y = tf.train.AdamOptimizer(learning_rate=y_lr, beta1=0.9, beta2=0.999) \
       .minimize(loss, var_list=[input_y])
   else:
+    z_lr = 0.001
+    y_lr = 0.001
     loss = tf.compat.v1.losses.mean_squared_error(target, output)
     train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ')
     train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY')
 
@@ -183,17 +185,22 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
         print("Iterating!")
       if i % 20 == 0:
         print('iter: {}, loss: {}'.format(i, curr_loss))
-      if opt_stochastic_clipping != 0 and (i % opt_stochastic_clipping) == 0:
-        sess.run(clip_latent)
-      if opt_label_clipping != 0 and (i % opt_label_clipping) == 0:
-        sess.run(clip_labels)
-      if opt_snapshot_interval != 0 and (i % opt_snapshot_interval) == 0:
+      if i > 0:
+        if opt_stochastic_clipping != 0 and (i % opt_stochastic_clipping) == 0:
+          sess.run(clip_latent)
+        if opt_label_clipping != 0 and (i % opt_label_clipping) == 0:
+          sess.run(clip_labels)
+      if opt_video and opt_snapshot_interval != 0 and (i % opt_snapshot_interval) == 0:
         phi_guess = sess.run(output)
         guess_im = imgrid(imconvert_uint8(phi_guess), cols=1)
         imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(int(i / opt_snapshot_interval))), guess_im)
   except KeyboardInterrupt:
     pass
 
+  phi_guess = sess.run(output)
+  guess_im = imgrid(imconvert_uint8(phi_guess), cols=1)
+  imwrite(join(app_cfg.DIR_OUTPUTS, 'frame_{}_final.png'.format(opt_tag)), guess_im)
+
   z_guess, y_guess = sess.run([input_z, input_y])
   out_images[index] = phi_target_for_inversion
   out_labels[index] = y_guess
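For context on the opt_clip hunk above: rather than hard-clipping, the graph resamples any latent component that has drifted out of range, via the tf.where shown in the diff. Below is a minimal NumPy sketch of that stochastic-clipping idea; the function name and standalone setup are illustrative, not part of this repo.

import numpy as np

def stochastic_clip(z, clip=2.0, rng=None):
  # Replace out-of-range components with fresh uniform samples in
  # [-clip, clip), instead of pinning them to the boundary. This mirrors
  # the tf.where(tf.abs(input_z) >= opt_clip, ...) construct in the diff.
  rng = rng or np.random.default_rng()
  resampled = rng.uniform(-clip, clip, size=z.shape)
  return np.where(np.abs(z) >= clip, resampled, z)

z = np.array([[0.5, -3.1, 2.0, 1.9]])
print(stochastic_clip(z))  # only the entries with |z_i| >= 2.0 change

Resampling keeps the latent distributed like the truncated prior instead of piling probability mass on the clip boundary, which is why it is applied periodically (every opt_stochastic_clipping steps) rather than through the variable's constraint alone.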

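The z_lr/y_lr hunks simply relocate the learning rates next to the optimizers that consume them; the underlying pattern is one AdamOptimizer per variable, each minimizing the shared loss over its own var_list so the latent vector and the class vector can be tuned independently. A self-contained TF1-compat sketch of that pattern follows; the toy elementwise "generator" and all names are illustrative, not the repo's code.

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

target = tf.constant(np.full((1, 4), 2.0, dtype=np.float32))
input_z = tf.Variable(np.full((1, 4), 0.5, dtype=np.float32))
input_y = tf.Variable(np.full((1, 4), 0.5, dtype=np.float32))
output = input_z * input_y  # toy stand-in for the generator

loss = tf.losses.mean_squared_error(target, output)

# One optimizer per variable: each step updates only its var_list,
# so z and y can use different learning rates or update schedules.
z_lr = 0.001
y_lr = 0.001
train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z])
train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for i in range(200):
    curr_loss, _, _ = sess.run([loss, train_step_z, train_step_y])
    if i % 50 == 0:
      print('iter: {}, loss: {}'.format(i, curr_loss))

Running both train steps in a single sess.run, as the diff's loop does, applies one Adam update to each variable per iteration while keeping their optimizer states (moments) separate.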