summaryrefslogtreecommitdiff
path: root/cli
diff options
context:
space:
mode:
Diffstat (limited to 'cli')
-rw-r--r-- cli/app/commands/biggan/search.py | 131
-rw-r--r-- cli/app/search/image.py | 87
-rw-r--r-- cli/app/search/util.py | 78
-rw-r--r-- cli/app/search/vector.py | 20
-rw-r--r-- cli/app/settings/app_cfg.py | 1
5 files changed, 146 insertions, 171 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py
index e764487..f1cf385 100644
--- a/cli/app/commands/biggan/search.py
+++ b/cli/app/commands/biggan/search.py
@@ -8,7 +8,6 @@ from os.path import join
import time
import numpy as np
import random
-from scipy.stats import truncnorm
from subprocess import call
import cv2 as cv
from PIL import Image
@@ -16,87 +15,22 @@ from glob import glob
import tensorflow as tf
import tensorflow_hub as hub
import shutil
+import h5py
-def image_to_uint8(x):
- """Converts [-1, 1] float array to [0, 255] uint8."""
- x = np.asarray(x)
- x = (256. / 2.) * (x + 1.)
- x = np.clip(x, 0, 255)
- x = x.astype(np.uint8)
- return x
-
-def truncated_z_sample(batch_size, z_dim, truncation):
- values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim))
- return truncation * values
-
-def truncated_z_single(z_dim, truncation):
- values = truncnorm.rvs(-2, 2, size=(1, z_dim))
- return truncation * values
-
-def create_labels(batch_size, vocab_size, num_classes):
- label = np.zeros((batch_size, vocab_size))
- for i in range(batch_size):
- for _ in range(random.randint(1, num_classes)):
- j = random.randint(0, vocab_size-1)
- label[i, j] = random.random()
- label[i] /= label[i].sum()
- return label
-
-def imconvert_uint8(im):
- im = np.clip(((im + 1) / 2.0) * 256, 0, 255)
- im = np.uint8(im)
- return im
-
-def imconvert_float32(im):
- im = np.float32(im)
- im = (im / 256) * 2.0 - 1
- return im
-
-def imread(filename):
- img = cv.imread(filename, cv.IMREAD_UNCHANGED)
- if img is not None:
- if len(img.shape) > 2:
- img = img[...,::-1]
- return img
-
-def imwrite(filename, img):
- if img is not None:
- if len(img.shape) > 2:
- img = img[...,::-1]
- return cv.imwrite(filename, img)
-
-def imgrid(imarray, cols=5, pad=1):
- if imarray.dtype != np.uint8:
- raise ValueError('imgrid input imarray must be uint8')
- pad = int(pad)
- assert pad >= 0
- cols = int(cols)
- assert cols >= 1
- N, H, W, C = imarray.shape
- rows = int(np.ceil(N / float(cols)))
- batch_pad = rows * cols - N
- assert batch_pad >= 0
- post_pad = [batch_pad, pad, pad, 0]
- pad_arg = [[0, p] for p in post_pad]
- imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255)
- H += pad
- W += pad
- grid = (imarray
- .reshape(rows, cols, H, W, C)
- .transpose(0, 2, 1, 3, 4)
- .reshape(rows*H, cols*W, C))
- if pad:
- grid = grid[:-pad, :-pad]
- return grid
+from app.search.image import image_to_uint8, imconvert_uint8, imconvert_float32, \
+ imread, imwrite, imgrid, resize_and_crop_image
+from app.search.vector import truncated_z_sample, truncated_z_single, create_labels
@click.command('')
@click.option('-i', '--input', 'opt_fp_in', required=True,
help='Path to input image')
@click.option('-s', '--dims', 'opt_dims', default=128, type=int,
help='Dimensions of BigGAN network (128, 256, 512)')
+@click.option('-v', '--video', 'opt_video', is_flag=True,
+ help='Export a video for each dataset')
# @click.option('-r', '--recursive', 'opt_recursive', is_flag=True)
@click.pass_context
-def cli(ctx, opt_fp_in, opt_dims):
+def cli(ctx, opt_fp_in, opt_dims, opt_video):
"""
Search for an image in BigGAN using gradient descent
"""
@@ -109,6 +43,9 @@ def cli(ctx, opt_fp_in, opt_dims):
input_trunc = inputs['truncation']
output = generator(inputs)
+ z_dim = input_z.shape.as_list()[1]
+ vocab_size = input_y.shape.as_list()[1]
+
sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())
sess.run(tf.compat.v1.tables_initializer())
@@ -117,12 +54,24 @@ def cli(ctx, opt_fp_in, opt_dims):
paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \
glob(os.path.join(opt_fp_in, '*.jpeg')) + \
glob(os.path.join(opt_fp_in, '*.png'))
- for path in paths:
- find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims)
else:
- find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims)
+ paths = [opt_fp_in]
+
+ # The HDF5 output is a single file inside DIR_INVERSES, so create the
+ # directory (not the file path itself) before opening the file for writing.
+ # NOTE(review): `params` is not defined in the visible hunks - confirm where
+ # `params.dataset_out` comes from.
+ os.makedirs(app_cfg.DIR_INVERSES, exist_ok=True)
+ fp_inverses = os.path.join(app_cfg.DIR_INVERSES, params.dataset_out)
+ # NOTE(review): out_file is never closed in the visible code; close it after
+ # the loop so the datasets are flushed to disk.
+ out_file = h5py.File(fp_inverses, 'w')
+ out_images = out_file.create_dataset('xtrain', [len(paths), 3, 512, 512], dtype='float32')
+ out_labels = out_file.create_dataset('ytrain', [len(paths), vocab_size], dtype='float32')
+
+ # enumerate() yields (index, item) pairs, in that order.
+ for index, path in enumerate(paths):
+ fp_frames = find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims, out_images, out_labels, index)
+ if opt_video:
+ export_video(fp_frames)
-def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims):
+def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims, out_images, out_labels, index):
+ """
+ Find the closest latent and class vectors for an image. Store the class vector in an HDF5.
+ """
z_dim = input_z.shape.as_list()[1]
vocab_size = input_y.shape.as_list()[1]
@@ -143,18 +92,14 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output,
fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000))
os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True)
target_im = imread(opt_fp_in)
- w = target_im.shape[1]
- h = target_im.shape[0]
- if w <= h:
- scale = opt_dims / w
- else:
- scale = opt_dims / h
- #print("{} {}".format(w, h))
- target_im = cv.resize(target_im,(0,0), fx=scale, fy=scale)
- phi_target = imconvert_float32(target_im)
- phi_target = phi_target[:opt_dims,:opt_dims]
- if phi_target.shape[2] == 4:
- phi_target = phi_target[:,:,1:4]
+ # crop image to 512 and save for later processing
+ phi_target_for_inversion = resize_and_crop_image(target_im, 512)
+ b = np.dsplit(phi_target_for_inversion, 3)
+ phi_target_for_inversion = np.stack(b).reshape((3, 512, 512))
+ out_images[index] = phi_target_for_inversion
+
+ # crop image to 128 to find vectors
+ phi_target = resize_and_crop_image(target_im, opt_dims)
phi_target = np.expand_dims(phi_target, 0)
phi_target = np.repeat(phi_target, batch_size, axis=0)
else:
@@ -211,11 +156,12 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output,
imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(i)), guess_im)
if i % 20 == 0:
print('lr: {}, iter: {}, grad_z: {}, grad_y: {}'.format(lr_z, i, np.std(grad_z), np.std(grad_y)))
- #print('lr: {}, iter: {}, grad_z: {}'.format(lr, i, np.std(grad_z)))
- #print('lr: {}, iter: {}, grad_y: {}'.format(lr, i, np.std(grad_y)))
except KeyboardInterrupt:
pass
+ out_labels[index] = y
+ return fp_frames
+def export_video(fp_frames):
print("Exporting video...")
cmd = [
'/home/lens/bin/ffmpeg',
@@ -225,7 +171,6 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output,
'-pix_fmt', 'yuv420p',
join(app_cfg.DIR_OUTPUTS, fp_frames + '.mp4')
]
- print(' '.join(cmd))
+ # print(' '.join(cmd))
call(cmd)
shutil.rmtree(join(app_cfg.DIR_OUTPUTS, fp_frames))
-
diff --git a/cli/app/search/image.py b/cli/app/search/image.py
new file mode 100644
index 0000000..f800a33
--- /dev/null
+++ b/cli/app/search/image.py
@@ -0,0 +1,87 @@
+import cv2 as cv
+import numpy as np
+
+def image_to_uint8(x):
+    """Map a float array in [-1, 1] to uint8 pixel values in [0, 255].
+
+    The input is shifted by 1 and scaled by 128, so an input of 1.0 lands on
+    256 and is clipped to 255 before the cast to uint8.
+    """
+    scaled = (np.asarray(x) + 1.) * (256. / 2.)
+    return np.clip(scaled, 0, 255).astype(np.uint8)
+
+def imconvert_uint8(im):
+    """Convert an image with values in [-1, 1] to uint8 values in [0, 255].
+
+    Values are rescaled to [0, 256], clipped to [0, 255] and cast to uint8.
+    """
+    unit_range = (im + 1) / 2.0
+    clipped = np.clip(unit_range * 256, 0, 255)
+    return np.uint8(clipped)
+
+def imconvert_float32(im):
+    """Convert pixel values to float32 and rescale them as (im / 256) * 2 - 1.
+
+    With 8-bit input this maps 0 to -1.0 and 255 to just under 1.0.
+    """
+    as_float = np.float32(im)
+    return (as_float / 256) * 2.0 - 1
+
+def imread(filename):
+    """Read an image with OpenCV and reverse the order of its channel axis.
+
+    The file is loaded with cv.IMREAD_UNCHANGED. Arrays with more than two
+    dimensions get their last axis reversed (BGR order becomes RGB order);
+    2-D arrays and a None result from cv.imread are returned untouched.
+    """
+    img = cv.imread(filename, cv.IMREAD_UNCHANGED)
+    if img is None or img.ndim <= 2:
+        return img
+    return img[..., ::-1]
+
+def imwrite(filename, img):
+    """Write an image with OpenCV, reversing its channel axis first.
+
+    Arrays with more than two dimensions have their last axis reversed before
+    being handed to cv.imwrite; anything else is passed through unchanged.
+    Returns the result of cv.imwrite.
+    """
+    has_channels = img is not None and img.ndim > 2
+    out = img[..., ::-1] if has_channels else img
+    return cv.imwrite(filename, out)
+
+def imgrid(imarray, cols=5, pad=1):
+    """Tile a batch of uint8 images into one grid image.
+
+    imarray is unpacked as (N, H, W, C). Images are laid out row by row with
+    `cols` images per row; `pad` pixels of value 255 separate neighbouring
+    cells, and unused cells in the last row are filled with 255.
+
+    Raises ValueError if imarray is not uint8; pad and cols are checked with
+    assertions (pad >= 0, cols >= 1).
+    """
+    if imarray.dtype != np.uint8:
+        raise ValueError('imgrid input imarray must be uint8')
+    pad = int(pad)
+    assert pad >= 0
+    cols = int(cols)
+    assert cols >= 1
+    count, height, width, channels = imarray.shape
+    rows = int(np.ceil(count / float(cols)))
+    missing = rows * cols - count
+    assert missing >= 0
+    # Pad only at the end of each axis: blank images to complete the last
+    # row, plus `pad` pixels below and to the right of every image.
+    padding = [[0, missing], [0, pad], [0, pad], [0, 0]]
+    padded = np.pad(imarray, padding, 'constant', constant_values=255)
+    cell_h = height + pad
+    cell_w = width + pad
+    tiles = padded.reshape(rows, cols, cell_h, cell_w, channels)
+    grid = tiles.transpose(0, 2, 1, 3, 4).reshape(rows * cell_h, cols * cell_w, channels)
+    # Trim the trailing padding so the border is only between cells.
+    return grid[:-pad, :-pad] if pad else grid
+
+def resize_and_crop_image(target_im, opt_dims):
+    """Resize an image so its shorter side is opt_dims, then centre-crop a square.
+
+    The image is scaled uniformly with cv.resize, converted with
+    imconvert_float32, and cropped to at most opt_dims x opt_dims around the
+    centre. Grayscale (2-D) input is expanded to three identical channels, and
+    4-channel input keeps channels 1..3 only.
+
+    Raises ValueError if target_im is None (e.g. the image could not be read).
+    """
+    if target_im is None:
+        raise ValueError('resize_and_crop_image: target_im is None (image could not be read)')
+
+    h, w = target_im.shape[:2]
+    # Scale by the shorter side so both sides end up >= opt_dims.
+    scale = opt_dims / min(w, h)
+    target_im = cv.resize(target_im, (0, 0), fx=scale, fy=scale)
+
+    h, w = target_im.shape[:2]
+    # Offsets of the centred crop window; zero when a side is not larger than opt_dims.
+    x0 = max((w - opt_dims) // 2, 0)
+    y0 = max((h - opt_dims) // 2, 0)
+
+    phi_target = imconvert_float32(target_im)
+    phi_target = phi_target[y0:y0 + opt_dims, x0:x0 + opt_dims]
+    if phi_target.ndim == 2:
+        # Grayscale image: replicate the single plane so callers get 3 channels.
+        phi_target = np.repeat(phi_target[:, :, np.newaxis], 3, axis=2)
+    elif phi_target.shape[2] == 4:
+        # Drop the first channel. NOTE(review): presumably alpha, since imread
+        # reverses the channel order of a BGRA file - confirm.
+        phi_target = phi_target[:, :, 1:4]
+    return phi_target
diff --git a/cli/app/search/util.py b/cli/app/search/util.py
deleted file mode 100644
index a4cdfd9..0000000
--- a/cli/app/search/util.py
+++ /dev/null
@@ -1,78 +0,0 @@
-
-def truncated_z_sample(batch_size, truncation=1., seed=None):
- state = None if seed is None else np.random.RandomState(seed)
- values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state)
- return truncation * values
-
-def one_hot(index, vocab_size=vocab_size):
- index = np.asarray(index)
- if len(index.shape) == 0:
- index = np.asarray([index])
- assert len(index.shape) == 1
- num = index.shape[0]
- output = np.zeros((num, vocab_size), dtype=np.float32)
- output[np.arange(num), index] = 1
- return output
-
-def imconvert_uint8(im):
- im = np.clip(((im + 1) / 2.0) * 256, 0, 255)
- im = np.uint8(im)
- return im
-
-def imconvert_float32(im):
- im = np.float32(im)
- im = (im / 256) * 2.0 - 1
- return im
-
-def imread(filename):
- img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
- if img is not None:
- if len(img.shape) > 2:
- img = img[...,::-1]
- return img
-
-def imwrite(filename, img):
- if img is not None:
- if len(img.shape) > 2:
- img = img[...,::-1]
- return cv2.imwrite(filename, img)
-
-def imgrid(imarray, cols=5, pad=1):
- if imarray.dtype != np.uint8:
- raise ValueError('imgrid input imarray must be uint8')
- pad = int(pad)
- assert pad >= 0
- cols = int(cols)
- assert cols >= 1
- N, H, W, C = imarray.shape
- rows = int(np.ceil(N / float(cols)))
- batch_pad = rows * cols - N
- assert batch_pad >= 0
- post_pad = [batch_pad, pad, pad, 0]
- pad_arg = [[0, p] for p in post_pad]
- imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255)
- H += pad
- W += pad
- grid = (imarray
- .reshape(rows, cols, H, W, C)
- .transpose(0, 2, 1, 3, 4)
- .reshape(rows*H, cols*W, C))
- if pad:
- grid = grid[:-pad, :-pad]
- return grid
-
-def imshow(a, format='png', jpeg_fallback=True):
- a = np.asarray(a, dtype=np.uint8)
- str_file = cStringIO.StringIO()
- PIL.Image.fromarray(a).save(str_file, format)
- im_data = str_file.getvalue()
- try:
- disp = IPython.display.display(IPython.display.Image(im_data))
- except IOError:
- if jpeg_fallback and format != 'jpeg':
- print ('Warning: image was too large to display in format "{}"; '
- 'trying jpeg instead.').format(format)
- return imshow(a, format='jpeg')
- else:
- raise
- return disp
diff --git a/cli/app/search/vector.py b/cli/app/search/vector.py
new file mode 100644
index 0000000..89cd949
--- /dev/null
+++ b/cli/app/search/vector.py
@@ -0,0 +1,20 @@
+import random
+import numpy as np
+from scipy.stats import truncnorm
+
+def truncated_z_sample(batch_size, z_dim, truncation):
+    """Draw a (batch_size, z_dim) array from a normal truncated to [-2, 2],
+    scaled by `truncation`."""
+    shape = (batch_size, z_dim)
+    return truncation * truncnorm.rvs(-2, 2, size=shape)
+
+def truncated_z_single(z_dim, truncation):
+    """Draw one (1, z_dim) vector from a normal truncated to [-2, 2],
+    scaled by `truncation`."""
+    sample = truncnorm.rvs(-2, 2, size=(1, z_dim))
+    return sample * truncation
+
+def create_labels(batch_size, vocab_size, num_classes):
+    """Build random class-mixture vectors, one row per batch item.
+
+    Each row gets between 1 and num_classes random draws of a class index
+    (repeats overwrite the earlier weight), each assigned a random weight,
+    and is then normalised to sum to 1.
+    """
+    labels = np.zeros((batch_size, vocab_size))
+    for row in labels:
+        n_draws = random.randint(1, num_classes)
+        for _ in range(n_draws):
+            # Draw the index before the weight to keep the random sequence stable.
+            cls = random.randint(0, vocab_size - 1)
+            row[cls] = random.random()
+        # `row` is a view into `labels`, so this normalises in place.
+        row /= row.sum()
+    return labels
diff --git a/cli/app/settings/app_cfg.py b/cli/app/settings/app_cfg.py
index 7ec107e..bdbbb90 100644
--- a/cli/app/settings/app_cfg.py
+++ b/cli/app/settings/app_cfg.py
@@ -30,6 +30,7 @@ CLICK_GROUPS = {
SELF_CWD = os.path.dirname(os.path.realpath(__file__)) # Script CWD
DIR_APP = str(Path(SELF_CWD).parent.parent.parent)
DIR_IMAGENET = join(DIR_APP, 'data_store/imagenet')
+DIR_INVERSES = join(DIR_APP, 'data_store/inverses')
DIR_OUTPUTS = join(DIR_APP, 'data_store/outputs')
FP_MODELZOO = join(DIR_APP, 'modelzoo/modelzoo.yaml')