summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorcam <cameron@ideum.com>2016-10-22 19:59:24 -0600
committercam <cameron@ideum.com>2016-10-22 19:59:24 -0600
commita4909f47bd2672edadfeded8f4213ea6d378baae (patch)
treeaa302dbe6cdb141aa7bee19c204e2971baa9da90
parent12f7c0fa5b9bc23deac65aae71c46be611adec97 (diff)
Simplified color conversion
-rw-r--r--neural_style.py21
1 files changed, 9 insertions, 12 deletions
diff --git a/neural_style.py b/neural_style.py
index 4af4bca..56916fa 100644
--- a/neural_style.py
+++ b/neural_style.py
@@ -792,23 +792,20 @@ def convert_to_original_colors(content_img, stylized_img):
content_img = postprocess(content_img, vgg19_mean)
stylized_img = postprocess(stylized_img, vgg19_mean)
if args.color_convert_type == 'yuv':
- content_cvt = cv2.cvtColor(content_img, cv2.COLOR_BGR2YUV)
- stylized_cvt = cv2.cvtColor(stylized_img, cv2.COLOR_BGR2YUV)
+ cvt_type = cv2.COLOR_BGR2YUV
+ inv_cvt_type = cv2.COLOR_YUV2BGR
elif args.color_convert_type == 'ycrcb':
- content_cvt = cv2.cvtColor(content_img, cv2.COLOR_BGR2YCR_CB)
- stylized_cvt = cv2.cvtColor(stylized_img, cv2.COLOR_BGR2YCR_CB)
+ cvt_type = cv2.COLOR_BGR2YCR_CB
+ inv_cvt_type = cv2.COLOR_YCR_CB2BGR
elif args.color_convert_type == 'luv':
- content_cvt = cv2.cvtColor(content_img, cv2.COLOR_BGR2LUV)
- stylized_cvt = cv2.cvtColor(stylized_img, cv2.COLOR_BGR2LUV)
+ cvt_type = cv2.COLOR_BGR2LUV
+ inv_cvt_type = cv2.COLOR_LUV2BGR
+ content_cvt = cv2.cvtColor(content_img, cvt_type)
+ stylized_cvt = cv2.cvtColor(stylized_img, cvt_type)
c1, _, _ = cv2.split(stylized_cvt)
_, c2, c3 = cv2.split(content_cvt)
merged = cv2.merge((c1, c2, c3))
- if args.color_convert_type == 'yuv':
- dst = cv2.cvtColor(merged, cv2.COLOR_YUV2BGR).astype(np.float32)
- elif args.color_convert_type == 'ycrcb':
- dst = cv2.cvtColor(merged, cv2.COLOR_YCR_CB2BGR).astype(np.float32)
- elif args.color_convert_type == 'luv':
- dst = cv2.cvtColor(merged, cv2.COLOR_LUV2BGR).astype(np.float32)
+ dst = cv2.cvtColor(merged, inv_cvt_type).astype(np.float32)
dst = preprocess(dst, vgg19_mean)
return dst