summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--cli/app/search/live.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/cli/app/search/live.py b/cli/app/search/live.py
index 398e25e..9fcf91b 100644
--- a/cli/app/search/live.py
+++ b/cli/app/search/live.py
@@ -272,11 +272,11 @@ class Interpolator:
'rotate2d': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'rotate2d/model.ckpt'), 'walk'),
'rotate3d': read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'rotate3d/model.ckpt'), 'walk'),
}
- disentangled_color = read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'color/model.ckpt'), 'walk')[0]
- self.disentangled['r'] = disentangled_color[:, 0]
- self.disentangled['g'] = disentangled_color[:, 1]
- self.disentangled['b'] = disentangled_color[:, 2]
- self.disentangled['luminance'] = np.sum(disentangled_color, axis=1)
+ disentangled_color = read_checkpoint(os.path.join(app_cfg.DIR_DISENTANGLED, 'color/model.ckpt'), 'walk')
+ self.disentangled['r'] = disentangled_color[:, :, 0]
+ self.disentangled['g'] = disentangled_color[:, :, 1]
+ self.disentangled['b'] = disentangled_color[:, :, 2]
+ self.disentangled['luminance'] = np.sum(disentangled_color, axis=2)
def get_feed_dict(self):