diff options
| -rw-r--r-- | cli/app/search/search_class.py | 23 |
1 files changed, 16 insertions, 7 deletions
diff --git a/cli/app/search/search_class.py b/cli/app/search/search_class.py index 667afae..fb497f3 100644 --- a/cli/app/search/search_class.py +++ b/cli/app/search/search_class.py @@ -152,17 +152,24 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la z_lr = 0.001 y_lr = 0.001 - train_step_z = tf.train.AdamOptimizer(learning_rate=z_lr, beta1=0.9, beta2=0.999) \ - .minimize(loss, var_list=[input_z]) - train_step_y = tf.train.AdamOptimizer(learning_rate=y_lr, beta1=0.9, beta2=0.999) \ - .minimize(loss, var_list=[input_y]) - + optimizer_z = tf.train.AdamOptimizer(learning_rate=z_lr, beta1=0.9, beta2=0.999) + train_step_z = optimizer_z.minimize(loss, var_list=[input_z]) + optimizer_y = tf.train.AdamOptimizer(learning_rate=y_lr, beta1=0.9, beta2=0.999) + train_step_y = optimizer_y.minimize(loss, var_list=[input_y]) + reinit_optimizer_z = tf.variables_initializer(optimizer_z.variables()) + reinit_optimizer_y = tf.variables_initializer(optimizer_y.variables()) else: z_lr = 0.001 y_lr = 0.001 loss = tf.compat.v1.losses.mean_squared_error(target, output) - train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ') - train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY') + optimizer_z = tf.train.AdamOptimizer(learning_rate=z_lr, beta1=0.9, beta2=0.999) + train_step_z = optimizer_z.minimize(loss, var_list=[input_z]) + optimizer_y = tf.train.AdamOptimizer(learning_rate=y_lr, beta1=0.9, beta2=0.999) + train_step_y = optimizer_y.minimize(loss, var_list=[input_y]) + reinit_optimizer_z = tf.variables_initializer(optimizer_z.variables()) + reinit_optimizer_y = tf.variables_initializer(optimizer_y.variables()) + # train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ') + # train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY') target_im, fp_frames, fn_base = load_target_image(opt_fp_in, opt_video) @@ -195,9 +202,11 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la if i > 0: if opt_stochastic_clipping and (i % opt_clip_interval) == 0: # and i < opt_steps * 0.45: sess.run(clip_latent, { clipped_alpha: 0.0 }) + sess.run(reinit_optimizer_z) if opt_label_clipping and (i % opt_clip_interval) == 0: # and i < opt_steps * 0.75: # sess.run(clip_labels, { normalized_alpha: (i / opt_steps) ** 2 }) sess.run(clip_labels, { normalized_alpha: 0.0 }) + sess.run(reinit_optimizer_y) if opt_video and opt_snapshot_interval != 0 and (i % opt_snapshot_interval) == 0: phi_guess = sess.run(output) guess_im = imgrid(imconvert_uint8(phi_guess), cols=1) |
