diff options
| -rw-r--r-- | inversion/live.py | 39 |
1 files changed, 30 insertions, 9 deletions
diff --git a/inversion/live.py b/inversion/live.py index 4e42703..b233724 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -45,11 +45,28 @@ BATCH_SIZE = 1 Z_DIM = input_info['z'].get_shape().as_list()[1] N_CLASS = input_info['y'].get_shape().as_list()[1] +# -------------------------- +# Initializers +# -------------------------- + +def label_sampler(shape=[BATCH_SIZE, N_CLASS]): + num_classes = 2 + label = np.zeros(shape) + for i in range(shape[0]): + for _ in range(random.randint(1, shape[1])): + j = random.randint(0, shape[1]-1) + label[i, j] = random.random() + label[i] /= label[i].sum() + return label + +# -------------------------- +# More complex ops +# -------------------------- def sin(opts, key, shape): noise = lerp(opts, key + '_noise', shape) scale = InterpolatorParam(name=key + '_scale') - time = opts['global']['time'].variable + time = opts['time'].variable out = tf.sin(time + noise) * scale.variable opts[key] = { 'scale': scale, @@ -69,9 +86,10 @@ def lerp(opts, key, shape): return out class InterpolatorParam: - def __init__(self, name, dtype=tf.float32, shape=(), value=None): + def __init__(self, name, dtype=tf.float32, shape=(), value=None, type="noise"): self.scalar = shape == () self.shape = shape + self.type = type self.value = value or np.zeros(shape) self.variable = tf.Variable(self.value, name=name, dtype=dtype, shape=shape) @@ -79,8 +97,12 @@ class InterpolatorParam: self.value = value return self.variable.assign(value) - def randomize(self): - return self.assign(np.random.normal(size=self.shape)) + def randomize(self, num_classes): + if self.type == 'noise': + val = np.random.normal(size=self.shape) + elif self.type == 'label': + val = label_sampler(shape=self.shape) + self.assign(val) class Interpolator: def __init__(self): @@ -114,14 +136,13 @@ class Interpolator: def get_state(self): opt = {} - for group, lookup in self.opts.items(): - for key, param in group.items(): - if param.scalar: - opt[group][key] = param.value + for key, param in self.opts.items(): + if param.scalar: + opt[key] = param.value return opt def set_value(self, key, value): - self.opts[key].assign(value).eval(session=sess) + self.opts[key].assign(value) def on_step(self, i): gen_time = time.time() |
