diff options
| -rw-r--r-- | cli/app/commands/biggan/search_class.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py index bd260fb..9076e26 100644 --- a/cli/app/commands/biggan/search_class.py +++ b/cli/app/commands/biggan/search_class.py @@ -94,7 +94,8 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2)) input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1)) - input_trunc = tf.compat.v1.constant(1.0) + # input_trunc = tf.compat.v1.constant(1.0) + input_trunc = tf.placeholder(tf.float32, shape=None) output = generator({ 'z': input_z, 'y': input_y, @@ -123,13 +124,19 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l phi_target = np.repeat(phi_target, batch_size, axis=0) # feed_dict = {input_z: z, input_y: y, input_trunc: truncation} - phi_start = sess.run(output) + phi_start = sess.run(output, { + input_trunc: truncation, + }) start_im = imconvert_uint8(phi_start) imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im) try: for i in range(opt_steps): - sess.run([train_step_z, train_step_y]) + feed_dict = { + target: phi_target, + input_trunc: truncation, + } + sess.run([train_step_z, train_step_y, loss], feed_dict) phi_guess = sess.run(output) guess_im = imconvert_uint8(phi_guess) |
