import glob
import h5py
import itertools
import numpy as np
from io import BytesIO
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from PIL import Image
import scipy
import sys
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_hub as hub
import time
import app.search.visualize as vs
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
from app.search.params import Params
from app.settings import app_cfg
from app.utils.file_utils import write_pickle
from app.utils.cortex_utils import upload_bytes_to_cortex

feature_layer_names = {
    '1a': "InceptionV3/Conv2d_1a_3x3",
    '2a': "InceptionV3/Conv2d_2a_3x3",
    '2b': "InceptionV3/Conv2d_2b_3x3",
    '3a': "InceptionV3/Conv2d_3a_3x3",
    '3b': "InceptionV3/Conv2d_3b_3x3",
    '4a': "InceptionV3/Conv2d_4a_3x3",
    '5b': "InceptionV3/Mixed_5b",
    '5c': "InceptionV3/Mixed_5c",
    '5d': "InceptionV3/Mixed_5d",
    '6a': "InceptionV3/Mixed_6a",
    '6b': "InceptionV3/Mixed_6b",
    '6c': "InceptionV3/Mixed_6c",
    '6d': "InceptionV3/Mixed_6d",
    '6e': "InceptionV3/Mixed_6e",
    '7a': "InceptionV3/Mixed_7a",
    '7b': "InceptionV3/Mixed_7b",
    '7c': "InceptionV3/Mixed_7c",
}

# Layers compared by the feature (perceptual) loss, as
# (key into feature_layer_names, weight) pairs. Per-layer terms are summed
# in this order.
FEATURE_LOSS_LAYERS = (
    ('1a', 0.25),
    ('2a', 0.25),
    ('6a', 0.25),
    ('7a', 0.25),
)


def _load_hdf5_dataset(path, batch_size):
    """Load the inversion inputs from an HDF5 file and pad them to whole batches.

    Reads the ``xtrain``, ``ytrain``, ``fn`` and ``latent`` datasets fully into
    memory and closes the file before returning. When the sample count is not a
    multiple of ``batch_size``, trailing samples are repeated (cycling through
    the data if it holds fewer samples than the padding needs) so that every
    returned array has a multiple of ``batch_size`` rows.

    Returns:
        ``(images, labels, latents, fns, num_loaded)``; ``num_loaded`` is the
        number of samples in the file, before padding.
    """
    with h5py.File(path, 'r') as in_file:
        images = in_file['xtrain'][()]
        labels = in_file['ytrain'][()]
        fns = in_file['fn'][()]
        latents = in_file['latent'][()]

    num_loaded = images.shape[0]
    shortfall = num_loaded % batch_size
    if shortfall:
        remainder = batch_size - shortfall
        # Equals the last `remainder` indices when remainder <= num_loaded;
        # the modulo wraps around when the dataset is smaller than that.
        pad_idx = np.arange(-remainder, 0) % num_loaded
        images = np.append(images, images[pad_idx], axis=0)
        labels = np.append(labels, labels[pad_idx], axis=0)
        latents = np.append(latents, latents[pad_idx], axis=0)
        fns = np.append(fns, fns[pad_idx], axis=0)
    assert images.shape[0] % batch_size == 0
    return images, labels, latents, fns, num_loaded


