"""Invert images into the input space of a TF-Hub GAN generator.

The entry point, `find_dense_embedding_for_images`, optimises the generator's
latent vector (and optionally an intermediate layer and per-layer label
inputs) so that the generated image matches each target image under a pixel
MSE plus Inception/VGG feature loss.
"""
import glob
import itertools
import json
import os
import sys
import time
from io import BytesIO

# Set before TensorFlow is imported, as in the original import ordering.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import h5py
import numpy as np
import scipy
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_hub as hub
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
from PIL import Image

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

import app.search.visualize as vs
from app.search.params import Params, timestamp
from app.settings import app_cfg
from app.utils.file_utils import write_pickle
from app.utils.cortex_utils import upload_bytes_to_cortex

# Short layer tag -> endpoint name in the Inception V3 hub module's output dict.
# NOTE(review): '3a' and '3b' both map to Conv2d_3b_1x1 -- confirm intended.
feature_layer_names = {
    '1a': "InceptionV3/Conv2d_1a_3x3",
    '2a': "InceptionV3/Conv2d_2a_3x3",
    '2b': "InceptionV3/Conv2d_2b_3x3",
    '3a': "InceptionV3/Conv2d_3b_1x1",
    '3b': "InceptionV3/Conv2d_3b_1x1",
    '4a': "InceptionV3/Conv2d_4a_3x3",
    '5a': "InceptionV3/Mixed_5a",
    '5b': "InceptionV3/Mixed_5b",
    '5c': "InceptionV3/Mixed_5c",
    '5d': "InceptionV3/Mixed_5d",
    '6a': "InceptionV3/Mixed_6a",
    '6b': "InceptionV3/Mixed_6b",
    '6c': "InceptionV3/Mixed_6c",
    '6d': "InceptionV3/Mixed_6d",
    '6e': "InceptionV3/Mixed_6e",
    '7a': "InceptionV3/Mixed_7a",
    '7b': "InceptionV3/Mixed_7b",
    '7c': "InceptionV3/Mixed_7c",
}

# Hard-coded pieces of the objective (values unchanged from the original code).
_INCEPTION_MODULE_URL = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1"
_VGG_FEATURE_LAYERS = (
    'conv1/conv1_1',
    'conv1/conv1_2',
    'conv3/conv3_2',
    'conv4/conv4_2',
)
_VGG_INPUT_SIZE = 224
_MSE_LOSS_WEIGHT = 100.0
_INCEPTION_LOSS_WEIGHT = 10.0
_LOG_EVERY = 500


