summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-08 02:02:56 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-08 02:02:56 +0100
commit4d05266f3c1f62646a5948f532528f591c6dc2ee (patch)
tree3e517d0d1f388fb8704ee1475f5e066f9ff6c396
parent2034d4c0cd241106900273980ee84f808a73d196 (diff)
up
-rw-r--r--cli/app/search/search_class.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/cli/app/search/search_class.py b/cli/app/search/search_class.py
index 134a139..691a0a8 100644
--- a/cli/app/search/search_class.py
+++ b/cli/app/search/search_class.py
@@ -144,13 +144,11 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat), [batch_size, -1])
feat_loss += tf.reduce_mean(feat_square_diff) / len(opt_feature_layers)
- # Batch reconstruction error.
- inv_loss = 1.0 * mse_loss + 1.0 * feat_loss
-
+ loss = 1.0 * mse_loss + 1.0 * feat_loss
train_step_z = tf.train.AdamOptimizer(learning_rate=z_lr, beta1=0.9, beta2=0.999) \
- .minimize(inv_loss, var_list=[input_z])
+ .minimize(loss, var_list=[input_z])
train_step_y = tf.train.AdamOptimizer(learning_rate=y_lr, beta1=0.9, beta2=0.999) \
- .minimize(inv_loss, var_list=[input_y])
+ .minimize(loss, var_list=[input_y])
else:
loss = tf.compat.v1.losses.mean_squared_error(target, output)