summaryrefslogtreecommitdiff
path: root/live-mogrify.py
diff options
context:
space:
mode:
Diffstat (limited to 'live-mogrify.py')
-rw-r--r--live-mogrify.py315
1 files changed, 10 insertions, 305 deletions
diff --git a/live-mogrify.py b/live-mogrify.py
index 9f3c8a3..decccc8 100644
--- a/live-mogrify.py
+++ b/live-mogrify.py
@@ -23,23 +23,16 @@ import glob
import gevent
from time import sleep
-from rpc import CortexRPC
+from img_ops import process_image, mix_next_image
+from listener import Listener
module_name = 'pix2pix'
-def clamp(n,a,b):
- return max(a, min(n, b))
-
-def lerp(n,a,b):
- return (b-a)*n+a
-
def load_opt():
opt_parser = TestOptions()
opt = opt_parser.parse()
data_opt_parser = DatasetOptions()
data_opt = data_opt_parser.parse(opt.unknown)
- global module_name
- module_name = opt.module_name
opt.nThreads = 1 # test code only supports nThreads = 1
opt.batchSize = 1 # test code only supports batchSize = 1
opt.serial_batches = True # no shuffle
@@ -94,197 +87,7 @@ def load_first_frame(opt, data_opt, i=0):
print("Sequence not found")
return A_offset, A_im, A_dir
-def process_image(opt, data_opt, im):
- img = im[:, :, ::-1].copy()
- processed = False
-
- if data_opt.process_frac == 0:
- return img
-
- if data_opt.clahe is True:
- processed = True
- lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
- l, a, b = cv2.split(lab)
- clahe = cv2.createCLAHE(clipLimit=data_opt.clip_limit, tileGridSize=(8,8))
- l = clahe.apply(l)
- limg = cv2.merge((l,a,b))
- img = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
- if data_opt.posterize is True:
- processed = True
- img = cv2.pyrMeanShiftFiltering(img, data_opt.spatial_window, data_opt.color_window)
- if data_opt.grayscale is True:
- processed = True
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
- if data_opt.blur is True:
- processed = True
- img = cv2.GaussianBlur(img, (data_opt.blur_radius, data_opt.blur_radius), data_opt.blur_sigma)
- if data_opt.canny is True:
- processed = True
- img = cv2.Canny(img, data_opt.canny_lo, data_opt.canny_hi)
- img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-
- if processed is False:
- return img
-
- src_img = im[:, :, ::-1].copy()
- frac_a = data_opt.process_frac
- frac_b = 1.0 - frac_a
- array_a = np.multiply(src_img.astype('float64'), frac_a)
- array_b = np.multiply(img.astype('float64'), frac_b)
- img = np.add(array_a, array_b).astype('uint8')
- return img
-
-def list_checkpoints(payload):
- print("> list checkpoints")
- return sorted([f.split('/')[3] for f in glob.glob('./checkpoints/' + payload + '/*/latest_net_G.pth')])
-
-def list_epochs(path):
- print("> list epochs for {}".format(path))
- if not os.path.exists(os.path.join('./checkpoints/', path)):
- return "not found"
- return sorted([f.split('/')[4].replace('_net_G.pth','') for f in glob.glob('./checkpoints/' + path + '/*_net_G.pth')])
-
-def list_sequences(module):
- print("> list sequences")
- sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module)) if os.path.isdir(os.path.join('./sequences/', module, name))])
- results = []
- for path in sequences:
- count = len([name for name in os.listdir(os.path.join('./sequences/', module, path)) if os.path.isfile(os.path.join('./sequences/', module, path, name))])
- results.append({
- 'name': path,
- 'count': count,
- })
- return results
-
-import torchvision.transforms as transforms
-
-# def get_transform(opt={}):
-# transform_list = []
-# if opt.resize_or_crop == 'resize_and_crop':
-# osize = [opt.loadSize, opt.loadSize]
-# transform_list.append(transforms.Scale(osize, Image.BICUBIC))
-# if opt.center_crop:
-# transform_list.append(transforms.CenterCrop(opt.fineSize))
-# else:
-# transform_list.append(transforms.RandomCrop(opt.fineSize))
-# # elif opt.resize_or_crop == 'crop':
-# # transform_list.append(transforms.RandomCrop(opt.fineSize))
-# # elif opt.resize_or_crop == 'scale_width':
-# # transform_list.append(transforms.Lambda(
-# # lambda img: __scale_width(img, opt.fineSize)))
-# # elif opt.resize_or_crop == 'scale_width_and_crop':
-# # transform_list.append(transforms.Lambda(
-# # lambda img: __scale_width(img, opt.loadSize)))
-# # transform_list.append(transforms.RandomCrop(opt.fineSize))
-
-# # if opt.isTrain and not opt.no_flip:
-# # transform_list.append(transforms.RandomHorizontalFlip())
-
-# transform_list += [transforms.ToTensor(),
-# transforms.Normalize((0.5, 0.5, 0.5),
-# (0.5, 0.5, 0.5))]
-# return transforms.Compose(transform_list)
-
-def load_frame(opt, index):
- A_path = os.path.join(opt.render_dir, "frame_{:05d}.png".format(index))
- if not os.path.exists(A_path):
- print("path doesn't exist: {}".format(A_path))
- return None
- transform = get_transform(opt)
- A_img = Image.open(A_path).convert('RGB')
- A = transform(A_img)
- # if self.opt.which_direction == 'BtoA':
- # input_nc = self.opt.output_nc
- # else:
- # input_nc = self.opt.input_nc
-
- # if input_nc == 1: # RGB to gray
- # tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
- # A = tmp.unsqueeze(0)
-
- return {'A': A, 'A_paths': A_path}
-
-def read_sequence(path):
- print("> read sequence {}".format(path))
- return sorted([f for f in glob.glob(os.path.join('./sequences/', module_name, path, '*.png'))])
-
-class Listener():
- def __init__(self):
- opt, data_opt, data_opt_parser = load_opt()
- self.opt = opt
- self.data_opt = data_opt
- self.data_opt_parser = data_opt_parser.parser
- self.model = create_model(opt)
- self.data_opt.load_checkpoint = True
- self.working = False
- def _set_fn(self, key, value):
- if hasattr(self.data_opt, key):
- try:
- if str(value) == 'True':
- setattr(self.data_opt, key, True)
- print('set {} {}: {}'.format('bool', key, True))
- elif str(value) == 'False':
- setattr(self.data_opt, key, False)
- print('set {} {}: {}'.format('bool', key, False))
- else:
- new_opt, misc = self.data_opt_parser.parse_known_args([ '--' + key.replace('_', '-'), str(value) ])
- new_value = getattr(new_opt, key)
- setattr(self.data_opt, key, new_value)
- print('set {} {}: {} [{}]'.format(type(new_value), key, new_value, value))
- except Exception as e:
- print('error {} - cant set value {}: {}'.format(e, key, value))
- def _get_fn(self):
- return vars(self.data_opt)
- def _cmd_fn(self, cmd, payload):
- print("got command {}".format(cmd))
- if cmd == 'list_checkpoints':
- return list_checkpoints(payload)
- if cmd == 'list_epochs':
- return list_epochs(payload)
- if cmd == 'list_sequences':
- return list_sequences(payload)
- if cmd == 'load_epoch':
- name, epoch = payload.split(':')
- print(">>> loading checkpoint {}, epoch {}".format(name, epoch))
- self.data_opt.checkpoint_name = name
- self.data_opt.epoch = epoch
- self.data_opt.load_checkpoint = True
- return 'ok'
- if cmd == 'load_sequence' and os.path.exists(os.path.join('./sequences/', self.opt.module_name, payload)):
- print('load sequence: {}'.format(payload))
- self.data_opt.sequence_name = payload
- self.data_opt.load_sequence = True
- return 'loaded sequence'
- if cmd == 'seek':
- self.data_opt.seek_to = payload
- return 'seeked'
- if cmd == 'get_status':
- return {
- 'processing': self.data_opt.processing,
- 'checkpoint': self.data_opt.checkpoint_name,
- 'epoch': self.data_opt.epoch,
- 'sequence': self.data_opt.sequence_name,
- }
- if cmd == 'play' and self.data_opt.processing is False:
- self.data_opt.pause = False
- process_live_input(self.opt, self.data_opt, self.rpc_client, self.model)
- return 'playing'
- if cmd == 'pause' and self.data_opt.processing is True:
- self.data_opt.pause = True
- return 'paused'
- if cmd == 'exit':
- print("Exiting now...!")
- sys.exit(0)
- return 'exited'
- return 'ok'
- def _ready_fn(self, rpc_client):
- print("Ready!")
- self.rpc_client = rpc_client
- process_live_input(self.opt, self.data_opt, rpc_client, self.model)
- def connect(self):
- self.rpc_client = CortexRPC(self._get_fn, self._set_fn, self._ready_fn, self._cmd_fn)
-
-def process_live_input(opt, data_opt, rpc_client, model):
+def process_live_input(opt, data_opt, rpc_client):
print(">>> Process live input")
if data_opt.processing:
print("Already processing...")
@@ -293,7 +96,7 @@ def process_live_input(opt, data_opt, rpc_client, model):
dataset = data_loader.load_data()
create_render_dir(opt)
- sequence = read_sequence(data_opt.sequence_name)
+ sequence = read_sequence(data_opt.sequence_name, opt.module_name)
print("Got sequence {}, {} images".format(data_opt.sequence, len(sequence)))
if len(sequence) == 0:
print("Got empty sequence...")
@@ -307,17 +110,11 @@ def process_live_input(opt, data_opt, rpc_client, model):
start_img_path = os.path.join(opt.render_dir, "frame_{:05d}.png".format(0))
copyfile(sequence[0], start_img_path)
- last_im = None
+ model = create_model(opt)
- print("generating...")
sequence_i = 1
- i = 0
- # while True:
- # data = load_frame(opt, i)
- # if data is None:
- # print("got no frame, exiting")
- # break
+ print("generating...")
for i, data in enumerate(data_loader):
if i >= opt.how_many:
print("generated {} images, exiting".format(i))
@@ -346,109 +143,17 @@ def process_live_input(opt, data_opt, rpc_client, model):
visuals = model.get_current_visuals()
img_path = model.get_image_paths()
- if (i % 100) == 0:
- print('%04d: process image...' % (i))
-
- im = visuals['fake_B']
- last_path = opt.render_dir + "frame_{:05d}.png".format(i)
- tmp_path = opt.render_dir + "frame_{:05d}_tmp.png".format(i+1)
- next_path = opt.render_dir + "frame_{:05d}.png".format(i+1)
- current_path = opt.render_dir + "ren_{:05d}.png".format(i+1)
- meta = { 'i': i, 'sequence_i': sequence_i, 'sequence_len': len(sequence) }
- if data_opt.sequence and len(sequence):
- sequence_path = sequence[sequence_i]
- if sequence_i >= len(sequence)-1:
- print('(((( sequence looped ))))')
- sequence_i = 1
- else:
- sequence_i += 1
-
- if data_opt.send_image == 'b':
- image_pil = Image.fromarray(im, mode='RGB')
- rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, image_pil, data_opt.output_format)
-
- if data_opt.store_a is not True:
- os.remove(last_path)
- if data_opt.store_b is True:
- image_pil = Image.fromarray(im, mode='RGB')
- image_pil.save(tmp_path)
- os.rename(tmp_path, current_path)
-
- if data_opt.recursive and last_im is not None:
- if data_opt.sequence and len(sequence):
- A_img = Image.open(sequence_path).convert('RGB')
- A_im = np.asarray(A_img)
- frac_a = data_opt.recursive_frac
- frac_b = data_opt.sequence_frac
- frac_sum = frac_a + frac_b
- if frac_sum > 1.0:
- frac_a = frac_a / frac_sum
- frac_b = frac_b / frac_sum
- if data_opt.transition:
- t = lerp(math.sin(i / data_opt.transition_period * math.pi * 2.0 ) / 2.0 + 0.5, data_opt.transition_min, data_opt.transition_max)
- frac_a *= 1.0 - t
- frac_b *= 1.0 - t
- frac_c = 1.0 - frac_a - frac_b
- array_a = np.multiply(last_im.astype('float64'), frac_a)
- array_b = np.multiply(A_im.astype('float64'), frac_b)
- array_c = np.multiply(im.astype('float64'), frac_c)
- array_ab = np.add(array_a, array_b)
- array_abc = np.add(array_ab, array_c)
- next_im = array_abc.astype('uint8')
-
- else:
- frac_a = data_opt.recursive_frac
- frac_b = 1.0 - frac_a
- array_a = np.multiply(last_im.astype('float64'), frac_a)
- array_b = np.multiply(im.astype('float64'), frac_b)
- next_im = np.add(array_a, array_b).astype('uint8')
-
- if data_opt.recurse_roll != 0:
- last_im = np.roll(im, data_opt.recurse_roll, axis=data_opt.recurse_roll_axis)
- else:
- last_im = next_im.copy().astype('uint8')
+ sequence_i = mix_next_image(opt, data_opt, visuals['fake_B'], i, sequence, sequence_i)
- elif data_opt.sequence and len(sequence):
- A_img = Image.open(sequence_path).convert('RGB')
- A_im = np.asarray(A_img)
- frac_b = data_opt.sequence_frac
- if data_opt.transition:
- t = lerp(math.sin(i / data_opt.transition_period * math.pi * 2.0 ) / 2.0 + 0.5, data_opt.transition_min, data_opt.transition_max)
- frac_b *= 1.0 - t
- frac_c = 1.0 - frac_b
- array_b = np.multiply(A_im.astype('float64'), frac_b)
- array_c = np.multiply(im.astype('float64'), frac_c)
- array_bc = np.add(array_b, array_c)
- next_im = array_bc.astype('uint8')
-
- else:
- last_im = im.copy().astype('uint8')
- next_im = im
-
- next_img = process_image(opt, data_opt, next_im)
-
- if data_opt.send_image == 'sequence':
- rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, A_img, data_opt.output_format)
- if data_opt.send_image == 'recursive':
- pil_im = Image.fromarray(next_im)
- rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, pil_im, data_opt.output_format)
- if data_opt.send_image == 'a':
- rgb_im = cv2.cvtColor(next_img, cv2.COLOR_BGR2RGB)
- pil_im = Image.fromarray(rgb_im)
- rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, pil_im, data_opt.output_format)
-
- cv2.imwrite(tmp_path, next_img)
- os.rename(tmp_path, next_path)
- if (i % 20) == 0:
- print("created {}".format(next_path))
if data_opt.pause:
data_opt.pause = False
break
gevent.sleep(data_opt.frame_delay)
- i += 1
+
data_opt.processing = False
rpc_client.send_status('processing', False)
if __name__ == '__main__':
- listener = Listener()
+ opt, data_opt, data_opt_parser = load_opt()
+ listener = Listener(opt, data_opt, data_opt_parser, process_live_input)
listener.connect()