# path: root/cli/app/commands/biggan/random.py
# blob: 67e46c44301ffcd08af53e0835aa0c5742186d6d
import click

from app.utils import click_utils
from app.settings import app_cfg

from os.path import join
import time
import numpy as np
import random
from scipy.stats import truncnorm

from PIL import Image

# Length of the z (noise) vector, keyed by BigGAN model size (the values
# accepted by the --dims option below).
z_dim = {128: 120, 256: 140, 512: 128}

def image_to_uint8(x):
  """Turn a float image with values nominally in [-1, 1] into uint8 pixels.

  Values are mapped linearly as 128 * (x + 1), clipped to [0, 255] and then
  truncated to uint8, so -1.0 -> 0, 0.0 -> 128 and 1.0 saturates at 255.
  The input shape is preserved.
  """
  scaled = (np.asarray(x) + 1.) * 128.
  return np.clip(scaled, 0, 255).astype(np.uint8)

def truncated_z_sample(batch_size, z_dim, truncation, seed=None):
  """Sample a batch of truncated-normal noise vectors for BigGAN.

  Values are drawn from a standard normal truncated to [-2, 2] and then
  scaled by ``truncation``. Every entry therefore lies in
  [-2 * truncation, 2 * truncation].

  Args:
    batch_size: number of noise vectors (rows).
    z_dim: length of each noise vector (columns).
    truncation: scalar multiplier applied to the samples.
    seed: optional seed for reproducible sampling. Accepts an int, or a
      ``np.random.RandomState`` / ``np.random.Generator``. It is passed
      unchanged to ``truncnorm.rvs(random_state=...)``. ``None`` (the
      default) uses numpy's global RNG, as before.

  Returns:
    ndarray of shape (batch_size, z_dim).
  """
  values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=seed)
  return truncation * values

def create_labels(batch_size, vocab_size, num_classes, rng=None):
  """Build a batch of random soft class-label vectors.

  For each row, a count between 1 and ``num_classes`` is drawn. That many
  (index, weight) pairs are then sampled, with weights taken from
  ``rng.random()``. Indices may repeat, in which case the later weight
  overwrites the earlier one. Finally the row is normalised to sum to 1.

  Args:
    batch_size: number of label rows to create.
    vocab_size: length of each label vector; must be >= 1.
    num_classes: maximum number of classes mixed into one row; must be >= 1
      (``randint`` raises ValueError otherwise).
    rng: optional ``random.Random`` instance for reproducible labels. The
      default ``None`` uses the module-level ``random`` functions with the
      same draw order as before.

  Returns:
    float64 ndarray of shape (batch_size, vocab_size) whose rows sum to 1.
  """
  rng = random if rng is None else rng
  label = np.zeros((batch_size, vocab_size))
  # Iterating a 2-D ndarray yields row views, so writes land in `label`.
  for row in label:
    chosen = []
    for _ in range(rng.randint(1, num_classes)):
      j = rng.randint(0, vocab_size - 1)
      row[j] = rng.random()
      chosen.append(j)
    total = row.sum()
    if total > 0:
      row /= total
    else:
      # random() can return exactly 0.0. If every sampled weight did, the
      # plain normalisation would be 0/0 -> NaN, so instead spread the
      # mass evenly over the chosen classes.
      row[chosen] = 1.0 / len(set(chosen))
  return label

@click.command('')
@click.option('-s', '--dims', 'opt_dims', default=256, type=int,
  help='Dimensions of BigGAN network (128, 256, 512)')
@click.pass_context
def cli(ctx, opt_dims):
  """
  Generate random BigGAN images

  Loads the TF-Hub BigGAN generator for the requested size and draws one
  batch of truncated noise vectors. For each class-mixing count in a fixed
  list, it renders that same batch with random soft labels and saves every
  sample as a PNG in app_cfg.DIR_OUTPUTS.
  """
  import tensorflow as tf
  import tensorflow_hub as hub

  # Only the sizes in the module-level z_dim table are published on TF-Hub.
  if opt_dims not in z_dim:
    raise click.BadParameter(
      'must be one of {}'.format(sorted(z_dim)), ctx=ctx, param_hint='--dims')

  # hub.Module and tf.compat.v1.placeholder require graph mode. Eager
  # execution is on by default in TF2, so switch it off when active.
  if tf.executing_eagerly():
    tf.compat.v1.disable_eager_execution()

  print("Loading module...")
  module = hub.Module('https://tfhub.dev/deepmind/biggan-{}/2'.format(opt_dims))

  inputs = {k: tf.compat.v1.placeholder(v.dtype, v.get_shape().as_list(), k)
            for k, v in module.get_input_info_dict().items()}
  input_z = inputs['z']
  input_y = inputs['y']
  input_trunc = inputs['truncation']
  output = module(inputs)

  # Read the sizes from the loaded graph. A distinct local name keeps the
  # module-level z_dim table visible (it is used for validation above).
  dim_z = input_z.shape.as_list()[1]
  vocab_size = input_y.shape.as_list()[1]

  batch_size = 8
  # scalar truncation value in [0.02, 1.0]
  truncation = 0.5

  # Sampled once, so every label mix below is rendered from the same noise.
  z = truncated_z_sample(batch_size, dim_z, truncation)

  # Per-run prefix plus per-sample suffix. A bare millisecond timestamp per
  # image lets samples saved within the same millisecond overwrite each other.
  run_id = int(time.time() * 1000)

  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(tf.compat.v1.tables_initializer())

    for num_classes in [1, 2, 3, 5, 10, 20, 100]:
      print("Generating {} images mixing up to {} classes...".format(
        batch_size, num_classes))
      y = create_labels(batch_size, vocab_size, num_classes)

      results = sess.run(output, feed_dict={input_z: z, input_y: y, input_trunc: truncation})
      for i, sample in enumerate(results):
        img = Image.fromarray(image_to_uint8(sample), "RGB")
        fp_img_out = "{}_{:03d}_{:02d}.png".format(run_id, num_classes, i)
        img.save(join(app_cfg.DIR_OUTPUTS, fp_img_out))