summaryrefslogtreecommitdiff
path: root/inversion/live.py
diff options
context:
space:
mode:
Diffstat (limited to 'inversion/live.py')
-rw-r--r--inversion/live.py39
1 files changed, 30 insertions, 9 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 4e42703..b233724 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -45,11 +45,28 @@ BATCH_SIZE = 1
Z_DIM = input_info['z'].get_shape().as_list()[1]
N_CLASS = input_info['y'].get_shape().as_list()[1]
+# --------------------------
+# Initializers
+# --------------------------
+
+def label_sampler(shape=[BATCH_SIZE, N_CLASS]):
+ num_classes = 2
+ label = np.zeros(shape)
+ for i in range(shape[0]):
+ for _ in range(random.randint(1, shape[1])):
+ j = random.randint(0, shape[1]-1)
+ label[i, j] = random.random()
+ label[i] /= label[i].sum()
+ return label
+
+# --------------------------
+# More complex ops
+# --------------------------
def sin(opts, key, shape):
noise = lerp(opts, key + '_noise', shape)
scale = InterpolatorParam(name=key + '_scale')
- time = opts['global']['time'].variable
+ time = opts['time'].variable
out = tf.sin(time + noise) * scale.variable
opts[key] = {
'scale': scale,
@@ -69,9 +86,10 @@ def lerp(opts, key, shape):
return out
class InterpolatorParam:
- def __init__(self, name, dtype=tf.float32, shape=(), value=None):
+ def __init__(self, name, dtype=tf.float32, shape=(), value=None, type="noise"):
self.scalar = shape == ()
self.shape = shape
+ self.type = type
self.value = value or np.zeros(shape)
self.variable = tf.Variable(self.value, name=name, dtype=dtype, shape=shape)
@@ -79,8 +97,12 @@ class InterpolatorParam:
self.value = value
return self.variable.assign(value)
- def randomize(self):
- return self.assign(np.random.normal(size=self.shape))
+ def randomize(self, num_classes):
+ if self.type == 'noise':
+ val = np.random.normal(size=self.shape)
+ elif self.type == 'label':
+ val = label_sampler(shape=self.shape)
+ self.assign(val)
class Interpolator:
def __init__(self):
@@ -114,14 +136,13 @@ class Interpolator:
def get_state(self):
opt = {}
- for group, lookup in self.opts.items():
- for key, param in group.items():
- if param.scalar:
- opt[group][key] = param.value
+ for key, param in self.opts.items():
+ if param.scalar:
+ opt[key] = param.value
return opt
def set_value(self, key, value):
- self.opts[key].assign(value).eval(session=sess)
+ self.opts[key].assign(value)
def on_step(self, i):
gen_time = time.time()