def find_dense_embedding_for_images(params, opt_tag=None,
                                    opt_feature_layers=("1a,2a,4a,7a",),
                                    opt_save_progress=True):
    """Optimise generator inputs so the generated images match a dataset.

    For every batch of the input HDF5 file the latent (and, unless
    ``params.inv_layer == 'latent'``, the named intermediate layer; and, if
    ``params.invert_labels``, the per-layer label inputs) is optimised with
    Adam for ``params.inv_it`` steps.  Each result is uploaded as a PNG,
    pickled next to its vectors, and written to ``params.out_dataset``.

    Args:
        params: Settings object.  Attributes read here: ``inv_layer``,
            ``batch_size``, ``sample_size``, ``path``, ``generator_path``,
            ``generator_fixed_inputs``, ``invert_labels``,
            ``custom_grad_relu``, ``inv_it``, ``clipping``,
            ``stochastic_clipping``, ``clip``, ``log_activation_layer``,
            ``mse``, ``features``, ``feature_extractor_path``, ``decay_lr``,
            ``lr``, ``dataset``, ``out_dataset``, ``max_batches``,
            ``folder_id``.
        opt_tag: Suffix for uploaded file names and progress images.  When
            ``None`` a fresh ``"inverse_" + timestamp()`` is generated per
            call (the old default was evaluated only once, at import time).
        opt_feature_layers: Inception layer tags (keys of
            ``feature_layer_names``), as a comma-separated string or an
            iterable of such strings.
        opt_save_progress: When true, save a target/reconstruction grid to
            the samples directory every ``_LOG_EVERY`` iterations.

    Raises:
        SystemExit: If ``params.dataset`` does not end in ``.hdf5``.
    """
    if opt_tag is None:
        opt_tag = "inverse_" + timestamp()

    # --------------------------
    # Global directories.
    # --------------------------
    invert_latent = params.inv_layer == 'latent'
    LATENT_TAG = 'latent' if invert_latent else 'dense'
    BATCH_SIZE = params.batch_size
    SAMPLE_SIZE = params.sample_size
    LOGS_DIR = os.path.join(params.path, LATENT_TAG, 'logs')
    SAMPLES_DIR = os.path.join(params.path, LATENT_TAG, 'samples')
    os.makedirs(LOGS_DIR, exist_ok=True)
    os.makedirs(SAMPLES_DIR, exist_ok=True)
    os.makedirs(app_cfg.DIR_VECTORS, exist_ok=True)

    summary_writer = tf.summary.FileWriter(LOGS_DIR)

    def log_stats(name, val, it):
        # Scalar TensorBoard logging helper; not called anywhere at present.
        summary = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=val)])
        summary_writer.add_summary(summary, it)

    # --------------------------
    # Load Graph.
    # --------------------------
    tf.reset_default_graph()
    generator = hub.Module(str(params.generator_path))

    gen_signature = 'generator'
    if 'generator' not in generator.get_signature_names():
        gen_signature = 'default'

    input_info = generator.get_input_info_dict(gen_signature)
    COND_GAN = 'y' in input_info

    label = None
    N_CLASS = None
    if COND_GAN:
        Z_DIM = input_info['z'].get_shape().as_list()[1]
        latent = tf.get_variable(name='latent', dtype=tf.float32,
                                 shape=[BATCH_SIZE, Z_DIM])
        N_CLASS = input_info['y'].get_shape().as_list()[1]
        label = tf.get_variable(name='label', dtype=tf.float32,
                                shape=[BATCH_SIZE, N_CLASS])
        gen_in = dict(params.generator_fixed_inputs)
        gen_in['z'] = latent
        gen_in['y'] = label
        gen_img = generator(gen_in, signature=gen_signature)
    else:
        Z_DIM = input_info['default'].get_shape().as_list()[1]
        latent = tf.get_variable(name='latent', dtype=tf.float32,
                                 shape=[BATCH_SIZE, Z_DIM])
        if params.generator_fixed_inputs:
            gen_in = dict(params.generator_fixed_inputs)
            gen_in['z'] = latent
            gen_img = generator(gen_in, signature=gen_signature)
        else:
            gen_img = generator(latent, signature=gen_signature)

    # Keep the generator's native layout for the final export.
    gen_img_orig = gen_img
    # Convert generated image to channels_first.
    gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

    # Override intermediate layer.
    gen_encoding = None
    if invert_latent:
        encoding = latent
        ENC_SHAPE = [Z_DIM]
    else:
        layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
        gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
        ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
        encoding = tf.get_variable(name='encoding', dtype=tf.float32,
                                   shape=[BATCH_SIZE] + ENC_SHAPE)
        # Consumers of the generator's intermediate tensor now read the
        # trainable `encoding` variable instead.
        tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))

    # Optionally replace the label input of each generator concat op with its
    # own trainable variable.
    layer_label_variables = []
    gen_label = None
    if params.invert_labels:
        # Only looked up when needed, so generators without this tensor still
        # work when label inversion is off.
        gen_label = tf.get_default_graph().get_tensor_by_name(
            'module_apply_{}/linear_1/MatMul:0'.format(gen_signature))
        op_names = [
            "Generator_2/concat",
            "Generator_2/concat_1",
            "Generator_2/concat_2",
            "Generator_2/concat_3",
            "Generator_2/concat_4",
            "Generator_2/concat_5",
            "Generator_2/concat_6",
        ]
        op_input_index = 1
        layer_shape = [128]
        for op_name in op_names:
            layer_name = 'module_apply_{}/{}'.format(gen_signature, op_name)
            variable_name = op_name + "_label"
            raw_op = tf.get_default_graph().get_operation_by_name(layer_name)
            new_op_input = tf.Variable(
                tf.zeros([BATCH_SIZE] + layer_shape, dtype=tf.float32),
                name=variable_name, trainable=True)
            # `+ 0.0` turns the variable into a plain tensor before it is
            # wired in through the private `_update_input` API.
            maybe_a_tensor = new_op_input + tf.constant(0.0)
            raw_op._update_input(op_input_index, maybe_a_tensor)
            layer_label_variables.append(new_op_input)

    # Step counter.
    inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)

    # Define target image (stored as-is; no rescaling is applied here).
    IMG_SHAPE = gen_img.get_shape().as_list()[1:]
    target = tf.get_variable(name='target', dtype=tf.float32,
                             shape=[BATCH_SIZE] + IMG_SHAPE)
    target_img = target

    # Custom Gradient for Relus.
    if params.custom_grad_relu:
        grad_lambda = tf.train.exponential_decay(0.1, inv_step, params.inv_it / 5,
                                                 0.1, staircase=False)

        @tf.custom_gradient
        def relu_custom_grad(x):
            def grad(dy):
                # Inactive units still pass a decayed share of negative
                # gradients.
                return tf.where(x >= 0, dy,
                                grad_lambda * tf.where(dy < 0, dy, tf.zeros_like(dy)))
            return tf.nn.relu(x), grad

        gen_scope = 'module_apply_' + gen_signature + '/'
        for op in tf.get_default_graph().get_operations():
            if 'Relu' in op.name and gen_scope in op.name:
                assert len(op.inputs) == 1
                assert len(op.outputs) == 1
                new_out = relu_custom_grad(op.inputs[0])
                tf.contrib.graph_editor.swap_ts(op.outputs[0], new_out)

    # Operations to clip the values of the encodings.
    clip_latent = None
    if params.clipping or params.stochastic_clipping:
        assert params.clip >= 0
        if params.stochastic_clipping:
            # Out-of-range entries are resampled uniformly inside the range.
            new_enc = tf.where(
                tf.abs(latent) >= params.clip,
                tf.random.uniform([BATCH_SIZE, Z_DIM],
                                  minval=-params.clip, maxval=params.clip),
                latent)
        else:
            new_enc = tf.clip_by_value(latent, -params.clip, params.clip)
        clip_latent = tf.assign(latent, new_enc)

    # Monitor relu's activation (built but not evaluated at present).
    if params.log_activation_layer:
        gen_scope = 'module_apply_' + gen_signature + '/'
        activation_rate = 1.0 - tf.nn.zero_fraction(
            tf.get_default_graph().get_tensor_by_name(
                gen_scope + params.log_activation_layer))

    # --------------------------
    # Reconstruction losses.
    # --------------------------
    # Mse loss for image comparison.
    if params.mse:
        pix_square_diff = tf.square((target_img - gen_img) / 2.0)
        mse_loss = tf.reduce_mean(pix_square_diff)
        img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1, 2, 3])
    else:
        mse_loss = tf.constant(0.0)
        img_mse_err = tf.constant(0.0)

    # Use custom features for image comparison.
    restorer = None
    model_path = None
    if params.features:
        # Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
        gen_img_ch = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
        target_img_ch = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])

        # Inception feature extractor (module URL is fixed;
        # params.feature_extractor_path only controls the VGG restore below).
        feature_extractor = hub.Module(_INCEPTION_MODULE_URL)
        height, width = hub.get_expected_image_size(feature_extractor)
        feat_loss_inception, img_feat_err = feature_loss_tfhub(
            feature_extractor, opt_feature_layers, BATCH_SIZE,
            gen_img_ch, target_img_ch, None, None, height, width)

        # VGG feature extractor.
        model_path = os.path.join(app_cfg.DIR_NETS, 'vgg_16.ckpt')
        feat_loss_vgg, img_feat_err_vgg = feature_loss_vgg(
            feature_extractor, list(_VGG_FEATURE_LAYERS), BATCH_SIZE,
            gen_img_ch, target_img_ch, None, None,
            _VGG_INPUT_SIZE, _VGG_INPUT_SIZE)

        feat_loss = feat_loss_vgg + _INCEPTION_LOSS_WEIGHT * feat_loss_inception

        if 'vgg' in params.feature_extractor_path:
            variables_to_restore = slim.get_variables_to_restore(include=['vgg_16'])
            restorer = tf.train.Saver(variables_to_restore)
    else:
        feat_loss = tf.constant(0.0)
        feat_loss_vgg = tf.constant(0.0)
        feat_loss_inception = tf.constant(0.0)
        img_feat_err = tf.constant(0.0)
        img_feat_err_vgg = tf.constant(0.0)

    inv_loss = _MSE_LOSS_WEIGHT * mse_loss + feat_loss
    # Per-image error using the same weights as `inv_loss`; stored in 'err'.
    img_rec_err = (_MSE_LOSS_WEIGHT * img_mse_err + img_feat_err_vgg
                   + _INCEPTION_LOSS_WEIGHT * img_feat_err)

    # --------------------------
    # Optimizer.
    # --------------------------
    if params.decay_lr:
        lrate = tf.train.exponential_decay(params.lr, inv_step, params.inv_it, 0.9)
    else:
        lrate = tf.constant(params.lr)

    trained_params = [latent] if invert_latent else [latent, encoding]
    if params.invert_labels:
        trained_params += layer_label_variables

    optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
    inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
                                      global_step=inv_step)
    reinit_optimizer = tf.variables_initializer(optimizer.variables())

    # Per-batch (re)initialisation ops.  Built once and fed via placeholders
    # so the graph does not grow with every batch.
    target_in = tf.placeholder(tf.float32, shape=[BATCH_SIZE] + IMG_SHAPE)
    latent_in = tf.placeholder(tf.float32, shape=[BATCH_SIZE, Z_DIM])
    load_batch_ops = [
        target.assign(target_in),
        latent.assign(latent_in),
        inv_step.assign(0),
    ]
    label_in = None
    if COND_GAN:
        label_in = tf.placeholder(tf.float32, shape=[BATCH_SIZE, N_CLASS])
        load_batch_ops.append(label.assign(label_in))

    # Run after the batch is loaded: the intermediate encoding and label
    # inputs start from whatever the generator computes for the new latent.
    reset_ops = [reinit_optimizer]
    if not invert_latent:
        reset_ops.append(encoding.assign(gen_encoding))
    for layer_label in layer_label_variables:
        reset_ops.append(layer_label.assign(gen_label))

    # --------------------------
    # Dataset.
    # --------------------------
    if not params.dataset.endswith('.hdf5'):
        sys.exit('Unknown dataset {}.'.format(params.dataset))

    with h5py.File(params.dataset, 'r') as in_file:
        sample_images = in_file['xtrain'][()]
        sample_labels = in_file['ytrain'][()]
        sample_fns = in_file['fn'][()]
        sample_latents = in_file['latent'][()]

    NUM_IMGS = sample_images.shape[0]  # number of images to be inverted.
    print("Number of images: {}".format(NUM_IMGS))
    print("Batch size: {}".format(BATCH_SIZE))

    # Pad to a whole number of batches by repeating the last sample; padded
    # entries are optimised but skipped when results are saved.
    pad = -NUM_IMGS % BATCH_SIZE
    INFILL_IMGS = NUM_IMGS + pad
    if pad:
        sample_images = np.append(
            sample_images, np.repeat(sample_images[-1:], pad, axis=0), axis=0)
        sample_labels = np.append(
            sample_labels, np.repeat(sample_labels[-1:], pad, axis=0), axis=0)
        sample_latents = np.append(
            sample_latents, np.repeat(sample_latents[-1:], pad, axis=0), axis=0)
        sample_fns = np.append(
            sample_fns, np.repeat(sample_fns[-1:], pad, axis=0), axis=0)
    assert INFILL_IMGS % BATCH_SIZE == 0

    num_batches = INFILL_IMGS // BATCH_SIZE
    if params.max_batches > 0:
        num_batches = min(num_batches, params.max_batches)
        # Clamped so the output datasets never exceed the available samples.
        NUM_IMGS_TO_PROCESS = min(params.max_batches * BATCH_SIZE, INFILL_IMGS)
    else:
        NUM_IMGS_TO_PROCESS = NUM_IMGS

    # --------------------------
    # Training.
    # --------------------------
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    out_file = None
    try:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())

        if restorer is not None:
            restorer.restore(sess, model_path)

        # Output file.
        out_file = h5py.File(params.out_dataset, 'w')
        out_images = out_file.create_dataset(
            'xtrain', [NUM_IMGS_TO_PROCESS] + IMG_SHAPE, dtype='float32')
        out_enc = out_file.create_dataset(
            'encoding', [NUM_IMGS_TO_PROCESS] + ENC_SHAPE, dtype='float32')
        out_lat = out_file.create_dataset(
            'latent', [NUM_IMGS_TO_PROCESS, Z_DIM], dtype='float32')
        out_fns = out_file.create_dataset(
            'fn', [NUM_IMGS_TO_PROCESS], dtype=h5py.string_dtype())
        out_labels = None
        if COND_GAN:
            out_labels = out_file.create_dataset(
                'ytrain', (NUM_IMGS_TO_PROCESS, N_CLASS), dtype='float32')
        out_err = out_file.create_dataset('err', (NUM_IMGS_TO_PROCESS,))

        out_fns[:] = sample_fns[:NUM_IMGS_TO_PROCESS]

        # Gradient descent w.r.t. generator's inputs.
        it = 0  # global iteration counter, shared across batches
        start_time = time.time()

        for batch_idx in range(num_batches):
            out_pos = batch_idx * BATCH_SIZE
            batch = slice(out_pos, out_pos + BATCH_SIZE)
            image_batch = sample_images[batch]

            feed = {target_in: image_batch, latent_in: sample_latents[batch]}
            if COND_GAN:
                feed[label_in] = sample_labels[batch]
            sess.run(load_batch_ops, feed_dict=feed)
            sess.run(reset_ops)

            # Main optimization loop.
            print("Beginning dense iteration...")
            for _ in range(params.inv_it):
                (_inv_loss, _mse_loss, _feat_loss, _feat_loss_vgg,
                 _feat_loss_inception, _lrate, _) = sess.run(
                     [inv_loss, mse_loss, feat_loss, feat_loss_vgg,
                      feat_loss_inception, lrate, inv_train_op])

                if clip_latent is not None:
                    sess.run(clip_latent)

                # Save logs with training information.
                if it % _LOG_EVERY == 0:
                    etime = time.time() - start_time
                    print(('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
                           + 'feat [{:.4f}] '
                           + 'vgg [{:.4f}] '
                           + 'incep [{:.4f}] '
                           + 'lr [{:.4f}]').format(
                               it, etime, _inv_loss, _mse_loss, _feat_loss,
                               _feat_loss_vgg, _feat_loss_inception, _lrate))
                    sys.stdout.flush()

                    # Save target images and reconstructions.
                    if opt_save_progress:
                        assert SAMPLE_SIZE <= BATCH_SIZE
                        gen_images = sess.run(gen_img)
                        inv_batch = vs.interleave(
                            vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
                            vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
                        inv_batch = vs.grid_transform(inv_batch)
                        vs.save_image(
                            '{}/progress_{}_{:04d}.png'.format(
                                SAMPLES_DIR, opt_tag, int(it / _LOG_EVERY)),
                            inv_batch)

                it += 1

            # Fetch the optimised values for this batch.
            latent_trained = sess.run(latent)
            label_trained = sess.run(label) if COND_GAN else None
            enc_trained = None if invert_latent else sess.run(encoding)
            layer_labels_trained = (sess.run(layer_label_variables)
                                    if params.invert_labels else [])
            # img_rec_err is a scalar when every per-image term is disabled.
            rec_err = np.broadcast_to(sess.run(img_rec_err), (BATCH_SIZE,))
            images = vs.data2img(sess.run(gen_img_orig))

            # Upload each reconstruction and store its vectors.
            for i in range(BATCH_SIZE):
                out_i = out_pos + i
                if out_i >= NUM_IMGS:
                    # Padding sample added to fill the last batch.
                    print("{} >= {}, skipping...".format(out_i, NUM_IMGS))
                    continue

                raw_fn = sample_fns[out_i]
                if isinstance(raw_fn, bytes):
                    # Avoids "b'name'" ending up in the uploaded file name.
                    raw_fn = raw_fn.decode('utf-8')
                sample_fn, _ext = os.path.splitext(raw_fn)

                image = Image.fromarray(images[i])
                fp = BytesIO()
                image.save(fp, format='png')
                data = upload_bytes_to_cortex(
                    params.folder_id, "{}-{}.png".format(sample_fn, opt_tag),
                    fp, "image/png")
                print(json.dumps(data, indent=2))

                if data is not None and 'files' in data:
                    file_id = data['files'][0]['id']
                    fp_out_pkl = os.path.join(
                        app_cfg.DIR_VECTORS, "file_{}.pkl".format(file_id))
                    out_data = {
                        'id': file_id,
                        'folder_id': params.folder_id,
                        'sample_fn': sample_fn,
                        'label': label_trained[i] if COND_GAN else None,
                        'latent': latent_trained[i],
                    }
                    if not invert_latent:
                        out_data['encoding'] = enc_trained[i]
                    if params.invert_labels:
                        out_data['layer_labels'] = [
                            layer[i] for layer in layer_labels_trained]
                    write_pickle(out_data, fp_out_pkl)

                # NOTE(review): the HDF5 writes are done whether or not the
                # upload succeeded -- confirm this matches the intended
                # nesting of the original code.
                out_lat[out_i] = latent_trained[i]
                if not invert_latent:
                    out_enc[out_i] = enc_trained[i]
                out_images[out_i] = image_batch[i]
                if COND_GAN:
                    out_labels[out_i] = label_trained[i]
                out_err[out_i] = rec_err[i]

        print('Mean reconstruction error: {}'.format(np.mean(out_err)))
        print('Stdev reconstruction error: {}'.format(np.std(out_err)))
        print('End of inversion.')
    finally:
        # Release resources even when optimisation or upload fails.
        if out_file is not None:
            out_file.close()
        sess.close()
        summary_writer.close()


def mse_loss_crop(img_a, img_b, y, x, height, width):
    """Return the mean of ``((img_a - img_b) / 2) ** 2`` over one crop.

    Both images are cropped to the box with top-left corner ``(y, x)`` and
    size ``height`` x ``width``; all four values are truncated to ``int``.
    """
    y = int(y)
    x = int(x)
    height = int(height)
    width = int(width)
    img_a = tf.image.crop_to_bounding_box(img_a, y, x, height, width)
    img_b = tf.image.crop_to_bounding_box(img_b, y, x, height, width)
    return tf.reduce_mean(tf.square((img_a - img_b) / 2.0))


def _crop_or_resize_pair(img_a, img_b, y, x, height, width,
                         resize_height=None, resize_width=None):
    """Apply the same crop (``y`` given) or resize (``y`` is None) to both images."""
    height = int(height)
    width = int(width)
    if y is not None:
        x = int(x)
        y = int(y)
        img_a = tf.image.crop_to_bounding_box(img_a, y, x, height, width)
        img_b = tf.image.crop_to_bounding_box(img_b, y, x, height, width)
    else:
        img_a = tf.image.resize_images(img_a, [height, width])
        img_b = tf.image.resize_images(img_b, [height, width])

    if resize_height is not None:
        img_a = tf.image.resize_images(img_a, [resize_height, resize_width])
        img_b = tf.image.resize_images(img_b, [resize_height, resize_width])
    return img_a, img_b


def feature_loss_tfhub(feature_extractor, opt_feature_layers, BATCH_SIZE,
                       img_a, img_b, y, x, height, width,
                       resize_height=None, resize_width=None):
    """Squared feature distance between two image batches via a hub module.

    Args:
        feature_extractor: Callable hub module exposing the
            ``image_feature_vector`` signature with ``as_dict`` outputs.
        opt_feature_layers: Comma-separated string, or iterable of strings
            (each may itself contain commas), of ``feature_layer_names`` keys.
            Tags not present in ``feature_layer_names`` are ignored.
        BATCH_SIZE: Leading dimension used to flatten each feature map.
        img_a, img_b: Image batches to compare.
        y, x: Top-left corner of a crop; when ``y`` is None the images are
            resized to ``height`` x ``width`` instead.
        height, width: Crop or resize size.
        resize_height, resize_width: Optional second resize.

    Returns:
        ``(feat_loss, img_feat_err)``: the summed per-layer mean squared
        difference and its per-image counterpart, each divided by
        ``len(opt_feature_layers)``.
    """
    img_a, img_b = _crop_or_resize_pair(img_a, img_b, y, x, height, width,
                                        resize_height, resize_width)

    gen_feat_ex = feature_extractor(dict(images=img_a), as_dict=True,
                                    signature='image_feature_vector')
    target_feat_ex = feature_extractor(dict(images=img_b), as_dict=True,
                                       signature='image_feature_vector')

    if isinstance(opt_feature_layers, str):
        opt_feature_layers = opt_feature_layers.split(',')

    # Flatten entries such as "1a,2a" into individual tags.
    fixed_layers = []
    for layer in opt_feature_layers:
        fixed_layers += layer.split(',')

    feat_loss = tf.constant(0.0)
    img_feat_err = tf.constant(0.0)
    for layer in fixed_layers:
        if layer in feature_layer_names:
            layer_name = feature_layer_names[layer]
            gen_feat = gen_feat_ex[layer_name]
            target_feat = target_feat_ex[layer_name]
            feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
                                          [BATCH_SIZE, -1])
            feat_loss += tf.reduce_mean(feat_square_diff)
            img_feat_err += tf.reduce_mean(feat_square_diff, axis=1)

    # NOTE(review): the divisor is the number of entries passed in, not the
    # number of layers summed (["1a,2a,4a,7a"] divides by 1, "1a,2a,4a,7a"
    # by 4).  Kept as-is because the loss weights used by the caller depend
    # on this scale -- confirm before normalising by len(fixed_layers).
    n_entries = len(opt_feature_layers)
    return feat_loss / n_entries, img_feat_err / n_entries


