1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
|
import os
import sys
import glob
from rpc import CortexRPC
from img_ops import process_image
def list_checkpoints(payload):
    """Return the sorted names of checkpoint directories under ``./checkpoints/<payload>/``.

    A directory counts as a checkpoint when it directly contains a
    ``latest_net_G.pth`` file.

    Args:
        payload: Sub-directory of ``./checkpoints/`` to search. May be ``''``
            to search ``./checkpoints/`` itself.

    Returns:
        Sorted list of checkpoint directory names (basenames only).
    """
    print("> list checkpoints")
    pattern = os.path.join('./checkpoints/', payload, '*', 'latest_net_G.pth')
    # Use the name of the directory holding each match instead of a fixed index
    # into the '/'-split path: a fixed index picks the wrong component when
    # payload is empty or contains separators, and ignores os.sep.
    return sorted(os.path.basename(os.path.dirname(f)) for f in glob.glob(pattern))
def list_all_checkpoints(payload):
    """Return every ``*_net_G.pth`` file one directory below ``./checkpoints/``.

    ``payload`` is accepted because the command dispatcher passes one, but it
    is not used.

    Returns:
        The matching paths exactly as produced by ``glob`` (with the
        ``./checkpoints/`` prefix), in sorted order.
    """
    print("> list all checkpoints")
    matches = glob.glob('./checkpoints/*/*_net_G.pth')
    matches.sort()
    return matches
def list_epochs(path):
    """List the saved generator epochs of one checkpoint.

    Looks for ``<epoch>_net_G.pth`` files directly inside
    ``<cwd>/checkpoints/<path>`` and returns the ``<epoch>`` parts.

    Args:
        path: Checkpoint directory, relative to ``<cwd>/checkpoints``.

    Returns:
        Sorted list of epoch labels (plain string sort, so
        ``'10' < '5' < 'latest'``), or the string ``"not found"`` when the
        checkpoint directory does not exist.
    """
    print("> list epochs for {}".format(path))
    checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints', path)
    # NOTE(review): `path` is not checked for '..' components, so it can name a
    # directory outside ./checkpoints; the dispatcher forwards the raw RPC
    # payload here — confirm callers are trusted.
    if not os.path.exists(checkpoint_dir):
        print('not found')
        return "not found"
    suffix = '_net_G.pth'
    # Glob once, and strip only the trailing suffix (str.replace would also
    # remove an occurrence earlier in the file name).
    files = glob.glob(os.path.join(checkpoint_dir, '*' + suffix))
    return sorted(os.path.basename(f)[:-len(suffix)] for f in files)
def list_sequences(module):
    """Describe the sequences available under ``./sequences/<module>``.

    Every immediate sub-directory of ``./sequences/<module>`` is one sequence.

    Returns:
        A list of ``{'name': <directory name>, 'count': <number of regular
        files directly inside it>}`` dicts, ordered by name.
    """
    print("> list sequences")
    module_dir = os.path.join('./sequences/', module)

    def _file_count(seq_dir):
        # Only regular files count; nested directories are skipped.
        return sum(1 for entry in os.listdir(seq_dir)
                   if os.path.isfile(os.path.join(seq_dir, entry)))

    seq_names = sorted(entry for entry in os.listdir(module_dir)
                       if os.path.isdir(os.path.join(module_dir, entry)))
    return [
        {'name': seq, 'count': _file_count(os.path.join(module_dir, seq))}
        for seq in seq_names
    ]
