author     Jules Laplace <julescarbon@gmail.com>    2020-02-20 23:18:40 +0100
committer  Jules Laplace <julescarbon@gmail.com>    2020-02-20 23:18:40 +0100
commit     50d642ee747690fa56980c79d5fbc7c1b438d672 (patch)
tree       861679f2909fc753babf5b2105e83bd769596672 /cli/app/search
parent     b8377075977b781a7495d0fe14d834a382e74306 (diff)
vgg?
Diffstat (limited to 'cli/app/search')
-rw-r--r--  cli/app/search/search_class.py  71
-rw-r--r--  cli/app/search/search_dense.py   2
2 files changed, 49 insertions(+), 24 deletions(-)
diff --git a/cli/app/search/search_class.py b/cli/app/search/search_class.py
index 3e36a02..c74c5f7 100644
--- a/cli/app/search/search_class.py
+++ b/cli/app/search/search_class.py
@@ -25,6 +25,7 @@ from app.search.vector import truncated_z_sample, truncated_z_single, \
create_labels, create_labels_uniform
from app.search.video import export_video
from app.search.params import timestamp
+from app.search.search_dense import feature_loss_vgg
feature_layer_names = {
'1a': "InceptionV3/Conv2d_1a_3x3",
@@ -122,34 +123,58 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
print("Initializing feature detector...")
# pix_square_diff = tf.square((target - output) / 2.0)
# mse_loss = tf.reduce_mean(pix_square_diff)
- mse_loss = tf.compat.v1.losses.mean_squared_error(target, output)
+ # mse_loss = tf.compat.v1.losses.mean_squared_error(target, output)
- feature_extractor = hub.Module("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1")
+ # feature_extractor = hub.Module("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1")
- # Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
- # gen_img_1 = tf.transpose(output / 2.0 + 0.5, [0, 2, 3, 1])
- # target_img_1 = tf.transpose(target / 2.0 + 0.5, [0, 2, 3, 1])
- gen_img_1 = output / 2.0 + 0.5
- target_img_1 = target / 2.0 + 0.5
+ # # Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
+ # # gen_img_1 = tf.transpose(output / 2.0 + 0.5, [0, 2, 3, 1])
+ # # target_img_1 = tf.transpose(target / 2.0 + 0.5, [0, 2, 3, 1])
+ # gen_img_1 = output / 2.0 + 0.5
+ # target_img_1 = target / 2.0 + 0.5
- # Convert images to appropriate size for feature extraction.
- height, width = hub.get_expected_image_size(feature_extractor)
- gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
- target_img_1 = tf.image.resize_images(target_img_1, [height, width])
+ # # Convert images to appropriate size for feature extraction.
+ # height, width = hub.get_expected_image_size(feature_extractor)
+ # gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
+ # target_img_1 = tf.image.resize_images(target_img_1, [height, width])
- gen_feat_ex = feature_extractor(dict(images=gen_img_1), as_dict=True, signature='image_feature_vector')
- target_feat_ex = feature_extractor(dict(images=target_img_1), as_dict=True, signature='image_feature_vector')
+ # gen_feat_ex = feature_extractor(dict(images=gen_img_1), as_dict=True, signature='image_feature_vector')
+ # target_feat_ex = feature_extractor(dict(images=target_img_1), as_dict=True, signature='image_feature_vector')
- feat_loss = tf.constant(0.0)
+ # feat_loss = tf.constant(0.0)
- for layer in opt_feature_layers:
- layer_name = feature_layer_names[layer]
- gen_feat = gen_feat_ex[layer_name]
- target_feat = target_feat_ex[layer_name]
- feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat), [batch_size, -1])
- feat_loss += tf.reduce_mean(feat_square_diff) / len(opt_feature_layers)
+ # for layer in opt_feature_layers:
+ # layer_name = feature_layer_names[layer]
+ # gen_feat = gen_feat_ex[layer_name]
+ # target_feat = target_feat_ex[layer_name]
+ # feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat), [batch_size, -1])
+ # feat_loss += tf.reduce_mean(feat_square_diff) / len(opt_feature_layers)
- loss = 1.0 * mse_loss + 1.0 * feat_loss
+ # loss = 1.0 * mse_loss + 1.0 * feat_loss
+
+ # gen_img_ch = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
+ # target_img_ch = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
+
+ pix_square_diff = tf.square((target_img - gen_img) / 2.0)
+ mse_loss = tf.reduce_mean(pix_square_diff) # , axis=1)
+
+ gen_img_ch = output # / 2.0 + 0.5
+ target_img_ch = target # / 2.0 + 0.5
+
+ model_path = os.path.join(app_cfg.DIR_NETS, 'vgg_16.ckpt')
+ # conv1_1, conv1_2, conv3_2, conv4_2
+ opt_feature_layers = [
+ 'conv1/conv1_1',
+ 'conv1/conv1_2',
+ 'conv3/conv3_2',
+ 'conv4/conv4_2',
+ ]
+ height = 224
+ width = 224
+
+ feat_loss_vgg, img_feat_err_vgg = feature_loss_vgg(None, opt_feature_layers, batch_size, gen_img_ch, target_img_ch, None, None, height, width)
+
+ loss = 100.0 * mse_loss + 1.0 * feat_loss_vgg
z_lr = 0.001
y_lr = 0.001
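
The hunk above replaces the TF-Hub Inception-v3 feature extractor with a VGG-16 perceptual loss imported from search_dense.py, and reweights the total loss to 100x pixel MSE plus 1x feature loss. The body of feature_loss_vgg is not part of this diff; the sketch below shows the usual slim-based feature-matching pattern the call site suggests, assuming vgg_16.ckpt is the standard slim VGG-16 checkpoint. The helper name feature_loss_vgg_sketch and all internals are illustrative, not taken from the repo.

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import vgg

_VGG_MEAN_RGB = [123.68, 116.779, 103.939]  # standard slim VGG preprocessing means

def feature_loss_vgg_sketch(opt_feature_layers, batch_size,
                            gen_img, target_img, height=224, width=224):
    # Map images from [-1, 1] to 0-255 RGB and resize to VGG's input size.
    gen = tf.image.resize_images((gen_img + 1.0) * 127.5, [height, width])
    tgt = tf.image.resize_images((target_img + 1.0) * 127.5, [height, width])

    # One shared-weight VGG tower over both batches; subtract channel means.
    both = tf.concat([gen, tgt], axis=0) - tf.constant(_VGG_MEAN_RGB, shape=[1, 1, 1, 3])
    with slim.arg_scope(vgg.vgg_arg_scope()):
        _, end_points = vgg.vgg_16(both, is_training=False)

    # Accumulate mean squared feature differences over the chosen layers.
    feat_loss = tf.constant(0.0)
    img_feat_err = tf.zeros([batch_size])
    for layer in opt_feature_layers:  # e.g. 'conv1/conv1_1'
        feats = end_points['vgg_16/' + layer]
        gen_feat, tgt_feat = tf.split(feats, 2, axis=0)
        sq = tf.reshape(tf.square(gen_feat - tgt_feat), [batch_size, -1])
        img_feat_err += tf.reduce_mean(sq, axis=1) / len(opt_feature_layers)
        feat_loss += tf.reduce_mean(sq) / len(opt_feature_layers)
    return feat_loss, img_feat_err

The VGG weights would then be restored from model_path before the optimization loop, e.g. with slim.assign_from_checkpoint_fn(model_path, slim.get_model_variables('vgg_16')).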
@@ -195,11 +220,11 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
try:
print("Preparing to iterate...")
for i in range(opt_steps):
- curr_loss, _, _ = sess.run([loss, train_step_z, train_step_y], feed_dict=feed_dict)
+ curr_loss, _mse, _vgg, _, _ = sess.run([loss, mse_loss, feat_loss_vgg, train_step_z, train_step_y], feed_dict=feed_dict)
if i == 0:
print("Iterating!")
if i % 20 == 0:
- print('iter: {}, loss: {}'.format(i, curr_loss))
+ print('[it] {} [loss] {:.4f} [mse] {:.4f} [vgg] {:.4f}'.format(i, curr_loss, _mse, _vgg))
if i > 0:
if opt_stochastic_clipping and (i % opt_clip_interval) == 0: # and i < opt_steps * 0.75:
sess.run(clip_latent, { clipped_alpha: 0.0 })
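
The loop above also runs a stochastic-clipping op on the latent (clip_latent, fed through clipped_alpha) every opt_clip_interval steps; its definition sits outside this hunk. A minimal sketch of the common technique (Lipton & Tripathi, "Precise Recovery of Latent Vectors from Generative Adversarial Networks", 2017), with z_var and the range bound as hypothetical names:

import tensorflow as tf

def make_stochastic_clip_op(z_var, bound=1.0):
    # Resample latent entries that drift outside [-bound, bound] instead
    # of clamping them, which avoids values piling up at the boundary.
    out_of_range = tf.logical_or(z_var > bound, z_var < -bound)
    resampled = tf.random.uniform(tf.shape(z_var), -bound, bound)
    return tf.compat.v1.assign(z_var, tf.where(out_of_range, resampled, z_var))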
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index 3c8a98c..53c548b 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -320,7 +320,7 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
# Optimizer.
# --------------------------
if params.decay_lr:
- lrate = tf.train.exponential_decay(params.lr, inv_step, params.inv_it, 0.9)
+ lrate = tf.train.exponential_decay(params.lr, inv_step, params.inv_it, 0.1)
# lrate = tf.train.exponential_decay(params.lr, inv_step, params.inv_it / params.decay_n, 0.1, staircase=True)
else:
lrate = tf.constant(params.lr)
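
The lone change to search_dense.py makes the learning-rate decay far more aggressive: tf.train.exponential_decay computes params.lr * decay_rate ** (inv_step / params.inv_it), so the schedule now ends at 10% of the initial rate rather than 90%. Illustrative numbers, assuming lr=0.01 and inv_it=1000:

# decayed = lr * decay_rate ** (step / decay_steps)
lr, inv_it = 0.01, 1000
for step in (0, 500, 1000):
    old = lr * 0.9 ** (step / inv_it)  # before this commit
    new = lr * 0.1 ** (step / inv_it)  # after this commit
    print(step, round(old, 5), round(new, 5))
# 0 -> 0.01 / 0.01, 500 -> 0.00949 / 0.00316, 1000 -> 0.009 / 0.001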