diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2019-12-09 09:58:49 +0100 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2019-12-09 09:58:49 +0100 |
| commit | 898c9cf94d9e58b1174463a08caa51400d529311 (patch) | |
| tree | d88ee224b38db03ed154dfb0bc7342f67a1d427b /cli/app/commands | |
| parent | 50711c691295304cc94dbaf7b1178e2057ed9b5e (diff) | |
improve search with mean?
Diffstat (limited to 'cli/app/commands')
| -rw-r--r-- | cli/app/commands/biggan/search.py | 46 |
1 files changed, 24 insertions, 22 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py index d1e1a0a..600e698 100644 --- a/cli/app/commands/biggan/search.py +++ b/cli/app/commands/biggan/search.py @@ -11,8 +11,8 @@ import random from scipy.stats import truncnorm from subprocess import call import cv2 as cv - from PIL import Image +from glob import glob def image_to_uint8(x): """Converts [-1, 1] float array to [0, 255] uint8.""" @@ -100,24 +100,32 @@ def cli(ctx, opt_fp_in, opt_dims): import tensorflow as tf import tensorflow_hub as hub - module = hub.Module('https://tfhub.dev/deepmind/biggan-' + str(opt_dims) + '/2') - # module = hub.Module('https://tfhub.dev/deepmind/biggan-256/2') - # module = hub.Module('https://tfhub.dev/deepmind/biggan-512/2') + generator = hub.Module('https://tfhub.dev/deepmind/biggan-' + str(opt_dims) + '/2') inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k) - for k, v in module.get_input_info_dict().items()} + for k, v in generator.get_input_info_dict().items()} input_z = inputs['z'] input_y = inputs['y'] input_trunc = inputs['truncation'] - output = module(inputs) - - z_dim = input_z.shape.as_list()[1] - vocab_size = input_y.shape.as_list()[1] + output = generator(inputs) sess = tf.compat.v1.Session() sess.run(tf.compat.v1.global_variables_initializer()) sess.run(tf.compat.v1.tables_initializer()) + if os.path.isdir(opt_fp_in): + paths = glob(os.path.join(params.input_dir, '*.jpg')) + \ + glob(os.path.join(params.input_dir, '*.jpeg')) + \ + glob(os.path.join(params.input_dir, '*.png')) + for path in paths: + find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims) + else: + find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims) + +def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims): + z_dim = input_z.shape.as_list()[1] + vocab_size = input_y.shape.as_list()[1] + # scalar truncation value in [0.02, 1.0] batch_size = 25 @@ -128,17 +136,11 @@ def cli(ctx, opt_fp_in, opt_dims): num_classes = 1 y = create_labels(batch_size, vocab_size, num_classes) - fp_frames = "frames_{}".format(int(time.time() * 1000)) - os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True) - - #results = sess.run(output, feed_dict={input_z: z, input_y: y, input_trunc: truncation}) - #for sample in results: - # sample = image_to_uint8(sample) - # img = Image.fromarray(sample, "RGB") - # fp_img_out = "{}.png".format(int(time.time() * 1000)) - # img.save(join(app_cfg.DIR_OUTPUTS, fp_img_out)) - if opt_fp_in: + fn = os.path.basename(opt_fp_in) + fbase, ext = os.path.splitext(fn) + fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000)) + os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True) target_im = imread(opt_fp_in) w = target_im.shape[1] h = target_im.shape[0] @@ -155,6 +157,8 @@ def cli(ctx, opt_fp_in, opt_dims): phi_target = np.expand_dims(phi_target, 0) phi_target = np.repeat(phi_target, batch_size, axis=0) else: + fp_frames = "frames_{}".format(int(time.time() * 1000)) + os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True) z_target = np.repeat(truncated_z_single(z_dim, truncation), batch_size, axis=0) y_target = np.repeat(create_labels(1, vocab_size, 1), batch_size, axis=0) feed_dict = {input_z: z_target, input_y: y_target, input_trunc: truncation} @@ -163,8 +167,6 @@ def cli(ctx, opt_fp_in, opt_dims): target_im = imgrid(imconvert_uint8(phi_target), cols=5) imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_target.png'), target_im) - optimizer = tf.keras.optimizers.Adam() - #dy_dx = g.gradient(y, x) cost = tf.reduce_sum(tf.pow(output - phi_target, 2)) dc_dz, dc_dy, = tf.gradients(cost, [input_z, input_y]) @@ -196,7 +198,7 @@ def cli(ctx, opt_fp_in, opt_dims): y[j] /= y[j].sum() if i > 200 and i % 100 == 0: mean = np.mean(y, axis=0) - y = y / 2 + mean / 2 + y = y * 3 / 4 + mean / 4 indices = np.logical_or(z <= -2*truncation, z >= +2*truncation) z[indices] = np.random.randn(np.count_nonzero(indices)) |
