Diffstat (limited to 'inversion/segmentation.py')
-rw-r--r--  inversion/segmentation.py  191
1 file changed, 191 insertions, 0 deletions
diff --git a/inversion/segmentation.py b/inversion/segmentation.py
new file mode 100644
index 0000000..42f2f9e
--- /dev/null
+++ b/inversion/segmentation.py
@@ -0,0 +1,191 @@
+# ------------------------------------------------------------------------------
+# Cluster the attention map of inverted images for unsupervised segmentation.
+# ------------------------------------------------------------------------------
+
+import h5py
+import numpy as np
+import os
+import params
+import scipy
+import scipy.cluster.hierarchy
+from sklearn.cluster import AgglomerativeClustering
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+if len(sys.argv) < 2:
+  sys.exit('Must provide a configuration file.')
+
+params = params.Params(sys.argv[1])
+params.batch_size = 1
+params.sample_size = 1
+
+# --------------------------
+# Global directories.
+# --------------------------
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+SAMPLES_DIR = 'attention'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(SAMPLES_DIR):
+  os.makedirs(SAMPLES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+# One-hot encoding for classes.
+def one_hot(values):
+  return np.eye(N_CLASS)[values]
+
+def segment_img(diss_matrix, n_clusters):
+  # Cluster the image based on the information from the attention map.
+  clustering = AgglomerativeClustering(n_clusters=n_clusters,
+                                       affinity='precomputed', linkage='average')
+  clustering.fit(diss_matrix)
+  labels = clustering.labels_
+
+  # Upsample the segmentation (from 64x64 to 128x128) and create an image where
+  # each segment has the average color of its members.
+  labels = np.broadcast_to(labels.reshape(64, 1, 64, 1), (64, 2, 64, 2))\
+      .reshape(128*128)
+  labels = np.eye(labels.max() + 1)[labels]
+  cluster_col = np.matmul(labels.T,
+                          np.transpose(_gen_img, [0, 2, 3, 1]).reshape(128*128, 3))
+  cluster_count = labels.T.sum(axis=1).reshape(-1, 1)
+  labels_img = np.matmul(labels, cluster_col) / np.matmul(labels, cluster_count)
+  labels_img = np.transpose(labels_img, [1, 0]).reshape(1, 3, 128, 128)
+  return vs.data2img(labels_img)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+  gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+  Z_DIM = input_info['z'].get_shape().as_list()[1]
+  latent = tf.get_variable(name='latent', dtype=tf.float32,
+                           shape=[BATCH_SIZE, Z_DIM])
+  N_CLASS = input_info['y'].get_shape().as_list()[1]
+  label = tf.get_variable(name='label', dtype=tf.float32,
+                          shape=[BATCH_SIZE, N_CLASS])
+  gen_in = dict(params.generator_fixed_inputs)
+  gen_in['z'] = latent
+  gen_in['y'] = label
+  gen_img = generator(gen_in, signature=gen_signature)
+else:
+  Z_DIM = input_info['default'].get_shape().as_list()[1]
+  latent = tf.get_variable(name='latent', dtype=tf.float32,
+                           shape=[BATCH_SIZE, Z_DIM])
+  if params.generator_fixed_inputs:
+    gen_in = dict(params.generator_fixed_inputs)
+    gen_in['z'] = latent
+    gen_img = generator(gen_in, signature=gen_signature)
+  else:
+    gen_img = generator(latent, signature=gen_signature)
+
+# Convert the generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override intermediate layer.
+if params.inv_layer == 'latent':
+  encoding = latent
+  ENC_SHAPE = [Z_DIM]
+else:
+  layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+  gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+  ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+  encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+                             shape=[BATCH_SIZE] + ENC_SHAPE)
+  tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+
+# Get the attention map.
+att_map_name = 'module_apply_' + gen_signature + '/' + params.attention_map_layer
+att_map = tf.get_default_graph().get_tensor_by_name(att_map_name)
+
+# Define the image shape.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
+# --------------------------
+# Dataset.
+# --------------------------
+if params.out_dataset.endswith('.hdf5'):
+  in_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'r')
+  sample_images = in_file['xtrain']
+  if COND_GAN:
+    sample_labels = in_file['ytrain']
+  sample_latents = in_file['latent']
+  sample_encodings = in_file['encoding']
+  NUM_IMGS = sample_images.shape[0]  # Number of images.
+  def sample_images_gen():
+    for i in range(NUM_IMGS // BATCH_SIZE):
+      i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
+      if COND_GAN:
+        label_batch = sample_labels[i_1:i_2]
+      else:
+        label_batch = np.zeros(BATCH_SIZE)
+      yield sample_images[i_1:i_2], label_batch, sample_latents[i_1:i_2],\
+          sample_encodings[i_1:i_2]
+  image_gen = sample_images_gen()
+else:
+  sys.exit('Unknown dataset {}.'.format(params.out_dataset))
+
+NUM_IMGS -= NUM_IMGS % BATCH_SIZE
+
+# --------------------------
+# Training.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+# Export the attention map for reconstructed images.
+it = 0
+out_pos = 0
+start_time = time.time()
+
+for image_batch, label_batch, lat_batch, enc_batch in image_gen:
+
+  # Set the target label.
+  if COND_GAN:
+    sess.run(label.assign(one_hot(label_batch)))
+
+  # Initialize the encodings.
+  sess.run(latent.assign(lat_batch))
+  sess.run(encoding.assign(enc_batch))
+
+  # Get the attention map.
+  _att_map, _gen_img = sess.run([att_map, gen_img])
+
+  # Upsample the attention map's key axis (from 32x32 to 64x64).
+  _att_map = np.broadcast_to(_att_map.reshape(64, 64, 32, 1, 32, 1),
+                             (64, 64, 32, 2, 32, 2)).reshape(4096, 4096)
+
+  # Define the dissimilarity matrix.
+  dissimilarity = 1.0 - (_att_map + _att_map.T) / 2.0
+  dissimilarity *= (np.ones((4096, 4096)) - np.identity(4096))
+
+  # Segment the image with different numbers of clusters.
+  seg_img_8 = segment_img(dissimilarity, 8)
+  seg_img_20 = segment_img(dissimilarity, 20)
+  seg_img_40 = segment_img(dissimilarity, 40)
+
+  # Save the segmentation.
+  out_batch_1 = vs.interleave(image_batch, seg_img_20)
+  out_batch_2 = vs.interleave(seg_img_8, seg_img_40)
+  out_batch = vs.interleave(out_batch_1, out_batch_2)
+  out_batch = vs.seq_transform(out_batch)
+  vs.save_image('{}/segmented_img_{}.png'.format(SAMPLES_DIR, it), out_batch)
+
+  it += 1
+
+sess.close()
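For reference, the clustering step can be tried in isolation. The sketch below is not part of the commit: it mirrors how the main loop symmetrizes the attention map into a precomputed dissimilarity matrix and hands it to AgglomerativeClustering, but on a toy 4x4 grid with random attention weights. The affinity= keyword matches the scikit-learn versions this script targets; newer releases renamed it to metric=.

import numpy as np
from sklearn.cluster import AgglomerativeClustering

# Toy attention map over a 4x4 spatial grid (16 positions, rows sum to 1).
rng = np.random.default_rng(0)
att = rng.random((16, 16))
att /= att.sum(axis=1, keepdims=True)

# Symmetrize and invert, as in the main loop: positions that attend to each
# other strongly end up close together.
diss = 1.0 - (att + att.T) / 2.0
np.fill_diagonal(diss, 0.0)  # zero self-distance, like the identity mask

clustering = AgglomerativeClustering(n_clusters=3, affinity='precomputed',
                                     linkage='average')
labels = clustering.fit_predict(diss)
print(labels.reshape(4, 4))  # one cluster id per grid position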

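Two NumPy idioms in segment_img are easy to miss: nearest-neighbour upsampling via broadcast_to, and painting each segment with the mean colour of its pixels through one-hot matrix products. A minimal standalone sketch of both, on hypothetical 4x4 labels and a random RGB image (all names here are illustrative, not from the commit):

import numpy as np

# Nearest-neighbour 2x upsampling of integer labels (4x4 -> 8x8), the same
# broadcast_to trick segment_img uses for 64x64 -> 128x128.
labels = (np.arange(16) % 3).reshape(4, 4)
up = np.broadcast_to(labels.reshape(4, 1, 4, 1), (4, 2, 4, 2)).reshape(8, 8)

# Paint each segment with the mean colour of its pixels via one-hot matmuls.
flat = up.reshape(-1)                      # (64,) label per pixel
onehot = np.eye(flat.max() + 1)[flat]      # (64, n_clusters)
img = np.random.rand(64, 3)               # toy flattened RGB image
col_sum = onehot.T @ img                   # per-cluster colour sums
count = onehot.sum(axis=0).reshape(-1, 1)  # pixels per cluster
painted = onehot @ (col_sum / count)       # (64, 3) mean-colour image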