summaryrefslogtreecommitdiff
path: root/inversion/segmentation.py
diff options
context:
space:
mode:
Diffstat (limited to 'inversion/segmentation.py')
-rw-r--r--inversion/segmentation.py191
1 files changed, 191 insertions, 0 deletions
diff --git a/inversion/segmentation.py b/inversion/segmentation.py
new file mode 100644
index 0000000..42f2f9e
--- /dev/null
+++ b/inversion/segmentation.py
@@ -0,0 +1,191 @@
+# ------------------------------------------------------------------------------
+# Cluster the attention map of inverted images for unsupervised segmentation.
+# ------------------------------------------------------------------------------
+
+import h5py
+import numpy as np
+import os
+import params
+import scipy
+import scipy.cluster.hierarchy
+from sklearn.cluster import AgglomerativeClustering
+import sys
+import tensorflow as tf
+import tensorflow_hub as hub
+import time
+import visualize as vs
+
+if len(sys.argv) < 2:
+ sys.exit('Must provide a configuration file.')
+
+params = params.Params(sys.argv[1])
+params.batch_size = 1
+params.sample_size = 1
+
+# --------------------------
+# Global directories.
+# --------------------------
+BATCH_SIZE = params.batch_size
+SAMPLE_SIZE = params.sample_size
+SAMPLES_DIR = 'attention'
+INVERSES_DIR = 'inverses'
+if not os.path.exists(SAMPLES_DIR):
+ os.makedirs(SAMPLES_DIR)
+
+# --------------------------
+# Util functions.
+# --------------------------
+# One hot encoding for classes.
+def one_hot(values):
+ return np.eye(N_CLASS)[values]
+
+def segment_img(diss_matrix, n_clusters):
+ # Cluster image based on the information from the attention map.
+ clustering = AgglomerativeClustering(n_clusters=n_clusters,
+ affinity='precomputed', linkage='average')
+ clustering.fit(diss_matrix)
+ labels = clustering.labels_
+
+ # Upsample segmentation (from 64x64 to 128x128) and create an image where each
+ # segment has the average color of its members.
+ labels = np.broadcast_to(labels.reshape(64, 1, 64, 1), (64, 2, 64, 2))\
+ .reshape(128*128)
+ labels = np.eye(labels.max() + 1)[labels]
+ cluster_col = np.matmul(labels.T,
+ np.transpose(_gen_img, [0, 2, 3, 1]).reshape(128*128, 3))
+ cluster_count = labels.T.sum(axis=1).reshape(-1, 1)
+ labels_img = np.matmul(labels, cluster_col) / np.matmul(labels, cluster_count)
+ labels_img = np.transpose(labels_img, [1, 0]).reshape(1,3,128,128)
+ return vs.data2img(labels_img)
+
+# --------------------------
+# Load Graph.
+# --------------------------
+generator = hub.Module(str(params.generator_path))
+
+gen_signature = 'generator'
+if 'generator' not in generator.get_signature_names():
+ gen_signature = 'default'
+
+input_info = generator.get_input_info_dict(gen_signature)
+COND_GAN = 'y' in input_info
+
+if COND_GAN:
+ Z_DIM = input_info['z'].get_shape().as_list()[1]
+ latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+ N_CLASS = input_info['y'].get_shape().as_list()[1]
+ label = tf.get_variable(name='label', dtype=tf.float32,
+ shape=[BATCH_SIZE, N_CLASS])
+ gen_in = dict(params.generator_fixed_inputs)
+ gen_in['z'] = latent
+ gen_in['y'] = label
+ gen_img = generator(gen_in, signature=gen_signature)
+else:
+ Z_DIM = input_info['default'].get_shape().as_list()[1]
+ latent = tf.get_variable(name='latent', dtype=tf.float32,
+ shape=[BATCH_SIZE, Z_DIM])
+ if (params.generator_fixed_inputs):
+ gen_in = dict(params.generator_fixed_inputs)
+ gen_in['z'] = latent
+ gen_img = generator(gen_in, signature=gen_signature)
+ else:
+ gen_img = generator(latent, signature=gen_signature)
+
+# Convert generated image to channels_first.
+gen_img = tf.transpose(gen_img, [0, 3, 1, 2])
+
+# Override intermediate layer.
+if params.inv_layer == 'latent':
+ encoding = latent
+ ENC_SHAPE = [Z_DIM]
+else:
+ layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
+ gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
+ ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
+ encoding = tf.get_variable(name='encoding', dtype=tf.float32,
+ shape=[BATCH_SIZE,] + ENC_SHAPE)
+ tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))
+
+# Get attention map.
+att_map_name = 'module_apply_' + gen_signature + '/' + params.attention_map_layer
+att_map = tf.get_default_graph().get_tensor_by_name(att_map_name)
+
+# Define image shape.
+IMG_SHAPE = gen_img.get_shape().as_list()[1:]
+
+# --------------------------
+# Dataset.
+# --------------------------
+if params.out_dataset.endswith('.hdf5'):
+ in_file = h5py.File(os.path.join(INVERSES_DIR, params.out_dataset), 'r')
+ sample_images = in_file['xtrain']
+ if COND_GAN:
+ sample_labels = in_file['ytrain']
+ sample_latents = in_file['latent']
+ sample_encodings = in_file['encoding']
+ NUM_IMGS = sample_images.shape[0] # number of images.
+ def sample_images_gen():
+ for i in xrange(NUM_IMGS / BATCH_SIZE):
+ i_1, i_2 = i*BATCH_SIZE, (i+1)*BATCH_SIZE
+ if COND_GAN:
+ label_batch = sample_labels[i_1:i_2]
+ else:
+ label_batch = np.zeros(BATCH_SIZE)
+ yield sample_images[i_1:i_2], label_batch, sample_latents[i_1:i_2],\
+ sample_encodings[i_1:i_2]
+ image_gen = sample_images_gen()
+else:
+ sys.exit('Unknown dataset {}.'.format(params.out_dataset))
+
+NUM_IMGS -= NUM_IMGS % BATCH_SIZE
+
+# --------------------------
+# Training.
+# --------------------------
+# Start session.
+sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
+sess.run(tf.global_variables_initializer())
+sess.run(tf.tables_initializer())
+
+# Export attention map for reconstructed images.
+it = 0
+out_pos = 0
+start_time = time.time()
+
+for image_batch, label_batch, lat_batch, enc_batch in image_gen:
+
+ # Set target label.
+ if COND_GAN:
+ sess.run(label.assign(one_hot(label_batch)))
+
+ # Initialize encodings.
+ sess.run(latent.assign(lat_batch))
+ sess.run(encoding.assign(enc_batch))
+
+ # Get attention map.
+ _att_map, _gen_img = sess.run([att_map, gen_img])
+
+ # Upsampling (from 32x32 to 64x64).
+ _att_map = np.broadcast_to(_att_map.reshape(64,64,32,1,32,1),
+ (64,64,32,2,32,2)).reshape(4096,4096)
+
+ # Define dissimilarity matrix.
+ dissimilarity = 1.0 - (_att_map + _att_map.T) / 2.0
+ dissimilarity *= (np.ones((4096,4096)) - np.identity(4096))
+
+ # Segment the image with different number of clusters.
+ seg_img_8 = segment_img(dissimilarity, 8)
+ seg_img_20 = segment_img(dissimilarity, 20)
+ seg_img_40 = segment_img(dissimilarity, 40)
+
+ # Save segmentation.
+ out_batch_1 = vs.interleave(image_batch, seg_img_20)
+ out_batch_2 = vs.interleave(seg_img_8, seg_img_40)
+ out_batch = vs.interleave(out_batch_1, out_batch_2)
+ out_batch = vs.seq_transform(out_batch)
+ vs.save_image('{}/segmented_img_{}.png'.format(SAMPLES_DIR, it), out_batch)
+
+ it += 1
+
+sess.close()