summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-02-18 18:29:03 +0100
committerJules Laplace <julescarbon@gmail.com>2020-02-18 18:29:03 +0100
commit75fa7f62aa9dbbbe3d69d03ad243a63a4b17c192 (patch)
tree6bc8c1e56a4d11387ea048bcda55405fa13843e6
parentfba1ca15295914e2d04bd91a1bcf5f23a68725e0 (diff)
simple sum
-rw-r--r--cli/app/search/search_dense.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index ed22557..f52954c 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -231,9 +231,11 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
# print("Unknown feature extractor")
# return
+ ################################################
# Inception feature extractor
-
- feature_extractor = hub.Module(str(params.feature_extractor_path))
+ ################################################
+ # feature_extractor = hub.Module(str(params.feature_extractor_path))
+ feature_extractor = hub.Module("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1")
feature_loss = feature_loss_tfhub
height, width = hub.get_expected_image_size(feature_extractor)
@@ -245,6 +247,9 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
# feat_loss_d, feat_err_d = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w - width, img_w - width, height, width)
# feat_loss_e, feat_err_e = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, int((img_w - width) / 2), int((img_w - width) / 2), height, width)
+ ################################################
+ # VGG feature extractor
+ ################################################
model_path = os.path.join(app_cfg.DIR_NETS, 'vgg_16.ckpt')
# conv1_1, conv1_2, conv3_2, conv4_2
opt_feature_layers = [
@@ -295,7 +300,7 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
# feat_loss_quint = tf.constant(0.0)
# img_feat_err_quint = tf.constant(0.0)
- img_rec_err = params.lambda_mse * img_mse_err + params.lambda_feat * img_feat_err
+ # img_rec_err = params.lambda_mse * img_mse_err + params.lambda_feat * img_feat_err
inv_loss = (params.lambda_mse * mse_loss + params.lambda_feat * feat_loss)
inv_loss_quad = (params.lambda_mse * mse_loss_quad + params.lambda_feat * feat_loss_quad)
# inv_loss_quint = params.lambda_mse * mse_loss_quint + params.lambda_feat * feat_loss_quint