diff options
Diffstat (limited to 'cli/app')
| -rw-r--r-- | cli/app/commands/biggan/search_class.py | 17 |
1 files changed, 9 insertions, 8 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py index 9b9d466..58f0d86 100644 --- a/cli/app/commands/biggan/search_class.py +++ b/cli/app/commands/biggan/search_class.py @@ -94,12 +94,14 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l z_lr = 0.001 y_lr = 0.001 - input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2)) - input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1)) + input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32) + input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32) + input_z_sigmoid = tf.compat.v1.sigmoid(input_z) * 2.0 - 1.0 + input_y_sigmoid = tf.compat.v1.sigmoid(input_y) input_trunc = tf.compat.v1.constant(1.0) output = generator({ - 'z': input_z, - 'y': input_y, + 'z': input_z_sigmoid, + 'y': input_y_sigmoid, 'truncation': input_trunc, }) @@ -138,20 +140,19 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im) try: - print("Iteration start") + print("Preparing to iterate...") for i in range(opt_steps): curr_loss, _, _ = sess.run([loss, train_step_z, train_step_y], feed_dict=feed_dict) if i % 20 == 0: phi_guess = sess.run(output) guess_im = imgrid(imconvert_uint8(phi_guess), cols=1) - imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(i)), guess_im) + imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(int(i / 20))), guess_im) print('iter: {}, loss: {}'.format(i, curr_loss)) except KeyboardInterrupt: pass - z_guess = sess.run(input_z) - y_guess = sess.run(input_y) + z_guess, y_guess = sess.run([input_z_sigmoid, input_y_sigmoid]) out_labels[index] = y_guess out_latent[index] = z_guess return fp_frames |
