summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
Diffstat (limited to 'inversion')
-rw-r--r--inversion/live.py30
1 files changed, 15 insertions, 15 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 2dc4970..979e137 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -85,8 +85,8 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
# --------------------------
class SinParam:
- def __init__(self, name, shape, type="noise"):
- noise = LerpParam(name + '_noise', shape, type=type)
+ def __init__(self, name, shape, datatype="noise"):
+ noise = LerpParam(name + '_noise', shape, datatype=datatype)
orbit_radius = InterpolatorParam(name=name + '_radius', value=0.1)
orbit_speed = InterpolatorParam(name=name + '_speed', value=1.0)
orbit_time = InterpolatorParam(name=name + '_time', value=0.0)
@@ -101,9 +101,9 @@ class SinParam:
self.orbit_time.value += self.orbit_speed.value * dt
class LerpParam:
- def __init__(self, name, shape, type="noise"):
- a = InterpolatorParam(name=name + '_a', shape=shape, type=type)
- b = InterpolatorParam(name=name + '_b', shape=shape, type=type)
+ def __init__(self, name, shape, datatype="noise"):
+ a = InterpolatorParam(name=name + '_a', shape=shape, datatype=datatype)
+ b = InterpolatorParam(name=name + '_b', shape=shape, datatype=datatype)
n = InterpolatorParam(name=name + '_n', value=0.0)
speed = InterpolatorParam(name=name + '_speed', value=0.1)
output = a.variable * (1 - n.variable) + b.variable * n.variable
@@ -135,11 +135,11 @@ class LerpParam:
# --------------------------
class InterpolatorParam:
- def __init__(self, name, dtype=tf.float32, shape=(), value=None, type=float):
+ def __init__(self, name, dtype=tf.float32, shape=(), value=None, datatype=float):
self.scalar = shape == ()
self.shape = shape
- self.type = type
- self.value = (value or 0.0) if type == float else np.zeros(shape)
+ self.datatype = datatype
+ self.value = (value or 0.0) if datatype == float else np.zeros(shape)
self.variable = tf.placeholder(dtype=dtype, shape=shape)
interpolator.opts[name] = self
@@ -147,9 +147,9 @@ class InterpolatorParam:
self.value = value
def randomize(self):
- if self.type == 'noise':
+ if self.datatype == 'noise':
val = truncated_z_sample(shape=self.shape, truncation=interpolator.opt['truncation'].value)
- elif self.type == 'label':
+ elif self.datatype == 'label':
val = label_sampler(shape=self.shape, num_classes=interpolator.opt['num_classes'].value)
self.assign(val)
@@ -166,12 +166,12 @@ class Interpolator:
def build(self):
InterpolatorParam(name='truncation', value=1.0),
InterpolatorParam(name='num_classes', value=1.0),
- lerp_z = LerpParam('latent', [BATCH_SIZE, Z_DIM])
- sin_z = SinParam('orbit', [BATCH_SIZE, Z_DIM])
- lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS], type="label")
+ lerp_z = LerpParam('latent', [BATCH_SIZE, Z_DIM], datatype="noise")
+ sin_z = SinParam('orbit', [BATCH_SIZE, Z_DIM], datatype="noise")
+ lerp_label = LerpParam('label', [BATCH_SIZE, N_CLASS], datatype="label")
- # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], type='noise')
- # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], type='label')
+ # self.opts['z'] = InterpolatorParam(name='z', shape=[BATCH_SIZE, Z_DIM], datatype='noise')
+ # self.opts['y'] = InterpolatorParam(name='y', shape=[BATCH_SIZE, N_CLASS], datatype='label')
gen_in = {}
gen_in['truncation'] = 1.0 # self.opts['truncation'].variable