summaryrefslogtreecommitdiff
path: root/inversion/random_sample.py
blob: 61cac9c440ece3488a064faf59ad1adae0a0a98f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# ------------------------------------------------------------------------------
# Generate random samples of the generator and save the images to a hdf5 file.
# ------------------------------------------------------------------------------

import h5py
import numpy as np
import os
import params
import sys
import tensorflow as tf
import tensorflow_hub as hub
import time
import visualize as vs

# --------------------------
# Configuration.
# --------------------------
# Expected parameters:
#  generator_path: path to generator module.
#  generator_fixed_inputs: dictionary of fixed generator's input parameters.
#  dataset_out: name for the output created dataset (hdf5 file).
# General parameters:
#  batch_size: number of images generated at the same time.
#  random_label: choose random labels.
#  num_imgs: number of instances to generate.
#  custom_label: custom label to be fixed.
# Logging:
#  sample_size: number of images included in sampled images.
if len(sys.argv) < 2:
  sys.exit('Must provide a configuration file.')
# NOTE: rebinds the module name `params` to the loaded Params object; every
# later `params.<attr>` read goes through this object.
params = params.Params(sys.argv[1])

# --------------------------
# Hyper-parameters.
# --------------------------
# General parameters.
BATCH_SIZE = params.batch_size
SAMPLE_SIZE = params.sample_size
# Validate explicitly instead of `assert`, which is stripped under `python -O`.
if SAMPLE_SIZE > BATCH_SIZE:
  sys.exit('sample_size must not exceed batch_size.')
NUM_IMGS = params.num_imgs

# --------------------------
# Global directories.
# --------------------------
# Output locations: sampled image grids and the generated hdf5 dataset.
SAMPLES_DIR = 'random_samples'
INVERSES_DIR = 'inverses'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(SAMPLES_DIR, exist_ok=True)
os.makedirs(INVERSES_DIR, exist_ok=True)

# --------------------------
# Util functions.
# --------------------------
def one_hot(values, num_classes=None):
  """Convert integer labels to one-hot rows.

  Args:
    values: int or sequence/array of ints in [0, num_classes).
    num_classes: width of the encoding; defaults to the global N_CLASS,
      preserving the original call signature.

  Returns:
    Float array with an extra trailing axis of size num_classes, with a 1.0
    at each label position.
  """
  if num_classes is None:
    num_classes = N_CLASS
  return np.eye(num_classes)[values]

def label_sampler(size=1, num_classes=None):
  """Sample `size` integer labels uniformly from [0, num_classes).

  Args:
    size: number of labels to draw.
    num_classes: number of classes; defaults to the global N_CLASS,
      preserving the original call signature.

  Returns:
    Int array of shape (size,).
  """
  if num_classes is None:
    num_classes = N_CLASS
  # np.random.random_integers (inclusive high) was deprecated and removed in
  # NumPy >= 1.25; randint uses an exclusive upper bound, hence num_classes.
  return np.random.randint(low=0, high=num_classes, size=size)

# --------------------------
# Load Graph.
# --------------------------
# Load the generator as a TF-Hub module (TF1-style hub.Module).
generator = hub.Module(str(params.generator_path))

# Prefer an explicit 'generator' signature; fall back to 'default' when the
# module does not export one.
gen_signature = 'generator'
if 'generator' not in generator.get_signature_names():
  gen_signature = 'default'

# A conditional GAN is detected by the presence of a label input 'y' in the
# chosen signature; unconditional modules expose only a latent input.
input_info = generator.get_input_info_dict(gen_signature)
COND_GAN = 'y' in input_info

if COND_GAN:
  # Latent 'z' and label 'y' are created as graph variables so the
  # generation loop below can set them per batch via .assign().
  Z_DIM = input_info['z'].get_shape().as_list()[1]
  latent = tf.get_variable(name='latent', dtype=tf.float32,
                           shape=[BATCH_SIZE, Z_DIM])
  N_CLASS = input_info['y'].get_shape().as_list()[1]
  label = tf.get_variable(name='label', dtype=tf.float32,
                          shape=[BATCH_SIZE, N_CLASS])
  # Copy so the config-provided dict is not mutated by adding 'z'/'y'.
  gen_in = dict(params.generator_fixed_inputs)
  gen_in['z'] = latent
  gen_in['y'] = label
  gen_img = generator(gen_in, signature=gen_signature)
else:
  # Unconditional module: the single latent input is named 'default'.
  Z_DIM = input_info['default'].get_shape().as_list()[1]
  latent = tf.get_variable(name='latent', dtype=tf.float32,
                           shape=[BATCH_SIZE, Z_DIM])
  if (params.generator_fixed_inputs):
    gen_in = dict(params.generator_fixed_inputs)
    gen_in['z'] = latent
    gen_img = generator(gen_in, signature=gen_signature)
  else:
    gen_img = generator(latent, signature=gen_signature)

# Convert generated image to channels_first (NHWC -> NCHW); assumes the
# module emits channels_last output -- TODO confirm per generator.
gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

# Define image shape: [C, H, W] after the transpose above (batch dropped).
IMG_SHAPE = gen_img.get_shape().as_list()[1:]

# --------------------------
# Noise source.
# --------------------------
def noise_sampler(batch_size=None, z_dim=None):
  """Draw a batch of standard-normal latent vectors.

  Args:
    batch_size: rows to sample; defaults to the global BATCH_SIZE,
      preserving the original zero-argument call.
    z_dim: latent dimensionality; defaults to the global Z_DIM.

  Returns:
    Float array of shape [batch_size, z_dim], i.i.d. N(0, 1).
  """
  if batch_size is None:
    batch_size = BATCH_SIZE
  if z_dim is None:
    z_dim = Z_DIM
  return np.random.normal(size=[batch_size, z_dim])

# --------------------------
# Generation.
# --------------------------
# Start a TF1 session and initialize module variables/tables.
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())

# Output file (hdf5): images stored as uint8, channels_first.
out_file = h5py.File(os.path.join(INVERSES_DIR, params.dataset_out), 'w')
out_images = out_file.create_dataset('xtrain', [NUM_IMGS,] + IMG_SHAPE,
                                     dtype='uint8')
if COND_GAN:
  out_labels = out_file.create_dataset('ytrain', (NUM_IMGS,), dtype='uint32')

for i in range(0, NUM_IMGS, BATCH_SIZE):
  # The last batch may be partial: a full batch is generated but only the
  # first n_encs results are written out.
  n_encs = min(BATCH_SIZE, NUM_IMGS - i)

  if COND_GAN:
    if params.random_label:
      label_batch = label_sampler(BATCH_SIZE)
    else:
      label_batch = [params.custom_label]*BATCH_SIZE
    sess.run(label.assign(one_hot(label_batch)))

  sess.run(latent.assign(noise_sampler()))

  gen_images = sess.run(gen_img)

  # Map generator output to uint8 image data (see visualize.data2img).
  gen_images = vs.data2img(gen_images)

  out_images[i:i+n_encs] = gen_images[:n_encs]
  if COND_GAN:
    out_labels[i:i+n_encs] = label_batch[:n_encs]

  # Save a small grid of this batch's samples for visual inspection.
  out_batch = vs.grid_transform(gen_images[:SAMPLE_SIZE])
  vs.save_image('{}/generated_{}.png'.format(SAMPLES_DIR, i), out_batch)
  print('Saved samples for imgs: {}-{}.'.format(i,i+n_encs))

# Release resources explicitly: without close() the hdf5 file may be left
# unflushed/truncated, and the session keeps graph memory alive.
out_file.close()
sess.close()