summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2019-12-19 18:30:44 +0100
committerJules Laplace <julescarbon@gmail.com>2019-12-19 18:30:44 +0100
commita09eda7cd2033bb6cd18f50ef5be4c5eb2e984f0 (patch)
treee9ebb7d222b64cd33dac7f7b3050aa1a4fe442da /inversion
parent468af07857297ba0cbf01eda2442501dc7351c33 (diff)
type
Diffstat (limited to 'inversion')
-rw-r--r--inversion/live.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 7f07189..3e40be5 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -73,8 +73,8 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
# More complex ops
# --------------------------
-def sin(opts, key, shape):
- noise = lerp(opts, key + '_noise', shape)
+def sin(opts, key, shape, type=type):
+ noise = lerp(opts, key + '_noise', shape, type=type)
scale = InterpolatorParam(name=key + '_scale', value=0.1)
speed = InterpolatorParam(name=key + '_speed', value=1.0)
time = opts['time'].variable
@@ -83,9 +83,9 @@ def sin(opts, key, shape):
opts[key + '_speed'] = speed
return out
-def lerp(opts, key, shape):
- a = InterpolatorParam(name=key + '_a', shape=shape)
- b = InterpolatorParam(name=key + '_b', shape=shape)
+def lerp(opts, key, shape, type='noise'):
+ a = InterpolatorParam(name=key + '_a', shape=shape, type=type)
+ b = InterpolatorParam(name=key + '_b', shape=shape, type=type)
n = InterpolatorParam(name=key + '_n', value=0.0)
speed = InterpolatorParam(name=key + '_speed', value=0.1)
out = a.variable * (1 - n.variable) + b.variable * n.variable
@@ -133,8 +133,8 @@ class Interpolator:
# sin_z = sin(self.opts, 'orbit', [BATCH_SIZE, Z_DIM])
# lerp_label = lerp(self.opts, 'label', [BATCH_SIZE, N_CLASS])
- self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM])
- self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS])
+ self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], type='noise')
+ self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], type='label')
gen_in = {}
gen_in['truncation'] = 1.0 # self.opts['truncation'].variable
@@ -212,6 +212,8 @@ class Listener:
for i in range(99999):
if (i % 100) == 0:
print("Step {}".format(i))
+ self.opts['z'].randomize(truncation=1.0)
+ self.opts['y'].randomize(num_classes=1.0)
gen_time = time.time()
self.interpolator.opts['time'].assign(gen_time)
gen_images = self.interpolator.on_step(i, self.sess)