summaryrefslogtreecommitdiff
path: root/rpc/listener.py
blob: 4b91230380af272ee649b59087254a12955e3b30 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import sys
import glob

from rpc import CortexRPC
from img_ops import process_image

def list_checkpoints(payload):
  """Return the sorted names of checkpoint directories under ./checkpoints/<payload>.

  A subdirectory counts as a checkpoint when it contains ``latest_net_G.pth``.

  Args:
    payload: subdirectory of ./checkpoints to search; an empty string
      searches ./checkpoints itself.

  Returns:
    Sorted list of bare directory names (no path components).
  """
  print("> list checkpoints")
  # Escape glob metacharacters so a name like 'run[1]' is matched literally.
  pattern = os.path.join('./checkpoints/', glob.escape(payload), '*', 'latest_net_G.pth')
  # Take the parent directory's name rather than a fixed split('/') index,
  # which picked the wrong component for an empty or nested payload.
  return sorted(os.path.basename(os.path.dirname(f)) for f in glob.glob(pattern))

def list_all_checkpoints(payload):
  """List every ``*_net_G.pth`` file exactly one directory below ./checkpoints.

  ``payload`` is accepted so the signature matches the other list_* handlers,
  but it is not used.

  Returns:
    Sorted list of matching paths, as produced by glob.
  """
  print("> list all checkpoints")
  found = glob.glob('./checkpoints/*/*_net_G.pth')
  found.sort()
  return found

def list_epochs(path):
  """Return the sorted epoch labels saved in ./checkpoints/<path>.

  Each ``<epoch>_net_G.pth`` file contributes ``<epoch>`` (e.g. '10', 'latest').

  Args:
    path: checkpoint subdirectory, relative to <cwd>/checkpoints.

  Returns:
    Sorted list of epoch labels, or the string "not found" when the
    checkpoint directory does not exist.
  """
  print("> list epochs for {}".format(path))
  checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints', path)
  if not os.path.exists(checkpoint_dir):
    print('not found')
    return "not found"
  suffix = '_net_G.pth'
  # Escape the directory part so names containing '[', '*' or '?' still match;
  # glob once instead of repeating it for debug output.
  weights = glob.glob(os.path.join(glob.escape(checkpoint_dir), '*' + suffix))
  # Strip only the trailing suffix (str.replace would remove every occurrence).
  return sorted(os.path.basename(f)[:-len(suffix)] for f in weights)

def list_sequences(module):
  """Summarise the sequence directories under ./sequences/<module>.

  Returns:
    One ``{'name': <subdirectory>, 'count': <regular files directly inside>}``
    dict per subdirectory, ordered by subdirectory name.
  """
  print("> list sequences")
  base = os.path.join('./sequences/', module)
  seq_names = [entry for entry in os.listdir(base) if os.path.isdir(os.path.join(base, entry))]
  seq_names.sort()

  def file_count(seq_dir):
    # Only regular files count; nested directories are skipped.
    return sum(1 for item in os.listdir(seq_dir) if os.path.isfile(os.path.join(seq_dir, item)))

  return [
    {'name': seq, 'count': file_count(os.path.join(base, seq))}
    for seq in seq_names
  ]

