path: root/cli/app/commands/biggan/search_working.py
author    jules@lens <julescarbon@gmail.com>  2019-12-08 11:59:27 +0100
committer jules@lens <julescarbon@gmail.com>  2019-12-08 11:59:27 +0100
commit    326db345db13b1ab3a76406644654cb78b4d1b8d (patch)
tree      4fc4f7ca8f6cf6d838332692212fcf4cc79e143f /cli/app/commands/biggan/search_working.py
parent    7f60e705f230e005f033c82ddfeb0261db70d645 (diff)
biggan search test
Diffstat (limited to 'cli/app/commands/biggan/search_working.py')
-rw-r--r--  cli/app/commands/biggan/search_working.py  |  256
1 file changed, 256 insertions(+), 0 deletions(-)
diff --git a/cli/app/commands/biggan/search_working.py b/cli/app/commands/biggan/search_working.py
new file mode 100644
index 0000000..0e52b17
--- /dev/null
+++ b/cli/app/commands/biggan/search_working.py
@@ -0,0 +1,256 @@
+import click
+
+from app.utils import click_utils
+from app.settings import app_cfg
+
+import os
+from os.path import join
+import time
+import numpy as np
+import random
+from scipy.stats import truncnorm
+from subprocess import call
+import cv2 as cv
+
+from PIL import Image
+
+def image_to_uint8(x):
+ """Converts [-1, 1] float array to [0, 255] uint8."""
+ x = np.asarray(x)
+ x = (256. / 2.) * (x + 1.)
+ x = np.clip(x, 0, 255)
+ x = x.astype(np.uint8)
+ return x
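+# e.g. image_to_uint8(np.array([-1., 0., 1.])) -> array([  0, 128, 255], dtype=uint8)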
+
+def truncated_z_sample(batch_size, z_dim, truncation):
+ values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim))
+ return truncation * values
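+# the "truncation trick": values are drawn from a normal clipped to [-2, 2] and
+# scaled down, trading sample diversity for quality; e.g.
+# truncated_z_sample(25, 128, 0.5) is a (25, 128) array with entries in [-1, 1]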
+
+def truncated_z_single(z_dim, truncation):
+  return truncated_z_sample(1, z_dim, truncation)
+
+def create_labels(batch_size, vocab_size, num_classes):
+  """Random soft labels: each row mixes up to num_classes random classes
+  with random weights, normalized to sum to 1."""
+  label = np.zeros((batch_size, vocab_size))
+  for i in range(batch_size):
+    for _ in range(random.randint(1, num_classes)):
+      j = random.randint(0, vocab_size-1)
+      label[i, j] = random.random()
+    label[i] /= label[i].sum()
+  return label
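+# e.g. create_labels(2, 1000, 3) -> (2, 1000) array; each row has 1-3 nonzero
+# entries and sums to 1.0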
+
+def imconvert_uint8(im):
+  """[-1, 1] float image -> [0, 255] uint8."""
+  im = np.clip(((im + 1) / 2.0) * 256, 0, 255)
+  im = np.uint8(im)
+  return im
+
+def imconvert_float32(im):
+  """[0, 255] uint8 image -> [-1, 1] float32."""
+  im = np.float32(im)
+  im = (im / 256) * 2.0 - 1
+  return im
+
+def imread(filename):
+  """Reads with OpenCV and flips BGR(A) -> RGB(A); note a 4-channel BGRA
+  image comes back as ARGB (see the alpha handling further down)."""
+  img = cv.imread(filename, cv.IMREAD_UNCHANGED)
+  if img is not None:
+    if len(img.shape) > 2:
+      img = img[...,::-1]
+  return img
+
+def imwrite(filename, img):
+  """Flips RGB back to BGR for OpenCV and writes the file."""
+  if img is not None:
+    if len(img.shape) > 2:
+      img = img[...,::-1]
+    return cv.imwrite(filename, img)
+
+def imgrid(imarray, cols=5, pad=1):
+ if imarray.dtype != np.uint8:
+ raise ValueError('imgrid input imarray must be uint8')
+ pad = int(pad)
+ assert pad >= 0
+ cols = int(cols)
+ assert cols >= 1
+ N, H, W, C = imarray.shape
+ rows = int(np.ceil(N / float(cols)))
+ batch_pad = rows * cols - N
+ assert batch_pad >= 0
+ post_pad = [batch_pad, pad, pad, 0]
+ pad_arg = [[0, p] for p in post_pad]
+ imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255)
+ H += pad
+ W += pad
+ grid = (imarray
+ .reshape(rows, cols, H, W, C)
+ .transpose(0, 2, 1, 3, 4)
+ .reshape(rows*H, cols*W, C))
+ if pad:
+ grid = grid[:-pad, :-pad]
+ return grid
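+# e.g. a (25, 128, 128, 3) uint8 batch with cols=5 becomes one (644, 644, 3)
+# contact sheet (5x5 tiles with 1px padding, outer padding trimmed)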
+
+@click.command('')
+@click.option('-i', '--input', 'opt_fp_in', default=None,
+  help='Path to target image (omit to chase a random BigGAN sample)')
+@click.option('-s', '--dims', 'opt_dims', default=128, type=int,
+  help='BigGAN module resolution (128, 256, or 512)')
+@click.pass_context
+def cli(ctx, opt_fp_in, opt_dims):
+  """
+  Search for an image in BigGAN using gradient descent: hold a target image
+  fixed and repeatedly step the latent z and the class vector y down the
+  gradient of the L2 distance between the generator's output and the target.
+  """
+ import tensorflow as tf
+ import tensorflow_hub as hub
+
+  # the public BigGAN generators on TF-Hub come in 128, 256 and 512px variants
+  module = hub.Module('https://tfhub.dev/deepmind/biggan-{}/2'.format(opt_dims))
+
+ inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k)
+ for k, v in module.get_input_info_dict().items()}
+ input_z = inputs['z']
+ input_y = inputs['y']
+ input_trunc = inputs['truncation']
+ output = module(inputs)
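+  # z: the latent vector, y: the class vector (1000-way ImageNet classes for
+  # these modules), truncation: a scalar; output is a [batch, dims, dims, 3]
+  # float image in [-1, 1]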
+
+ z_dim = input_z.shape.as_list()[1]
+ vocab_size = input_y.shape.as_list()[1]
+
+ sess = tf.compat.v1.Session()
+ sess.run(tf.compat.v1.global_variables_initializer())
+ sess.run(tf.compat.v1.tables_initializer())
+
+  batch_size = 25
+  truncation = 1.0  # scalar truncation value in [0.02, 1.0]
+
+ z = truncated_z_sample(batch_size, z_dim, truncation/2)
+
+ num_classes = 1
+ y = create_labels(batch_size, vocab_size, num_classes)
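+  # z: (batch_size, z_dim) starting latents; y: (batch_size, vocab_size) soft labels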
+
+ fp_frames = "frames_{}".format(int(time.time() * 1000))
+ os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True)
+
+ #results = sess.run(output, feed_dict={input_z: z, input_y: y, input_trunc: truncation})
+ #for sample in results:
+ # sample = image_to_uint8(sample)
+ # img = Image.fromarray(sample, "RGB")
+ # fp_img_out = "{}.png".format(int(time.time() * 1000))
+ # img.save(join(app_cfg.DIR_OUTPUTS, fp_img_out))
+
+ if opt_fp_in:
+    # scale the short side of the target to opt_dims, then crop the
+    # top-left square
+    target_im = imread(opt_fp_in)
+    h, w = target_im.shape[:2]
+    scale = opt_dims / min(w, h)
+    target_im = cv.resize(target_im, (0, 0), fx=scale, fy=scale)
+    phi_target = imconvert_float32(target_im)
+    phi_target = phi_target[:opt_dims, :opt_dims]
+    if phi_target.shape[2] == 4:
+      # imread reversed BGRA to ARGB, so channels 1:4 are the RGB planes
+      phi_target = phi_target[:,:,1:4]
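+    # tile the single target across the batch so every sample descends toward it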
+ phi_target = np.expand_dims(phi_target, 0)
+ phi_target = np.repeat(phi_target, batch_size, axis=0)
+ else:
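+    # no target image given: chase a random BigGAN sample instead, so an
+    # exact solution exists in the generator's range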
+ z_target = np.repeat(truncated_z_single(z_dim, truncation), batch_size, axis=0)
+ y_target = np.repeat(create_labels(1, vocab_size, 1), batch_size, axis=0)
+ feed_dict = {input_z: z_target, input_y: y_target, input_trunc: truncation}
+ phi_target = sess.run(output, feed_dict=feed_dict)
+
+ target_im = imgrid(imconvert_uint8(phi_target), cols=5)
+ imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_target.png'), target_im)
+
+  # L2 reconstruction cost between generator output and target, differentiated
+  # w.r.t. both the latent z and the class vector y
+  cost = tf.reduce_sum(tf.pow(output - phi_target, 2))
+  dc_dz, dc_dy = tf.gradients(cost, [input_z, input_y])
+
+  # separate learning rates: the latent z moves much faster than the labels y
+  lr_z = 0.0001
+  lr_y = 0.000001
+
+ feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
+ phi_start = sess.run(output, feed_dict=feed_dict)
+ start_im = imgrid(imconvert_uint8(phi_start), cols=5)
+ imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_start.png'), start_im)
+
+  try:
+    for i in range(1000):
+      feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
+      grad_z, grad_y = sess.run([dc_dz, dc_dy], feed_dict=feed_dict)
+
+      # plain gradient-descent step on the latents and the class vector
+      z -= grad_z * lr_z
+      y -= grad_y * lr_y
+
+      # decay the learning rates: 0.997^900 brings lr_z to roughly 5-7% of
+      # its starting value by frame 1000; lr_y decays more gently
+      if i > 100:
+        lr_z *= 0.997
+      if i > 500:
+        lr_y *= 0.999
+
+      # resample any z entries that escape the truncated range [-2t, 2t]
+      indices = np.logical_or(z <= -2*truncation, z >= +2*truncation)
+      z[indices] = np.random.randn(np.count_nonzero(indices))
+
+      # keep y a valid class distribution: clip (and, early on, renormalize)
+      # on a schedule that loosens as the search settles
+ if i < 100:
+ if i % 30 == 0:
+ lr_z *= 1.002
+ y = np.clip(y, 0, 1)
+ for j in range(batch_size):
+ y[j] /= y[j].sum()
+ elif i < 300:
+ if i % 50 == 0:
+ lr_z *= 1.001
+ y = np.clip(y, 0, 1)
+ for j in range(batch_size):
+ y[j] /= y[j].sum()
+ elif i < 600:
+ if i % 60 == 0:
+ y = np.clip(y, 0, 1)
+ else:
+ if i % 100 == 0:
+ y = np.clip(y, 0, 1)
+
+ feed_dict = {input_z: z, input_y: y, input_trunc: truncation}
+ phi_guess = sess.run(output, feed_dict=feed_dict)
+ guess_im = imgrid(imconvert_uint8(phi_guess), cols=5)
+ imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(i)), guess_im)
+      if i % 20 == 0:
+        print('lr_z: {}, iter: {}, std(grad_z): {}, std(grad_y): {}'.format(lr_z, i, np.std(grad_z), np.std(grad_y)))
+ except KeyboardInterrupt:
+ pass
+
+ print("Exporting video...")
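+  # stitch the saved frames into a 30 fps mp4 (yuv420p for broad player support)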
+ cmd = [
+ '/home/lens/bin/ffmpeg',
+ '-y', # '-v', 'quiet',
+ '-r', '30',
+ '-i', join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_%04d.png'),
+ '-pix_fmt', 'yuv420p',
+ join(app_cfg.DIR_OUTPUTS, fp_frames + '.mp4')
+ ]
+ print(' '.join(cmd))
+ call(cmd)
+ print("Done")
+