summaryrefslogtreecommitdiff
path: root/rpc
diff options
context:
space:
mode:
Diffstat (limited to 'rpc')
-rw-r--r--rpc/listener.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/rpc/listener.py b/rpc/listener.py
index 81ad794..1df575e 100644
--- a/rpc/listener.py
+++ b/rpc/listener.py
@@ -7,7 +7,7 @@ from img_ops import process_image
def list_checkpoints(payload):
print("> list checkpoints")
- return sorted([f.split('/')[3] for f in glob.glob(os.path.join('./checkpoints/', payload, '/*/latest_net_G.pth'))])
+ return sorted([f.split('/')[3] for f in glob.glob(os.path.join('./checkpoints/', payload, '*', 'latest_net_G.pth'))])
def list_all_checkpoints(payload):
print("> list all checkpoints")
@@ -15,9 +15,13 @@ def list_all_checkpoints(payload):
def list_epochs(path):
print("> list epochs for {}".format(path))
- if not os.path.exists(os.path.join('./checkpoints/', path)):
+ if not os.path.exists(os.path.join(os.getcwd(), 'checkpoints', path)):
+ print('not found')
return "not found"
- return sorted([os.path.basename(f).replace('_net_G.pth', '') for f in glob.glob(os.path.join('./checkpoints/', path, '/*_net_G.pth'))])
+ print(os.getcwd())
+ print(os.path.join('./checkpoints/', path))
+ print(glob.glob(os.path.join(os.getcwd(), 'checkpoints', path, '*_net_G.pth')))
+ return sorted([os.path.basename(f).replace('_net_G.pth', '') for f in glob.glob(os.path.join(os.getcwd(), 'checkpoints', path, '*_net_G.pth'))])
def list_sequences(module):
print("> list sequences")