def _feature_loss(gen_img, target_img, feature_extractor_path, batch_size):
    """Build the weighted feature-space loss between generated and target images.

    Both inputs are expected channels_first in [-1, 1]. They are mapped to
    [0, 1] channels_last, resized to the extractor's expected input size and
    fed through the TF-Hub module's ``image_feature_vector`` signature. For each
    entry of ``FEATURE_LOSS_LAYERS``, the squared activation difference is
    flattened per image and averaged, then scaled by the layer weight.

    Returns:
        ``(feat_loss, img_feat_err)``: the weighted batch-mean loss (scalar) and
        the weighted per-image loss (one value per batch row).
    """
    feature_extractor = hub.Module(str(feature_extractor_path))
    # Convert images from range [-1, 1] channels_first to [0, 1] channels_last.
    gen_img_1 = tf.transpose(gen_img / 2.0 + 0.5, [0, 2, 3, 1])
    target_img_1 = tf.transpose(target_img / 2.0 + 0.5, [0, 2, 3, 1])
    # Convert images to appropriate size for feature extraction.
    height, width = hub.get_expected_image_size(feature_extractor)
    gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
    target_img_1 = tf.image.resize_images(target_img_1, [height, width])
    gen_feat_ex = feature_extractor(dict(images=gen_img_1), as_dict=True,
                                    signature='image_feature_vector')
    target_feat_ex = feature_extractor(dict(images=target_img_1), as_dict=True,
                                       signature='image_feature_vector')

    feat_loss = img_feat_err = None
    for layer_key, weight in FEATURE_LOSS_LAYERS:
        feat_layer = feature_layer_names[layer_key]
        feat_square_diff = tf.reshape(
            tf.square(gen_feat_ex[feat_layer] - target_feat_ex[feat_layer]),
            [batch_size, -1])
        layer_loss = tf.reduce_mean(feat_square_diff) * weight
        layer_err = tf.reduce_mean(feat_square_diff, axis=1) * weight
        if feat_loss is None:
            feat_loss, img_feat_err = layer_loss, layer_err
        else:
            feat_loss += layer_loss
            img_feat_err += layer_err
    return feat_loss, img_feat_err


