diff options
| -rw-r--r-- | cli/app/search/live.py | 41 |
1 files changed, 20 insertions, 21 deletions
diff --git a/cli/app/search/live.py b/cli/app/search/live.py index 99bc052..3e7d138 100644 --- a/cli/app/search/live.py +++ b/cli/app/search/live.py @@ -84,23 +84,6 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)): return np.random.normal(size=shape) # -------------------------- -# Disentangled Latents -# -------------------------- - -disentangled = { - 'zoom': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'zoom/model.ckpt'), 'walk')[0], - 'shiftx': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'shiftx/model.ckpt'), 'walk')[0], - 'shifty': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'shifty/model.ckpt'), 'walk')[0], - 'rotate2d': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'rotate2d/model.ckpt'), 'walk')[0], - 'rotate3d': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'rotate3d/model.ckpt'), 'walk')[0], -} -disentangled_color = read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'color/model.ckpt'), 'walk')[0] -disentangled['r'] = disentangled_color[:, 0] -disentangled['g'] = disentangled_color[:, 1] -disentangled['b'] = disentangled_color[:, 2] -disentangled['luminance'] = np.sum(disentangled_color, axis=1) - -# -------------------------- # More complex ops # -------------------------- @@ -212,6 +195,7 @@ class Interpolator: self.opts = {} self.sin_params = {} self.lerp_params = {} + self.load_disentangled_latents() def build(self): InterpolatorParam(name='truncation', value=1.0) @@ -228,10 +212,10 @@ class Interpolator: z_mix = LerpParam('abs_mix', a_in=z_sum, b_in=z_abs, shape=[BATCH_SIZE, Z_DIM], datatype="input") # Latent - disentangled vectors - zoom = InterpolatorParam(name='zoom', value=0.0).variable * disentangled['zoom'] - shiftx = InterpolatorParam(name='shiftx', value=0.0).variable * disentangled['shiftx'] - shifty = InterpolatorParam(name='shifty', value=0.0).variable * disentangled['shifty'] - luminance = InterpolatorParam(name='luminance', value=0.0).variable * disentangled['luminance'] + zoom = InterpolatorParam(name='zoom', value=0.0).variable * self.disentangled['zoom'] + shiftx = InterpolatorParam(name='shiftx', value=0.0).variable * self.disentangled['shiftx'] + shifty = InterpolatorParam(name='shifty', value=0.0).variable * self.disentangled['shifty'] + luminance = InterpolatorParam(name='luminance', value=0.0).variable * self.disentangled['luminance'] disentangled = z_mix.output + zoom + shiftx + shifty + luminance # Latent - stored vector @@ -280,6 +264,21 @@ class Interpolator: sys.stderr.write("Lerp params: {}\n".format(", ".join(self.lerp_params.keys()))) sys.stderr.write("Opts: {}\n".format(", ".join(self.opts.keys()))) + def load_disentangled_latents(self): + self.disentangled = { + 'zoom': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'zoom/model.ckpt'), 'walk')[0], + 'shiftx': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'shiftx/model.ckpt'), 'walk')[0], + 'shifty': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'shifty/model.ckpt'), 'walk')[0], + 'rotate2d': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'rotate2d/model.ckpt'), 'walk')[0], + 'rotate3d': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'rotate3d/model.ckpt'), 'walk')[0], + } + disentangled_color = read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'color/model.ckpt'), 'walk')[0] + self.disentangled['r'] = disentangled_color[:, 0] + self.disentangled['g'] = disentangled_color[:, 1] + self.disentangled['b'] = disentangled_color[:, 2] + self.disentangled['luminance'] = np.sum(disentangled_color, axis=1) + + def get_feed_dict(self): opt = {} for param in self.opts.values(): |
