summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--inversion/image_sample.py96
-rw-r--r--inversion/visualize.py8
2 files changed, 55 insertions, 49 deletions
diff --git a/inversion/image_sample.py b/inversion/image_sample.py
index 83622a1..b13cedf 100644
--- a/inversion/image_sample.py
+++ b/inversion/image_sample.py
@@ -8,10 +8,13 @@ import os
import sys
import tensorflow as tf
import tensorflow_hub as hub
+import tensorflow_probability as tfp
import time
+import random
import visualize as vs
import argparse
from glob import glob
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
# --------------------------
# Hyper-parameters.
@@ -27,12 +30,11 @@ from glob import glob
# custom_label: custom label to be fixed.
# Logging:
# sample_size: number of images included in sampled images.
-if len(sys.argv) < 2:
- sys.exit('Must provide a configuration file.')
parser = argparse.ArgumentParser(description='Initialize the image search.')
parser.add_argument('--input_dir', required=True, help='Input directory of images')
parser.add_argument('--tag', default=str(int(time.time())), help='Tag this build')
+parser.add_argument('--iterations', type=int, default=1000, help='Number of iterations to find vector')
params = parser.parse_args()
# --------------------------
@@ -43,16 +45,24 @@ BATCH_SIZE = 20
SAMPLE_SIZE = 20
assert SAMPLE_SIZE <= BATCH_SIZE
-INVERSION_ITERATIONS = 1000
-
# --------------------------
# Global directories.
# --------------------------
DATASET_OUT = "{}_dataset.hdf5".format(params.tag)
SAMPLES_DIR = './outputs/{}/samples'.format(params.tag)
INVERSES_DIR = './outputs/{}/inverses'.format(params.tag)
+LOGS_DIR = './outputs/{}/logs'.format(params.tag)
os.makedirs(SAMPLES_DIR, exist_ok=True)
os.makedirs(INVERSES_DIR, exist_ok=True)
+os.makedirs(LOGS_DIR, exist_ok=True)
+
+# --------------------------
+# Logging.
+# --------------------------
+summary_writer = tf.summary.FileWriter(LOGS_DIR)
+def log_stats(name, val, it):
+ summary = tf.Summary(value=[tf.Summary.Value(tag=name, simple_value=val)])
+ summary_writer.add_summary(summary, it)
# --------------------------
# Load Graph.
@@ -96,7 +106,7 @@ target_img = (tf.cast(target, tf.float32) / 255.) * 2.0 - 1. # Norm to [-1, 1].
# Monitor relu's activation.
gen_scope = 'module_apply_' + gen_signature + '/'
activation_rate = 1.0 - tf.nn.zero_fraction(tf.get_default_graph()\
- .get_tensor_by_name(gen_scope + params.log_activation_layer))
+ .get_tensor_by_name(gen_scope + "Generator_2/GBlock/Relu:0"))
# --------------------------
# Reconstruction losses.
@@ -119,9 +129,9 @@ gen_img_1 = tf.image.resize_images(gen_img_1, [height, width])
target_img_1 = tf.image.resize_images(target_img_1, [height, width])
gen_feat = feature_extractor(dict(images=gen_img_1), as_dict=True,
- signature='image_feature_vector')[params.feature_extractor_output]
+ signature='image_feature_vector')['InceptionV3/Mixed_7a']
target_feat = feature_extractor(dict(images=target_img_1), as_dict=True,
- signature='image_feature_vector')[params.feature_extractor_output]
+ signature='image_feature_vector')['InceptionV3/Mixed_7a']
feat_square_diff = tf.reshape(tf.square(gen_feat - target_feat),
[BATCH_SIZE, -1])
feat_loss = tf.reduce_mean(feat_square_diff)
@@ -148,7 +158,7 @@ inv_loss = rec_loss + 0.1 * likeli_loss
# Optimizer.
# --------------------------
lrate = tf.train.exponential_decay(0.1, inv_step,
- INVERSION_ITERATIONS / 2, 0.1, staircase=True)
+ params.iterations / 2, 0.1, staircase=True)
trained_params = [latent, label]
optimizer = tf.train.AdamOptimizer(learning_rate=lrate, beta1=0.9, beta2=0.999)
inv_train_op = optimizer.minimize(inv_loss, var_list=trained_params,
@@ -161,8 +171,18 @@ reinit_optimizer = tf.variables_initializer(optimizer.variables())
def noise_sampler():
return np.random.normal(size=[BATCH_SIZE, Z_DIM])
-def label_init(shape=[BATCH_SIZE, N_CLASS]):
- return np.random.uniform(low=0.00, high=0.01, size=shape)
+#def label_sampler(shape=[BATCH_SIZE, N_CLASS]):
+# return np.random.uniform(low=0.00, high=0.01, size=shape)
+
+def label_sampler(shape=[BATCH_SIZE, N_CLASS]):
+ num_classes = 2
+ label = np.zeros(shape)
+ for i in range(shape[0]):
+ for _ in range(random.randint(1, shape[1])):
+ j = random.randint(0, shape[1]-1)
+ label[i, j] = random.random()
+ label[i] /= label[i].sum()
+ return label
# --------------------------
# Dataset.
@@ -172,7 +192,7 @@ paths = glob(os.path.join(params.input_dir, '*.jpg')) + \
glob(os.path.join(params.input_dir, '*.jpeg')) + \
glob(os.path.join(params.input_dir, '*.png'))
sample_images = [ vs.load_image(fn, 128) for fn in sorted(paths) ]
-ACTUAL_NUM_IMGS = sample_images.shape[0] # number of images to be inverted.
+ACTUAL_NUM_IMGS = len(sample_images) # number of images to be inverted.
print("Number of images: {}".format(ACTUAL_NUM_IMGS))
NUM_IMGS = ACTUAL_NUM_IMGS
@@ -181,6 +201,8 @@ while NUM_IMGS % BATCH_SIZE != 0:
sample_images += sample_images[-1]
NUM_IMGS += 1
sample_images = np.array(sample_images)
+print(sample_images.shape)
+#sample_images = np.reshape(sample_images, (sample_images.shape[0],3,128,128,))
def sample_images_gen():
for i in range(int(NUM_IMGS / BATCH_SIZE)):
@@ -201,11 +223,12 @@ sess.run(tf.tables_initializer())
# Output file.
out_file = h5py.File(os.path.join(INVERSES_DIR, DATASET_OUT), 'w')
out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE, dtype='uint8')
-out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')
+out_labels = out_file.create_dataset('ytrain', [NUM_IMGS, N_CLASS], dtype='float32')
# Gradient descent w.r.t. generator's inputs.
it = 0
out_pos = 0
+count = 0
start_time = time.time()
for image_batch in image_gen:
@@ -213,8 +236,8 @@ for image_batch in image_gen:
sess.run(target.assign(image_batch))
# Start with a random label
- label_batch = label_sampler(BATCH_SIZE)
- sess.run(label.assign(label_init()))
+ #label_batch = label_sampler()
+ sess.run(label.assign(label_sampler()))
# Start with a random vector
sess.run(latent.assign(noise_sampler()))
@@ -224,7 +247,7 @@ for image_batch in image_gen:
sess.run(reinit_optimizer)
# Main optimization loop.
- for _ in range(params.inv_it):
+ for _ in range(params.iterations):
_inv_loss, _mse_loss, _feat_loss, _rec_loss, _likeli_loss,\
_lrate, _ = sess.run([inv_loss, mse_loss, feat_loss,
rec_loss, likeli_loss, lrate, inv_train_op])
@@ -245,17 +268,17 @@ for image_batch in image_gen:
log_stats('mse loss', _mse_loss, it)
log_stats('feat loss', _feat_loss, it)
log_stats('rec loss', _rec_loss, it)
- log_stats('reg loss', _reg_loss, it)
- log_stats('dist loss', _dist_loss, it)
+ log_stats('likeli loss', _likeli_loss, it)
log_stats('out pos', out_pos, it)
log_stats('lrate', _lrate, it)
summary_writer.flush()
gen_images = sess.run(gen_img)
- inv_batch = vs.interleave(image_batch[BATCH_SIZE - SAMPLE_SIZE:],
- vs.data2img(gen_images[BATCH_SIZE - SAMPLE_SIZE:]))
+ #print(gen_images.shape)
+ inv_batch = vs.interleave(vs.data2img(image_batch), vs.data2img(gen_images))
inv_batch = vs.grid_transform(inv_batch)
- vs.save_image('{}/progress_{}.png'.format(SAMPLES_DIR, it), inv_batch)
+ #print(inv_batch.shape)
+ vs.save_image('{}/initial_{:03d}_{:05d}.png'.format(SAMPLES_DIR, count, it), inv_batch)
it += 1
@@ -264,32 +287,11 @@ for image_batch in image_gen:
label_batch = sess.run(label)
print(label_batch.shape)
- out_images[i:i+BATCH_SIZE] = image_batch
- out_labels[i:i+BATCH_SIZE] = label_batch
-
- out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE])
- vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, i), out_batch)
- print('Saved samples for imgs: {}-{}.'.format(i,i+BATCH_SIZE))
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+ out_images[count:count+BATCH_SIZE] = image_batch
+ out_labels[count:count+BATCH_SIZE] = label_batch
+ out_batch = vs.grid_transform(vs.data2img(gen_images))
+ vs.save_image('{}/initial_generated_{}.png'.format(SAMPLES_DIR, count), out_batch)
+ print('Saved samples for imgs: {}-{}.'.format(count,count+BATCH_SIZE))
+ count += BATCH_SIZE
diff --git a/inversion/visualize.py b/inversion/visualize.py
index 07aea2d..d17fe13 100644
--- a/inversion/visualize.py
+++ b/inversion/visualize.py
@@ -55,7 +55,7 @@ def imread(filename):
def imconvert_float32(im):
im = np.float32(im)
- im = (im / 256) * 2.0 - 1
+ im = (im / 255) * 2.0 - 1
return im
def load_image(opt_fp_in, opt_dims=128):
@@ -84,5 +84,9 @@ def load_image(opt_fp_in, opt_dims=128):
phi_target = phi_target[y0:y1,x0:x1]
if phi_target.shape[2] == 4:
phi_target = phi_target[:,:,1:4]
- phi_target = np.expand_dims(phi_target, 0)
+ b = np.dsplit(phi_target, 3)
+ phi_target = np.stack(b).reshape((3,opt_dims, opt_dims))
+ #print(phi_target.shape)
+ #phi_target = np.expand_dims(phi_target, 0)
+ #phi_target = np.reshape(3, opt_dims, opt_dims)
return phi_target