-rw-r--r--  cli/app/search/json.py           3
-rw-r--r--  cli/app/search/search_dense.py 130
-rw-r--r--  cli/app/settings/app_cfg.py      1
3 files changed, 98 insertions(+), 36 deletions(-)
diff --git a/cli/app/search/json.py b/cli/app/search/json.py
index ecea4a9..2eeaeca 100644
--- a/cli/app/search/json.py
+++ b/cli/app/search/json.py
@@ -83,7 +83,8 @@ def make_params_dense(tag, folder_id):
"truncation": 1.0
},
"log_z_norm": False,
- "feature_extractor_path": "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1",
+ # "feature_extractor_path": "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1",
+ "feature_extractor_path": "vgg_16",
"mse": True,
"custom_grad_relu": False,
"random_label": False,
diff --git a/cli/app/search/search_dense.py b/cli/app/search/search_dense.py
index a35ab07..1281d64 100644
--- a/cli/app/search/search_dense.py
+++ b/cli/app/search/search_dense.py
@@ -13,6 +13,7 @@ import sys
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_hub as hub
+import tensorflow.contrib.slim as slim
import time
import app.search.visualize as vs
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
@@ -198,7 +199,29 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
# Use custom features for image comparison.
if params.features:
- feature_extractor = hub.Module(str(params.feature_extractor_path))
+ if 'http' in params.feature_extractor_path:
+ feature_extractor = hub.Module(str(params.feature_extractor_path))
+ feature_loss = feature_loss_tfhub
+ elif 'vgg' in params.feature_extractor_path:
+ if params.feature_extractor_path == 'vgg_16':
+ model_path = os.path.join(app_cfg.DIR_NETS, 'vgg_16.ckpt')
+ feature_extractor = slim.nets.vgg.vgg_16
+ # conv1_1, conv1_2, conv3_2, conv4_2
+ opt_feature_layers = [
+ 'vgg_16/conv1/conv1_1',
+ 'vgg_16/conv1/conv1_2',
+ 'vgg_16/conv3/conv3_2',
+ 'vgg_16/conv4/conv4_2',
+ ]
+ feature_loss = feature_loss_vgg
+ else:
+ print("Unknown feature extractor")
+ return
+ variables_to_restore = slim.get_variables_to_restore()
+ restorer = tf.train.Saver(variables_to_restore)
+ else:
+ print("Unknown feature extractor")
+ return
# Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
gen_img_ch = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
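
One hedged caveat on the restore setup added above: slim.get_variables_to_restore() with no arguments collects every model variable that exists in the graph at call time, not just the VGG ones. If anything besides the 'vgg_16/...' variables is present, the resulting tf.train.Saver will look for it in vgg_16.ckpt and fail. The usual guard is to scope the restore to the checkpoint's variables:

  # Sketch, assuming only 'vgg_16/...' variables live in the checkpoint.
  variables_to_restore = slim.get_variables_to_restore(include=['vgg_16'])
  restorer = tf.train.Saver(variables_to_restore)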
@@ -216,50 +239,50 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
feat_loss_d, feat_err_d = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w - width, img_w - width, height, width)
feat_loss_e, feat_err_e = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, int((img_w - width) / 2), int((img_w - width) / 2), height, width)
- feat_loss_aa, feat_err_aa = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, 0, img_w/3, img_w/3, height, width)
- feat_loss_ab, feat_err_ab = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*1/3, 0, img_w/3, img_w/3, height, width)
- feat_loss_ac, feat_err_ac = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*2/3, 0, img_w/3, img_w/3, height, width)
- feat_loss_ad, feat_err_ad = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, img_w*1/3, img_w/3, img_w/3, height, width)
- feat_loss_ae, feat_err_ae = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*1/3, img_w*1/3, img_w/3, img_w/3, height, width)
- feat_loss_af, feat_err_af = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*2/3, img_w*1/3, img_w/3, img_w/3, height, width)
- feat_loss_ag, feat_err_ag = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, img_w*2/3, img_w/3, img_w/3, height, width)
- feat_loss_ah, feat_err_ah = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*1/3, img_w*2/3, img_w/3, img_w/3, height, width)
- feat_loss_ai, feat_err_ai = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*2/3, img_w*2/3, img_w/3, img_w/3, height, width)
+ # feat_loss_aa, feat_err_aa = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, 0, img_w/3, img_w/3, height, width)
+ # feat_loss_ab, feat_err_ab = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*1/3, 0, img_w/3, img_w/3, height, width)
+ # feat_loss_ac, feat_err_ac = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*2/3, 0, img_w/3, img_w/3, height, width)
+ # feat_loss_ad, feat_err_ad = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, img_w*1/3, img_w/3, img_w/3, height, width)
+ # feat_loss_ae, feat_err_ae = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*1/3, img_w*1/3, img_w/3, img_w/3, height, width)
+ # feat_loss_af, feat_err_af = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*2/3, img_w*1/3, img_w/3, img_w/3, height, width)
+ # feat_loss_ag, feat_err_ag = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, 0, img_w*2/3, img_w/3, img_w/3, height, width)
+ # feat_loss_ah, feat_err_ah = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*1/3, img_w*2/3, img_w/3, img_w/3, height, width)
+ # feat_loss_ai, feat_err_ai = feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, gen_img_ch, target_img_ch, img_w*2/3, img_w*2/3, img_w/3, img_w/3, height, width)
mse_loss_a = mse_loss_crop(target_img_ch, gen_img_ch, 0, 0, img_w / 2, img_w / 2)
mse_loss_b = mse_loss_crop(target_img_ch, gen_img_ch, img_w / 2, 0, img_w / 2, img_w / 2)
mse_loss_c = mse_loss_crop(target_img_ch, gen_img_ch, 0, img_w / 2, img_w / 2, img_w / 2)
mse_loss_d = mse_loss_crop(target_img_ch, gen_img_ch, img_w / 2, img_w / 2, img_w / 2, img_w / 2)
- mse_loss_aa = mse_loss_crop(target_img_ch, gen_img_ch, 0, 0, img_w/3, img_w/3)
- mse_loss_ab = mse_loss_crop(target_img_ch, gen_img_ch, img_w*1/3, 0, img_w/3, img_w/3)
- mse_loss_ac = mse_loss_crop(target_img_ch, gen_img_ch, img_w*2/3, 0, img_w/3, img_w/3)
- mse_loss_ad = mse_loss_crop(target_img_ch, gen_img_ch, 0, img_w*1/3, img_w/3, img_w/3)
- mse_loss_ae = mse_loss_crop(target_img_ch, gen_img_ch, img_w*1/3, img_w*1/3, img_w/3, img_w/3)
- mse_loss_af = mse_loss_crop(target_img_ch, gen_img_ch, img_w*2/3, img_w*1/3, img_w/3, img_w/3)
- mse_loss_ag = mse_loss_crop(target_img_ch, gen_img_ch, 0, img_w*2/3, img_w/3, img_w/3)
- mse_loss_ah = mse_loss_crop(target_img_ch, gen_img_ch, img_w*1/3, img_w*2/3, img_w/3, img_w/3)
- mse_loss_ai = mse_loss_crop(target_img_ch, gen_img_ch, img_w*2/3, img_w*2/3, img_w/3, img_w/3)
+ # mse_loss_aa = mse_loss_crop(target_img_ch, gen_img_ch, 0, 0, img_w/3, img_w/3)
+ # mse_loss_ab = mse_loss_crop(target_img_ch, gen_img_ch, img_w*1/3, 0, img_w/3, img_w/3)
+ # mse_loss_ac = mse_loss_crop(target_img_ch, gen_img_ch, img_w*2/3, 0, img_w/3, img_w/3)
+ # mse_loss_ad = mse_loss_crop(target_img_ch, gen_img_ch, 0, img_w*1/3, img_w/3, img_w/3)
+ # mse_loss_ae = mse_loss_crop(target_img_ch, gen_img_ch, img_w*1/3, img_w*1/3, img_w/3, img_w/3)
+ # mse_loss_af = mse_loss_crop(target_img_ch, gen_img_ch, img_w*2/3, img_w*1/3, img_w/3, img_w/3)
+ # mse_loss_ag = mse_loss_crop(target_img_ch, gen_img_ch, 0, img_w*2/3, img_w/3, img_w/3)
+ # mse_loss_ah = mse_loss_crop(target_img_ch, gen_img_ch, img_w*1/3, img_w*2/3, img_w/3, img_w/3)
+ # mse_loss_ai = mse_loss_crop(target_img_ch, gen_img_ch, img_w*2/3, img_w*2/3, img_w/3, img_w/3)
feat_loss_quad = feat_loss_a + feat_loss_b + feat_loss_c + feat_loss_d + feat_loss_e
img_feat_err_quad = feat_err_a + feat_err_b + feat_err_c + feat_err_d + feat_err_e
mse_loss_quad = mse_loss_a + mse_loss_b + mse_loss_c + mse_loss_d
- feat_loss_quint = feat_loss_aa + feat_loss_ab + feat_loss_ac + feat_loss_ad + feat_loss_ae + feat_loss_af + feat_loss_ag + feat_loss_ah + feat_loss_ai
- img_feat_err_quint = feat_err_aa + feat_err_ab + feat_err_ac + feat_err_ad + feat_err_ae + feat_err_af + feat_err_ag + feat_err_ah + feat_err_ai
- mse_loss_quint = mse_loss_aa + mse_loss_ab + mse_loss_ac + mse_loss_ad + mse_loss_ae + mse_loss_af + mse_loss_ag + mse_loss_ah + mse_loss_ai
+ # feat_loss_quint = feat_loss_aa + feat_loss_ab + feat_loss_ac + feat_loss_ad + feat_loss_ae + feat_loss_af + feat_loss_ag + feat_loss_ah + feat_loss_ai
+ # img_feat_err_quint = feat_err_aa + feat_err_ab + feat_err_ac + feat_err_ad + feat_err_ae + feat_err_af + feat_err_ag + feat_err_ah + feat_err_ai
+ # mse_loss_quint = mse_loss_aa + mse_loss_ab + mse_loss_ac + mse_loss_ad + mse_loss_ae + mse_loss_af + mse_loss_ag + mse_loss_ah + mse_loss_ai
else:
feat_loss = tf.constant(0.0)
img_feat_err = tf.constant(0.0)
feat_loss_quad = tf.constant(0.0)
img_feat_err_quad = tf.constant(0.0)
- feat_loss_quint = tf.constant(0.0)
- img_feat_err_quint = tf.constant(0.0)
+ # feat_loss_quint = tf.constant(0.0)
+ # img_feat_err_quint = tf.constant(0.0)
img_rec_err = params.lambda_mse * img_mse_err + params.lambda_feat * img_feat_err
- inv_loss = (params.lambda_mse * mse_loss + params.lambda_feat * feat_loss) * 9
- inv_loss_quad = (params.lambda_mse * mse_loss_quad + params.lambda_feat * feat_loss_quad) * 9/4
- inv_loss_quint = params.lambda_mse * mse_loss_quint + params.lambda_feat * feat_loss_quint
+ inv_loss = (params.lambda_mse * mse_loss + params.lambda_feat * feat_loss)
+ inv_loss_quad = (params.lambda_mse * mse_loss_quad + params.lambda_feat * feat_loss_quad)
+ # inv_loss_quint = params.lambda_mse * mse_loss_quint + params.lambda_feat * feat_loss_quint
# --------------------------
# Optimizer.
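
A hedged reading of the scale factors dropped from inv_loss and inv_loss_quad above: the removed 3x3 "quint" stage summed nine patch losses, so the single full-image loss was previously multiplied by 9 and the four-patch quad loss by 9/4 to keep the three stages at comparable magnitude. With the 3x3 stage commented out, both multipliers go away:

  n_full, n_quad, n_quint = 1, 4, 9    # patches per loss stage
  assert n_quint / n_full == 9         # old full-image multiplier
  assert n_quint / n_quad == 9 / 4     # old quad multiplier (2.25)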
@@ -288,9 +311,9 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
inv_train_op_quad = optimizer_quad.minimize(inv_loss_quad, var_list=trained_params, global_step=inv_step)
reinit_optimizer_quad = tf.variables_initializer(optimizer_quad.variables())
- optimizer_quint = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
- inv_train_op_quint = optimizer_quint.minimize(inv_loss_quint, var_list=trained_params, global_step=inv_step)
- reinit_optimizer_quint = tf.variables_initializer(optimizer_quint.variables())
+ # optimizer_quint = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
+ # inv_train_op_quint = optimizer_quint.minimize(inv_loss_quint, var_list=trained_params, global_step=inv_step)
+ # reinit_optimizer_quint = tf.variables_initializer(optimizer_quint.variables())
# --------------------------
# Noise source.
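
Why each remaining stage keeps its own optimizer plus a variables_initializer: Adam stores per-variable moment accumulators, and the reinit op lets those slots be zeroed when optimization restarts for a new target image (via encoding_init_funcs below). A minimal sketch of the pattern kept for the base and quad stages:

  opt = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
  train_op = opt.minimize(inv_loss_quad, var_list=trained_params, global_step=inv_step)
  reset_op = tf.variables_initializer(opt.variables())  # zeros Adam's m/v slots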
@@ -343,6 +366,9 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())
+ if 'vgg' in params.feature_extractor_path:
+ restorer.restore(sess, model_path)
+
if params.max_batches > 0:
NUM_IMGS_TO_PROCESS = params.max_batches * BATCH_SIZE
else:
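
The placement of the new restore call matters: tf.global_variables_initializer() randomizes the VGG weights first, and restorer.restore() then overwrites them from the checkpoint; in the reverse order the random init would clobber the restored weights. In full:

  sess.run(tf.global_variables_initializer())
  sess.run(tf.tables_initializer())
  if 'vgg' in params.feature_extractor_path:
      restorer.restore(sess, model_path)   # model_path = DIR_NETS/vgg_16.ckpt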
@@ -376,7 +402,7 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
encoding_init_funcs = [
reinit_optimizer,
reinit_optimizer_quad,
- reinit_optimizer_quint,
+ # reinit_optimizer_quint,
]
if params.inv_layer != 'latent':
@@ -396,10 +422,11 @@ def find_dense_embedding_for_images(params, opt_tag="inverse_" + timestamp(), op
if it < params.inv_it * 0.5:
_inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run([inv_loss, mse_loss, feat_loss, lrate, inv_train_op])
- elif it < params.inv_it * 0.75:
- _inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run([inv_loss_quad, mse_loss, feat_loss_quad, lrate, inv_train_op_quad])
+ # elif it < params.inv_it * 0.75:
else:
- _inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run([inv_loss_quint, mse_loss, feat_loss_quint, lrate, inv_train_op_quint])
+ _inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run([inv_loss_quad, mse_loss, feat_loss_quad, lrate, inv_train_op_quad])
+ # else:
+ # _inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run([inv_loss_quint, mse_loss, feat_loss_quint, lrate, inv_train_op_quint])
if params.clipping or params.stochastic_clipping:
sess.run(clip_latent)
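
The hunk above collapses the old three-stage schedule into two: the first half of the iterations optimizes the whole-image loss, the second half the 2x2 quad loss, and the 3x3 quint stage is gone. Schematically:

  for it in range(params.inv_it):
      if it < params.inv_it * 0.5:
          sess.run(inv_train_op)        # stage 1: full-image loss
      else:
          sess.run(inv_train_op_quad)   # stage 2: 2x2 patch loss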
@@ -492,7 +519,7 @@ def mse_loss_crop(img_a, img_b, y, x, height, width):
img_b = tf.image.crop_to_bounding_box(img_b, y, x, height, width)
return tf.reduce_mean(tf.square((img_a - img_b) / 2.0))
-def feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, img_a, img_b, y, x, height, width, resize_height=None, resize_width=None):
+def feature_loss_tfhub(feature_extractor, opt_feature_layers, BATCH_SIZE, img_a, img_b, y, x, height, width, resize_height=None, resize_width=None):
height = int(height)
width = int(width)
if y is not None:
@@ -532,3 +559,36 @@ def feature_loss(feature_extractor, opt_feature_layers, BATCH_SIZE, img_a, img_b
feat_loss += tf.reduce_mean(feat_square_diff) / len(opt_feature_layers)
img_feat_err += tf.reduce_mean(feat_square_diff, axis=1) / len(opt_feature_layers)
return feat_loss, img_feat_err
+
+
+def feature_loss_vgg(feature_extractor, opt_feature_layers, BATCH_SIZE, img_a, img_b, y, x, height, width, resize_height=None, resize_width=None):
+ height = int(height)
+ width = int(width)
+ if y is not None:
+ x = int(x)
+ y = int(y)
+ img_a = tf.image.crop_to_bounding_box(img_a, y, x, height, width)
+ img_b = tf.image.crop_to_bounding_box(img_b, y, x, height, width)
+ else:
+ img_a = tf.image.resize_images(img_a, [height, width])
+ img_b = tf.image.resize_images(img_b, [height, width])
+
+ if resize_height is not None:
+ img_a = tf.image.resize_images(img_a, [resize_height, resize_width])
+ img_b = tf.image.resize_images(img_b, [resize_height, resize_width])
+
+ gen_fc, gen_feat_ex = slim.nets.vgg.vgg_16(img_a)
+ target_fc, target_feat_ex = slim.nets.vgg.vgg_16(img_b)
+ # gen_feat_ex = feature_extractor(dict(images=img_a), as_dict=True, signature='image_feature_vector')
+ # target_feat_ex = feature_extractor(dict(images=img_b), as_dict=True, signature='image_feature_vector')
+
+ feat_loss = tf.constant(0.0)
+ img_feat_err = tf.constant(0.0)
+
+ for layer_name in opt_feature_layers:
+ gen_feat = gen_feat_ex[layer_name]
+ target_feat = target_feat_ex[layer_name]
+ feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat), [BATCH_SIZE, -1])
+ feat_loss += tf.reduce_mean(feat_square_diff) / len(opt_feature_layers)
+ img_feat_err += tf.reduce_mean(feat_square_diff, axis=1) / len(opt_feature_layers)
+ return feat_loss, img_feat_err
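
A hedged caveat on the new function: slim.nets.vgg.vgg_16 creates its variables under the 'vgg_16' scope, and this loss is built several times per graph (once per crop, with two vgg_16 calls each), so without scope reuse the second construction raises a duplicate-variable error. One way to make the repeated builds safe:

  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      gen_fc, gen_feat_ex = slim.nets.vgg.vgg_16(img_a)
      target_fc, target_feat_ex = slim.nets.vgg.vgg_16(img_b)

Note also that slim's vgg_16 builds its fully connected layers as VALID-padded convolutions sized for 224x224 inputs, so small crops may fail at graph construction unless resized first; the resize_height/resize_width arguments can serve that purpose.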
diff --git a/cli/app/settings/app_cfg.py b/cli/app/settings/app_cfg.py
index 8478e5b..b42d46c 100644
--- a/cli/app/settings/app_cfg.py
+++ b/cli/app/settings/app_cfg.py
@@ -39,6 +39,7 @@ DIR_OUTPUTS = join(DIR_APP, 'data_store/outputs')
DIR_RESULTS = join(DIR_APP, 'data_store/results')
DIR_RENDERS = join(DIR_APP, 'data_store/renders')
DIR_DISENTANGLED = join(DIR_APP, 'data_store/disentangled')
+DIR_NETS = join(DIR_APP, 'data_store/nets')
FP_MODELZOO = join(DIR_APP, 'modelzoo/modelzoo.yaml')
os.makedirs(DIR_INVERSES, exist_ok=True)
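
The new DIR_NETS directory is where search_dense.py looks for vgg_16.ckpt. A sketch of one way to populate it, assuming the standard TF-Slim model-zoo release (the URL is the published slim checkpoint, not something this repo provides):

  import os
  import tarfile
  import urllib.request
  import app.settings.app_cfg as app_cfg

  URL = "http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz"
  os.makedirs(app_cfg.DIR_NETS, exist_ok=True)
  archive = os.path.join(app_cfg.DIR_NETS, "vgg_16.tar.gz")
  urllib.request.urlretrieve(URL, archive)
  with tarfile.open(archive) as tar:
      tar.extractall(app_cfg.DIR_NETS)   # yields DIR_NETS/vgg_16.ckpt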