From 326db345db13b1ab3a76406644654cb78b4d1b8d Mon Sep 17 00:00:00 2001 From: "jules@lens" Date: Sun, 8 Dec 2019 11:59:27 +0100 Subject: biggan search test --- cli/app/commands/biggan/test.py | 116 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 cli/app/commands/biggan/test.py (limited to 'cli/app/commands/biggan/test.py') diff --git a/cli/app/commands/biggan/test.py b/cli/app/commands/biggan/test.py new file mode 100644 index 0000000..593d557 --- /dev/null +++ b/cli/app/commands/biggan/test.py @@ -0,0 +1,116 @@ +import click + +from app.utils import click_utils +from app.settings import app_cfg + +import os +from os.path import join +import time +import numpy as np +import random +from scipy.stats import truncnorm +from subprocess import call +import cv2 as cv + +from PIL import Image + +def image_to_uint8(x): + """Converts [-1, 1] float array to [0, 255] uint8.""" + x = np.asarray(x) + x = (256. / 2.) * (x + 1.) + x = np.clip(x, 0, 255) + x = x.astype(np.uint8) + return x + +def truncated_z_sample(batch_size, z_dim, truncation): + values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim)) + return truncation * values + +def truncated_z_single(z_dim, truncation): + values = truncnorm.rvs(-2, 2, size=(1, z_dim)) + return truncation * values + +def create_labels(batch_size, vocab_size, num_classes): + label = np.zeros((batch_size, vocab_size)) + for i in range(batch_size): + for _ in range(random.randint(1, num_classes)): + j = random.randint(0, vocab_size-1) + label[i, j] = random.random() + label[i] /= label[i].sum() + return label + +def imconvert_uint8(im): + im = np.clip(((im + 1) / 2.0) * 256, 0, 255) + im = np.uint8(im) + return im + +def imconvert_float32(im): + im = np.float32(im) + im = (im / 256) * 2.0 - 1 + return im + +def imread(filename): + img = cv.imread(filename, cv.IMREAD_UNCHANGED) + if img is not None: + if len(img.shape) > 2: + img = img[...,::-1] + return img + +def imwrite(filename, img): + if img is not None: + if len(img.shape) > 2: + img = img[...,::-1] + return cv.imwrite(filename, img) + +def imgrid(imarray, cols=5, pad=1): + if imarray.dtype != np.uint8: + raise ValueError('imgrid input imarray must be uint8') + pad = int(pad) + assert pad >= 0 + cols = int(cols) + assert cols >= 1 + N, H, W, C = imarray.shape + rows = int(np.ceil(N / float(cols))) + batch_pad = rows * cols - N + assert batch_pad >= 0 + post_pad = [batch_pad, pad, pad, 0] + pad_arg = [[0, p] for p in post_pad] + imarray = np.pad(imarray, pad_arg, 'constant', constant_values=255) + H += pad + W += pad + grid = (imarray + .reshape(rows, cols, H, W, C) + .transpose(0, 2, 1, 3, 4) + .reshape(rows*H, cols*W, C)) + if pad: + grid = grid[:-pad, :-pad] + return grid + +@click.command('') +@click.option('-i', '--input', 'opt_fp_in', required=True, + help='Path to input image') +@click.option('-s', '--dims', 'opt_dims', default=128, type=int, + help='Dimensions of BigGAN network (128, 256, 512)') +# @click.option('-r', '--recursive', 'opt_recursive', is_flag=True) +@click.pass_context +def cli(ctx, opt_fp_in, opt_dims): + """ + Search for an image in BigGAN using gradient descent + """ + if opt_fp_in: + target_im = imread(opt_fp_in) + w = target_im.shape[1] + h = target_im.shape[0] + if w <= h: + scale = opt_dims / w + else: + scale = opt_dims / h + print("{} {}".format(w, h)) + target_im = cv.resize(target_im,(0,0), fx=scale, fy=scale) + phi_target = target_im[:opt_dims,:opt_dims] + print(phi_target.shape) + print(phi_target[64,64]) + if phi_target.shape[2] == 4: + phi_target_a = phi_target[:,:,1:4] + imwrite('crop.png', phi_target_a) + -- cgit v1.2.3-70-g09d2