From fba1ca15295914e2d04bd91a1bcf5f23a68725e0 Mon Sep 17 00:00:00 2001
From: Jules Laplace
Date: Tue, 18 Feb 2020 18:24:20 +0100
Subject: simple sum

---
 cli/app/search/search_dense.py | 134 +++++++++++++++++++++++------------------
 1 file changed, 75 insertions(+), 59 deletions(-)

(limited to 'cli/app/search/search_dense.py')

diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index c1a8426..ed22557 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -200,60 +200,70 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
 
   # Use custom features for image comparison.
   if params.features:
-    if 'http' in params.feature_extractor_path:
-      feature_extractor = hub.Module(str(params.feature_extractor_path))
-      feature_loss = feature_loss_tfhub
-      height, width = hub.get_expected_image_size(feature_extractor)
-    elif 'vgg' in params.feature_extractor_path:
-      if params.feature_extractor_path == 'vgg_16':
-        model_path = os.path.join(app_cfg.DIR_NETS, 'vgg_16.ckpt')
-        feature_extractor = slim.nets.vgg.vgg_16
-        # conv1_1, conv1_2, conv3_2, conv4_2
-        opt_feature_layers = [
-          'conv1/conv1_1',
-          'conv1/conv1_2',
-          'conv3/conv3_2',
-          'conv4/conv4_2',
-        ]
-        feature_loss = feature_loss_vgg
-        height = 224
-        width = 224
-      else:
-        print("Unknown feature extractor")
-        return
-    else:
-      print("Unknown feature extractor")
-      return
-
     # Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
     gen_img_ch = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
     target_img_ch = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
 
-    # Convert images to appropriate size for feature extraction.
     img_w = 512
 
-    feat_loss, img_feat_err = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, None, None, height, width)
-
-    feat_loss_a, feat_err_a = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, 0, height, width)
-    feat_loss_b, feat_err_b = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w - width, 0, height, width)
-    feat_loss_c, feat_err_c = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, img_w - width, height, width)
-    feat_loss_d, feat_err_d = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w - width, img_w - width, height, width)
-    feat_loss_e, feat_err_e = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, int((img_w - width) / 2), int((img_w - width) / 2), height, width)
-
-    # feat_loss_aa, feat_err_aa = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, 0, img_w/3, img_w/3, height, width)
-    # feat_loss_ab, feat_err_ab = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*1/3, 0, img_w/3, img_w/3, height, width)
-    # feat_loss_ac, feat_err_ac = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*2/3, 0, img_w/3, img_w/3, height, width)
-    # feat_loss_ad, feat_err_ad = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, img_w*1/3, img_w/3, img_w/3, height, width)
-    # feat_loss_ae, feat_err_ae = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*1/3, img_w*1/3, img_w/3, img_w/3, height, width)
-    # feat_loss_af, feat_err_af = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*2/3, img_w*1/3, img_w/3, img_w/3, height, width)
-    # feat_loss_ag, feat_err_ag = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, img_w*2/3, img_w/3, img_w/3, height, width)
-    # feat_loss_ah, feat_err_ah = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*1/3, img_w*2/3, img_w/3, img_w/3, height, width)
-    # feat_loss_ai, feat_err_ai = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*2/3, img_w*2/3, img_w/3, img_w/3, height, width)
-
-    mse_loss_a = mse_loss_crop(target_img_ch, gen_img_ch, 0, 0, img_w / 2, img_w / 2)
-    mse_loss_b = mse_loss_crop(target_img_ch, gen_img_ch, img_w / 2, 0, img_w / 2, img_w / 2)
-    mse_loss_c = mse_loss_crop(target_img_ch, gen_img_ch, 0, img_w / 2, img_w / 2, img_w / 2)
-    mse_loss_d = mse_loss_crop(target_img_ch, gen_img_ch, img_w / 2, img_w / 2, img_w / 2, img_w / 2)
+    # if 'http' in params.feature_extractor_path:
+    #   feature_extractor = hub.Module(str(params.feature_extractor_path))
+    #   feature_loss = feature_loss_tfhub
+    #   height, width = hub.get_expected_image_size(feature_extractor)
+    # elif 'vgg' in params.feature_extractor_path:
+    #   if params.feature_extractor_path == 'vgg_16':
+    #     model_path = os.path.join(app_cfg.DIR_NETS, 'vgg_16.ckpt')
+    #     feature_extractor = slim.nets.vgg.vgg_16
+    #     # conv1_1, conv1_2, conv3_2, conv4_2
+    #     opt_feature_layers = [
+    #       'conv1/conv1_1',
+    #       'conv1/conv1_2',
+    #       'conv3/conv3_2',
+    #       'conv4/conv4_2',
+    #     ]
+    #     feature_loss = feature_loss_vgg
+    #     height = 224
+    #     width = 224
+    #   else:
+    #     print("Unknown feature extractor")
+    #     return
+    # else:
+    #   print("Unknown feature extractor")
+    #   return
+
+    # Inception feature extractor
+
+    feature_extractor = hub.Module(str(params.feature_extractor_path))
+    feature_loss = feature_loss_tfhub
+    height, width = hub.get_expected_image_size(feature_extractor)
+
+    feat_loss_inception, img_feat_err = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, None, None, height, width)
+
+    # feat_loss_a, feat_err_a = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, 0, height, width)
+    # feat_loss_b, feat_err_b = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w - width, 0, height, width)
+    # feat_loss_c, feat_err_c = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, img_w - width, height, width)
+    # feat_loss_d, feat_err_d = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w - width, img_w - width, height, width)
+    # feat_loss_e, feat_err_e = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, int((img_w - width) / 2), int((img_w - width) / 2), height, width)
+
+    model_path = os.path.join(app_cfg.DIR_NETS, 'vgg_16.ckpt')
+    # conv1_1, conv1_2, conv3_2, conv4_2
+    opt_feature_layers = [
+      'conv1/conv1_1',
+      'conv1/conv1_2',
+      'conv3/conv3_2',
+      'conv4/conv4_2',
+    ]
+    height = 224
+    width = 224
+
+    feat_loss_vgg, img_feat_err_vgg = feature_loss_vgg(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, None, None, height, width)
+
+    feat_loss = feat_loss_vgg + feat_loss_inception
+
+    # mse_loss_a = mse_loss_crop(target_img_ch, gen_img_ch, 0, 0, img_w / 2, img_w / 2)
+    # mse_loss_b = mse_loss_crop(target_img_ch, gen_img_ch, img_w / 2, 0, img_w / 2, img_w / 2)
+    # mse_loss_c = mse_loss_crop(target_img_ch, gen_img_ch, 0, img_w / 2, img_w / 2, img_w / 2)
+    # mse_loss_d = mse_loss_crop(target_img_ch, gen_img_ch, img_w / 2, img_w / 2, img_w / 2, img_w / 2)
 
     # mse_loss_aa = mse_loss_crop(target_img_ch, gen_img_ch, 0, 0, img_w/3, img_w/3)
     # mse_loss_ab = mse_loss_crop(target_img_ch, gen_img_ch, img_w*1/3, 0, img_w/3, img_w/3)
@@ -265,13 +275,13 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
     # mse_loss_ah = mse_loss_crop(target_img_ch, gen_img_ch, img_w*1/3, img_w*2/3, img_w/3, img_w/3)
     # mse_loss_ai = mse_loss_crop(target_img_ch, gen_img_ch, img_w*2/3, img_w*2/3, img_w/3, img_w/3)
 
-    feat_loss_quad = feat_loss_a + feat_loss_b + feat_loss_c + feat_loss_d + feat_loss_e
-    img_feat_err_quad = feat_err_a + feat_err_b + feat_err_c + feat_err_d + feat_err_e
-    mse_loss_quad = mse_loss_a + mse_loss_b + mse_loss_c + mse_loss_d
+    # feat_loss_quad = feat_loss_a + feat_loss_b + feat_loss_c + feat_loss_d + feat_loss_e
+    # img_feat_err_quad = feat_err_a + feat_err_b + feat_err_c + feat_err_d + feat_err_e
+    # mse_loss_quad = mse_loss_a + mse_loss_b + mse_loss_c + mse_loss_d
 
     if 'vgg' in params.feature_extractor_path:
       variables_to_restore = slim.get_variables_to_restore(include=['vgg_16'])
-      print(variables_to_restore)
+      # print(variables_to_restore)
       restorer = tf.train.Saver(variables_to_restore)
 
     # feat_loss_quint = feat_loss_aa + feat_loss_ab + feat_loss_ac + feat_loss_ad + feat_loss_ae + feat_loss_af + feat_loss_ag + feat_loss_ah + feat_loss_ai
@@ -426,11 +436,12 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
 
       print("Beginning dense iteration...")
      for _ in range(params.inv_it):
-        if it < params.inv_it * 0.5:
-          _inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss, lrate, inv_train_op])
+        _inv_loss, _mse_loss, _feat_loss, _feat_loss_vgg, _feat_loss_inception, _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss, feat_loss_vgg, feat_loss_inception, lrate, inv_train_op])
+        # if it < params.inv_it * 0.5:
+        #   _inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss, lrate, inv_train_op])
         # elif it < params.inv_it * 0.75:
-        else:
-          _inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run([inv_loss_quad, mse_loss, feat_loss_quad, lrate, inv_train_op_quad])
+        # else:
+        #   _inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run([inv_loss_quad, mse_loss, feat_loss_quad, lrate, inv_train_op_quad])
         # else:
         #   _inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run([inv_loss_quint, mse_loss, feat_loss_quint, lrate, inv_train_op_quint])
 
@@ -441,10 +452,15 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
         if it % 500 == 0:
           # Log losses.
           etime = time.time() - start_time
-          print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
-                'feat [{:.4f}] '
+          print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] ' +
+                'feat [{:.4f}] ' +
+                'vgg [{:.4f}] ' +
+                'incep [{:.4f}] ' +
                 'lr [{:.4f}]'.format(it, etime, _inv_loss, _mse_loss,
-                                     _feat_loss, _lrate))
+                                     _feat_loss,
+                                     _feat_loss_vgg,
+                                     _feat_loss_inception,
+                                     _lrate))
           sys.stdout.flush()
 
 
--
cgit v1.2.3-70-g09d2
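
The change boils down to computing a VGG-16 feature loss and a TF-Hub (Inception) feature loss on the same image pair and optimizing their simple sum (feat_loss = feat_loss_vgg + feat_loss_inception), with both terms logged separately. A minimal, framework-agnostic sketch of that combination follows; the helper names, dictionary layout, and toy feature shapes are assumptions for illustration, not the repo's TF1/slim/hub code.

    import numpy as np

    def l2_feature_loss(feat_a, feat_b):
        # Mean squared error between two feature tensors from one extractor.
        return float(np.mean((feat_a - feat_b) ** 2))

    def combined_feature_loss(gen_feats, target_feats):
        # Unweighted sum over extractors, mirroring
        # feat_loss = feat_loss_vgg + feat_loss_inception in the patch.
        return sum(l2_feature_loss(gen_feats[k], target_feats[k]) for k in gen_feats)

    # Toy usage with features from two hypothetical extractors.
    rng = np.random.default_rng(0)
    gen = {"vgg_16": rng.normal(size=(1, 56, 56, 256)),
           "inception": rng.normal(size=(1, 2048))}
    target = {"vgg_16": rng.normal(size=(1, 56, 56, 256)),
              "inception": rng.normal(size=(1, 2048))}
    print(combined_feature_loss(gen, target))

An unweighted sum leaves both terms at whatever scale their extractors produce; if one term dominates during optimization, a per-extractor weight is the usual adjustment.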