summaryrefslogtreecommitdiff
path: root/cli/app/commands
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-09 09:58:49 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-09 09:58:49 +0100
commit898c9cf94d9e58b1174463a08caa51400d529311 (patch)
treed88ee224b38db03ed154dfb0bc7342f67a1d427b /cli/app/commands
parent50711c691295304cc94dbaf7b1178e2057ed9b5e (diff)
improve search with mean?
Diffstat (limited to 'cli/app/commands')
-rw-r--r--cli/app/commands/biggan/search.py46
1 files changed, 24 insertions, 22 deletions
diff --git a/cli/app/commands/biggan/search.py b/cli/app/commands/biggan/search.py
index d1e1a0a..600e698 100644
--- a/cli/app/commands/biggan/search.py
+++ b/cli/app/commands/biggan/search.py
@@ -11,8 +11,8 @@ import random
from scipy.stats import truncnorm
from subprocess import call
import cv2 as cv
-
from PIL import Image
+from glob import glob
def image_to_uint8(x):
"""Converts [-1, 1] float array to [0, 255] uint8."""
@@ -100,24 +100,32 @@ def cli(ctx, opt_fp_in, opt_dims):
import tensorflow as tf
import tensorflow_hub as hub
- module = hub.Module('https://tfhub.dev/deepmind/biggan-' + str(opt_dims) + '/2')
- # module = hub.Module('https://tfhub.dev/deepmind/biggan-256/2')
- # module = hub.Module('https://tfhub.dev/deepmind/biggan-512/2')
+ generator = hub.Module('https://tfhub.dev/deepmind/biggan-' + str(opt_dims) + '/2')
inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k)
- for k, v in module.get_input_info_dict().items()}
+ for k, v in generator.get_input_info_dict().items()}
input_z = inputs['z']
input_y = inputs['y']
input_trunc = inputs['truncation']
- output = module(inputs)
-
- z_dim = input_z.shape.as_list()[1]
- vocab_size = input_y.shape.as_list()[1]
+ output = generator(inputs)
sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())
sess.run(tf.compat.v1.tables_initializer())
+ if os.path.isdir(opt_fp_in):
+ paths = glob(os.path.join(params.input_dir, '*.jpg')) + \
+ glob(os.path.join(params.input_dir, '*.jpeg')) + \
+ glob(os.path.join(params.input_dir, '*.png'))
+ for path in paths:
+ find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, path, opt_dims)
+ else:
+ find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims)
+
+def find_nearest_vector(generator, sess, input_z, input_y, input_trunc, output, opt_fp_in, opt_dims):
+ z_dim = input_z.shape.as_list()[1]
+ vocab_size = input_y.shape.as_list()[1]
+
# scalar truncation value in [0.02, 1.0]
batch_size = 25
@@ -128,17 +136,11 @@ def cli(ctx, opt_fp_in, opt_dims):
num_classes = 1
y = create_labels(batch_size, vocab_size, num_classes)
- fp_frames = "frames_{}".format(int(time.time() * 1000))
- os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True)
-
- #results = sess.run(output, feed_dict={input_z: z, input_y: y, input_trunc: truncation})
- #for sample in results:
- # sample = image_to_uint8(sample)
- # img = Image.fromarray(sample, "RGB")
- # fp_img_out = "{}.png".format(int(time.time() * 1000))
- # img.save(join(app_cfg.DIR_OUTPUTS, fp_img_out))
-
if opt_fp_in:
+ fn = os.path.basename(opt_fp_in)
+ fbase, ext = os.path.splitext(fn)
+ fp_frames = "frames_{}_{}".format(fbase, int(time.time() * 1000))
+ os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True)
target_im = imread(opt_fp_in)
w = target_im.shape[1]
h = target_im.shape[0]
@@ -155,6 +157,8 @@ def cli(ctx, opt_fp_in, opt_dims):
phi_target = np.expand_dims(phi_target, 0)
phi_target = np.repeat(phi_target, batch_size, axis=0)
else:
+ fp_frames = "frames_{}".format(int(time.time() * 1000))
+ os.makedirs(join(app_cfg.DIR_OUTPUTS, fp_frames), exist_ok=True)
z_target = np.repeat(truncated_z_single(z_dim, truncation), batch_size, axis=0)
y_target = np.repeat(create_labels(1, vocab_size, 1), batch_size, axis=0)
feed_dict = {input_z: z_target, input_y: y_target, input_trunc: truncation}
@@ -163,8 +167,6 @@ def cli(ctx, opt_fp_in, opt_dims):
target_im = imgrid(imconvert_uint8(phi_target), cols=5)
imwrite(join(app_cfg.DIR_OUTPUTS, fp_frames, 'frame_0000_target.png'), target_im)
- optimizer = tf.keras.optimizers.Adam()
-
#dy_dx = g.gradient(y, x)
cost = tf.reduce_sum(tf.pow(output - phi_target, 2))
dc_dz, dc_dy, = tf.gradients(cost, [input_z, input_y])
@@ -196,7 +198,7 @@ def cli(ctx, opt_fp_in, opt_dims):
y[j] /= y[j].sum()
if i > 200 and i % 100 == 0:
mean = np.mean(y, axis=0)
- y = y / 2 + mean / 2
+ y = y * 3 / 4 + mean / 4
indices = np.logical_or(z <= -2*truncation, z >= +2*truncation)
z[indices] = np.random.randn(np.count_nonzero(indices))