summaryrefslogtreecommitdiff
path: root/cli/app/search/search_class.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2020-01-08 01:59:20 +0100
committerJules Laplace <julescarbon@gmail.com>2020-01-08 01:59:20 +0100
commit2034d4c0cd241106900273980ee84f808a73d196 (patch)
treee1a331d6fd0288561a5b4944eadcdcb25514ac2b /cli/app/search/search_class.py
parenta194eaa66108d753aac1eac70b7016a9b20897e1 (diff)
up
Diffstat (limited to 'cli/app/search/search_class.py')
-rw-r--r--cli/app/search/search_class.py12
1 files changed, 7 insertions, 5 deletions
diff --git a/cli/app/search/search_class.py b/cli/app/search/search_class.py
index cd53a71..134a139 100644
--- a/cli/app/search/search_class.py
+++ b/cli/app/search/search_class.py
@@ -46,7 +46,7 @@ feature_layer_names = {
def find_nearest_vector_for_images(paths, opt_dims, opt_steps, opt_video, opt_tag,
opt_limit=-1, opt_stochastic_clipping=0, opt_label_clipping=0,
- opt_use_feature_detector=False, opt_feature_layers=[1,2,4,7], opt_snapshot_interval=20):
+ opt_use_feature_detector=False, opt_feature_layers=[1,2,4,7], opt_snapshot_interval=20, opt_clip_interval=500):
tf.reset_default_graph()
sess = tf.compat.v1.Session()
print("Initializing generator...")
@@ -66,13 +66,13 @@ def find_nearest_vector_for_images(paths, opt_dims, opt_steps, opt_video, opt_ta
break
out_fns[index] = os.path.basename(path)
fp_frames = find_nearest_vector(sess, generator, path, opt_dims, out_images, out_labels, out_latent, opt_steps, index,
- opt_stochastic_clipping, opt_label_clipping, opt_use_feature_detector, opt_feature_layers, opt_snapshot_interval)
+ opt_stochastic_clipping, opt_label_clipping, opt_use_feature_detector, opt_feature_layers, opt_snapshot_interval, opt_clip_interval)
if opt_video:
export_video(fp_frames)
sess.close()
def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_labels, out_latent, opt_steps, index,
- opt_stochastic_clipping, opt_label_clipping, opt_use_feature_detector, opt_feature_layers, opt_snapshot_interval):
+ opt_stochastic_clipping, opt_label_clipping, opt_use_feature_detector, opt_feature_layers, opt_snapshot_interval, opt_clip_interval):
"""
Find the closest latent and class vectors for an image. Store the class vector in an HDF5.
"""
@@ -122,8 +122,10 @@ def find_nearest_vector(sess, generator, opt_fp_in, opt_dims, out_images, out_la
feature_extractor = hub.Module("https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1")
# Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
- gen_img_1 = tf.transpose(output / 2.0 + 0.5, [0, 2, 3, 1])
- target_img_1 = tf.transpose(target / 2.0 + 0.5, [0, 2, 3, 1])
+ # gen_img_1 = tf.transpose(output / 2.0 + 0.5, [0, 2, 3, 1])
+ # target_img_1 = tf.transpose(target / 2.0 + 0.5, [0, 2, 3, 1])
+ gen_img_1 = output / 2.0 + 0.5
+ target_img_1 = target / 2.0 + 0.5
# Convert images to appropriate size for feature extraction.
height, width = hub.get_expected_image_size(feature_extractor)