diff options
| author | Jules Laplace <julescarbon@gmail.com> | 2018-06-07 02:07:23 +0200 |
|---|---|---|
| committer | Jules Laplace <julescarbon@gmail.com> | 2018-06-07 02:07:23 +0200 |
| commit | b7b305a6dda37e0ac131d42007c4014e82decd2c (patch) | |
| tree | 52b20357630152cea4cedfd9a0d356112fd59196 /live-mogrify.py | |
| parent | a438c6e88484a68abdd17384e720452ffa2b96bb (diff) | |
augment paths
Diffstat (limited to 'live-mogrify.py')
| -rw-r--r-- | live-mogrify.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/live-mogrify.py b/live-mogrify.py index 89ce10c..2372bc9 100644 --- a/live-mogrify.py +++ b/live-mogrify.py @@ -39,7 +39,7 @@ def load_opt(): data_opt_parser = DatasetOptions() data_opt = data_opt_parser.parse(opt.unknown) global module_name - module_name = data_opt.module_name + module_name = opt.module_name opt.nThreads = 1 # test code only supports nThreads = 1 opt.batchSize = 1 # test code only supports batchSize = 1 opt.serial_batches = True # no shuffle @@ -131,9 +131,9 @@ def process_image(opt, data_opt, im): img = np.add(array_a, array_b).astype('uint8') return img -def list_checkpoints(): +def list_checkpoints(payload): print("> list checkpoints") - return sorted([f.split('/')[2] for f in glob.glob('./checkpoints/*/latest_net_G.pth')]) + return sorted([f.split('/')[2] for f in glob.glob('./checkpoints/' + payload + '/*/latest_net_G.pth')]) def list_epochs(path): print("> list epochs for {}".format(path)) @@ -141,12 +141,12 @@ def list_epochs(path): return "not found" return sorted([f.split('/')[3].replace('_net_G.pth','') for f in glob.glob('./checkpoints/' + path + '/*_net_G.pth')]) -def list_sequences(): +def list_sequences(module): print("> list sequences") - sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module_name)) if os.path.isdir(os.path.join('./sequences/', module_name, name))]) + sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module)) if os.path.isdir(os.path.join('./sequences/', module, name))]) results = [] for path in sequences: - count = len([name for name in os.listdir(os.path.join('./sequences/', module_name, path)) if os.path.isfile(os.path.join('./sequences/', module_name, path, name))]) + count = len([name for name in os.listdir(os.path.join('./sequences/', module, path)) if os.path.isfile(os.path.join('./sequences/', module, path, name))]) results.append({ 'name': path, 'count': count, @@ -155,7 +155,7 @@ def list_sequences(): def read_sequence(path): print("> read sequence {}".format(path)) - return sorted([f for f in glob.glob(os.path.join('./sequences/', module_name, path, '*.png'))]) + return sorted([f for f in glob.glob(os.path.join('./sequences/', path, '*.png'))]) class Listener(): def __init__(self): @@ -187,11 +187,11 @@ class Listener(): def _cmd_fn(self, cmd, payload): print("got command {}".format(cmd)) if cmd == 'list_checkpoints': - return list_checkpoints() + return list_checkpoints(payload) if cmd == 'list_epochs': return list_epochs(payload) if cmd == 'list_sequences': - return list_sequences() + return list_sequences(payload) if cmd == 'load_epoch': name, epoch = payload.split(':') print(">>> loading checkpoint {}, epoch {}".format(name, epoch)) |
