summary refs log tree commit diff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-21 15:15:34 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-21 15:15:34 +0100
commit62ca16040e9f456c24bbde3a6e4cad6f31919b01 (patch)
tree8357d1ed3b91ae7e97cb82d28657f0b236d0425f
parent310d105be45736e46d70bbd6aa6e9e1752c22d0d (diff)
reinit optimizer
-rw-r--r--cli/app/search/search_class.py23
1 files changed, 16 insertions, 7 deletions
diff --git a/cli/app/search/search_class.py b/cli/app/search/search_class.py
index 667afae..fb497f3 100644
--- a/cli/app/search/search_class.py
+++ b/cli/app/search/search_class.py
@@ -152,17 +152,24 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
z_lr = 0.001
y_lr = 0.001
- train_step_z = tf.train.AdamOptimizer(learning_rate=z_lr, beta1=0.9, beta2=0.999) \
- .minimize(loss, var_list=[input_z])
- train_step_y = tf.train.AdamOptimizer(learning_rate=y_lr, beta1=0.9, beta2=0.999) \
- .minimize(loss, var_list=[input_y])
-
+ optimizer_z = tf.train.AdamOptimizer(learning_rate=z_lr, beta1=0.9, beta2=0.999)
+ train_step_z = optimizer_z.minimize(loss, var_list=[input_z])
+ optimizer_y = tf.train.AdamOptimizer(learning_rate=y_lr, beta1=0.9, beta2=0.999)
+ train_step_y = optimizer_y.minimize(loss, var_list=[input_y])
+ reinit_optimizer_z = tf.variables_initializer(optimizer_z.variables())
+ reinit_optimizer_y = tf.variables_initializer(optimizer_y.variables())
else:
z_lr = 0.001
y_lr = 0.001
loss = tf.compat.v1.losses.mean_squared_error(target, output)
- train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ')
- train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY')
+ optimizer_z = tf.train.AdamOptimizer(learning_rate=z_lr, beta1=0.9, beta2=0.999)
+ train_step_z = optimizer_z.minimize(loss, var_list=[input_z])
+ optimizer_y = tf.train.AdamOptimizer(learning_rate=y_lr, beta1=0.9, beta2=0.999)
+ train_step_y = optimizer_y.minimize(loss, var_list=[input_y])
+ reinit_optimizer_z = tf.variables_initializer(optimizer_z.variables())
+ reinit_optimizer_y = tf.variables_initializer(optimizer_y.variables())
+ # train_step_z = tf.train.AdamOptimizer(z_lr).minimize(loss, var_list=[input_z], name='AdamOpterZ')
+ # train_step_y = tf.train.AdamOptimizer(y_lr).minimize(loss, var_list=[input_y], name='AdamOpterY')
target_im, fp_frames, fn_base = load_target_image(opt_fp_in, opt_video)
@@ -195,9 +202,11 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
if i > 0:
if opt_stochastic_clipping and (i % opt_clip_interval) == 0: # and i < opt_steps * 0.45:
sess.run(clip_latent, { clipped_alpha: 0.0 })
+ sess.run(reinit_optimizer_z)
if opt_label_clipping and (i % opt_clip_interval) == 0: # and i < opt_steps * 0.75:
# sess.run(clip_labels, { normalized_alpha: (i / opt_steps) ** 2 })
sess.run(clip_labels, { normalized_alpha: 0.0 })
+ sess.run(reinit_optimizer_y)
if opt_video and opt_snapshot_interval != 0 and (i % opt_snapshot_interval) == 0:
phi_guess = sess.run(output)
guess_im = imgrid(imconvert_uint8(phi_guess), cols=1)