Diffstat (limited to 'cli/app/commands/biggan/random.py')
-rw-r--r--  cli/app/commands/biggan/random.py  73
1 file changed, 54 insertions(+), 19 deletions(-)
diff --git a/cli/app/commands/biggan/random.py b/cli/app/commands/biggan/random.py
index 3e4bff6..67e46c4 100644
--- a/cli/app/commands/biggan/random.py
+++ b/cli/app/commands/biggan/random.py
@@ -6,9 +6,17 @@ from app.settings import app_cfg
from os.path import join
import time
import numpy as np
+import random
+from scipy.stats import truncnorm
from PIL import Image


+z_dim = {
+  128: 120,
+  256: 140,
+  512: 128,
+}
+
def image_to_uint8(x):
"""Converts [-1, 1] float array to [0, 255] uint8."""
x = np.asarray(x)
@@ -17,12 +25,27 @@ def image_to_uint8(x):
  x = x.astype(np.uint8)
  return x

+def truncated_z_sample(batch_size, z_dim, truncation):
+  values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim))
+  return truncation * values
+
+def create_labels(batch_size, vocab_size, num_classes):
+  label = np.zeros((batch_size, vocab_size))
+  for i in range(batch_size):
+    for _ in range(random.randint(1, num_classes)):
+      j = random.randint(0, vocab_size-1)
+      label[i, j] = random.random()
+    label[i] /= label[i].sum()
+  return label
+
@click.command('')
+@click.option('-s', '--dims', 'opt_dims', default=256, type=int,
+  help='Dimensions of BigGAN network (128, 256, 512)')
# @click.option('-i', '--input', 'opt_dir_in', required=True,
# help='Path to input image glob directory')
# @click.option('-r', '--recursive', 'opt_recursive', is_flag=True)
@click.pass_context
-def cli(ctx):
+def cli(ctx, opt_dims):
"""
Generate a random BigGAN image
"""
@@ -30,30 +53,42 @@ def cli(ctx):
  import tensorflow_hub as hub

  print("Loading module...")
-  module = hub.Module('https://tfhub.dev/deepmind/biggan-256/2')
+  module = hub.Module('https://tfhub.dev/deepmind/biggan-' + str(opt_dims) + '/2')
  # module = hub.Module('https://tfhub.dev/deepmind/biggan-256/2')
  # module = hub.Module('https://tfhub.dev/deepmind/biggan-512/2')

-  batch_size = 8
-  truncation = 0.5 # scalar truncation value in [0.02, 1.0]
+  inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k)
+            for k, v in module.get_input_info_dict().items()}
+  input_z = inputs['z']
+  input_y = inputs['y']
+  input_trunc = inputs['truncation']
+  output = module(inputs)
+
+  z_dim = input_z.shape.as_list()[1]
+  vocab_size = input_y.shape.as_list()[1]
+
+  sess = tf.compat.v1.Session()
+  sess.run(tf.compat.v1.global_variables_initializer())
+  sess.run(tf.compat.v1.tables_initializer())

-  z = truncation * tf.random.truncated_normal([batch_size, 120]) # noise sample
+  # scalar truncation value in [0.02, 1.0]

-  # y_index = tf.random.uniform([batch_size], maxval=1000, dtype=tf.int32)
-  # y = tf.one_hot(y_index, 1000)
-  y = tf.random.normal([None, 1000])
+  batch_size = 8
+  truncation = 0.5

-  outputs = module(dict(y=y, z=z, truncation=truncation))
+  #z = truncation * tf.random.truncated_normal([batch_size, z_dim]) # noise sample
+  z = truncated_z_sample(batch_size, z_dim, truncation)

-  with tf.Session() as sess:
-    sess.run(tf.compat.v1.global_variables_initializer())
-    sess.run(tf.compat.v1.tables_initializer())
-    results = sess.run(outputs)
+  for num_classes in [1, 2, 3, 5, 10, 20, 100]:
+    print(num_classes)
+    #y = tf.random.gamma([batch_size, 1000], gamma[0], gamma[1])
+    #y = np.random.gamma(gamma[0], gamma[1], (batch_size, 1000,))
+    y = create_labels(batch_size, vocab_size, num_classes)

-  print(results)
+    results = sess.run(output, feed_dict={input_z: z, input_y: y, input_trunc: truncation})
+    for sample in results:
+      sample = image_to_uint8(sample)
+      img = Image.fromarray(sample, "RGB")
+      fp_img_out = "{}.png".format(int(time.time() * 1000))
+      img.save(join(app_cfg.DIR_OUTPUTS, fp_img_out))

-  for sample in results:
-    sample = image_to_uint8(sample)
-    img = Image.fromarray(sample, "RGB")
-    fp_img_out = "{}.png".format(int(time.time() * 1000))
-    img.save(join(app_cfg.DIR_OUTPUTS, fp_img_out))
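
With this patch, the output resolution is selected via the new -s/--dims option (128, 256, or 512), and the latent size and label vocabulary are read from the loaded TF-Hub module instead of being hard-coded. Below is a quick standalone sanity check of the two new helpers; it is a sketch, not part of the commit. The import path is hypothetical (it assumes the cli/ directory is on PYTHONPATH), z_dim=140 is the 256px entry from the z_dim table above, and vocab_size=1000 matches BigGAN's ImageNet classes.

    import numpy as np
    # hypothetical import path for the helpers added in this patch
    from app.commands.biggan.random import truncated_z_sample, create_labels

    z = truncated_z_sample(batch_size=8, z_dim=140, truncation=0.5)  # truncated noise
    y = create_labels(batch_size=8, vocab_size=1000, num_classes=3)  # sparse class mixture per row

    print(z.shape)          # (8, 140)
    print(np.abs(z).max())  # <= 1.0: values drawn from a normal truncated to [-2, 2], then scaled by 0.5
    print(y.shape)          # (8, 1000)
    print(y.sum(axis=1))    # each row is normalized to sum to 1.0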