summaryrefslogtreecommitdiff
path: root/cli/app/commands
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-06 15:16:36 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-06 15:16:36 +0100
commit4f625c70a4975654868c4c536def6ede472f4a93 (patch)
tree173316be66c11c56884d0cc09ecf90077c325439 /cli/app/commands
parente5430f310f6bd5430d55517c900bc83317d91048 (diff)
path
Diffstat (limited to 'cli/app/commands')
-rw-r--r--cli/app/commands/biggan/search_class.py13
1 files changed, 10 insertions, 3 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
index bd260fb..9076e26 100644
--- a/cli/app/commands/biggan/search_class.py
+++ b/cli/app/commands/biggan/search_class.py
@@ -94,7 +94,8 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
- input_trunc = tf.compat.v1.constant(1.0)
+ # input_trunc = tf.compat.v1.constant(1.0)
+ input_trunc = tf.placeholder(tf.float32, shape=None)
output = generator({
'z': input_z,
'y': input_y,
@@ -123,13 +124,19 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
phi_target = np.repeat(phi_target, batch_size, axis=0)
# feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
- phi_start = sess.run(output)
+ phi_start = sess.run(output, {
+ input_trunc: truncation,
+ })
start_im = imconvert_uint8(phi_start)
imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im)
try:
for i in range(opt_steps):
- sess.run([train_step_z, train_step_y])
+ feed_dict = {
+ target: phi_target,
+ input_trunc: truncation,
+ }
+ sess.run([train_step_z, train_step_y, loss], feed_dict)
phi_guess = sess.run(output)
guess_im = imconvert_uint8(phi_guess)