From db4a24df4a8482089bcb88d0e79c9b484307fc75 Mon Sep 17 00:00:00 2001 From: tingchunw Date: Sat, 9 Dec 2017 01:02:39 +0000 Subject: add explanation for training with new dataset --- util/util.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) (limited to 'util/util.py') diff --git a/util/util.py b/util/util.py index 95d1315..f5ed60a 100755 --- a/util/util.py +++ b/util/util.py @@ -24,13 +24,15 @@ def tensor2im(image_tensor, imtype=np.uint8, normalize=True): return image_numpy.astype(imtype) # Converts a one-hot tensor into a colorful label map -def tensor2label(output, n_label, imtype=np.uint8): - output = output.cpu().float() - if output.size()[0] > 1: - output = output.max(0, keepdim=True)[1] - output = Colorize(n_label)(output) - output = np.transpose(output.numpy(), (1, 2, 0)) - return output.astype(imtype) +def tensor2label(label_tensor, n_label, imtype=np.uint8): + if n_label == 0: + return tensor2im(label_tensor, imtype) + label_tensor = label_tensor.cpu().float() + if label_tensor.size()[0] > 1: + label_tensor = label_tensor.max(0, keepdim=True)[1] + label_tensor = Colorize(n_label)(label_tensor) + label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) + return label_numpy.astype(imtype) def save_image(image_numpy, image_path): image_pil = Image.fromarray(image_numpy) -- cgit v1.2.3-70-g09d2