diff options
| -rw-r--r-- | cli/app/search/search_dense.py | 19 |
1 files changed, 13 insertions, 6 deletions
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py index d0aef69..6126f33 100644 --- a/cli/app/search/search_dense.py +++ b/cli/app/search/search_dense.py @@ -232,14 +232,21 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op if type(opt_feature_layers) == str: opt_feature_layers = opt_feature_layers.split(',') + fixed_layers = [] + for layer in opt_feature_layers: + if ',' in layer: + fixed_layers += layer.split(',') + else: + fixed_layers.append(layer) for layer in opt_feature_layers: - layer_name = feature_layer_names[layer] - gen_feat = gen_feat_ex[layer_name] - target_feat = target_feat_ex[layer_name] - feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat), [BATCH_SIZE, -1]) - feat_loss += tf.reduce_mean(feat_square_diff) / len(opt_feature_layers) - img_feat_err += tf.reduce_mean(feat_square_diff, axis=1) / len(opt_feature_layers) + if layer in feature_layer_names: + layer_name = feature_layer_names[layer] + gen_feat = gen_feat_ex[layer_name] + target_feat = target_feat_ex[layer_name] + feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat), [BATCH_SIZE, -1]) + feat_loss += tf.reduce_mean(feat_square_diff) / len(opt_feature_layers) + img_feat_err += tf.reduce_mean(feat_square_diff, axis=1) / len(opt_feature_layers) # conv1 1, conv1 2, conv3 2 and conv4 2 # gen_feat = gen_feat_ex["InceptionV3/Conv2d_1a_3x3"] |
