summaryrefslogtreecommitdiff
path: root/inversion/live.py
blob: 672853aed7044a317b0482c954df552a49faa978 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import sys
import glob
import h5py
import numpy as np
import params
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_hub as hub
import time
import visualize as vs
# Only let TensorFlow emit log messages at ERROR level or above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

from listener import Listener
from params import Params
from params_opt import ParamsOpt

# NOTE: this rebinds `params`, shadowing the `params` module imported above.
params = Params('params_dense.json')
opt = ParamsOpt()

# --------------------------
# Make directories.
# --------------------------
# `tag` was previously used here without ever being defined, which raised
# NameError at import time. Default to a fixed run name.
# TODO(review): confirm whether the tag should come from params/opt instead.
tag = 'live'
OUTPUT_DIR = os.path.join('output', tag)
if not os.path.exists(OUTPUT_DIR):
  os.makedirs(OUTPUT_DIR)

# --------------------------
# Load Graph.
# --------------------------
generator = hub.Module(str(params.generator_path))

# Use the module's 'generator' signature when it exports one; otherwise fall
# back to the 'default' signature.
gen_signature = ('generator'
                 if 'generator' in generator.get_signature_names()
                 else 'default')

input_info = generator.get_input_info_dict(gen_signature)
# NOTE(review): hard-coded and not read anywhere in this file; the graph is
# always built with a 'y' (label) input.
COND_GAN = True

BATCH_SIZE = 1
# Width of the latent ('z') and label ('y') inputs, taken from axis 1 of each
# input's declared shape.
Z_DIM, N_CLASS = (input_info[input_name].get_shape().as_list()[1]
                  for input_name in ('z', 'y'))

def sin(key, shape):
  """Build a sinusoidal offset tensor ``sin(t + noise) * scale``.

  ``noise`` is produced by ``lerp`` (a blend of two variables), while ``t``
  and ``scale`` are shape-(1,) float32 variables.

  Args:
    key: name prefix for every variable created by this call.
    shape: shape of the noise variables.

  Returns:
    (out, t, scale, noise_a, noise_b, noise_n)
  """
  # speed should be computed outside of tensorflow
  # (so we can recursively update t = last_t + speed)
  # Prefix the noise variables with `key`, like `_scale`/`_t` below. A fixed
  # 'sin_noise' prefix made a second sin() call collide on variable names.
  noise, noise_a, noise_b, noise_n = lerp(key + '_noise', shape)
  scale = tf.get_variable(name=key + '_scale', dtype=tf.float32, shape=(1,))
  t = tf.get_variable(name=key + '_t', dtype=tf.float32, shape=(1,))
  out = tf.sin(t + noise) * scale
  return out, t, scale, noise_a, noise_b, noise_n

def lerp(key, shape):
  """Create a linear blend between two float32 variables.

  Creates ``<key>_a`` and ``<key>_b`` with the given shape, plus a
  shape-(1,) mix factor ``<key>_n``. Builds ``a * (1 - n) + b * n``.

  Returns:
    (out, a, b, n)
  """
  def _make(suffix, var_shape):
    return tf.get_variable(name=key + suffix, dtype=tf.float32, shape=var_shape)

  start = _make('_a', shape)
  end = _make('_b', shape)
  mix = _make('_n', (1,))
  blended = start * (1 - mix) + end * mix
  return blended, start, end, mix

lerp_z, z_a, z_b, z_n = lerp('latent', [BATCH_SIZE, Z_DIM])
# sin() returns six values. The previous lerp() call here returns only four
# and raised "ValueError: not enough values to unpack" at import time.
(sin_z, sin_t, sin_scale,
 sin_noise_a, sin_noise_b, sin_noise_n) = sin('sin_z', [BATCH_SIZE, Z_DIM])
lerp_label, label_a, label_b, label_n = lerp('label', [BATCH_SIZE, N_CLASS])

# Generator inputs: latent = interpolated z plus the sinusoidal offset;
# label = interpolated class vector.
gen_in = dict(params.generator_fixed_inputs)
gen_in['z'] = lerp_z + sin_z
gen_in['y'] = lerp_label
gen_img = generator(gen_in, signature=gen_signature)

# Convert generated image to channels_first.
gen_img = tf.transpose(gen_img, [0, 3, 1, 2])

# Disabled experiment (swap an inner generator layer for a variable), kept
# for reference:
# layer_name = 'module_apply_' + gen_signature + '/' + params.inv_layer
# gen_encoding = tf.get_default_graph().get_tensor_by_name(layer_name)
# ENC_SHAPE = gen_encoding.get_shape().as_list()[1:]
# encoding = tf.get_variable(name='encoding', dtype=tf.float32,
#                            shape=[BATCH_SIZE,] + ENC_SHAPE)
# tf.contrib.graph_editor.swap_ts(gen_encoding, tf.convert_to_tensor(encoding))

IMG_SHAPE = gen_img.get_shape().as_list()[1:]

# on_step() and run_live() use a module-level `sess` that was never created.
# Create it and run the initializers before the first sess.run(gen_img).
sess = tf.compat.v1.Session()
sess.run([tf.compat.v1.global_variables_initializer(),
          tf.compat.v1.tables_initializer()])

t = time.time()

def on_step():
  """Generate one frame from the current generator inputs.

  Runs ``gen_img`` in the module-level session, prints how long it took, and
  converts the first image of the batch with ``vs.data2img``.

  Returns:
    None, so ``run_live`` keeps looping.
    TODO: return a truthy value to stop the loop.
  """
  # local variables to update:
  #   t, sin_speed, sin_t
  # variables to assign:
  #   z_a, z_b, z_n
  #   label_a, label_b, label_n
  #   sin_t, sin_noise, sin_scale, sin_amount
  #   sin_noise_a, sin_noise_b, sin_noise_n
  # sess.run([
  #   target.assign(image_batch)
  # ])
  # sess.run(label.assign(label_batch))
  gen_time = time.time()
  gen_images = sess.run(gen_img)
  print("Generation time: {:.1f}s".format(time.time() - gen_time))
  # Convert the frame just generated. The previous code read an undefined
  # `image_batch` here, which raised NameError on every step.
  # NOTE(review): gen_img is transposed to channels_first when the graph is
  # built; confirm vs.data2img expects that layout.
  out_img = vs.data2img(gen_images[0])
  # TODO: convert out_img to png and send it back to the client.
  return None

def run_live():
  """Call on_step() repeatedly until it returns a truthy value.

  The module-level session is closed when the loop ends. This also happens
  when on_step() raises or the loop is interrupted (e.g. KeyboardInterrupt),
  which previously skipped sess.close().
  """
  try:
    while not on_step():
      pass
  finally:
    sess.close()

if __name__ == '__main__':
  # Register run_live as the Listener's callback and start listening.
  Listener(opt, run_live).connect()