def find_dense_embedding_for_images(params):
    """Invert a TF-Hub generator for every image of an HDF5 dataset.

    For each batch of target images (``xtrain``), the generator latent is
    initialised from the dataset's ``latent`` entries. When ``params.inv_layer``
    names an intermediate layer, that layer's activations are replaced by a
    trainable ``encoding`` variable initialised from the generator's forward
    pass. ``params.inv_it`` Adam steps then minimise
    ``lambda_mse * pixel_mse + lambda_feat * feature_mse``.

    Results are written to the HDF5 file ``params.out_dataset`` (datasets
    ``xtrain``, ``encoding``, ``latent``, ``fn``, ``err`` and, for conditional
    generators, ``ytrain``). Each reconstruction is also uploaded as a PNG via
    ``upload_bytes_to_cortex``; for every successful upload, its id, file stem,
    label, latent and encoding are pickled into ``app_cfg.DIR_VECTORS``.

    Args:
        params: run configuration (presumably ``app.search.params.Params``).
            Fields read include ``path``, ``inv_layer``, ``batch_size``,
            ``sample_size``, ``generator_path``, ``generator_fixed_inputs``,
            ``dataset``, ``out_dataset``, ``inv_it``, ``lr``, the loss, clipping
            and decay switches, ``max_batches`` and ``folder_id``.

    Exits the process (``sys.exit``) if ``params.dataset`` is not a ``.hdf5``
    file.
    """
    # --------------------------
    # Global directories.
    # --------------------------
    LATENT_TAG = 'latent' if params.inv_layer == 'latent' else 'dense'
    BATCH_SIZE = params.batch_size
    SAMPLE_SIZE = params.sample_size
    LOGS_DIR = os.path.join(params.path, LATENT_TAG, 'logs')
    SAMPLES_DIR = os.path.join(params.path, LATENT_TAG, 'samples')

    os.makedirs(LOGS_DIR, exist_ok=True)
    os.makedirs(SAMPLES_DIR, exist_ok=True)
    os.makedirs(app_cfg.DIR_VECTORS, exist_ok=True)

    summary_writer = tf.summary.FileWriter(LOGS_DIR)

    def log_stats(name, val, it):
        # Scalar TensorBoard logging helper (not currently called).
        summary = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=val)])
        summary_writer.add_summary(summary, it)

    # --------------------------
    # Load Graph.
    # --------------------------
    tf.reset_default_graph()
    generator = hub.Module(str(params.generator_path))

    gen_signature = 'generator'
    if 'generator' not in generator.get_signature_names():
        gen_signature = 'default'

    input_info = generator.get_input_info_dict(gen_signature)
    COND_GAN = 'y' in input_info

    if COND_GAN:
        Z_DIM = input_info['z'].get_shape().as_list()[1]
        latent = tf.get_variable(name='latent', dtype=tf.float32,
                                 shape=[BATCH_SIZE, Z_DIM])
        N_CLASS = input_info['y'].get_shape().as_list()[1]
        label = tf.get_variable(name='label', dtype=tf.float32,
                                shape=[BATCH_SIZE, N_CLASS])
        gen_in = dict(params.generator_fixed_inputs)
        gen_in['z'] = latent
        gen_in['y'] = label
        gen_img = generator(gen_in, signature=gen_signature)
    else:
        Z_DIM = input_info['default'].get_shape().as_list()[1]
        latent = tf.get_variable(name='latent', dtype=tf.float32,
                                 shape=[BATCH_SIZE, Z_DIM])
        if params.generator_fixed_inputs:
            gen_in = dict(params.generator_fixed_inputs)
            gen_in['z'] = latent
            gen_img = generator(gen_in, signature=gen_signature)
        else:
            gen_img = generator(latent, signature=gen_signature)
    gen_img_orig = gen_img

    # Convert generated image to channels_first.
    gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

    # Override intermediate layer: consumers of the generator's `inv_layer`
    # tensor are rewired to read the trainable `encoding` variable instead.
    if params.inv_layer == 'latent':
        encoding = latent
        ENC_SHAPE = [Z_DIM]
    else:
        layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
        gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
        ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
        encoding = tf.get_variable(name='encoding', dtype=tf.float32,
                                   shape=[BATCH_SIZE, ] + ENC_SHAPE)
        tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))

    # Step counter.
    inv_step = tf.get_variable('inv_step', initializer=0, trainable=False)

    # Define target image. The values are used as-is, so they are assumed to
    # already be channels_first in the generator's output range ([-1, 1],
    # per the feature-loss conversion) -- NOTE(review): confirm the dataset.
    IMG_SHAPE = gen_img.get_shape().as_list()[1:]
    target = tf.get_variable(name='target', dtype=tf.float32,
                             shape=[BATCH_SIZE, ] + IMG_SHAPE)
    target_img = target

    # Custom Gradient for Relus: lets a decaying fraction of negative gradient
    # through inactive units inside the generator scope.
    if params.custom_grad_relu:
        grad_lambda = tf.train.exponential_decay(0.1, inv_step, params.inv_it / 5,
                                                 0.1, staircase=False)

        @tf.custom_gradient
        def relu_custom_grad(x):
            def grad(dy):
                return tf.where(x >= 0, dy,
                                grad_lambda * tf.where(dy < 0, dy, tf.zeros_like(dy)))
            return tf.nn.relu(x), grad

        gen_scope = 'module_apply_' + gen_signature + '/'
        for op in tf.get_default_graph().get_operations():
            if 'Relu' in op.name and gen_scope in op.name:
                assert len(op.inputs) == 1
                assert len(op.outputs) == 1
                new_out = relu_custom_grad(op.inputs[0])
                tf.contrib.graph_editor.swap_ts(op.outputs[0], new_out)

    # Operations to clip the values of the encodings.
    if params.clipping or params.stochastic_clipping:
        assert params.clip >= 0
        if params.stochastic_clipping:
            # Resample out-of-range entries uniformly inside [-clip, clip].
            new_enc = tf.where(tf.abs(latent) >= params.clip,
                               tf.random.uniform([BATCH_SIZE, Z_DIM],
                                                 minval=-params.clip,
                                                 maxval=params.clip),
                               latent)
        else:
            new_enc = tf.clip_by_value(latent, -params.clip, params.clip)
        clip_latent = tf.assign(latent, new_enc)

    # Monitor relu's activation.
    if params.log_activation_layer:
        gen_scope = 'module_apply_' + gen_signature + '/'
        activation_rate = 1.0 - tf.nn.zero_fraction(
            tf.get_default_graph().get_tensor_by_name(gen_scope + params.log_activation_layer))

    # --------------------------
    # Reconstruction losses.
    # --------------------------
    # Mse loss for image comparison.
    if params.mse:
        pix_square_diff = tf.square((target_img - gen_img) / 2.0)
        mse_loss = tf.reduce_mean(pix_square_diff)
        img_mse_err = tf.reduce_mean(pix_square_diff, axis=[1, 2, 3])
    else:
        mse_loss = tf.constant(0.0)
        img_mse_err = tf.constant(0.0)

    # Use custom features for image comparison.
    if params.features:
        feat_loss, img_feat_err = _feature_loss(
            gen_img, target_img, params.feature_extractor_path, BATCH_SIZE)
    else:
        feat_loss = tf.constant(0.0)
        img_feat_err = tf.constant(0.0)

    # Per image reconstruction error.
    img_rec_err = params.lambda_mse * img_mse_err + params.lambda_feat * img_feat_err

    # Batch reconstruction error.
    rec_loss = params.lambda_mse * mse_loss + params.lambda_feat * feat_loss

    # Total inversion loss.
    inv_loss = rec_loss

    # --------------------------
    # Optimizer.
    # --------------------------
    if params.decay_lr:
        lrate = tf.train.exponential_decay(params.lr, inv_step,
                                           params.inv_it / params.decay_n,
                                           0.1, staircase=True)
    else:
        lrate = tf.constant(params.lr)

    if params.inv_layer == 'latent':
        # `encoding` is the `latent` variable itself here; list it only once.
        trained_params = [latent]
    else:
        trained_params = [encoding] if params.fixed_z else [latent, encoding]
    optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
    inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
                                      global_step=inv_step)
    reinit_optimizer = tf.variables_initializer(optimizer.variables())

    # Per-batch setup ops are built once, so the batch loop does not add new
    # graph nodes (with each batch embedded as a constant) on every iteration.
    target_in = tf.placeholder(tf.float32, shape=target.get_shape())
    latent_in = tf.placeholder(tf.float32, shape=latent.get_shape())
    load_batch_ops = [
        tf.assign(target, target_in),
        tf.assign(latent, latent_in),
        tf.assign(inv_step, 0),
    ]
    if COND_GAN:
        label_in = tf.placeholder(tf.float32, shape=label.get_shape())
        load_batch_ops.append(tf.assign(label, label_in))
    # Run after the latent is loaded: resets Adam state and, for an
    # intermediate layer, seeds `encoding` with the generator's own activations.
    init_batch_ops = [reinit_optimizer]
    if params.inv_layer != 'latent':
        init_batch_ops.append(tf.assign(encoding, gen_encoding))

    # --------------------------
    # Dataset.
    # --------------------------
    if not params.dataset.endswith('.hdf5'):
        sys.exit('Unknown dataset {}.'.format(params.dataset))

    sample_images, sample_labels, sample_latents, sample_fns, num_loaded = \
        _load_hdf5_dataset(params.dataset, BATCH_SIZE)
    NUM_IMGS = sample_images.shape[0]  # number of images (after padding) to be inverted.
    print("Number of images: {}".format(num_loaded))
    print("Batch size: {}".format(BATCH_SIZE))

    if params.max_batches > 0:
        # Never size the outputs beyond the data actually available.
        NUM_IMGS_TO_PROCESS = min(params.max_batches * BATCH_SIZE, NUM_IMGS)
    else:
        NUM_IMGS_TO_PROCESS = NUM_IMGS

    # --------------------------
    # Training.
    # --------------------------
    # Start session.
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    try:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())

        # Output file.
        with h5py.File(params.out_dataset, 'w') as out_file:
            out_images = out_file.create_dataset(
                'xtrain', [NUM_IMGS_TO_PROCESS, ] + IMG_SHAPE, dtype='float32')
            out_enc = out_file.create_dataset(
                'encoding', [NUM_IMGS_TO_PROCESS, ] + ENC_SHAPE, dtype='float32')
            out_lat = out_file.create_dataset(
                'latent', [NUM_IMGS_TO_PROCESS, Z_DIM], dtype='float32')
            out_fns = out_file.create_dataset(
                'fn', [NUM_IMGS_TO_PROCESS], dtype=h5py.string_dtype())
            if COND_GAN:
                out_labels = out_file.create_dataset(
                    'ytrain', (NUM_IMGS_TO_PROCESS, N_CLASS,), dtype='float32')
            out_err = out_file.create_dataset('err', (NUM_IMGS_TO_PROCESS,))

            out_fns[:] = sample_fns[:NUM_IMGS_TO_PROCESS]

            # Gradient descent w.r.t. generator's inputs.
            it = 0
            start_time = time.time()

            # NUM_IMGS_TO_PROCESS is a multiple of BATCH_SIZE, so this visits
            # exactly the first NUM_IMGS_TO_PROCESS / BATCH_SIZE full batches.
            for out_pos in range(0, NUM_IMGS_TO_PROCESS, BATCH_SIZE):
                batch_end = out_pos + BATCH_SIZE
                image_batch = sample_images[out_pos:batch_end]
                label_batch = sample_labels[out_pos:batch_end]

                feed = {
                    target_in: image_batch,
                    latent_in: sample_latents[out_pos:batch_end],
                }
                if COND_GAN:
                    feed[label_in] = label_batch
                sess.run(load_batch_ops, feed_dict=feed)
                sess.run(init_batch_ops)

                # Main optimization loop.
                print("Beginning dense iteration...")
                for _ in range(params.inv_it):
                    _inv_loss, _mse_loss, _feat_loss, _lrate, _ = sess.run(
                        [inv_loss, mse_loss, feat_loss, lrate, inv_train_op])

                    if params.clipping or params.stochastic_clipping:
                        sess.run(clip_latent)

                    # Save logs with training information.
                    if it % 500 == 0:
                        # Log losses.
                        etime = time.time() - start_time
                        print('It [{:8d}] time [{:5.1f}] total [{:.4f}] mse [{:.4f}] '
                              'feat [{:.4f}] '
                              'lr [{:.4f}]'.format(it, etime, _inv_loss, _mse_loss,
                                                   _feat_loss, _lrate))
                        sys.stdout.flush()

                        # Save target images and reconstructions.
                        if params.save_progress:
                            assert SAMPLE_SIZE <= BATCH_SIZE
                            gen_time = time.time()
                            gen_images = sess.run(gen_img)
                            print("Generation time: {:.1f}s".format(time.time() - gen_time))
                            inv_batch = vs.interleave(
                                vs.data2img(image_batch[BATCH_SIZE - SAMPLE_SIZE:]),
                                vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
                            inv_batch = vs.grid_transform(inv_batch)
                            vs.save_image('{}/progress_{}_{:04d}.png'.format(
                                SAMPLES_DIR, params.tag, int(it / 500)), inv_batch)
                    it += 1

                # Save images that are ready.
                latent_batch, enc_batch, rec_err_batch = sess.run(
                    [latent, encoding, img_rec_err])
                out_lat[out_pos:batch_end] = latent_batch
                out_enc[out_pos:batch_end] = enc_batch
                out_images[out_pos:batch_end] = image_batch
                if COND_GAN:
                    out_labels[out_pos:batch_end] = label_batch
                out_err[out_pos:batch_end] = rec_err_batch

                gen_images = sess.run(gen_img_orig)
                images = vs.data2img(gen_images)

                # Upload each reconstruction; on success, pickle its vectors.
                for i in range(BATCH_SIZE):
                    out_i = out_pos + i
                    fn = sample_fns[out_i]
                    if isinstance(fn, bytes):
                        # h5py may return stored strings as bytes.
                        fn = fn.decode('utf-8')
                    sample_fn, ext = os.path.splitext(fn)
                    image = Image.fromarray(images[i])
                    fp = BytesIO()
                    image.save(fp, format='png')
                    # Rewind so a reader consuming `fp` from its current
                    # position gets the PNG bytes (harmless if the uploader
                    # uses getvalue() or seeks itself).
                    fp.seek(0)

                    data = upload_bytes_to_cortex(params.folder_id,
                                                  sample_fn + "-inverse.png",
                                                  fp, "image/png")
                    if data is not None:
                        file_id = data['id']
                        fp_out_pkl = os.path.join(app_cfg.DIR_VECTORS,
                                                  "file_{}.pkl".format(file_id))
                        out_data = {
                            'id': file_id,
                            'sample_fn': sample_fn,
                            'label': out_labels[out_i] if COND_GAN else None,
                            'latent': out_lat[out_i],
                            'encoding': out_enc[out_i],
                        }
                        write_pickle(out_data, fp_out_pkl)

            all_err = out_err[()]
            print('Mean reconstruction error: {}'.format(np.mean(all_err)))
            print('Stdev reconstruction error: {}'.format(np.std(all_err)))
            print('End of inversion.')
    finally:
        sess.close()
        summary_writer.close()