summaryrefslogtreecommitdiff
path: root/util/util.py
diff options
context:
space:
mode:
authorjunyanz <junyanz@berkeley.edu>2017-04-18 03:38:47 -0700
committerjunyanz <junyanz@berkeley.edu>2017-04-18 03:38:47 -0700
commitc99ce7c4e781712e0252c6127ad1a4e8021cc489 (patch)
treeba99dfd56a47036d9c1f18620abf4efc248839ab /util/util.py
first commit
Diffstat (limited to 'util/util.py')
-rw-r--r--util/util.py71
1 files changed, 71 insertions, 0 deletions
diff --git a/util/util.py b/util/util.py
new file mode 100644
index 0000000..781239f
--- /dev/null
+++ b/util/util.py
@@ -0,0 +1,71 @@
+from __future__ import print_function
+import torch
+import numpy as np
+from PIL import Image
+import inspect, re
+import numpy as np
+import os
+import collections
+
+# Converts a Tensor into a Numpy array
+# |imtype|: the desired type of the converted numpy array
+def tensor2im(image_tensor, imtype=np.uint8):
+ image_numpy = image_tensor[0].cpu().float().numpy()
+ image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
+ return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+ mean = 0.0
+ count = 0
+ for param in net.parameters():
+ if param.grad is not None:
+ mean += torch.mean(torch.abs(param.grad.data))
+ count += 1
+ if count > 0:
+ mean = mean / count
+ print(name)
+ print(mean)
+
+
+def save_image(image_numpy, image_path):
+ image_pil = Image.fromarray(image_numpy)
+ image_pil.save(image_path)
+
+def info(object, spacing=10, collapse=1):
+ """Print methods and doc strings.
+ Takes module, class, list, dictionary, or string."""
+ methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)]
+ processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
+ print( "\n".join(["%s %s" %
+ (method.ljust(spacing),
+ processFunc(str(getattr(object, method).__doc__)))
+ for method in methodList]) )
+
+def varname(p):
+ for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
+ m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
+ if m:
+ return m.group(1)
+
+def print_numpy(x, val=True, shp=False):
+ x = x.astype(np.float64)
+ if shp:
+ print('shape,', x.shape)
+ if val:
+ x = x.flatten()
+ print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+ np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)