summaryrefslogtreecommitdiff
path: root/cli
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-02-14 16:23:59 +0100
committerJules Laplace <julescarbon@gmail.com>2020-02-14 16:23:59 +0100
commit0a7725b4e1ee2d10e7fe6c99c643bd5c25b325c8 (patch)
tree7ba86c0e1b8d2c3f0eb44b8a1f4f5104c6223d3c /cli
parentf2fa78a23ca62db4f52486fad5a55a266e78ceb8 (diff)
vgg feature loss
Diffstat (limited to 'cli')
-rw-r--r--cli/app/search/search_dense.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index e65a6c9..fba393a 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -203,6 +203,7 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
if 'http' in params.feature_extractor_path:
feature_extractor = hub.Module(str(params.feature_extractor_path))
feature_loss = feature_loss_tfhub
+ height, width = hub.get_expected_image_size(feature_extractor)
elif 'vgg' in params.feature_extractor_path:
if params.feature_extractor_path == 'vgg_16':
model_path = os.path.join(app_cfg.DIR_NETS, 'vgg_16.ckpt')
@@ -215,6 +216,8 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
'vgg_16/conv4/conv4_2',
]
feature_loss = feature_loss_vgg
+ height = 224
+ width = 224
else:
print("Unknown feature extractor")
return
@@ -229,7 +232,6 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
target_img_ch = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
# Convert images to appropriate size for feature extraction.
- height, width = hub.get_expected_image_size(feature_extractor)
img_w = 512
feat_loss, img_feat_err = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, None, None, height, width)