import os
import sys
import glob
from rpc import CortexRPC
from img_ops import process_image


def list_checkpoints(payload):
    """Return the sorted names of checkpoints under ``./checkpoints/<payload>``.

    A checkpoint is an immediate subdirectory that contains a
    ``latest_net_G.pth`` file; the name of that subdirectory is returned.

    Args:
        payload: Sub-path below ``./checkpoints`` to search. May be ``''``, in
            which case ``./checkpoints`` itself is searched.

    Returns:
        Sorted list of checkpoint directory names.
    """
    print("> list checkpoints")
    pattern = os.path.join('./checkpoints/', payload, '*', 'latest_net_G.pth')
    # Use the name of the directory that holds the weights file instead of a
    # fixed path component: the former ``f.split('/')[3]`` yielded
    # 'latest_net_G.pth' for an empty payload, the wrong component for nested
    # payloads, and ignored the platform path separator.
    return sorted(os.path.basename(os.path.dirname(f)) for f in glob.glob(pattern))


def list_all_checkpoints(payload):
    """Return sorted paths of every ``*_net_G.pth`` file one level below ``./checkpoints``.

    Args:
        payload: Unused; the command dispatcher passes the RPC payload anyway.

    Returns:
        Sorted list of matching file paths (relative, prefixed ``./checkpoints/``).
    """
    print("> list all checkpoints")
    return sorted(glob.glob('./checkpoints/*/*_net_G.pth'))


def list_epochs(path):
    """Return the sorted epoch labels saved for checkpoint ``path``.

    An epoch label is the file-name prefix of a ``<label>_net_G.pth`` file in
    ``<cwd>/checkpoints/<path>``. Sorting is plain string ordering.

    Args:
        path: Checkpoint sub-path below ``<cwd>/checkpoints``.
            NOTE(review): this presumably arrives from an RPC client; a value
            containing '..' can reach outside ./checkpoints — consider
            validating it.

    Returns:
        Sorted list of epoch labels, or the string ``"not found"`` when the
        checkpoint directory does not exist.
    """
    print("> list epochs for {}".format(path))
    checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints', path)
    if not os.path.exists(checkpoint_dir):
        print('not found')
        return "not found"
    suffix = '_net_G.pth'
    # Glob once; the pattern guarantees every basename ends with ``suffix``,
    # so slicing strips exactly that suffix and nothing else.
    weights = glob.glob(os.path.join(checkpoint_dir, '*' + suffix))
    return sorted(os.path.basename(f)[:-len(suffix)] for f in weights)


def list_sequences(module):
    """Describe each sequence directory under ``./sequences/<module>``.

    Args:
        module: Sub-directory of ``./sequences`` to scan.

    Returns:
        List, sorted by directory name, of ``{'name': str, 'count': int}``
        dicts; ``count`` is the number of regular files directly inside that
        sequence directory (subdirectories are not counted).

    Raises:
        OSError: if ``./sequences/<module>`` (or a sequence directory) cannot
            be listed, as raised by ``os.listdir``.
    """
    print("> list sequences")
    root = os.path.join('./sequences/', module)
    results = []
    for name in sorted(os.listdir(root)):
        seq_dir = os.path.join(root, name)
        if not os.path.isdir(seq_dir):
            continue
        count = sum(1 for entry in os.listdir(seq_dir)
                    if os.path.isfile(os.path.join(seq_dir, entry)))
        results.append({
            'name': name,
            'count': count,
        })
    return results


