summaryrefslogtreecommitdiff
path: root/cli/app/commands
diff options
context:
space:
mode:
Diffstat (limited to 'cli/app/commands')
-rw-r--r--cli/app/commands/biggan/search_class.py20
1 files changed, 9 insertions, 11 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
index 8895a67..658f4a8 100644
--- a/cli/app/commands/biggan/search_class.py
+++ b/cli/app/commands/biggan/search_class.py
@@ -92,8 +92,7 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
- # input_trunc = tf.compat.v1.constant(1.0)
- input_trunc = tf.compat.v1.placeholder(tf.float32, shape=None)
+ input_trunc = tf.compat.v1.constant(1.0)
output = generator({
'z': input_z,
'y': input_y,
@@ -125,20 +124,19 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
sess.run(tf.compat.v1.global_variables_initializer())
sess.run(tf.compat.v1.tables_initializer())
+ feed_dict = {
+ target: phi_target,
+ }
+
# feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
- phi_start = sess.run(output, {
- input_trunc: truncation,
- })
+ phi_start = sess.run(output, feed_dict=feed_dict)
start_im = imconvert_uint8(phi_start)
imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im)
try:
+ print("Iteration start")
for i in range(opt_steps):
- feed_dict = {
- target: phi_target,
- input_trunc: truncation,
- }
- sess.run([train_step_z, train_step_y, loss], feed_dict)
+ sess.run([train_step_z, train_step_y, loss], feed_dict=feed_dict)
phi_guess = sess.run(output)
guess_im = imconvert_uint8(phi_guess)
@@ -169,7 +167,7 @@ def export_video(fp_frames):
shutil.rmtree(join(app_cfg.DIR_OUTPUTS, fp_frames))
def load_target_image(opt_fp_in):
- print("Processing {}".format(opt_fp_in))
+ print("Loading {}".format(opt_fp_in))
fn = os.path.basename(opt_fp_in)
fbase, ext = os.path.splitext(fn)
fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000))