summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cli/app/search/live.py41
1 files changed, 20 insertions, 21 deletions
diff --git a/cli/app/search/live.py b/cli/app/search/live.py
index 99bc052..3e7d138 100644
--- a/cli/app/search/live.py
+++ b/cli/app/search/live.py
@@ -84,23 +84,6 @@ def normal_z_sample(shape=(BATCH_SIZE, Z_DIM,)):
return np.random.normal(size=shape)
# --------------------------
-# Disentangled Latents
-# --------------------------
-
-disentangled = {
- 'zoom': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'zoom/model.ckpt'), 'walk')[0],
- 'shiftx': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'shiftx/model.ckpt'), 'walk')[0],
- 'shifty': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'shifty/model.ckpt'), 'walk')[0],
- 'rotate2d': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'rotate2d/model.ckpt'), 'walk')[0],
- 'rotate3d': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'rotate3d/model.ckpt'), 'walk')[0],
-}
-disentangled_color = read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'color/model.ckpt'), 'walk')[0]
-disentangled['r'] = disentangled_color[:, 0]
-disentangled['g'] = disentangled_color[:, 1]
-disentangled['b'] = disentangled_color[:, 2]
-disentangled['luminance'] = np.sum(disentangled_color, axis=1)
-
-# --------------------------
# More complex ops
# --------------------------
@@ -212,6 +195,7 @@ class Interpolator:
self.opts = {}
self.sin_params = {}
self.lerp_params = {}
+ self.load_disentangled_latents()
def build(self):
InterpolatorParam(name='truncation', value=1.0)
@@ -228,10 +212,10 @@ class Interpolator:
z_mix = LerpParam('abs_mix', a_in=z_sum, b_in=z_abs, shape=[BATCH_SIZE, Z_DIM], datatype="input")
# Latent - disentangled vectors
- zoom = InterpolatorParam(name='zoom', value=0.0).variable * disentangled['zoom']
- shiftx = InterpolatorParam(name='shiftx', value=0.0).variable * disentangled['shiftx']
- shifty = InterpolatorParam(name='shifty', value=0.0).variable * disentangled['shifty']
- luminance = InterpolatorParam(name='luminance', value=0.0).variable * disentangled['luminance']
+ zoom = InterpolatorParam(name='zoom', value=0.0).variable * self.disentangled['zoom']
+ shiftx = InterpolatorParam(name='shiftx', value=0.0).variable * self.disentangled['shiftx']
+ shifty = InterpolatorParam(name='shifty', value=0.0).variable * self.disentangled['shifty']
+ luminance = InterpolatorParam(name='luminance', value=0.0).variable * self.disentangled['luminance']
disentangled = z_mix.output + zoom + shiftx + shifty + luminance
# Latent - stored vector
@@ -280,6 +264,21 @@ class Interpolator:
sys.stderr.write("Lerp params: {}\n".format(", ".join(self.lerp_params.keys())))
sys.stderr.write("Opts: {}\n".format(", ".join(self.opts.keys())))
+ def load_disentangled_latents(self):
+ self.disentangled = {
+ 'zoom': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'zoom/model.ckpt'), 'walk')[0],
+ 'shiftx': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'shiftx/model.ckpt'), 'walk')[0],
+ 'shifty': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'shifty/model.ckpt'), 'walk')[0],
+ 'rotate2d': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'rotate2d/model.ckpt'), 'walk')[0],
+ 'rotate3d': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'rotate3d/model.ckpt'), 'walk')[0],
+ }
+ disentangled_color = read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'color/model.ckpt'), 'walk')[0]
+ self.disentangled['r'] = disentangled_color[:, 0]
+ self.disentangled['g'] = disentangled_color[:, 1]
+ self.disentangled['b'] = disentangled_color[:, 2]
+ self.disentangled['luminance'] = np.sum(disentangled_color, axis=1)
+
+
def get_feed_dict(self):
opt = {}
for param in self.opts.values():