class Listener():
    """Connect a CortexRPC client to the live-processing options.

    Exposes get/set access to ``data_opt`` and a set of named commands
    (listing checkpoints/epochs/sequences, loading them, play/pause, exit),
    and starts ``run_live`` once the RPC connection reports ready.
    """

    def __init__(self, opt, data_opt, data_opt_parser, run_live):
        """Store the option objects and the live-run callback.

        Args:
            opt: Static options; ``opt.module_name`` is read by ``_cmd_fn``.
            data_opt: Mutable options object changed by RPC get/set/commands.
            data_opt_parser: Object whose ``.parser`` is used to re-parse
                single options (presumably an ``argparse.ArgumentParser`` —
                it must provide ``parse_known_args``).
            run_live: Callable invoked as ``run_live(opt, data_opt, rpc_client)``.
        """
        self.opt = opt
        self.data_opt = data_opt
        self.data_opt_parser = data_opt_parser.parser
        # Request a checkpoint load on the first run.
        self.data_opt.load_checkpoint = True
        self.working = False
        self.run_live = run_live

    def _set_fn(self, key, value):
        """Set option ``key`` on ``data_opt`` from an RPC-supplied ``value``.

        The strings ``'True'``/``'False'`` become booleans. Any other value is
        converted by re-parsing ``--<key-with-dashes> <value>`` with the options
        parser, so the option's declared type is applied. Keys that ``data_opt``
        lacks are ignored; conversion failures are logged and the option is
        left unchanged.
        """
        if hasattr(self.data_opt, key):
            try:
                if str(value) == 'True':
                    setattr(self.data_opt, key, True)
                    print('set {} {}: {}'.format('bool', key, True))
                elif str(value) == 'False':
                    setattr(self.data_opt, key, False)
                    print('set {} {}: {}'.format('bool', key, False))
                else:
                    new_opt, misc = self.data_opt_parser.parse_known_args([
                        '--' + key.replace('_', '-'), str(value)
                    ])
                    new_value = getattr(new_opt, key)
                    setattr(self.data_opt, key, new_value)
                    print('set {} {}: {} [{}]'.format(type(new_value), key, new_value, value))
            # argparse reports an invalid value via parser.error(), which raises
            # SystemExit rather than an Exception; catch it so one bad "set"
            # request cannot terminate the process.
            except (Exception, SystemExit) as e:
                print('error {} - cant set value {}: {}'.format(e, key, value))

    def _get_fn(self):
        """Return the current options as a dict.

        NOTE(review): ``vars`` returns ``data_opt``'s live ``__dict__``, not a
        copy; a caller that mutates it mutates the options.
        """
        return vars(self.data_opt)

    def _cmd_fn(self, cmd, payload):
        """Handle a named RPC command and return its reply.

        Args:
            cmd: Command name, e.g. ``'list_checkpoints'``, ``'load_epoch'``,
                ``'play'``.
            payload: Command argument; its meaning depends on ``cmd``.

        Returns:
            A command-specific reply (list, dict or status string). Unknown
            commands, and commands whose precondition fails (missing sequence,
            ``play`` while processing, ``pause`` while idle), reply ``'ok'``.
        """
        print("got command {}".format(cmd))
        # pix2pixHD sequences presumably live directly under ./sequences,
        # hence the empty module sub-directory. TODO confirm.
        module_name = self.opt.module_name
        if module_name == 'pix2pixHD':
            module_name = ''

        if cmd == 'list_checkpoints':
            return list_checkpoints(payload)
        if cmd == 'list_all_checkpoints':
            return list_all_checkpoints(payload)
        if cmd == 'list_epochs':
            return list_epochs(payload)
        if cmd == 'list_sequences':
            return list_sequences(payload)
        if cmd == 'load_epoch':
            # Payload is "<checkpoint>:<epoch>". Split on the LAST ':' so
            # checkpoint names containing ':' still parse; a payload without
            # ':' still raises ValueError.
            name, epoch = payload.rsplit(':', 1)
            print(">>> loading checkpoint {}, epoch {}".format(name, epoch))
            self.data_opt.checkpoint_name = name
            self.data_opt.epoch = epoch
            self.data_opt.load_checkpoint = True
            return 'ok'
        # NOTE(review): payload comes from the RPC client and is joined into a
        # filesystem path; '..' components are not rejected.
        if cmd == 'load_sequence' and os.path.exists(os.path.join('./sequences/', module_name, payload)):
            print('load sequence: {}'.format(payload))
            self.data_opt.sequence_name = payload
            self.data_opt.load_sequence = True
            return 'loaded sequence'
        if cmd == 'seek':
            self.data_opt.seek_to = payload
            return 'seeked'
        if cmd == 'get_status':
            return {
                'processing': self.data_opt.processing,
                'checkpoint': self.data_opt.checkpoint_name,
                'epoch': self.data_opt.epoch,
                'sequence': self.data_opt.sequence_name,
            }
        if cmd == 'play' and self.data_opt.processing is False:
            self.data_opt.pause = False
            self.run_live(self.opt, self.data_opt, self.rpc_client)
            return 'playing'
        if cmd == 'pause' and self.data_opt.processing is True:
            self.data_opt.pause = True
            return 'paused'
        if cmd == 'exit':
            print("Exiting now...!")
            # sys.exit raises SystemExit, so nothing after this line runs.
            sys.exit(0)
        return 'ok'

    def _ready_fn(self, rpc_client):
        """Remember the connected client and start the live loop."""
        print("Ready!")
        self.rpc_client = rpc_client
        self.run_live(self.opt, self.data_opt, rpc_client)

    def connect(self):
        """Create the CortexRPC client wired to this listener's callbacks.

        ``_ready_fn`` also stores the client it is given; this assignment keeps
        ``self.rpc_client`` set once the constructor returns.
        """
        self.rpc_client = CortexRPC(self._get_fn, self._set_fn, self._ready_fn, self._cmd_fn)