summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cli/app/search/search_dense.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index e65a6c9..fba393a 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -203,6 +203,7 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
if 'http' in params.feature_extractor_path:
feature_extractor = hub.Module(str(params.feature_extractor_path))
feature_loss = feature_loss_tfhub
+ height, width = hub.get_expected_image_size(feature_extractor)
elif 'vgg' in params.feature_extractor_path:
if params.feature_extractor_path == 'vgg_16':
model_path = os.path.join(app_cfg.DIR_NETS, 'vgg_16.ckpt')
@@ -215,6 +216,8 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
'vgg_16/conv4/conv4_2',
]
feature_loss = feature_loss_vgg
+ height = 224
+ width = 224
else:
print("Unknown feature extractor")
return
@@ -229,7 +232,6 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
target_img_ch = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
# Convert images to appropriate size for feature extraction.
- height, width = hub.get_expected_image_size(feature_extractor)
img_w = 512
feat_loss, img_feat_err = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, None, None, height, width)