From d9af07ade49673185f5bcfb1a007e9c5a05c59af Mon Sep 17 00:00:00 2001 From: Jules Laplace Date: Thu, 19 Dec 2019 20:54:02 +0100 Subject: sending command --- inversion/live.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'inversion') diff --git a/inversion/live.py b/inversion/live.py index 711df22..84bfa25 100644 --- a/inversion/live.py +++ b/inversion/live.py @@ -139,7 +139,10 @@ class InterpolatorParam: self.scalar = shape == () self.shape = shape self.datatype = datatype - self.value = (value or 0.0) if datatype == "float" else np.zeros(shape) + if datatype == "float": + self.assign(value or 0.0) + else: + self.randomize() self.variable = tf.placeholder(dtype=dtype, shape=shape) interpolator.opts[name] = self @@ -151,9 +154,11 @@ class InterpolatorParam: def randomize(self): if self.datatype == 'noise': - val = truncated_z_sample(shape=self.shape, truncation=interpolator.opt['truncation'].value) + val = truncated_z_sample(shape=self.shape, truncation=interpolator.opts['truncation'].value) elif self.datatype == 'label': - val = label_sampler(shape=self.shape, num_classes=interpolator.opt['num_classes'].value) + val = label_sampler(shape=self.shape, num_classes=interpolator.opts['num_classes'].value) + else: + val = 0.0 self.assign(val) # -------------------------- -- cgit v1.2.3-70-g09d2