summaryrefslogtreecommitdiff
path: root/cli/app/commands
diff options
context:
space:
mode:
Diffstat (limited to 'cli/app/commands')
-rw-r--r--cli/app/commands/biggan/search_class.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/cli/app/commands/biggan/search_class.py b/cli/app/commands/biggan/search_class.py
index 9b9d466..58f0d86 100644
--- a/cli/app/commands/biggan/search_class.py
+++ b/cli/app/commands/biggan/search_class.py
@@ -94,12 +94,14 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
z_lr = 0.001
y_lr = 0.001
- input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, -2, 2))
- input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32, constraint=lambda t: tf.clip_by_value(t, 0, 1))
+ input_z = tf.compat.v1.Variable(z_initial, dtype=np.float32)
+ input_y = tf.compat.v1.Variable(y_initial, dtype=np.float32)
+ input_z_sigmoid = tf.compat.v1.sigmoid(input_z) * 2.0 - 1.0
+ input_y_sigmoid = tf.compat.v1.sigmoid(input_y)
input_trunc = tf.compat.v1.constant(1.0)
output = generator({
- 'z': input_z,
- 'y': input_y,
+ 'z': input_z_sigmoid,
+ 'y': input_y_sigmoid,
'truncation': input_trunc,
})
@@ -138,20 +140,19 @@ def find_nearest_vector(sess, opt_fp_in, opt_dims, out_images, out_labels, out_l
imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im)
try:
- print("Iteration start")
+ print("Preparing to iterate...")
for i in range(opt_steps):
curr_loss, _, _ = sess.run([loss, train_step_z, train_step_y], feed_dict=feed_dict)
if i % 20 == 0:
phi_guess = sess.run(output)
guess_im = imgrid(imconvert_uint8(phi_guess), cols=1)
- imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(i)), guess_im)
+ imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(int(i / 20))), guess_im)
print('iter: {}, loss: {}'.format(i, curr_loss))
except KeyboardInterrupt:
pass
- z_guess = sess.run(input_z)
- y_guess = sess.run(input_y)
+ z_guess, y_guess = sess.run([input_z_sigmoid, input_y_sigmoid])
out_labels[index] = y_guess
out_latent[index] = z_guess
return fp_frames