diff options
Diffstat (limited to 'cli/app/commands/biggan/random.py')
| -rw-r--r-- | cli/app/commands/biggan/random.py | 73 |
1 files changed, 54 insertions, 19 deletions
diff --git a/cli/app/commands/biggan/random.py b/cli/app/commands/biggan/random.py index 3e4bff6..67e46c4 100644 --- a/cli/app/commands/biggan/random.py +++ b/cli/app/commands/biggan/random.py @@ -6,9 +6,17 @@ from app.settings import app_cfg from os.path import join import time import numpy as np +import random +from scipy.stats import truncnorm from PIL import Image +z_dim = { + 128: 120, + 256: 140, + 512: 128, +} + def image_to_uint8(x): """Converts [-1, 1] float array to [0, 255] uint8.""" x = np.asarray(x) @@ -17,12 +25,27 @@ def image_to_uint8(x): x = x.astype(np.uint8) return x +def truncated_z_sample(batch_size, z_dim, truncation): + values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim)) + return truncation * values + +def create_labels(batch_size, vocab_size, num_classes): + label = np.zeros((batch_size, vocab_size)) + for i in range(batch_size): + for _ in range(random.randint(1, num_classes)): + j = random.randint(0, vocab_size-1) + label[i, j] = random.random() + label[i] /= label[i].sum() + return label + @click.command('') +@click.option('-s', '--dims', 'opt_dims', default=256, type=int, + help='Dimensions of BigGAN network (128, 256, 512)') # @click.option('-i', '--input', 'opt_dir_in', required=True, # help='Path to input image glob directory') # @click.option('-r', '--recursive', 'opt_recursive', is_flag=True) @click.pass_context -def cli(ctx): +def cli(ctx, opt_dims): """ Generate a random BigGAN image """ @@ -30,30 +53,42 @@ def cli(ctx): import tensorflow_hub as hub print("Loading module...") - module = hub.Module('https://tfhub.dev/deepmind/biggan-256/2') + module = hub.Module('https://tfhub.dev/deepmind/biggan-' + str(opt_dims) + '/2') # module = hub.Module('https://tfhub.dev/deepmind/biggan-256/2') # module = hub.Module('https://tfhub.dev/deepmind/biggan-512/2') - batch_size = 8 - truncation = 0.5 # scalar truncation value in [0.02, 1.0] + inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k) + for k, v in module.get_input_info_dict().items()} + input_z = inputs['z'] + input_y = inputs['y'] + input_trunc = inputs['truncation'] + output = module(inputs) + + z_dim = input_z.shape.as_list()[1] + vocab_size = input_y.shape.as_list()[1] + + sess = tf.compat.v1.Session() + sess.run(tf.compat.v1.global_variables_initializer()) + sess.run(tf.compat.v1.tables_initializer()) - z = truncation * tf.random.truncated_normal([batch_size, 120]) # noise sample + # scalar truncation value in [0.02, 1.0] - # y_index = tf.random.uniform([batch_size], maxval=1000, dtype=tf.int32) - # y = tf.one_hot(y_index, 1000) - y = tf.random.normal([None, 1000]) + batch_size = 8 + truncation = 0.5 - outputs = module(dict(y=y, z=z, truncation=truncation)) + #z = truncation * tf.random.truncated_normal([batch_size, z_dim]) # noise sample + z = truncated_z_sample(batch_size, z_dim, truncation) - with tf.Session() as sess: - sess.run(tf.compat.v1.global_variables_initializer()) - sess.run(tf.compat.v1.tables_initializer()) - results = sess.run(outputs) + for num_classes in [1, 2, 3, 5, 10, 20, 100]: + print(num_classes) + #y = tf.random.gamma([batch_size, 1000], gamma[0], gamma[1]) + #y = np.random.gamma(gamma[0], gamma[1], (batch_size, 1000,)) + y = create_labels(batch_size, vocab_size, num_classes) - print(results) + results = sess.run(output, feed_dict={input_z: z, input_y: y, input_trunc: truncation}) + for sample in results: + sample = image_to_uint8(sample) + img = Image.fromarray(sample, "RGB") + fp_img_out = "{}.png".format(int(time.time() * 1000)) + img.save(join(app_cfg.DIR_OUTPUTS, fp_img_out)) - for sample in results: - sample = image_to_uint8(sample) - img = Image.fromarray(sample, "RGB") - fp_img_out = "{}.png".format(int(time.time() * 1000)) - img.save(join(app_cfg.DIR_OUTPUTS, fp_img_out)) |
