diff options
| author | tingchunw <tingchunw@nvidia.com> | 2017-12-09 01:02:39 +0000 |
|---|---|---|
| committer | tingchunw <tingchunw@nvidia.com> | 2017-12-09 01:02:39 +0000 |
| commit | db4a24df4a8482089bcb88d0e79c9b484307fc75 (patch) | |
| tree | 50f05acb700ff3b71e94d937ccd1b57da165cea0 /util/util.py | |
| parent | 99d031b469478434ea185e1da07f12b7b007c6b6 (diff) | |
add explanation for training with new dataset
Diffstat (limited to 'util/util.py')
| -rwxr-xr-x | util/util.py | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/util/util.py b/util/util.py index 95d1315..f5ed60a 100755 --- a/util/util.py +++ b/util/util.py @@ -24,13 +24,15 @@ def tensor2im(image_tensor, imtype=np.uint8, normalize=True): return image_numpy.astype(imtype) # Converts a one-hot tensor into a colorful label map -def tensor2label(output, n_label, imtype=np.uint8): - output = output.cpu().float() - if output.size()[0] > 1: - output = output.max(0, keepdim=True)[1] - output = Colorize(n_label)(output) - output = np.transpose(output.numpy(), (1, 2, 0)) - return output.astype(imtype) +def tensor2label(label_tensor, n_label, imtype=np.uint8): + if n_label == 0: + return tensor2im(label_tensor, imtype) + label_tensor = label_tensor.cpu().float() + if label_tensor.size()[0] > 1: + label_tensor = label_tensor.max(0, keepdim=True)[1] + label_tensor = Colorize(n_label)(label_tensor) + label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) + return label_numpy.astype(imtype) def save_image(image_numpy, image_path): image_pil = Image.fromarray(image_numpy) |