class Listener():
  """Expose the live-processing options and controls over CortexRPC.

  Remote peers read options (``_get_fn``), write them (``_set_fn``) and send
  named commands (``_cmd_fn``). Every handler acts on the shared ``data_opt``
  namespace, which is also the object handed to ``run_live``.
  """

  def __init__(self, opt, data_opt, data_opt_parser, run_live):
    """Store the option objects and the live-run callback.

    Args:
      opt: static options; ``opt.module_name`` is read by ``_cmd_fn``.
      data_opt: mutable runtime options shared with ``run_live``.
      data_opt_parser: wrapper whose ``.parser`` provides ``parse_known_args``
        (presumably an argparse parser); used by ``_set_fn`` to coerce values.
      run_live: callable invoked as ``run_live(opt, data_opt, rpc_client)``.
    """
    self.opt = opt
    self.data_opt = data_opt
    self.data_opt_parser = data_opt_parser.parser
    self.data_opt.load_checkpoint = True
    self.working = False  # NOTE(review): not read in this class - confirm still needed
    self.run_live = run_live
    # Assigned by _ready_fn / connect. Initialised here so a 'play' command
    # that arrives before either of them does not raise AttributeError.
    self.rpc_client = None

  def _set_fn(self, key, value):
    """Assign ``data_opt.<key>`` from a remotely supplied value.

    Keys that ``data_opt`` does not already have are ignored. The strings
    'True' / 'False' become booleans; any other value is parsed as
    ``--<key-with-dashes> <value>`` so it takes the parser's declared type.
    Failures are printed and swallowed so a bad value cannot stop the listener.
    """
    if not hasattr(self.data_opt, key):
      return
    try:
      if str(value) == 'True':
        setattr(self.data_opt, key, True)
        print('set {} {}: {}'.format('bool', key, True))
      elif str(value) == 'False':
        setattr(self.data_opt, key, False)
        print('set {} {}: {}'.format('bool', key, False))
      else:
        new_opt, _unknown = self.data_opt_parser.parse_known_args([ '--' + key.replace('_', '-'), str(value) ])
        new_value = getattr(new_opt, key)
        setattr(self.data_opt, key, new_value)
        print('set {} {}: {} [{}]'.format(type(new_value), key, new_value, value))
    except (Exception, SystemExit) as e:
      # argparse's parse_known_args reports an invalid value (e.g. a non-int
      # for an int option) by raising SystemExit, which `except Exception`
      # does not catch; without this a malformed remote value ends the process.
      print('error {} - cant set value {}: {}'.format(e, key, value))

  def _get_fn(self):
    """Return the attribute dict of ``data_opt`` (the live object, not a copy)."""
    return vars(self.data_opt)

  def _cmd_fn(self, cmd, payload):
    """Handle one named remote command and return a short reply.

    Commands that match no branch return 'ok'. That includes 'play' while
    already processing, 'pause' while not processing, and 'load_sequence'
    for a directory that does not exist.
    """
    print("got command {}".format(cmd))
    module_name = self.opt.module_name
    if module_name == 'pix2pixHD':
      # pix2pixHD sequences are looked up directly under ./sequences/.
      module_name = ''

    if cmd == 'list_checkpoints':
      return list_checkpoints(payload)
    if cmd == 'list_all_checkpoints':
      return list_all_checkpoints(payload)
    if cmd == 'list_epochs':
      return list_epochs(payload)
    if cmd == 'list_sequences':
      return list_sequences(payload)
    if cmd == 'load_epoch':
      # payload has the form '<checkpoint name>:<epoch>'.
      name, epoch = payload.split(':')
      print(">>> loading checkpoint {}, epoch {}".format(name, epoch))
      self.data_opt.checkpoint_name = name
      self.data_opt.epoch = epoch
      self.data_opt.load_checkpoint = True
      return 'ok'
    if cmd == 'load_sequence' and os.path.exists(os.path.join('./sequences/', module_name, payload)):
      print('load sequence: {}'.format(payload))
      self.data_opt.sequence_name = payload
      self.data_opt.load_sequence = True
      return 'loaded sequence'
    if cmd == 'seek':
      self.data_opt.seek_to = payload
      return 'seeked'
    if cmd == 'get_status':
      return {
        'processing': self.data_opt.processing,
        'checkpoint': self.data_opt.checkpoint_name,
        'epoch': self.data_opt.epoch,
        'sequence': self.data_opt.sequence_name,
      }
    if cmd == 'play' and self.data_opt.processing is False:
      self.data_opt.pause = False
      # run_live is called inline; 'playing' is returned only after it returns.
      self.run_live(self.opt, self.data_opt, self.rpc_client)
      return 'playing'
    if cmd == 'pause' and self.data_opt.processing is True:
      self.data_opt.pause = True
      return 'paused'
    if cmd == 'exit':
      print("Exiting now...!")
      # sys.exit raises SystemExit, so no reply is returned for this command.
      sys.exit(0)
    return 'ok'

  def _ready_fn(self, rpc_client):
    """Ready callback given to CortexRPC: keep the client and start run_live.

    NOTE(review): presumably fired once the RPC connection is up - confirm
    against CortexRPC.
    """
    print("Ready!")
    self.rpc_client = rpc_client
    self.run_live(self.opt, self.data_opt, rpc_client)

  def connect(self):
    """Create the CortexRPC client wired to this listener's four callbacks."""
    # NOTE(review): if the CortexRPC constructor blocks (e.g. runs its own
    # loop), this assignment only happens after it returns.
    self.rpc_client = CortexRPC(self._get_fn, self._set_fn, self._ready_fn, self._cmd_fn)