summaryrefslogtreecommitdiff
path: root/live-mogrify.py
diff options
context:
space:
mode:
authorJules Laplace <julescarbon@gmail.com>2018-06-07 02:07:23 +0200
committerJules Laplace <julescarbon@gmail.com>2018-06-07 02:07:23 +0200
commitb7b305a6dda37e0ac131d42007c4014e82decd2c (patch)
tree52b20357630152cea4cedfd9a0d356112fd59196 /live-mogrify.py
parenta438c6e88484a68abdd17384e720452ffa2b96bb (diff)
augment paths
Diffstat (limited to 'live-mogrify.py')
-rw-r--r--live-mogrify.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/live-mogrify.py b/live-mogrify.py
index 89ce10c..2372bc9 100644
--- a/live-mogrify.py
+++ b/live-mogrify.py
@@ -39,7 +39,7 @@ def load_opt():
data_opt_parser = DatasetOptions()
data_opt = data_opt_parser.parse(opt.unknown)
global module_name
- module_name = data_opt.module_name
+ module_name = opt.module_name
opt.nThreads = 1 # test code only supports nThreads = 1
opt.batchSize = 1 # test code only supports batchSize = 1
opt.serial_batches = True # no shuffle
@@ -131,9 +131,9 @@ def process_image(opt, data_opt, im):
img = np.add(array_a, array_b).astype('uint8')
return img
-def list_checkpoints():
+def list_checkpoints(payload):
print("> list checkpoints")
- return sorted([f.split('/')[2] for f in glob.glob('./checkpoints/*/latest_net_G.pth')])
+ return sorted([f.split('/')[2] for f in glob.glob('./checkpoints/' + payload + '/*/latest_net_G.pth')])
def list_epochs(path):
print("> list epochs for {}".format(path))
@@ -141,12 +141,12 @@ def list_epochs(path):
return "not found"
return sorted([f.split('/')[3].replace('_net_G.pth','') for f in glob.glob('./checkpoints/' + path + '/*_net_G.pth')])
-def list_sequences():
+def list_sequences(module):
print("> list sequences")
- sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module_name)) if os.path.isdir(os.path.join('./sequences/', module_name, name))])
+ sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module)) if os.path.isdir(os.path.join('./sequences/', module, name))])
results = []
for path in sequences:
- count = len([name for name in os.listdir(os.path.join('./sequences/', module_name, path)) if os.path.isfile(os.path.join('./sequences/', module_name, path, name))])
+ count = len([name for name in os.listdir(os.path.join('./sequences/', module, path)) if os.path.isfile(os.path.join('./sequences/', module, path, name))])
results.append({
'name': path,
'count': count,
@@ -155,7 +155,7 @@ def list_sequences():
def read_sequence(path):
print("> read sequence {}".format(path))
- return sorted([f for f in glob.glob(os.path.join('./sequences/', module_name, path, '*.png'))])
+ return sorted([f for f in glob.glob(os.path.join('./sequences/', path, '*.png'))])
class Listener():
def __init__(self):
@@ -187,11 +187,11 @@ class Listener():
def _cmd_fn(self, cmd, payload):
print("got command {}".format(cmd))
if cmd == 'list_checkpoints':
- return list_checkpoints()
+ return list_checkpoints(payload)
if cmd == 'list_epochs':
return list_epochs(payload)
if cmd == 'list_sequences':
- return list_sequences()
+ return list_sequences(payload)
if cmd == 'load_epoch':
name, epoch = payload.split(':')
print(">>> loading checkpoint {}, epoch {}".format(name, epoch))