summaryrefslogtreecommitdiff
path: root/inversion/live.py
diff options
context:
space:
mode:
Diffstat (limited to 'inversion/live.py')
-rw-r--r--inversion/live.py31
1 files changed, 27 insertions, 4 deletions
diff --git a/inversion/live.py b/inversion/live.py
index cf7f761..ad4b80f 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -42,6 +42,24 @@ BATCH_SIZE = 1
Z_DIM = input_info['z'].get_shape().as_list()[1]
N_CLASS = input_info['y'].get_shape().as_list()[1]
+# --------------------------
+# Initializers
+# --------------------------
+
+def label_sampler(shape=[BATCH_SIZE, N_CLASS]):
+ num_classes = 2
+ label = np.zeros(shape)
+ for i in range(shape[0]):
+ for _ in range(random.randint(1, shape[1])):
+ j = random.randint(0, shape[1]-1)
+ label[i, j] = random.random()
+ label[i] /= label[i].sum()
+ return label
+
+# --------------------------
+# More complex ops
+# --------------------------
+
def sin(opts, key, shape):
noise = lerp(opts, key + '_noise', shape)
scale = InterpolatorParam(name=key + '_scale')
@@ -61,17 +79,22 @@ def lerp(opts, key, shape):
return out
class InterpolatorParam:
- def __init__(self, name, dtype=tf.float32, shape=(), value=None):
+ def __init__(self, name, dtype=tf.float32, shape=(), value=None, type="noise"):
self.scalar = shape == ()
self.shape = shape
+ self.type = type
self.value = value or np.zeros(shape)
self.variable = tf.placeholder(dtype=dtype, shape=shape)
def assign(self, value):
self.value = value
- def randomize(self):
- return self.assign(np.random.normal(size=self.shape))
+ def randomize(self, num_classes):
+ if self.type == 'noise':
+ val = np.random.normal(size=self.shape)
+ elif self.type == 'label':
+ val = label_sampler(shape=self.shape)
+ self.assign(val)
class Interpolator:
def __init__(self):
@@ -108,7 +131,7 @@ class Interpolator:
return opt
def set_value(self, key, value):
- return self.opts[key].assign(value)
+ self.opts[key].assign(value)
def on_step(self, i, sess):
gen_images = sess.run(self.gen_img, feed_dict=self.get_feed_dict())