class Listener:
    """Bridges CortexRPC callbacks to the live-processing options object.

    ``connect`` hands four callbacks to ``CortexRPC``:

    - ``_get_fn`` and ``_set_fn`` read and write attributes of ``data_opt``.
    - ``_cmd_fn`` handles named commands.
    - ``_ready_fn`` stores the client and starts ``run_live``.
    """

    def __init__(self, opt, data_opt, data_opt_parser, run_live):
        """Store the options and the live-run callback.

        Args:
            opt: Static options; ``opt.module_name`` is read by ``_cmd_fn``.
            data_opt: Mutable options object handed to ``run_live``; its
                attributes are read and written by the RPC callbacks.
            data_opt_parser: Object exposing a ``.parser`` attribute, presumably
                an ``argparse.ArgumentParser`` (``parse_known_args`` is called on
                it by ``_set_fn``).
            run_live: Callable invoked as ``run_live(opt, data_opt, rpc_client)``.
        """
        self.opt = opt
        self.data_opt = data_opt
        self.data_opt_parser = data_opt_parser.parser
        self.data_opt.load_checkpoint = True
        self.working = False
        self.run_live = run_live
        # Defined up front so a 'play' command handled before the client has
        # been assigned passes None instead of raising AttributeError.
        self.rpc_client = None

    def _set_fn(self, key, value):
        """Set ``data_opt.<key>`` from an RPC value.

        Keys that ``data_opt`` does not already have are silently ignored.

        Value handling:

        - ``'True'``/``'False'`` (or the booleans themselves) are stored as
          booleans.
        - Anything else is converted by parsing ``--<key-with-dashes> <value>``
          with ``data_opt_parser``, so it gets the type that option declares.

        Conversion failures are printed, not raised.
        """
        if not hasattr(self.data_opt, key):
            return
        try:
            if str(value) == 'True':
                setattr(self.data_opt, key, True)
                print('set {} {}: {}'.format('bool', key, True))
            elif str(value) == 'False':
                setattr(self.data_opt, key, False)
                print('set {} {}: {}'.format('bool', key, False))
            else:
                new_opt, _ = self.data_opt_parser.parse_known_args(
                    ['--' + key.replace('_', '-'), str(value)])
                new_value = getattr(new_opt, key)
                setattr(self.data_opt, key, new_value)
                print('set {} {}: {} [{}]'.format(type(new_value), key, new_value, value))
        except (Exception, SystemExit) as e:
            # argparse reports an unconvertible value via parser.error(), which
            # raises SystemExit (not an Exception subclass); catch it too so a
            # bad value from the client is reported instead of exiting.
            print('error {} - cant set value {}: {}'.format(e, key, value))

    def _get_fn(self):
        """Return ``data_opt``'s attributes (``vars`` gives the live ``__dict__``, not a copy)."""
        return vars(self.data_opt)

    def _cmd_fn(self, cmd, payload):
        """Handle a named command from the RPC client and return its reply.

        Commands:

        - ``list_checkpoints``, ``list_all_checkpoints``, ``list_epochs``,
          ``list_sequences``: forwarded to the module-level helpers with
          ``payload``.
        - ``load_epoch``: ``payload`` is ``'<checkpoint>:<epoch>'``; records
          both and sets ``load_checkpoint``.
        - ``load_sequence``: ``payload`` is a sequence directory name; acted
          on only if ``./sequences/<module>/<payload>`` exists.
        - ``seek``: stores ``payload`` in ``data_opt.seek_to``.
        - ``get_status``: returns a dict summarising ``data_opt``.
        - ``play``: clears ``pause`` and calls ``run_live``, only when
          ``data_opt.processing is False``.
        - ``pause``: sets ``pause``, only when ``data_opt.processing is True``.
        - ``exit``: calls ``sys.exit(0)``.

        Unknown commands, and commands whose precondition fails, return
        ``'ok'``. A ``load_epoch`` payload without ``':'`` raises ValueError.
        """
        print("got command {}".format(cmd))
        # For the pix2pixHD module, sequences are looked up directly in
        # ./sequences/ (empty module component in the path below).
        module_name = self.opt.module_name
        if module_name == 'pix2pixHD':
            module_name = ''
        if cmd == 'list_checkpoints':
            return list_checkpoints(payload)
        if cmd == 'list_all_checkpoints':
            return list_all_checkpoints(payload)
        if cmd == 'list_epochs':
            return list_epochs(payload)
        if cmd == 'list_sequences':
            return list_sequences(payload)
        if cmd == 'load_epoch':
            # Split on the last ':' so a checkpoint name containing ':' survives.
            name, epoch = payload.rsplit(':', 1)
            print(">>> loading checkpoint {}, epoch {}".format(name, epoch))
            self.data_opt.checkpoint_name = name
            self.data_opt.epoch = epoch
            self.data_opt.load_checkpoint = True
            return 'ok'
        if cmd == 'load_sequence' and os.path.exists(os.path.join('./sequences/', module_name, payload)):
            print('load sequence: {}'.format(payload))
            self.data_opt.sequence_name = payload
            self.data_opt.load_sequence = True
            return 'loaded sequence'
        if cmd == 'seek':
            self.data_opt.seek_to = payload
            return 'seeked'
        if cmd == 'get_status':
            return {
                'processing': self.data_opt.processing,
                'checkpoint': self.data_opt.checkpoint_name,
                'epoch': self.data_opt.epoch,
                'sequence': self.data_opt.sequence_name,
            }
        if cmd == 'play' and self.data_opt.processing is False:
            self.data_opt.pause = False
            self.run_live(self.opt, self.data_opt, self.rpc_client)
            return 'playing'
        if cmd == 'pause' and self.data_opt.processing is True:
            self.data_opt.pause = True
            return 'paused'
        if cmd == 'exit':
            print("Exiting now...!")
            # sys.exit raises SystemExit, so no reply is returned for 'exit'.
            # NOTE(review): if CortexRPC runs callbacks on a worker thread this
            # only ends that thread — confirm.
            sys.exit(0)
        return 'ok'

    def _ready_fn(self, rpc_client):
        """Ready callback given to CortexRPC: keep the client and start ``run_live``."""
        print("Ready!")
        self.rpc_client = rpc_client
        self.run_live(self.opt, self.data_opt, rpc_client)

    def connect(self):
        """Create the CortexRPC client with this listener's four callbacks."""
        self.rpc_client = CortexRPC(self._get_fn, self._set_fn, self._ready_fn, self._cmd_fn)
|