diff options
Diffstat (limited to 'rpc')
| -rw-r--r-- | rpc/img_ops.py | 151 | ||||
| -rw-r--r-- | rpc/listener.py | 110 |
2 files changed, 261 insertions, 0 deletions
diff --git a/rpc/img_ops.py b/rpc/img_ops.py new file mode 100644 index 0000000..07e0bf1 --- /dev/null +++ b/rpc/img_ops.py @@ -0,0 +1,151 @@ +import numpy as np +import cv2 +import math + +def clamp(n,a,b): + return max(a, min(n, b)) + +def lerp(n,a,b): + return (b-a)*n+a + +def process_image(opt, data_opt, im): + img = im[:, :, ::-1].copy() + processed = False + + if data_opt.process_frac == 0: + return img + + if data_opt.clahe is True: + processed = True + lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) + l, a, b = cv2.split(lab) + clahe = cv2.createCLAHE(clipLimit=data_opt.clip_limit, tileGridSize=(8,8)) + l = clahe.apply(l) + limg = cv2.merge((l,a,b)) + img = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR) + if data_opt.posterize is True: + processed = True + img = cv2.pyrMeanShiftFiltering(img, data_opt.spatial_window, data_opt.color_window) + if data_opt.grayscale is True: + processed = True + img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + if data_opt.blur is True: + processed = True + img = cv2.GaussianBlur(img, (data_opt.blur_radius, data_opt.blur_radius), data_opt.blur_sigma) + if data_opt.canny is True: + processed = True + img = cv2.Canny(img, data_opt.canny_lo, data_opt.canny_hi) + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + if processed is False: + return img + + src_img = im[:, :, ::-1].copy() + frac_a = data_opt.process_frac + frac_b = 1.0 - frac_a + array_a = np.multiply(src_img.astype('float64'), frac_a) + array_b = np.multiply(img.astype('float64'), frac_b) + img = np.add(array_a, array_b).astype('uint8') + return img + +last_im = None +def mix_next_image(opt, data_opt, im, i, sequence, sequence_i): + global last_im + + if (i % 100) == 0: + print('%04d: process image...' % (i)) + + last_path = opt.render_dir + "frame_{:05d}.png".format(i) + tmp_path = opt.render_dir + "frame_{:05d}_tmp.png".format(i+1) + next_path = opt.render_dir + "frame_{:05d}.png".format(i+1) + current_path = opt.render_dir + "ren_{:05d}.png".format(i+1) + meta = { 'i': i, 'sequence_i': sequence_i, 'sequence_len': len(sequence) } + if data_opt.sequence and len(sequence): + sequence_path = sequence[sequence_i] + if sequence_i >= len(sequence)-1: + print('(((( sequence looped ))))') + sequence_i = 1 + else: + sequence_i += 1 + + if data_opt.store_a is not True: + os.remove(last_path) + + if data_opt.store_b is True: + image_pil = Image.fromarray(im, mode='RGB') + image_pil.save(tmp_path) + os.rename(tmp_path, current_path) + + if data_opt.recursive and last_im is not None: + if data_opt.sequence and len(sequence): + A_img = Image.open(sequence_path).convert('RGB') + A_im = np.asarray(A_img) + frac_a = data_opt.recursive_frac + frac_b = data_opt.sequence_frac + frac_sum = frac_a + frac_b + if frac_sum > 1.0: + frac_a = frac_a / frac_sum + frac_b = frac_b / frac_sum + if data_opt.transition: + t = lerp(math.sin(i / data_opt.transition_period * math.pi * 2.0 ) / 2.0 + 0.5, data_opt.transition_min, data_opt.transition_max) + frac_a *= 1.0 - t + frac_b *= 1.0 - t + frac_c = 1.0 - frac_a - frac_b + array_a = np.multiply(last_im.astype('float64'), frac_a) + array_b = np.multiply(A_im.astype('float64'), frac_b) + array_c = np.multiply(im.astype('float64'), frac_c) + array_ab = np.add(array_a, array_b) + array_abc = np.add(array_ab, array_c) + next_im = array_abc.astype('uint8') + + else: + frac_a = data_opt.recursive_frac + frac_b = 1.0 - frac_a + array_a = np.multiply(last_im.astype('float64'), frac_a) + array_b = np.multiply(im.astype('float64'), frac_b) + next_im = np.add(array_a, array_b).astype('uint8') + + if data_opt.recurse_roll != 0: + last_im = np.roll(im, data_opt.recurse_roll, axis=data_opt.recurse_roll_axis) + else: + last_im = next_im.copy().astype('uint8') + + elif data_opt.sequence and len(sequence): + A_img = Image.open(sequence_path).convert('RGB') + A_im = np.asarray(A_img) + frac_b = data_opt.sequence_frac + if data_opt.transition: + t = lerp(math.sin(i / data_opt.transition_period * math.pi * 2.0 ) / 2.0 + 0.5, data_opt.transition_min, data_opt.transition_max) + frac_b *= 1.0 - t + frac_c = 1.0 - frac_b + array_b = np.multiply(A_im.astype('float64'), frac_b) + array_c = np.multiply(im.astype('float64'), frac_c) + array_bc = np.add(array_b, array_c) + next_im = array_bc.astype('uint8') + + else: + last_im = im.copy().astype('uint8') + next_im = im + + next_img = process_image(opt, data_opt, next_im) + + if data_opt.send_image == 'sequence': + rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, A_img, data_opt.output_format) + if data_opt.send_image == 'recursive': + pil_im = Image.fromarray(next_im) + rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, pil_im, data_opt.output_format) + if data_opt.send_image == 'a': + rgb_im = cv2.cvtColor(next_img, cv2.COLOR_BGR2RGB) + pil_im = Image.fromarray(rgb_im) + rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, pil_im, data_opt.output_format) + if data_opt.send_image == 'b': + image_pil = Image.fromarray(im, mode='RGB') + rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, image_pil, data_opt.output_format) + + cv2.imwrite(tmp_path, next_img) + os.rename(tmp_path, next_path) + + if (i % 20) == 0: + print("created {}".format(next_path)) + + return sequence_i
\ No newline at end of file diff --git a/rpc/listener.py b/rpc/listener.py new file mode 100644 index 0000000..a9c571d --- /dev/null +++ b/rpc/listener.py @@ -0,0 +1,110 @@ +import os +import sys + +from rpc import CortexRPC +from img_ops import process_image + +def list_checkpoints(payload): + print("> list checkpoints") + return sorted([f.split('/')[3] for f in glob.glob('./checkpoints/' + payload + '/*/latest_net_G.pth')]) + +def list_epochs(path): + print("> list epochs for {}".format(path)) + if not os.path.exists(os.path.join('./checkpoints/', path)): + return "not found" + return sorted([f.split('/')[4].replace('_net_G.pth','') for f in glob.glob('./checkpoints/' + path + '/*_net_G.pth')]) + +def list_sequences(module): + print("> list sequences") + sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module)) if os.path.isdir(os.path.join('./sequences/', module, name))]) + results = [] + for path in sequences: + count = len([name for name in os.listdir(os.path.join('./sequences/', module, path)) if os.path.isfile(os.path.join('./sequences/', module, path, name))]) + results.append({ + 'name': path, + 'count': count, + }) + return results + +class Listener(): + def __init__(self, opt, data_opt, data_opt_parser, run_live): + self.opt = opt + self.data_opt = data_opt + self.data_opt_parser = data_opt_parser.parser + self.data_opt.load_checkpoint = True + self.working = False + self.run_live = run_live + def _set_fn(self, key, value): + if hasattr(self.data_opt, key): + try: + if str(value) == 'True': + setattr(self.data_opt, key, True) + print('set {} {}: {}'.format('bool', key, True)) + elif str(value) == 'False': + setattr(self.data_opt, key, False) + print('set {} {}: {}'.format('bool', key, False)) + else: + new_opt, misc = self.data_opt_parser.parse_known_args([ '--' + key.replace('_', '-'), str(value) ]) + new_value = getattr(new_opt, key) + setattr(self.data_opt, key, new_value) + print('set {} {}: {} [{}]'.format(type(new_value), key, new_value, value)) + except Exception as e: + print('error {} - cant set value {}: {}'.format(e, key, value)) + def _get_fn(self): + return vars(self.data_opt) + def _cmd_fn(self, cmd, payload): + print("got command {}".format(cmd)) + module_name = self.opt.module_name + if module_name == 'pix2pixHD': + module_name = '' + + if cmd == 'list_checkpoints': + return list_checkpoints(payload) + if cmd == 'list_epochs': + return list_epochs(payload) + if cmd == 'list_sequences': + return list_sequences(payload) + if cmd == 'load_epoch': + name, epoch = payload.split(':') + print(">>> loading checkpoint {}, epoch {}".format(name, epoch)) + self.data_opt.checkpoint_name = name + self.data_opt.epoch = epoch + self.data_opt.load_checkpoint = True + return 'ok' + if cmd == 'load_sequence' and os.path.exists(os.path.join('./sequences/', module_name, payload)): + print('load sequence: {}'.format(payload)) + self.data_opt.sequence_name = payload + self.data_opt.load_sequence = True + return 'loaded sequence' + if cmd == 'seek': + self.data_opt.seek_to = payload + return 'seeked' + if cmd == 'get_status': + return { + 'processing': self.data_opt.processing, + 'checkpoint': self.data_opt.checkpoint_name, + 'epoch': self.data_opt.epoch, + 'sequence': self.data_opt.sequence_name, + } + if cmd == 'play' and self.data_opt.processing is False: + self.data_opt.pause = False + self.run_live(self.opt, self.data_opt, self.rpc_client) + return 'playing' + if cmd == 'pause' and self.data_opt.processing is True: + self.data_opt.pause = True + return 'paused' + if cmd == 'exit': + print("Exiting now...!") + sys.exit(0) + return 'exited' + return 'ok' + def _ready_fn(self, rpc_client): + print("Ready!") + self.rpc_client = rpc_client + self.run_live(self.opt, self.data_opt, rpc_client) + def connect(self): + self.rpc_client = CortexRPC(self._get_fn, self._set_fn, self._ready_fn, self._cmd_fn) + +def read_sequence(path, module_name): + print("> read sequence {}".format(path)) + return sorted([f for f in glob.glob(os.path.join('./sequences/', module_name, path, '*.png'))]) |
