summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--live-mogrify.py18
-rw-r--r--models/base_model.py2
-rw-r--r--options/base_options.py6
-rw-r--r--options/dataset_options.py7
4 files changed, 16 insertions, 17 deletions
diff --git a/live-mogrify.py b/live-mogrify.py
index 89ce10c..2372bc9 100644
--- a/live-mogrify.py
+++ b/live-mogrify.py
@@ -39,7 +39,7 @@ def load_opt():
data_opt_parser = DatasetOptions()
data_opt = data_opt_parser.parse(opt.unknown)
global module_name
- module_name = data_opt.module_name
+ module_name = opt.module_name
opt.nThreads = 1 # test code only supports nThreads = 1
opt.batchSize = 1 # test code only supports batchSize = 1
opt.serial_batches = True # no shuffle
@@ -131,9 +131,9 @@ def process_image(opt, data_opt, im):
img = np.add(array_a, array_b).astype('uint8')
return img
-def list_checkpoints():
+def list_checkpoints(payload):
print("> list checkpoints")
- return sorted([f.split('/')[2] for f in glob.glob('./checkpoints/*/latest_net_G.pth')])
+ return sorted([f.split('/')[2] for f in glob.glob('./checkpoints/' + payload + '/*/latest_net_G.pth')])
def list_epochs(path):
print("> list epochs for {}".format(path))
@@ -141,12 +141,12 @@ def list_epochs(path):
return "not found"
return sorted([f.split('/')[3].replace('_net_G.pth','') for f in glob.glob('./checkpoints/' + path + '/*_net_G.pth')])
-def list_sequences():
+def list_sequences(module):
print("> list sequences")
- sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module_name)) if os.path.isdir(os.path.join('./sequences/', module_name, name))])
+ sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module)) if os.path.isdir(os.path.join('./sequences/', module, name))])
results = []
for path in sequences:
- count = len([name for name in os.listdir(os.path.join('./sequences/', module_name, path)) if os.path.isfile(os.path.join('./sequences/', module_name, path, name))])
+ count = len([name for name in os.listdir(os.path.join('./sequences/', module, path)) if os.path.isfile(os.path.join('./sequences/', module, path, name))])
results.append({
'name': path,
'count': count,
@@ -155,7 +155,7 @@ def list_sequences():
def read_sequence(path):
print("> read sequence {}".format(path))
- return sorted([f for f in glob.glob(os.path.join('./sequences/', module_name, path, '*.png'))])
+ return sorted([f for f in glob.glob(os.path.join('./sequences/', path, '*.png'))])
class Listener():
def __init__(self):
@@ -187,11 +187,11 @@ class Listener():
def _cmd_fn(self, cmd, payload):
print("got command {}".format(cmd))
if cmd == 'list_checkpoints':
- return list_checkpoints()
+ return list_checkpoints(payload)
if cmd == 'list_epochs':
return list_epochs(payload)
if cmd == 'list_sequences':
- return list_sequences()
+ return list_sequences(payload)
if cmd == 'load_epoch':
name, epoch = payload.split(':')
print(">>> loading checkpoint {}, epoch {}".format(name, epoch))
diff --git a/models/base_model.py b/models/base_model.py
index 9b55afe..d3d07d4 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -11,7 +11,7 @@ class BaseModel():
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
- self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
+ self.save_dir = os.path.join(opt.checkpoints_dir, opt.module_name, opt.name)
def set_input(self, input):
self.input = input
diff --git a/options/base_options.py b/options/base_options.py
index f3e93c6..575d0d1 100644
--- a/options/base_options.py
+++ b/options/base_options.py
@@ -43,6 +43,12 @@ class BaseOptions():
self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
self.parser.add_argument('--center_crop', action='store_true', help='center crop instead of random crop')
self.parser.add_argument('--poll_delay', type=float, default=0.01, help='time to wait before checking for the next frame')
+ self.parser.add_argument(
+ '--module-name',
+ type=str,
+ default='pix2pix',
+ help='module name... basically says where to look for sequences'
+ )
self.initialized = True
diff --git a/options/dataset_options.py b/options/dataset_options.py
index 92b63ec..7ac5a8a 100644
--- a/options/dataset_options.py
+++ b/options/dataset_options.py
@@ -75,13 +75,6 @@ class DatasetOptions(BaseOptions):
## LIVE IMAGE PROCESSING
self.parser.add_argument(
- '--module-name',
- type=str,
- default='pix2pix',
- help='module name... basically says where to look for sequences'
- )
-
- self.parser.add_argument(
'--send-image',
type=str,
default='b',