-rw-r--r--  .gitignore    7
-rw-r--r--  README.md    34
-rw-r--r--  nets.py     150
-rw-r--r--  test.py     164
-rw-r--r--  utils.py    229
5 files changed, 584 insertions(+), 0 deletions(-)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..85d7aee
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+
+inputs
+supple
+results
+.DS_Store
+*.h5
+
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f1e6019
--- /dev/null
+++ b/README.md
@@ -0,0 +1,34 @@
+# Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation
+
+This is a TensorFlow implementation of the paper. [PDF](http://yhjo09.github.io/files/VSR-DUF_CVPR18.pdf)
+
+## directory
+`./inputs/G/` Ground-truth video frames
+`./inputs/L/` Low-resolution video frames
+
+`./results/<L>L/G/` Outputs for ground-truth input frames, using the <L>-depth network
+`./results/<L>L/L/` Outputs for low-resolution input frames, using the <L>-depth network
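+
+`test.py` also expects the pre-trained weights `params_<L>L_x4.h5` in the repository root (loaded by `LoadParams` in `utils.py`).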
+
+## test
+Put your video frames in the input directory and run `test.py` with arguments `<L>` and `<T>`.
+```
+python test.py <L> <T>
+```
+`<L>` is the network depth: one of 16, 28, or 52.
+`<T>` is the input type: `G` denotes ground-truth (GT) inputs and `L` denotes low-resolution (LR) inputs.
+
+For example, `python test.py 16 G` super-resolves the input frames in `./inputs/G/*` using the `16`-depth network.
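+
+For LR inputs, an optional third argument selects a single scene directory; without it, every scene under `./inputs/L/*` is processed (the directory name below is only an example):
+```
+python test.py 16 L                    # all scenes in ./inputs/L/
+python test.py 16 L ./inputs/L/scene1  # a single scene directory
+```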
+
+## video
+[![supplementary video](./supple/title.png)](./supple/VSR_supple_crf28.mp4?raw=true)
+
+## bibtex
+```
+@InProceedings{Jo_2018_CVPR,
+ author = {Jo, Younghyun and Oh, Seoung Wug and Kang, Jaeyeon and Kim, Seon Joo},
+ title = {Deep Video Super-Resolution Network Using Dynamic Upsampling Filters Without Explicit Motion Compensation},
+ booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+ year = {2018}
+}
+```
+
diff --git a/nets.py b/nets.py
new file mode 100644
index 0000000..70ea3e9
--- /dev/null
+++ b/nets.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+import tensorflow as tf
+
+from utils import BatchNorm, Conv3D
+
+stp = [[0,0], [1,1], [1,1], [1,1], [0,0]]
+sp = [[0,0], [0,0], [1,1], [1,1], [0,0]]
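+# Padding specs for 5-D tensors [B,T,H,W,C]:
+# stp pads the temporal and spatial axes by 1, sp pads only the spatial axes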
+
+def FR_16L(x, is_train):
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,3,64], [1,1,1,1,1], 'VALID', name='conv1')
+
+ F = 64
+ G = 32
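+    # Dense blocks: each iteration concatenates G new feature maps onto x, so F grows by G per block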
+ for r in range(3):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, stp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x, t], 4)
+ F += G
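+    # The remaining blocks use temporally-valid 3x3x3 convolutions (no temporal padding), shrinking T by 2 per block;
+    # x[:,1:-1] crops x along time to match t before concatenation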
+ for r in range(3,6):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, sp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x[:,1:-1], t], 4)
+ F += G
+
+ x = BatchNorm(x, is_train, name='fbn1')
+ x = tf.nn.relu(x)
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,256,256], [1,1,1,1,1], 'VALID', name='conv2')
+ x = tf.nn.relu(x)
+
+ r = Conv3D(x, [1,1,1,256,256], [1,1,1,1,1], 'VALID', name='rconv1')
+ r = tf.nn.relu(r)
+ r = Conv3D(r, [1,1,1,256,3*16], [1,1,1,1,1], 'VALID', name='rconv2')
+
+ f = Conv3D(x, [1,1,1,256,512], [1,1,1,1,1], 'VALID', name='fconv1')
+ f = tf.nn.relu(f)
+ f = Conv3D(f, [1,1,1,512,1*5*5*16], [1,1,1,1,1], 'VALID', name='fconv2')
+
+ ds_f = tf.shape(f)
+ f = tf.reshape(f, [ds_f[0], ds_f[1], ds_f[2], ds_f[3], 25, 16])
+ f = tf.nn.softmax(f, dim=4)
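+    # f: [B,T',H,W,25,16] -- a softmax-normalized 5x5 filter for each of the R*R=16 sub-pixel positions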
+
+ return f, r
+
+def FR_28L(x, is_train):
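+    # Same architecture as FR_16L, but with 9 + 3 dense blocks and growth rate G=16 (64 + 12*16 = 256 features)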
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,3,64], [1,1,1,1,1], 'VALID', name='conv1')
+
+ F = 64
+ G = 16
+ for r in range(9):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, stp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x, t], 4)
+ F += G
+ for r in range(9,12):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, sp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x[:,1:-1], t], 4)
+ F += G
+
+ x = BatchNorm(x, is_train, name='fbn1')
+ x = tf.nn.relu(x)
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,256,256], [1,1,1,1,1], 'VALID', name='conv2')
+
+ x = tf.nn.relu(x)
+
+ r = Conv3D(x, [1,1,1,256,256], [1,1,1,1,1], 'VALID', name='rconv1')
+ r = tf.nn.relu(r)
+ r = Conv3D(r, [1,1,1,256,3*16], [1,1,1,1,1], 'VALID', name='rconv2')
+
+ f = Conv3D(x, [1,1,1,256,512], [1,1,1,1,1], 'VALID', name='fconv1')
+ f = tf.nn.relu(f)
+ f = Conv3D(f, [1,1,1,512,1*5*5*16], [1,1,1,1,1], 'VALID', name='fconv2')
+
+ ds_f = tf.shape(f)
+ f = tf.reshape(f, [ds_f[0], ds_f[1], ds_f[2], ds_f[3], 25, 16])
+ f = tf.nn.softmax(f, dim=4)
+
+ return f, r
+
+def FR_52L(x, is_train):
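+    # Same architecture as FR_16L, but with 21 + 3 dense blocks and growth rate G=16 (64 + 24*16 = 448 features)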
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,3,64], [1,1,1,1,1], 'VALID', name='conv1')
+
+ F = 64
+ G = 16
+ for r in range(0,21):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, stp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x, t], 4)
+ F += G
+ for r in range(21,24):
+ t = BatchNorm(x, is_train, name='Rbn'+str(r+1)+'a')
+ t = tf.nn.relu(t)
+ t = Conv3D(t, [1,1,1,F,F], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'a')
+
+ t = BatchNorm(t, is_train, name='Rbn'+str(r+1)+'b')
+ t = tf.nn.relu(t)
+ t = Conv3D(tf.pad(t, sp, mode='CONSTANT'), [3,3,3,F,G], [1,1,1,1,1], 'VALID', name='Rconv'+str(r+1)+'b')
+
+ x = tf.concat([x[:,1:-1], t], 4)
+ F += G
+
+ x = BatchNorm(x, is_train, name='fbn1')
+ x = tf.nn.relu(x)
+ x = Conv3D(tf.pad(x, sp, mode='CONSTANT'), [1,3,3,448,256], [1,1,1,1,1], 'VALID', name='conv2')
+
+ x = tf.nn.relu(x)
+
+ r = Conv3D(x, [1,1,1,256,256], [1,1,1,1,1], 'VALID', name='rconv1')
+ r = tf.nn.relu(r)
+ r = Conv3D(r, [1,1,1,256,3*16], [1,1,1,1,1], 'VALID', name='rconv2')
+
+ f = Conv3D(x, [1,1,1,256,512], [1,1,1,1,1], 'VALID', name='fconv1')
+ f = tf.nn.relu(f)
+ f = Conv3D(f, [1,1,1,512,1*5*5*16], [1,1,1,1,1], 'VALID', name='fconv2')
+
+ ds_f = tf.shape(f)
+ f = tf.reshape(f, [ds_f[0], ds_f[1], ds_f[2], ds_f[3], 25, 16])
+ f = tf.nn.softmax(f, dim=4)
+
+    return f, r
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..3347063
--- /dev/null
+++ b/test.py
@@ -0,0 +1,164 @@
+# -*- coding: utf-8 -*-
+import tensorflow as tf
+import numpy as np
+import time
+import glob
+import scipy
+import argparse
+import os
+from PIL import Image
+
+from utils import LoadImage, DownSample, AVG_PSNR, depth_to_space_3D, DynFilter3D, LoadParams
+from nets import FR_16L, FR_28L, FR_52L
+
+parser = argparse.ArgumentParser()
+parser.add_argument('L', metavar='L', type=int, help='Network depth: One of 16, 28, 52')
+parser.add_argument('T', metavar='T', help='Input type: L(Low-resolution) or G(Ground-truth)')
+parser.add_argument('dir', metavar='dir', nargs='?', default=None, help='Optional: a single input directory to process (used when T is L)')
+args = parser.parse_args()
+
+# Size of the input temporal window (number of frames)
+T_in = 7
+# Upscaling factor
+R = 4
+# Select the filter- and residual-generating network
+if args.L == 16:
+ FR = FR_16L
+elif args.L == 28:
+ FR = FR_28L
+elif args.L == 52:
+ FR = FR_52L
+else:
+ print('Invalid network depth: {} (Must be one of 16, 28, 52)'.format(args.L))
+ exit(1)
+
+if args.T not in ('L', 'G'):
+ print('Invalid input type: {} (Must be L(Low-resolution) or G(Ground-truth))'.format(args.T))
+ exit(1)
+
+
+def G(x, is_train):
+ # shape of x: [B,T_in,H,W,C]
+
+ # Generate filters and residual
+ # Fx: [B,1,H,W,1*5*5,R*R]
+ # Rx: [B,1,H,W,3*R*R]
+ Fx, Rx = FR(x, is_train)
+
+ x_c = []
+ for c in range(3):
+ t = DynFilter3D(x[:,T_in//2:T_in//2+1,:,:,c], Fx[:,0,:,:,:,:], [1,5,5]) # [B,H,W,R*R]
+ t = tf.depth_to_space(t, R) # [B,H*R,W*R,1]
+ x_c += [t]
+ x = tf.concat(x_c, axis=3) # [B,H*R,W*R,3]
+ x = tf.expand_dims(x, axis=1)
+
+ Rx = depth_to_space_3D(Rx, R) # [B,1,H*R,W*R,3]
+ x += Rx
+
+ return x
+
+# Gaussian kernel for downsampling
+def gkern(kernlen=13, nsig=1.6):
+ import scipy.ndimage.filters as fi
+    # Create a kernlen x kernlen array of zeros
+    inp = np.zeros((kernlen, kernlen))
+    # Set the middle element to one: a Dirac delta
+    inp[kernlen//2, kernlen//2] = 1
+    # Gaussian-smooth the Dirac delta, yielding a Gaussian filter mask
+ return fi.gaussian_filter(inp, nsig)
+
+h = gkern(13, 1.6) # 13 and 1.6 for x4
+h = h[:,:,np.newaxis,np.newaxis].astype(np.float32)
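+# h: [13,13,1,1] Gaussian blur kernel applied by DownSample when synthesizing LR frames from GT inputs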
+
+# Network
+H = tf.placeholder(tf.float32, shape=[None, T_in, None, None, 3])
+L_ = DownSample(H, h, R)
+L = L_[:,:,2:-2,2:-2,:] # Crop 2 pixels on each side to minimize boundary artifacts
+
+is_train = tf.placeholder(tf.bool, shape=[]) # Phase (train/test), scalar
+
+with tf.variable_scope('G') as scope:
+ GH = G(L, is_train)
+
+params_G = [v for v in tf.global_variables() if v.name.startswith('G/')]
+
+# Session
+config = tf.ConfigProto()
+config.gpu_options.allow_growth=True
+
+with tf.Session(config=config) as sess:
+ tf.global_variables_initializer().run()
+
+ # Load parameters
+ LoadParams(sess, [params_G], in_file='params_{}L_x{}.h5'.format(args.L, R))
+
+ if args.T == 'G':
+ # Test using GT videos
+ avg_psnrs = []
+ dir_inputs = glob.glob('./inputs/G/*')
+ for v in dir_inputs:
+ scene_name = v.split('/')[-1]
+ os.makedirs('./results/{}L/G/{}/'.format(args.L, scene_name), exist_ok=True)
+
+ dir_frames = glob.glob(v + '/*.png')
+ dir_frames.sort()
+
+ frames = []
+ for f in dir_frames:
+ frames.append(LoadImage(f))
+ frames = np.asarray(frames)
+ frames_padded = np.lib.pad(frames, pad_width=((T_in//2,T_in//2),(0,0),(0,0),(0,0)), mode='constant')
+ frames_padded = np.lib.pad(frames_padded, pad_width=((0,0),(8,8),(8,8),(0,0)), mode='reflect')
+
+ out_Hs = []
+ for i in range(frames.shape[0]):
+ print('Scene {}: Frame {}/{} processing'.format(scene_name, i+1, frames.shape[0]))
+ in_H = frames_padded[i:i+T_in] # select T_in frames
+ in_H = in_H[np.newaxis,:,:,:,:]
+
+ out_H = sess.run(GH, feed_dict={H: in_H, is_train: False})
+ out_H = np.clip(out_H, 0, 1)
+
+ Image.fromarray(np.around(out_H[0,0]*255).astype(np.uint8)).save('./results/{}L/G/{}/Frame{:03d}.png'.format(args.L, scene_name, i+1))
+
+ out_Hs.append(out_H[0, 0])
+ out_Hs = np.asarray(out_Hs)
+
+        avg_psnr = AVG_PSNR((frames*255).astype(np.uint8)/255.0, (out_Hs*255).astype(np.uint8)/255.0, vmin=0, vmax=1, t_border=2, sp_border=8)
+ avg_psnrs.append(avg_psnr)
+ print('Scene {}: PSNR {}'.format(scene_name, avg_psnr))
+
+    elif args.T == 'L':
+        # Test using low-resolution videos
+        def process_dir(v):
+            scene_name = v.split('/')[-1]
+            os.makedirs('./results/{}L/L/{}/'.format(args.L, scene_name), exist_ok=True)
+
+            dir_frames = glob.glob(v + '/*.png')
+            dir_frames.sort()
+
+            frames = []
+            for f in dir_frames:
+                frames.append(LoadImage(f))
+            frames = np.asarray(frames)
+            frames_padded = np.lib.pad(frames, pad_width=((T_in//2,T_in//2),(0,0),(0,0),(0,0)), mode='constant')
+
+            for i in range(frames.shape[0]):
+                print('Scene {}: Frame {}/{} processing'.format(scene_name, i+1, frames.shape[0]))
+                in_L = frames_padded[i:i+T_in] # select T_in frames
+                in_L = in_L[np.newaxis,:,:,:,:]
+
+                # Feed the LR frames directly into L (feeding an intermediate tensor overrides its computed value)
+                out_H = sess.run(GH, feed_dict={L: in_L, is_train: False})
+                out_H = np.clip(out_H, 0, 1)
+
+                Image.fromarray(np.around(out_H[0,0]*255).astype(np.uint8)).save('./results/{}L/L/{}/Frame{:03d}.png'.format(args.L, scene_name, i+1))
+
+        if args.dir:
+            process_dir(args.dir)
+        else:
+            dir_inputs = glob.glob('./inputs/L/*')
+            for v in dir_inputs:
+                process_dir(v)
+
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..26837cd
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,229 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import tensorflow as tf
+import h5py
+from PIL import Image
+
+def LoadImage(path, color_mode='RGB', channel_mean=None, modcrop=[0,0,0,0]):
+    '''Load an image using PIL, convert it into the specified color space,
+    and return it as a numpy array.
+
+ https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
+ The code is modified from Keras.preprocessing.image.load_img, img_to_array.
+ '''
+ ## Load image
+ img = Image.open(path)
+ if color_mode == 'RGB':
+ cimg = img.convert('RGB')
+ x = np.asarray(cimg, dtype='float32')
+
+    elif color_mode == 'YCbCr' or color_mode == 'Y':
+        cimg = img.convert('YCbCr')
+        x = np.asarray(cimg, dtype='float32')
+        if color_mode == 'Y':
+            x = x[:,:,0:1]
+    else:
+        raise ValueError('Unsupported color_mode: {}'.format(color_mode))
+
+ ## To 0-1
+ x *= 1.0/255.0
+
+ if channel_mean:
+ x[:,:,0] -= channel_mean[0]
+ x[:,:,1] -= channel_mean[1]
+ x[:,:,2] -= channel_mean[2]
+
+ if modcrop[0]*modcrop[1]*modcrop[2]*modcrop[3]:
+ x = x[modcrop[0]:-modcrop[1], modcrop[2]:-modcrop[3], :]
+
+ return x
+
+def DownSample(x, h, scale=4):
+ ds_x = tf.shape(x)
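+    # Merge the batch and time axes so all frames are blurred by one 2-D depthwise convolution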
+ x = tf.reshape(x, [ds_x[0]*ds_x[1], ds_x[2], ds_x[3], 3])
+
+ # Reflect padding
+ W = tf.constant(h)
+
+ filter_height, filter_width = 13, 13
+ pad_height = filter_height - 1
+ pad_width = filter_width - 1
+
+ # When pad_height (pad_width) is odd, we pad more to bottom (right),
+ # following the same convention as conv2d().
+ pad_top = pad_height // 2
+ pad_bottom = pad_height - pad_top
+ pad_left = pad_width // 2
+ pad_right = pad_width - pad_left
+ pad_array = [[0,0], [pad_top, pad_bottom], [pad_left, pad_right], [0,0]]
+
+ depthwise_F = tf.tile(W, [1, 1, 3, 1])
+ y = tf.nn.depthwise_conv2d(tf.pad(x, pad_array, mode='REFLECT'), depthwise_F, [1, scale, scale, 1], 'VALID')
+
+ ds_y = tf.shape(y)
+ y = tf.reshape(y, [ds_x[0], ds_x[1], ds_y[1], ds_y[2], 3])
+ return y
+
+def _rgb2ycbcr(img, maxVal=255):
+ O = np.array([[16],
+ [128],
+ [128]])
+ T = np.array([[0.256788235294118, 0.504129411764706, 0.097905882352941],
+ [-0.148223529411765, -0.290992156862745, 0.439215686274510],
+ [0.439215686274510, -0.367788235294118, -0.071427450980392]])
+
+ if maxVal == 1:
+ O = O / 255.0
+
+ t = np.reshape(img, (img.shape[0]*img.shape[1], img.shape[2]))
+ t = np.dot(t, np.transpose(T))
+ t[:, 0] += O[0]
+ t[:, 1] += O[1]
+ t[:, 2] += O[2]
+ ycbcr = np.reshape(t, [img.shape[0], img.shape[1], img.shape[2]])
+
+ return ycbcr
+
+def to_uint8(x, vmin, vmax):
+ x = x.astype('float32')
+ x = (x-vmin)/(vmax-vmin)*255 # 0~255
+ return np.clip(np.round(x), 0, 255)
+
+def AVG_PSNR(vid_true, vid_pred, vmin=0, vmax=255, t_border=2, sp_border=8, is_T_Y=False, is_P_Y=False):
+    '''
+    Average PSNR over a video: converts RGB to YCbCr and computes the PSNR on the Y channel,
+    excluding t_border frames at each temporal end and sp_border pixels at each spatial border.
+    '''
+ input_shape = vid_pred.shape
+ if is_T_Y:
+ Y_true = to_uint8(vid_true, vmin, vmax)
+ else:
+ Y_true = np.empty(input_shape[:-1])
+ for t in range(input_shape[0]):
+ Y_true[t] = _rgb2ycbcr(to_uint8(vid_true[t], vmin, vmax), 255)[:,:,0]
+
+ if is_P_Y:
+ Y_pred = to_uint8(vid_pred, vmin, vmax)
+ else:
+ Y_pred = np.empty(input_shape[:-1])
+ for t in range(input_shape[0]):
+ Y_pred[t] = _rgb2ycbcr(to_uint8(vid_pred[t], vmin, vmax), 255)[:,:,0]
+
+ diff = Y_true - Y_pred
+ diff = diff[t_border: input_shape[0]- t_border, sp_border: input_shape[1]- sp_border, sp_border: input_shape[2]- sp_border]
+
+ psnrs = []
+ for t in range(diff.shape[0]):
+ rmse = np.sqrt(np.mean(np.power(diff[t],2)))
+ psnrs.append(20*np.log10(255./rmse))
+
+ return np.mean(np.asarray(psnrs))
+
+
+he_normal_init = tf.contrib.layers.variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False)
+
+def BatchNorm(input, is_train, decay=0.999, name='BatchNorm'):
+ '''
+ https://github.com/zsdonghao/tensorlayer/blob/master/tensorlayer/layers.py
+ https://github.com/ry/tensorflow-resnet/blob/master/resnet.py
+ http://stackoverflow.com/questions/38312668/how-does-one-do-inference-with-batch-normalization-with-tensor-flow
+ '''
+ from tensorflow.python.training import moving_averages
+ from tensorflow.python.ops import control_flow_ops
+
+ axis = list(range(len(input.get_shape()) - 1))
+ fdim = input.get_shape()[-1:]
+
+ with tf.variable_scope(name):
+ beta = tf.get_variable('beta', fdim, initializer=tf.constant_initializer(value=0.0))
+ gamma = tf.get_variable('gamma', fdim, initializer=tf.constant_initializer(value=1.0))
+ moving_mean = tf.get_variable('moving_mean', fdim, initializer=tf.constant_initializer(value=0.0), trainable=False)
+ moving_variance = tf.get_variable('moving_variance', fdim, initializer=tf.constant_initializer(value=0.0), trainable=False)
+
+ def mean_var_with_update():
+ batch_mean, batch_variance = tf.nn.moments(input, axis)
+ update_moving_mean = moving_averages.assign_moving_average(moving_mean, batch_mean, decay, zero_debias=True)
+ update_moving_variance = moving_averages.assign_moving_average(moving_variance, batch_variance, decay, zero_debias=True)
+ with tf.control_dependencies([update_moving_mean, update_moving_variance]):
+ return tf.identity(batch_mean), tf.identity(batch_variance)
+
+ mean, variance = control_flow_ops.cond(is_train, mean_var_with_update, lambda: (moving_mean, moving_variance))
+
+ return tf.nn.batch_normalization(input, mean, variance, beta, gamma, 1e-3) #, tf.stack([mean[0], variance[0], beta[0], gamma[0]])
+
+def Conv3D(input, kernel_shape, strides, padding, name='Conv3d', W_initializer=he_normal_init, bias=True):
+ with tf.variable_scope(name):
+ W = tf.get_variable("W", kernel_shape, initializer=W_initializer)
+ if bias is True:
+ b = tf.get_variable("b", (kernel_shape[-1]),initializer=tf.constant_initializer(value=0.0))
+ else:
+ b = 0
+
+ return tf.nn.conv3d(input, W, strides, padding) + b
+
+def LoadParams(sess, params, in_file='params.hdf5'):
+ f = h5py.File(in_file, 'r')
+ g = f['params']
+ assign_ops = []
+ # Flatten list
+ params = [item for sublist in params for item in sublist]
+
+ for param in params:
+ flag = False
+ for idx, name in enumerate(g):
+            # h5 dataset names store '/' as '_' and a literal '_' as '__'; reverse that mapping
+ parsed_name = list(name)
+ for i in range(0+1, len(parsed_name)-1):
+ if parsed_name[i] == '_' and (parsed_name[i-1] != '_' and parsed_name[i+1] != '_'):
+ parsed_name[i] = '/'
+ parsed_name = ''.join(parsed_name)
+ parsed_name = parsed_name.replace('__','_')
+
+ if param.name == parsed_name:
+ flag = True
+ assign_ops += [param.assign(g[name][()])]
+
+ if not flag:
+            print('Warning: cannot find param {}; ignore if this is intended.'.format(param.name))
+
+ sess.run(assign_ops)
+
+ print('Parameters are loaded')
+
+def depth_to_space_3D(x, block_size):
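+    # Merge the batch and time axes, apply 2-D depth_to_space, then restore the 5-D shape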
+ ds_x = tf.shape(x)
+ x = tf.reshape(x, [ds_x[0]*ds_x[1], ds_x[2], ds_x[3], ds_x[4]])
+
+ y = tf.depth_to_space(x, block_size)
+
+ ds_y = tf.shape(y)
+ x = tf.reshape(y, [ds_x[0], ds_x[1], ds_y[1], ds_y[2], ds_y[3]])
+ return x
+
+def DynFilter3D(x, F, filter_size):
+    '''
+    3D dynamic filtering
+    x: (b, t, h, w) input frames (single channel)
+    F: (b, h, w, tower_depth, output_depth) per-pixel filters
+    filter_size: (ft, fh, fw)
+    '''
+    # Build an identity 'im2col' kernel that expands each local ft*fh*fw patch into channels
+ filter_localexpand_np = np.reshape(np.eye(np.prod(filter_size), np.prod(filter_size)), (filter_size[1], filter_size[2], filter_size[0], np.prod(filter_size)))
+ filter_localexpand = tf.Variable(filter_localexpand_np, trainable=False, dtype='float32',name='filter_localexpand')
+ x = tf.transpose(x, perm=[0,2,3,1])
+ x_localexpand = tf.nn.conv2d(x, filter_localexpand, [1,1,1,1], 'SAME') # b, h, w, 1*5*5
+ x_localexpand = tf.expand_dims(x_localexpand, axis=3) # b, h, w, 1, 1*5*5
+ x = tf.matmul(x_localexpand, F) # b, h, w, 1, R*R
+ x = tf.squeeze(x, axis=3) # b, h, w, R*R
+
+ return x
+
+def Huber(y_true, y_pred, delta, axis=None):
+ abs_error = tf.abs(y_pred - y_true)
+ quadratic = tf.minimum(abs_error, delta)
+ # The following expression is the same in value as
+ # tf.maximum(abs_error - delta, 0), but importantly the gradient for the
+ # expression when abs_error == delta is 0 (for tf.maximum it would be 1).
+ # This is necessary to avoid doubling the gradient, since there is already a
+ # nonzero contribution to the gradient from the quadratic term.
+ linear = (abs_error - quadratic)
+ losses = 0.5 * quadratic**2 + delta * linear
+ return tf.reduce_mean(losses, axis=axis)