diff options
Diffstat (limited to 'cli/app')
| -rw-r--r-- | cli/app/commands/biggan/search.py | 131 | ||||
| -rw-r--r-- | cli/app/search/image.py | 87 | ||||
| -rw-r--r-- | cli/app/search/util.py | 78 | ||||
| -rw-r--r-- | cli/app/search/vector.py | 20 | ||||
| -rw-r--r-- | cli/app/settings/app_cfg.py | 1 |
5 files changed, 146 insertions, 171 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py index e764487..f1cf385 100644 --- a/cli/app/commands/biggan/search.py +++ b/cli/app/commands/biggan/search.py @@ -8,7 +8,6 @@ from os.path import join import time import numpy as np import random -from scipy.stats import truncnorm from subprocess import call import cv2 as cv from PIL import Image @@ -16,87 +15,22 @@ from glob import glob import tensorflow as tf import tensorflow_hub as hub import shutil +import h5py -def image_to_uint8(x): - """Converts [-1, 1] float array to [0, 255] uint8.""" - x = np.asarray(x) - x = (256. / 2.) * (x + 1.) - x = np.clip(x, 0, 255) - x = x.astype(np.uint8) - return x - -def truncated_z_sample(batch_size, z_dim, truncation): - values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim)) - return truncation * values - -def truncated_z_single(z_dim, truncation): - values = truncnorm.rvs(-2, 2, size=(1, z_dim)) - return truncation * values - -def create_labels(batch_size, vocab_size, num_classes): - label = np.zeros((batch_size, vocab_size)) - for i in range(batch_size): - for _ in range(random.randint(1, num_classes)): - j = random.randint(0, vocab_size-1) - label[i, j] = random.random() - label[i] /= label[i].sum() - return label - -def imconvert_uint8(im): - im = np.clip(((im + 1) / 2.0) * 256, 0, 255) - im = np.uint8(im) - return im - -def imconvert_float32(im): - im = np.float32(im) - im = (im / 256) * 2.0 - 1 - return im - -def imread(filename): - img = cv.imread(filename, cv.IMREAD_UNCHANGED) - if img is not None: - if len(img.shape) > 2: - img = img[...,::-1] - return img - -def imwrite(filename, img): - if img is not None: - if len(img.shape) > 2: - img = img[...,::-1] - return cv.imwrite(filename, img) - -def imgrid(imarray, cols=5, pad=1): - if imarray.dtype != np.uint8: - raise ValueError('imgrid input imarray must be uint8') - pad = int(pad) - assert pad >= 0 - cols = int(cols) - assert cols >= 1 - N, H, W, C = imarray.shape - rows = int(np.ceil(N / float(cols))) - batch_pad = rows * cols - N - assert batch_pad >= 0 - post_pad = [batch_pad, pad, pad, 0] - pad_arg = [[0, p] for p in post_pad] - imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255) - H += pad - W += pad - grid = (imarray - .reshape(rows, cols, H, W, C) - .transpose(0, 2, 1, 3, 4) - .reshape(rows*H, cols*W, C)) - if pad: - grid = grid[:-pad, :-pad] - return grid +from app.search.image import image_to_uint8, imconvert_uint8, imconvert_float32, \ + imread, imwrite, imgrid +from app.search.vector import truncated_z_sample, truncated_z_single, create_labels @click.command('') @click.option('-i', '--input', 'opt_fp_in', required=True, help='Path to input image') @click.option('-s', '--dims', 'opt_dims', default=128, type=int, help='Dimensions of BigGAN network (128, 256, 512)') +@click.option('-v', '--video', 'opt_video', is_flag=True, + help='Export a video for each dataset') # @click.option('-r', '--recursive', 'opt_recursive', is_flag=True) @click.pass_context -def cli(ctx, opt_fp_in, opt_dims): +def cli(ctx, opt_fp_in, opt_dims, opt_video): """ Search for an image in BigGAN using gradient descent """ @@ -109,6 +43,9 @@ def cli(ctx, opt_fp_in, opt_dims): input_trunc = inputs['truncation'] output = generator(inputs) + z_dim = input_z.shape.as_list()[1] + vocab_size = input_y.shape.as_list()[1] + sess = tf.compat.v1.Session() sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.tables_initializer()) @@ -117,12 +54,24 @@ def cli(ctx, opt_fp_in, opt_dims): paths = glob(os.path.join(opt_fp_in, '*.jpg')) + \ glob(os.path.join(opt_fp_in, '*.jpeg')) + \ glob(os.path.join(opt_fp_in, '*.png')) - for path in paths: - find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims) else: - find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims) + paths = [opt_fp_in] + + fp_inverses = os.path.join(app_cfg.INVERSES_DIR, params.dataset_out) + os.makedirs(fp_inverses, exist_ok=True) + out_file = h5py.File(fp_inverses, 'w') + out_images = out_file.create_dataset('xtrain', [len(paths), 3, 512, 512], dtype='float32') + out_labels = out_file.create_dataset('ytrain', [len(paths), vocab_size], dtype='float32') + + for path, index in enumerate(paths): + fp_frames = find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims, out_images, out_labels, index) + if opt_video: + export_video(fp_frames) -def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims): +def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims, out_images, out_labels, index): + """ + Find the closest latent and class vectors for an image. Store the class vector in an HDF5. + """ z_dim = input_z.shape.as_list()[1] vocab_size = input_y.shape.as_list()[1] @@ -143,18 +92,14 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000)) os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True) target_im = imread(opt_fp_in) - w = target_im.shape[1] - h = target_im.shape[0] - if w <= h: - scale = opt_dims / w - else: - scale = opt_dims / h - #print("{} {}".format(w, h)) - target_im = cv.resize(target_im,(0,0), fx=scale, fy=scale) - phi_target = imconvert_float32(target_im) - phi_target = phi_target[:opt_dims,:opt_dims] - if phi_target.shape[2] == 4: - phi_target = phi_target[:,:,1:4] + # crop image to 512 and save for later processing + phi_target_for_inversion = resize_and_crop_image(target_im, 512) + b = np.dsplit(phi_target_for_inversion, 3) + phi_target_for_inversion = np.stack(b).reshape((3, 512, 512)) + out_images[index] = phi_target_for_inversion + + # crop image to 128 to find vectors + phi_target = resize_and_crop_image(target_im, opt_dims) phi_target = np.expand_dims(phi_target, 0) phi_target = np.repeat(phi_target, batch_size, axis=0) else: @@ -211,11 +156,12 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_{:04d}.png'.format(i)), guess_im) if i % 20 == 0: print('lr: {}, iter: {}, grad_z: {}, grad_y: {}'.format(lr_z, i, np.std(grad_z), np.std(grad_y))) - #print('lr: {}, iter: {}, grad_z: {}'.format(lr, i, np.std(grad_z))) - #print('lr: {}, iter: {}, grad_y: {}'.format(lr, i, np.std(grad_y))) except KeyboardInterrupt: pass + out_labels[index] = y + return fp_frames +def export_video(fp_frames): print("Exporting video...") cmd = [ '/home/lens/bin/ffmpeg', @@ -225,7 +171,6 @@ def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, '-pix_fmt', 'yuv420p', join(app_cfg.DIR_OUTPUTS, fp_frames + '.mp4') ] - print(' '.join(cmd)) + # print(' '.join(cmd)) call(cmd) shutil.rmtree(join(app_cfg.DIR_OUTPUTS, fp_frames)) - diff --git a/cli/app/search/image.py b/cli/app/search/image.py new file mode 100644 index 0000000..f800a33 --- /dev/null +++ b/cli/app/search/image.py @@ -0,0 +1,87 @@ +import cv2 as cv +import numpy as np + +def image_to_uint8(x): + """Converts [-1, 1] float array to [0, 255] uint8.""" + x = np.asarray(x) + x = (256. / 2.) * (x + 1.) + x = np.clip(x, 0, 255) + x = x.astype(np.uint8) + return x + +def imconvert_uint8(im): + im = np.clip(((im + 1) / 2.0) * 256, 0, 255) + im = np.uint8(im) + return im + +def imconvert_float32(im): + im = np.float32(im) + im = (im / 256) * 2.0 - 1 + return im + +def imread(filename): + img = cv.imread(filename, cv.IMREAD_UNCHANGED) + if img is not None: + if len(img.shape) > 2: + img = img[...,::-1] + return img + +def imwrite(filename, img): + if img is not None: + if len(img.shape) > 2: + img = img[...,::-1] + return cv.imwrite(filename, img) + +def imgrid(imarray, cols=5, pad=1): + if imarray.dtype != np.uint8: + raise ValueError('imgrid input imarray must be uint8') + pad = int(pad) + assert pad >= 0 + cols = int(cols) + assert cols >= 1 + N, H, W, C = imarray.shape + rows = int(np.ceil(N / float(cols))) + batch_pad = rows * cols - N + assert batch_pad >= 0 + post_pad = [batch_pad, pad, pad, 0] + pad_arg = [[0, p] for p in post_pad] + imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255) + H += pad + W += pad + grid = (imarray + .reshape(rows, cols, H, W, C) + .transpose(0, 2, 1, 3, 4) + .reshape(rows*H, cols*W, C)) + if pad: + grid = grid[:-pad, :-pad] + return grid + +def resize_and_crop_image(target_im, opt_dims): + w = target_im.shape[1] + h = target_im.shape[0] + if w <= h: + scale = opt_dims / w + else: + scale = opt_dims / h + #print("{} {}".format(w, h)) + target_im = cv.resize(target_im,(0,0), fx=scale, fy=scale) + + w = target_im.shape[1] + h = target_im.shape[0] + + x0 = 0 + x1 = opt_dims + y0 = 0 + y1 = opt_dims + if w > opt_dims: + x0 += int((w - opt_dims) / 2) + x1 += x0 + if h > opt_dims: + y0 += int((h - opt_dims) / 2) + y1 += y0 + + phi_target = imconvert_float32(target_im) + phi_target = phi_target[y0:y1,x0:x1] + if phi_target.shape[2] == 4: + phi_target = phi_target[:,:,1:4] + return phi_target diff --git a/cli/app/search/util.py b/cli/app/search/util.py deleted file mode 100644 index a4cdfd9..0000000 --- a/cli/app/search/util.py +++ /dev/null @@ -1,78 +0,0 @@ - -def truncated_z_sample(batch_size, truncation=1., seed=None): - state = None if seed is None else np.random.RandomState(seed) - values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state) - return truncation * values - -def one_hot(index, vocab_size=vocab_size): - index = np.asarray(index) - if len(index.shape) == 0: - index = np.asarray([index]) - assert len(index.shape) == 1 - num = index.shape[0] - output = np.zeros((num, vocab_size), dtype=np.float32) - output[np.arange(num), index] = 1 - return output - -def imconvert_uint8(im): - im = np.clip(((im + 1) / 2.0) * 256, 0, 255) - im = np.uint8(im) - return im - -def imconvert_float32(im): - im = np.float32(im) - im = (im / 256) * 2.0 - 1 - return im - -def imread(filename): - img = cv2.imread(filename, cv2.IMREAD_UNCHANGED) - if img is not None: - if len(img.shape) > 2: - img = img[...,::-1] - return img - -def imwrite(filename, img): - if img is not None: - if len(img.shape) > 2: - img = img[...,::-1] - return cv2.imwrite(filename, img) - -def imgrid(imarray, cols=5, pad=1): - if imarray.dtype != np.uint8: - raise ValueError('imgrid input imarray must be uint8') - pad = int(pad) - assert pad >= 0 - cols = int(cols) - assert cols >= 1 - N, H, W, C = imarray.shape - rows = int(np.ceil(N / float(cols))) - batch_pad = rows * cols - N - assert batch_pad >= 0 - post_pad = [batch_pad, pad, pad, 0] - pad_arg = [[0, p] for p in post_pad] - imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255) - H += pad - W += pad - grid = (imarray - .reshape(rows, cols, H, W, C) - .transpose(0, 2, 1, 3, 4) - .reshape(rows*H, cols*W, C)) - if pad: - grid = grid[:-pad, :-pad] - return grid - -def imshow(a, format='png', jpeg_fallback=True): - a = np.asarray(a, dtype=np.uint8) - str_file = cStringIO.StringIO() - PIL.Image.fromarray(a).save(str_file, format) - im_data = str_file.getvalue() - try: - disp = IPython.display.display(IPython.display.Image(im_data)) - except IOError: - if jpeg_fallback and format != 'jpeg': - print ('Warning: image was too large to display in format "{}"; ' - 'trying jpeg instead.').format(format) - return imshow(a, format='jpeg') - else: - raise - return disp diff --git a/cli/app/search/vector.py b/cli/app/search/vector.py new file mode 100644 index 0000000..89cd949 --- /dev/null +++ b/cli/app/search/vector.py @@ -0,0 +1,20 @@ +import random +import numpy as np +from scipy.stats import truncnorm + +def truncated_z_sample(batch_size, z_dim, truncation): + values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim)) + return truncation * values + +def truncated_z_single(z_dim, truncation): + values = truncnorm.rvs(-2, 2, size=(1, z_dim)) + return truncation * values + +def create_labels(batch_size, vocab_size, num_classes): + label = np.zeros((batch_size, vocab_size)) + for i in range(batch_size): + for _ in range(random.randint(1, num_classes)): + j = random.randint(0, vocab_size-1) + label[i, j] = random.random() + label[i] /= label[i].sum() + return label diff --git a/cli/app/settings/app_cfg.py b/cli/app/settings/app_cfg.py index 7ec107e..bdbbb90 100644 --- a/cli/app/settings/app_cfg.py +++ b/cli/app/settings/app_cfg.py @@ -30,6 +30,7 @@ CLICK_GROUPS = { SELF_CWD = os.path.dirname(os.path.realpath(__file__)) # Script CWD DIR_APP = str(Path(SELF_CWD).parent.parent.parent) DIR_IMAGENET = join(DIR_APP, 'data_store/imagenet') +DIR_INVERSES = join(DIR_APP, 'data_store/inverses') DIR_OUTPUTS = join(DIR_APP, 'data_store/outputs') FP_MODELZOO = join(DIR_APP, 'modelzoo/modelzoo.yaml') |
