summaryrefslogtreecommitdiff
path: root/rpc
diff options
context:
space:
mode:
Diffstat (limited to 'rpc')
-rw-r--r--rpc/img_ops.py151
-rw-r--r--rpc/listener.py110
2 files changed, 261 insertions, 0 deletions
diff --git a/rpc/img_ops.py b/rpc/img_ops.py
new file mode 100644
index 0000000..07e0bf1
--- /dev/null
+++ b/rpc/img_ops.py
@@ -0,0 +1,151 @@
+import numpy as np
+import cv2
+import math
+
+def clamp(n,a,b):
+ return max(a, min(n, b))
+
+def lerp(n,a,b):
+ return (b-a)*n+a
+
+def process_image(opt, data_opt, im):
+ img = im[:, :, ::-1].copy()
+ processed = False
+
+ if data_opt.process_frac == 0:
+ return img
+
+ if data_opt.clahe is True:
+ processed = True
+ lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
+ l, a, b = cv2.split(lab)
+ clahe = cv2.createCLAHE(clipLimit=data_opt.clip_limit, tileGridSize=(8,8))
+ l = clahe.apply(l)
+ limg = cv2.merge((l,a,b))
+ img = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
+ if data_opt.posterize is True:
+ processed = True
+ img = cv2.pyrMeanShiftFiltering(img, data_opt.spatial_window, data_opt.color_window)
+ if data_opt.grayscale is True:
+ processed = True
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ if data_opt.blur is True:
+ processed = True
+ img = cv2.GaussianBlur(img, (data_opt.blur_radius, data_opt.blur_radius), data_opt.blur_sigma)
+ if data_opt.canny is True:
+ processed = True
+ img = cv2.Canny(img, data_opt.canny_lo, data_opt.canny_hi)
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ if processed is False:
+ return img
+
+ src_img = im[:, :, ::-1].copy()
+ frac_a = data_opt.process_frac
+ frac_b = 1.0 - frac_a
+ array_a = np.multiply(src_img.astype('float64'), frac_a)
+ array_b = np.multiply(img.astype('float64'), frac_b)
+ img = np.add(array_a, array_b).astype('uint8')
+ return img
+
+last_im = None
+def mix_next_image(opt, data_opt, im, i, sequence, sequence_i):
+ global last_im
+
+ if (i % 100) == 0:
+ print('%04d: process image...' % (i))
+
+ last_path = opt.render_dir + "frame_{:05d}.png".format(i)
+ tmp_path = opt.render_dir + "frame_{:05d}_tmp.png".format(i+1)
+ next_path = opt.render_dir + "frame_{:05d}.png".format(i+1)
+ current_path = opt.render_dir + "ren_{:05d}.png".format(i+1)
+ meta = { 'i': i, 'sequence_i': sequence_i, 'sequence_len': len(sequence) }
+ if data_opt.sequence and len(sequence):
+ sequence_path = sequence[sequence_i]
+ if sequence_i >= len(sequence)-1:
+ print('(((( sequence looped ))))')
+ sequence_i = 1
+ else:
+ sequence_i += 1
+
+ if data_opt.store_a is not True:
+ os.remove(last_path)
+
+ if data_opt.store_b is True:
+ image_pil = Image.fromarray(im, mode='RGB')
+ image_pil.save(tmp_path)
+ os.rename(tmp_path, current_path)
+
+ if data_opt.recursive and last_im is not None:
+ if data_opt.sequence and len(sequence):
+ A_img = Image.open(sequence_path).convert('RGB')
+ A_im = np.asarray(A_img)
+ frac_a = data_opt.recursive_frac
+ frac_b = data_opt.sequence_frac
+ frac_sum = frac_a + frac_b
+ if frac_sum > 1.0:
+ frac_a = frac_a / frac_sum
+ frac_b = frac_b / frac_sum
+ if data_opt.transition:
+ t = lerp(math.sin(i / data_opt.transition_period * math.pi * 2.0 ) / 2.0 + 0.5, data_opt.transition_min, data_opt.transition_max)
+ frac_a *= 1.0 - t
+ frac_b *= 1.0 - t
+ frac_c = 1.0 - frac_a - frac_b
+ array_a = np.multiply(last_im.astype('float64'), frac_a)
+ array_b = np.multiply(A_im.astype('float64'), frac_b)
+ array_c = np.multiply(im.astype('float64'), frac_c)
+ array_ab = np.add(array_a, array_b)
+ array_abc = np.add(array_ab, array_c)
+ next_im = array_abc.astype('uint8')
+
+ else:
+ frac_a = data_opt.recursive_frac
+ frac_b = 1.0 - frac_a
+ array_a = np.multiply(last_im.astype('float64'), frac_a)
+ array_b = np.multiply(im.astype('float64'), frac_b)
+ next_im = np.add(array_a, array_b).astype('uint8')
+
+ if data_opt.recurse_roll != 0:
+ last_im = np.roll(im, data_opt.recurse_roll, axis=data_opt.recurse_roll_axis)
+ else:
+ last_im = next_im.copy().astype('uint8')
+
+ elif data_opt.sequence and len(sequence):
+ A_img = Image.open(sequence_path).convert('RGB')
+ A_im = np.asarray(A_img)
+ frac_b = data_opt.sequence_frac
+ if data_opt.transition:
+ t = lerp(math.sin(i / data_opt.transition_period * math.pi * 2.0 ) / 2.0 + 0.5, data_opt.transition_min, data_opt.transition_max)
+ frac_b *= 1.0 - t
+ frac_c = 1.0 - frac_b
+ array_b = np.multiply(A_im.astype('float64'), frac_b)
+ array_c = np.multiply(im.astype('float64'), frac_c)
+ array_bc = np.add(array_b, array_c)
+ next_im = array_bc.astype('uint8')
+
+ else:
+ last_im = im.copy().astype('uint8')
+ next_im = im
+
+ next_img = process_image(opt, data_opt, next_im)
+
+ if data_opt.send_image == 'sequence':
+ rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, A_img, data_opt.output_format)
+ if data_opt.send_image == 'recursive':
+ pil_im = Image.fromarray(next_im)
+ rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, pil_im, data_opt.output_format)
+ if data_opt.send_image == 'a':
+ rgb_im = cv2.cvtColor(next_img, cv2.COLOR_BGR2RGB)
+ pil_im = Image.fromarray(rgb_im)
+ rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, pil_im, data_opt.output_format)
+ if data_opt.send_image == 'b':
+ image_pil = Image.fromarray(im, mode='RGB')
+ rpc_client.send_pil_image("frame_{:05d}.png".format(i+1), meta, image_pil, data_opt.output_format)
+
+ cv2.imwrite(tmp_path, next_img)
+ os.rename(tmp_path, next_path)
+
+ if (i % 20) == 0:
+ print("created {}".format(next_path))
+
+ return sequence_i \ No newline at end of file
diff --git a/rpc/listener.py b/rpc/listener.py
new file mode 100644
index 0000000..a9c571d
--- /dev/null
+++ b/rpc/listener.py
@@ -0,0 +1,110 @@
+import os
+import sys
+
+from rpc import CortexRPC
+from img_ops import process_image
+
+def list_checkpoints(payload):
+ print("> list checkpoints")
+ return sorted([f.split('/')[3] for f in glob.glob('./checkpoints/' + payload + '/*/latest_net_G.pth')])
+
+def list_epochs(path):
+ print("> list epochs for {}".format(path))
+ if not os.path.exists(os.path.join('./checkpoints/', path)):
+ return "not found"
+ return sorted([f.split('/')[4].replace('_net_G.pth','') for f in glob.glob('./checkpoints/' + path + '/*_net_G.pth')])
+
+def list_sequences(module):
+ print("> list sequences")
+ sequences = sorted([name for name in os.listdir(os.path.join('./sequences/', module)) if os.path.isdir(os.path.join('./sequences/', module, name))])
+ results = []
+ for path in sequences:
+ count = len([name for name in os.listdir(os.path.join('./sequences/', module, path)) if os.path.isfile(os.path.join('./sequences/', module, path, name))])
+ results.append({
+ 'name': path,
+ 'count': count,
+ })
+ return results
+
+class Listener():
+ def __init__(self, opt, data_opt, data_opt_parser, run_live):
+ self.opt = opt
+ self.data_opt = data_opt
+ self.data_opt_parser = data_opt_parser.parser
+ self.data_opt.load_checkpoint = True
+ self.working = False
+ self.run_live = run_live
+ def _set_fn(self, key, value):
+ if hasattr(self.data_opt, key):
+ try:
+ if str(value) == 'True':
+ setattr(self.data_opt, key, True)
+ print('set {} {}: {}'.format('bool', key, True))
+ elif str(value) == 'False':
+ setattr(self.data_opt, key, False)
+ print('set {} {}: {}'.format('bool', key, False))
+ else:
+ new_opt, misc = self.data_opt_parser.parse_known_args([ '--' + key.replace('_', '-'), str(value) ])
+ new_value = getattr(new_opt, key)
+ setattr(self.data_opt, key, new_value)
+ print('set {} {}: {} [{}]'.format(type(new_value), key, new_value, value))
+ except Exception as e:
+ print('error {} - cant set value {}: {}'.format(e, key, value))
+ def _get_fn(self):
+ return vars(self.data_opt)
+ def _cmd_fn(self, cmd, payload):
+ print("got command {}".format(cmd))
+ module_name = self.opt.module_name
+ if module_name == 'pix2pixHD':
+ module_name = ''
+
+ if cmd == 'list_checkpoints':
+ return list_checkpoints(payload)
+ if cmd == 'list_epochs':
+ return list_epochs(payload)
+ if cmd == 'list_sequences':
+ return list_sequences(payload)
+ if cmd == 'load_epoch':
+ name, epoch = payload.split(':')
+ print(">>> loading checkpoint {}, epoch {}".format(name, epoch))
+ self.data_opt.checkpoint_name = name
+ self.data_opt.epoch = epoch
+ self.data_opt.load_checkpoint = True
+ return 'ok'
+ if cmd == 'load_sequence' and os.path.exists(os.path.join('./sequences/', module_name, payload)):
+ print('load sequence: {}'.format(payload))
+ self.data_opt.sequence_name = payload
+ self.data_opt.load_sequence = True
+ return 'loaded sequence'
+ if cmd == 'seek':
+ self.data_opt.seek_to = payload
+ return 'seeked'
+ if cmd == 'get_status':
+ return {
+ 'processing': self.data_opt.processing,
+ 'checkpoint': self.data_opt.checkpoint_name,
+ 'epoch': self.data_opt.epoch,
+ 'sequence': self.data_opt.sequence_name,
+ }
+ if cmd == 'play' and self.data_opt.processing is False:
+ self.data_opt.pause = False
+ self.run_live(self.opt, self.data_opt, self.rpc_client)
+ return 'playing'
+ if cmd == 'pause' and self.data_opt.processing is True:
+ self.data_opt.pause = True
+ return 'paused'
+ if cmd == 'exit':
+ print("Exiting now...!")
+ sys.exit(0)
+ return 'exited'
+ return 'ok'
+ def _ready_fn(self, rpc_client):
+ print("Ready!")
+ self.rpc_client = rpc_client
+ self.run_live(self.opt, self.data_opt, rpc_client)
+ def connect(self):
+ self.rpc_client = CortexRPC(self._get_fn, self._set_fn, self._ready_fn, self._cmd_fn)
+
+def read_sequence(path, module_name):
+ print("> read sequence {}".format(path))
+ return sorted([f for f in glob.glob(os.path.join('./sequences/', module_name, path, '*.png'))])