summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorcam <cameron@ideum.com>2016-10-22 19:52:58 -0600
committercam <cameron@ideum.com>2016-10-22 19:52:58 -0600
commit12f7c0fa5b9bc23deac65aae71c46be611adec97 (patch)
treef8eda1f4effc5136dbb4e6cef566200a0efc6e6a
parentdff0e16b9104cae053d9891689f656fd6d60de51 (diff)
Added more colorspaces for color conversion
-rw-r--r--neural_style.py35
1 file changed, 26 insertions, 9 deletions
diff --git a/neural_style.py b/neural_style.py
index d709025..4af4bca 100644
--- a/neural_style.py
+++ b/neural_style.py
@@ -91,6 +91,11 @@ def parse_args():
parser.add_argument('--original_colors', action='store_true',
help='Transfer the style but not the colors.')
+ parser.add_argument('--color_convert_type', type=str,
+ default='yuv',
+ choices=['yuv', 'ycrcb', 'luv'],
+ help='Color space for conversion to original colors (default: %(default)s)')
+
parser.add_argument('--style_mask', action='store_true',
help='Transfer the style to masked regions.')
@@ -481,7 +486,7 @@ def sum_total_variation_losses(sess, net, input_img):
'''
def read_image(path):
# bgr image
- img = cv2.imread(path, cv2.IMREAD_COLOR).astype('float')
+ img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32)
img = preprocess(img, vgg19_mean)
return img
@@ -701,7 +706,7 @@ def get_content_frame(frame):
def get_content_image(content_img):
path = os.path.join(args.content_img_dir, content_img)
# bgr image
- img = cv2.imread(path, cv2.IMREAD_COLOR).astype('float')
+ img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32)
h, w, d = img.shape
mx = args.max_size
# resize if > max size
@@ -755,7 +760,7 @@ def get_prev_warped_frame(frame):
flow_fn = args.backward_optical_flow_frmt.format(str(frame), str(prev_frame))
flow_path = os.path.join(args.video_input_dir, flow_fn)
flow = read_flow_file(flow_path)
- warped_img = warp_image(prev_img, flow).astype('float32')
+ warped_img = warp_image(prev_img, flow).astype(np.float32)
img = preprocess(warped_img, vgg19_mean)
return img
@@ -786,12 +791,24 @@ def warp_image(src, flow):
def convert_to_original_colors(content_img, stylized_img):
content_img = postprocess(content_img, vgg19_mean)
stylized_img = postprocess(stylized_img, vgg19_mean)
- content_yuv = cv2.cvtColor(content_img, cv2.COLOR_BGR2YUV)
- stylized_yuv = cv2.cvtColor(stylized_img, cv2.COLOR_BGR2YUV)
- y, _, _ = cv2.split(stylized_yuv)
- _, u, v = cv2.split(content_yuv)
- merged = cv2.merge((y, u, v))
- dst = cv2.cvtColor(merged, cv2.COLOR_YUV2BGR).astype('float')
+ if args.color_convert_type == 'yuv':
+ content_cvt = cv2.cvtColor(content_img, cv2.COLOR_BGR2YUV)
+ stylized_cvt = cv2.cvtColor(stylized_img, cv2.COLOR_BGR2YUV)
+ elif args.color_convert_type == 'ycrcb':
+ content_cvt = cv2.cvtColor(content_img, cv2.COLOR_BGR2YCR_CB)
+ stylized_cvt = cv2.cvtColor(stylized_img, cv2.COLOR_BGR2YCR_CB)
+ elif args.color_convert_type == 'luv':
+ content_cvt = cv2.cvtColor(content_img, cv2.COLOR_BGR2LUV)
+ stylized_cvt = cv2.cvtColor(stylized_img, cv2.COLOR_BGR2LUV)
+ c1, _, _ = cv2.split(stylized_cvt)
+ _, c2, c3 = cv2.split(content_cvt)
+ merged = cv2.merge((c1, c2, c3))
+ if args.color_convert_type == 'yuv':
+ dst = cv2.cvtColor(merged, cv2.COLOR_YUV2BGR).astype(np.float32)
+ elif args.color_convert_type == 'ycrcb':
+ dst = cv2.cvtColor(merged, cv2.COLOR_YCR_CB2BGR).astype(np.float32)
+ elif args.color_convert_type == 'luv':
+ dst = cv2.cvtColor(merged, cv2.COLOR_LUV2BGR).astype(np.float32)
dst = preprocess(dst, vgg19_mean)
return dst