diff options
| -rw-r--r-- | live-mogrify.py | 18 | ||||
| -rw-r--r-- | models/base_model.py | 2 | ||||
| -rw-r--r-- | options/base_options.py | 6 | ||||
| -rw-r--r-- | options/dataset_options.py | 7 |
4 files changed, 16 insertions, 17 deletions
diff --git a/live-mogrify.py b/live-mogrify.py index 89ce10c..2372bc9 100644 --- a/live-mogrify.py +++ b/live-mogrify.py @@ -39,7 +39,7 @@ def load_opt(): data_opt_parser = DatasetOptions() data_opt = data_opt_parser.parse(opt.unknown) global module_name - module_name = data_opt.module_name + module_name = opt.module_name opt.nThreads = 1 # test code only supports nThreads = 1 opt.batchSize = 1 # test code only supports batchSize = 1 opt.serial_batches = True # no shuffle @@ -131,9 +131,9 @@ def process_image(opt, data_opt, im): img = np.add(array_a, array_b).astype('uint8') return img -def list_checkpoints(): +def list_checkpoints(payload): print("> list checkpoints") - return sorted([f.split('/')[2] for f in glob.glob('./checkpoints/*/latest_net_G.pth')]) + return sorted([f.split('/')[2] for f in glob.glob('./checkpoints/' + payload + '/*/latest_net_G.pth')]) def list_epochs(path): print("> list epochs for {}".format(path)) @@ -141,12 +141,12 @@ def list_epochs(path): return "not found" return sorted([f.split('/')[3].replace('_net_G.pth','') for f in glob.glob('./checkpoints/' + path + '/*_net_G.pth')]) -def list_sequences(): +def list_sequences(module): print("> list sequences") - sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module_name)) if os.path.isdir(os.path.join('./sequences/', module_name, name))]) + sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module)) if os.path.isdir(os.path.join('./sequences/', module, name))]) results = [] for path in sequences: - count = len([name for name in os.listdir(os.path.join('./sequences/', module_name, path)) if os.path.isfile(os.path.join('./sequences/', module_name, path, name))]) + count = len([name for name in os.listdir(os.path.join('./sequences/', module, path)) if os.path.isfile(os.path.join('./sequences/', module, path, name))]) results.append({ 'name': path, 'count': count, @@ -155,7 +155,7 @@ def list_sequences(): def read_sequence(path): print("> read sequence {}".format(path)) - return sorted([f for f in glob.glob(os.path.join('./sequences/', module_name, path, '*.png'))]) + return sorted([f for f in glob.glob(os.path.join('./sequences/', path, '*.png'))]) class Listener(): def __init__(self): @@ -187,11 +187,11 @@ class Listener(): def _cmd_fn(self, cmd, payload): print("got command {}".format(cmd)) if cmd == 'list_checkpoints': - return list_checkpoints() + return list_checkpoints(payload) if cmd == 'list_epochs': return list_epochs(payload) if cmd == 'list_sequences': - return list_sequences() + return list_sequences(payload) if cmd == 'load_epoch': name, epoch = payload.split(':') print(">>> loading checkpoint {}, epoch {}".format(name, epoch)) diff --git a/models/base_model.py b/models/base_model.py index 9b55afe..d3d07d4 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -11,7 +11,7 @@ class BaseModel(): self.gpu_ids = opt.gpu_ids self.isTrain = opt.isTrain self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor - self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + self.save_dir = os.path.join(opt.checkpoints_dir, opt.module_name, opt.name) def set_input(self, input): self.input = input diff --git a/options/base_options.py b/options/base_options.py index f3e93c6..575d0d1 100644 --- a/options/base_options.py +++ b/options/base_options.py @@ -43,6 +43,12 @@ class BaseOptions(): self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') self.parser.add_argument('--center_crop', action='store_true', help='center crop instead of random crop') self.parser.add_argument('--poll_delay', type=float, default=0.01, help='time to wait before checking for the next frame') + self.parser.add_argument( + '--module-name', + type=str, + default='pix2pix', + help='module name... basically says where to look for sequences' + ) self.initialized = True diff --git a/options/dataset_options.py b/options/dataset_options.py index 92b63ec..7ac5a8a 100644 --- a/options/dataset_options.py +++ b/options/dataset_options.py @@ -75,13 +75,6 @@ class DatasetOptions(BaseOptions): ## LIVE IMAGE PROCESSING self.parser.add_argument( - '--module-name', - type=str, - default='pix2pix', - help='module name... basically says where to look for sequences' - ) - - self.parser.add_argument( '--send-image', type=str, default='b', |
