summaryrefslogtreecommitdiff
path: root/util/util.py
diff options
context:
space:
mode:
authortingchunw <tingchunw@nvidia.com>2017-12-09 01:02:39 +0000
committertingchunw <tingchunw@nvidia.com>2017-12-09 01:02:39 +0000
commitdb4a24df4a8482089bcb88d0e79c9b484307fc75 (patch)
tree50f05acb700ff3b71e94d937ccd1b57da165cea0 /util/util.py
parent99d031b469478434ea185e1da07f12b7b007c6b6 (diff)
add explanation for training with new dataset
Diffstat (limited to 'util/util.py')
-rwxr-xr-xutil/util.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/util/util.py b/util/util.py
index 95d1315..f5ed60a 100755
--- a/util/util.py
+++ b/util/util.py
@@ -24,13 +24,15 @@ def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
return image_numpy.astype(imtype)
# Converts a one-hot tensor into a colorful label map
-def tensor2label(output, n_label, imtype=np.uint8):
- output = output.cpu().float()
- if output.size()[0] > 1:
- output = output.max(0, keepdim=True)[1]
- output = Colorize(n_label)(output)
- output = np.transpose(output.numpy(), (1, 2, 0))
- return output.astype(imtype)
+def tensor2label(label_tensor, n_label, imtype=np.uint8):
+ if n_label == 0:
+ return tensor2im(label_tensor, imtype)
+ label_tensor = label_tensor.cpu().float()
+ if label_tensor.size()[0] > 1:
+ label_tensor = label_tensor.max(0, keepdim=True)[1]
+ label_tensor = Colorize(n_label)(label_tensor)
+ label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
+ return label_numpy.astype(imtype)
def save_image(image_numpy, image_path):
image_pil = Image.fromarray(image_numpy)