summaryrefslogtreecommitdiff
path: root/inversion
diff options
context:
space:
mode:
Diffstat (limited to 'inversion')
-rw-r--r--inversion/live.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/inversion/live.py b/inversion/live.py
index 711df22..84bfa25 100644
--- a/inversion/live.py
+++ b/inversion/live.py
@@ -139,7 +139,10 @@ class InterpolatorParam:
self.scalar = shape == ()
self.shape = shape
self.datatype = datatype
- self.value = (value or 0.0) if datatype == "float" else np.zeros(shape)
+ if datatype == "float":
+ self.assign(value or 0.0)
+ else:
+ self.randomize()
self.variable = tf.placeholder(dtype=dtype, shape=shape)
interpolator.opts[name] = self
@@ -151,9 +154,11 @@ class InterpolatorParam:
def randomize(self):
if self.datatype == 'noise':
- val = truncated_z_sample(shape=self.shape, truncation=interpolator.opt['truncation'].value)
+ val = truncated_z_sample(shape=self.shape, truncation=interpolator.opts['truncation'].value)
elif self.datatype == 'label':
- val = label_sampler(shape=self.shape, num_classes=interpolator.opt['num_classes'].value)
+ val = label_sampler(shape=self.shape, num_classes=interpolator.opts['num_classes'].value)
+ else:
+ val = 0.0
self.assign(val)
# --------------------------