diff options
Diffstat (limited to 'cli')
| -rw-r--r-- | cli/app/commands/biggan/search_class.py | 20 |
1 files changed, 9 insertions, 11 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py index 8895a67..658f4a8 100644 --- a/cli/app/commands/biggan/search_class.py +++ b/cli/app/commands/biggan/search_class.py @@ -92,8 +92,7 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2)) input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1)) - # input_trunc = tf.compat.v1.constant(1.0) - input_trunc = tf.compat.v1.placeholder(tf.float32, shape=None) + input_trunc = tf.compat.v1.constant(1.0) output = generator({ 'z': input_z, 'y': input_y, @@ -125,20 +124,19 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.tables_initializer()) + feed_dict = { + target: phi_target, + } + # feed_dict = {input_z: z, input_y: y, input_trunc: truncation} - phi_start = sess.run(output, { - input_trunc: truncation, - }) + phi_start = sess.run(output, feed_dict=feed_dict) start_im = imconvert_uint8(phi_start) imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im) try: + print("Iteration start") for i in range(opt_steps): - feed_dict = { - target: phi_target, - input_trunc: truncation, - } - sess.run([train_step_z, train_step_y, loss], feed_dict) + sess.run([train_step_z, train_step_y, loss], feed_dict=feed_dict) phi_guess = sess.run(output) guess_im = imconvert_uint8(phi_guess) @@ -169,7 +167,7 @@ def export_video(fp_frames): shutil.rmtree(join(app_cfg.DIR_OUTPUTS, fp_frames)) def load_target_image(opt_fp_in): - print("Processing {}".format(opt_fp_in)) + print("Loading {}".format(opt_fp_in)) fn = os.path.basename(opt_fp_in) fbase, ext = os.path.splitext(fn) fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000)) |