def feature_loss_vgg(feature_extractor, opt_feature_layers, BATCH_SIZE,
                     img_a, img_b, y, x, height, width,
                     resize_height=None, resize_width=None):
    """Squared feature distance between two image batches via slim VGG-16.

    Args:
        feature_extractor: Unused; kept so the signature matches
            ``feature_loss_tfhub``.
        opt_feature_layers: VGG endpoint names relative to the ``vgg_16``
            scope, e.g. ``'conv1/conv1_1'``.
        BATCH_SIZE: Batch size of ``img_a`` (and of ``img_b``).
        img_a, img_b: Image batches to compare.
        y, x, height, width, resize_height, resize_width: Crop/resize
            settings, as in ``feature_loss_tfhub``.

    Returns:
        ``(feat_loss, img_feat_err)`` averaged over ``opt_feature_layers``.
    """
    img_a, img_b = _crop_or_resize_pair(img_a, img_b, y, x, height, width,
                                        resize_height, resize_width)

    # One forward pass over both batches: a single set of 'vgg_16' variables
    # is created, so both images are scored with the same weights.  Building
    # the network twice under the same scope would need variable reuse.
    scope = 'vgg_16'
    both = tf.concat([img_a, img_b], axis=0)
    with slim.arg_scope(nets.vgg.vgg_arg_scope()):
        _, end_points = nets.vgg.vgg_16(both)

    feat_loss = tf.constant(0.0)
    img_feat_err = tf.constant(0.0)
    for layer_name in opt_feature_layers:
        feat = end_points[scope + '/' + layer_name]
        gen_feat = feat[:BATCH_SIZE]
        target_feat = feat[BATCH_SIZE:]
        feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
                                      [BATCH_SIZE, -1])
        feat_loss += tf.reduce_mean(feat_square_diff)
        img_feat_err += tf.reduce_mean(feat_square_diff, axis=1)

    n_layers = len(opt_feature_layers)
    return feat_loss / n_layers, img_feat_err / n_layers