-rw-r--r--  .gitignore                            1
-rw-r--r--  README.md                            44
-rw-r--r--  SeparableConvolution.py              48
-rw-r--r--  SeparableConvolution.pyc            bin 0 -> 1920 bytes
-rw-r--r--  images/README.md                      1
-rw-r--r--  images/first.png                    bin 0 -> 327694 bytes
-rw-r--r--  images/second.png                   bin 0 -> 328562 bytes
-rw-r--r--  install.bash                         10
-rw-r--r--  install.py                           33
-rw-r--r--  run.py                              232
-rw-r--r--  src/SeparableConvolution_cuda.c      23
-rw-r--r--  src/SeparableConvolution_cuda.h       6
-rw-r--r--  src/SeparableConvolution_kernel.cu   70
-rw-r--r--  src/SeparableConvolution_kernel.h    15
-rw-r--r--  src/SeparableConvolution_kernel.o   bin 0 -> 10984 bytes
15 files changed, 483 insertions(+), 0 deletions(-)
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..d8fe4fa
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+/.project
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..883c040
--- /dev/null
+++ b/README.md
@@ -0,0 +1,44 @@
+# pytorch-sepconv
+This is a reference implementation of Video Frame Interpolation via Adaptive Separable Convolution [1] using PyTorch. Given two frames, it will make use of <a href="http://graphics.cs.pdx.edu/project/adaconv">adaptive convolution</a> [2] in a separable manner to interpolate the intermediate frame. Should you be making use of our work, please cite our paper [1].
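+
+At its core, the method estimates, for every output pixel, a pair of 1D kernels with 51 taps each, one vertical and one horizontal, whose outer product forms the local 2D kernel that is applied to the padded input frame. The following is a minimal NumPy sketch of what the `SeparableConvolution` layer computes for a single frame; the function is purely illustrative, the actual implementation is the CUDA kernel in `src/`.
+
+```
+import numpy
+
+def sepconv_reference(input1, vertical, horizontal):
+	# input1: (channels, height + 50, width + 50), the replication-padded frame
+	# vertical / horizontal: (51, height, width), the estimated per-pixel 1D kernels
+	intHeight, intWidth = vertical.shape[1], vertical.shape[2]
+	output = numpy.zeros((input1.shape[0], intHeight, intWidth), numpy.float32)
+
+	for intY in range(intHeight):
+		for intX in range(intWidth):
+			# the local 2D kernel is the outer product of the vertical and horizontal 1D kernels
+			kernel = numpy.outer(vertical[:, intY, intX], horizontal[:, intY, intX])
+			output[:, intY, intX] = (input1[:, intY:intY + 51, intX:intX + 51] * kernel).sum(axis=(1, 2))
+		# end
+	# end
+
+	return output
+```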
+
+<a href="https://arxiv.org/abs/1703.07514" rel="Paper"><img src="http://content.sniklaus.com/SepConv/Paper.jpg" alt="Paper" width="100%"></a>
+
+For the Torch version of this work, please see: https://github.com/sniklaus/torch-sepconv
+
+## setup
+To build the implementation and download the pretrained networks, make sure that the `CUDA_HOME` environment variable is configured and run `bash install.bash`. After successfully completing this step, run `python run.py` to test it. Should you receive an error message regarding an invalid device function during execution, configure the utilized CUDA architecture within `install.bash` to one that your graphics card supports.
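+
+Should you need to target a different architecture, adjust the `nvcc` invocation in `install.bash` accordingly. For example, for a Pascal-generation card, the line could be changed as follows; `compute_61` is merely an example, please consult NVIDIA's documentation for the compute capability of your card.
+
+```
+nvcc -c -o src/SeparableConvolution_kernel.o src/SeparableConvolution_kernel.cu --gpu-architecture=compute_61 --gpu-code=compute_61 --compiler-options -fPIC -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC
+```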
+
+## usage
+To run it on your own pair of frames, use the following command. You can select either the `l1` or the `lf` model; please see our paper for more details.
+
+```
+python run.py --model lf --first ./images/first.png --second ./images/second.png --out ./result.png
+```
+
+## video
+<a href="http://web.cecs.pdx.edu/~fliu/project/sepconv/demo.mp4" rel="Video"><img src="http://web.cecs.pdx.edu/~fliu/project/sepconv/screen.jpg" alt="Video" width="100%"></a>
+
+## license
+The provided implementation is strictly for academic purposes only. Should you be interested in using our intellectual property, please feel free to contact us.
+
+## references
+```
+[1] @inproceedings{Niklaus_ICCV_2017,
+ author = {Simon Niklaus and Long Mai and Feng Liu},
+ title = {Video Frame Interpolation via Adaptive Separable Convolution},
+ booktitle = {IEEE International Conference on Computer Vision},
+ year = {2017}
+ }
+```
+
+```
+[2] @inproceedings{Niklaus_CVPR_2017,
+ author = {Simon Niklaus and Long Mai and Feng Liu},
+ title = {Video Frame Interpolation via Adaptive Convolution},
+ booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
+ year = {2017}
+ }
+```
+
+## acknowledgment
+This work was supported by NSF IIS-1321119. This video uses materials under a Creative Commons license or with the owner's permission, as detailed at the end.
\ No newline at end of file
diff --git a/SeparableConvolution.py b/SeparableConvolution.py
new file mode 100644
index 0000000..9e5d1e9
--- /dev/null
+++ b/SeparableConvolution.py
@@ -0,0 +1,48 @@
+import torch
+
+import _ext.cunnex
+
+class SeparableConvolution(torch.autograd.Function):
+ def __init__(self):
+ super(SeparableConvolution, self).__init__()
+ # end
+
+ def forward(self, input1, input2, input3):
+ intBatches = input1.size(0)
+ intInputDepth = input1.size(1)
+ intInputHeight = input1.size(2)
+ intInputWidth = input1.size(3)
+ intFilterSize = min(input2.size(1), input3.size(1))
+ intOutputHeight = min(input2.size(2), input3.size(2))
+ intOutputWidth = min(input2.size(3), input3.size(3))
+
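+ # the 1D filters have 51 taps, hence the output is 50 pixels smaller than the padded input in each dimension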
+ assert(intInputHeight - 51 == intOutputHeight - 1)
+ assert(intInputWidth - 51 == intOutputWidth - 1)
+ assert(intFilterSize == 51)
+
+ assert(input1.is_contiguous() == True)
+ assert(input2.is_contiguous() == True)
+ assert(input3.is_contiguous() == True)
+
+ output = input1.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_()
+
+ if input1.is_cuda == True:
+ _ext.cunnex.SeparableConvolution_cuda_forward(
+ input1,
+ input2,
+ input3,
+ output
+ )
+
+ elif input1.is_cuda == False:
+ assert(False) # NOT IMPLEMENTED
+
+ # end
+
+ return output
+ # end
+
+ def backward(self, gradOutput):
+ assert(False) # NOT IMPLEMENTED
+ # end
+# end
\ No newline at end of file
diff --git a/SeparableConvolution.pyc b/SeparableConvolution.pyc
new file mode 100644
index 0000000..b1e4d60
--- /dev/null
+++ b/SeparableConvolution.pyc
Binary files differ
diff --git a/images/README.md b/images/README.md
new file mode 100644
index 0000000..e9b6b81
--- /dev/null
+++ b/images/README.md
@@ -0,0 +1 @@
+The example used here originates from the Middlebury optical flow benchmark: http://vision.middlebury.edu/flow
\ No newline at end of file
diff --git a/images/first.png b/images/first.png
new file mode 100644
index 0000000..6aa7501
--- /dev/null
+++ b/images/first.png
Binary files differ
diff --git a/images/second.png b/images/second.png
new file mode 100644
index 0000000..de55c16
--- /dev/null
+++ b/images/second.png
Binary files differ
diff --git a/install.bash b/install.bash
new file mode 100644
index 0000000..75c4b04
--- /dev/null
+++ b/install.bash
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))")
+
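+# compile the CUDA kernel; adjust the compute architecture below should your graphics card not support compute_52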
+nvcc -c -o src/SeparableConvolution_kernel.o src/SeparableConvolution_kernel.cu --gpu-architecture=compute_52 --gpu-code=compute_52 --compiler-options -fPIC -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC
+
+python install.py
+
+wget --timestamping http://content.sniklaus.com/SepConv/network-l1.pytorch
+wget --timestamping http://content.sniklaus.com/SepConv/network-lf.pytorch
\ No newline at end of file
diff --git a/install.py b/install.py
new file mode 100644
index 0000000..42847fa
--- /dev/null
+++ b/install.py
@@ -0,0 +1,33 @@
+import os
+import torch
+import torch.utils.ffi
+
+strBasepath = os.path.split(os.path.abspath(__file__))[0] + '/'
+strHeaders = []
+strSources = []
+strDefines = []
+strObjects = []
+
+if torch.cuda.is_available() == True:
+ strHeaders += ['src/SeparableConvolution_cuda.h']
+ strSources += ['src/SeparableConvolution_cuda.c']
+ strDefines += [('WITH_CUDA', None)]
+ strObjects += ['src/SeparableConvolution_kernel.o']
+# end
+
+objectExtension = torch.utils.ffi.create_extension(
+ name='_ext.cunnex',
+ headers=strHeaders,
+ sources=strSources,
+ verbose=False,
+ with_cuda=any(strDefine[0] == 'WITH_CUDA' for strDefine in strDefines),
+ package=False,
+ relative_to=strBasepath,
+ include_dirs=[os.path.expandvars('$CUDA_HOME') + '/include'],
+ define_macros=strDefines,
+ extra_objects=[os.path.join(strBasepath, strObject) for strObject in strObjects]
+)
+
+if __name__ == '__main__':
+ objectExtension.build()
+# end
\ No newline at end of file
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..7c2bd54
--- /dev/null
+++ b/run.py
@@ -0,0 +1,232 @@
+#!/usr/bin/env python2.7
+
+import sys
+import getopt
+import math
+import numpy
+import torch
+import torch.utils.serialization
+import PIL
+import PIL.Image
+
+from SeparableConvolution import SeparableConvolution # the custom SeparableConvolution layer
+
+torch.cuda.set_device(0) # change this if you have multiple graphics cards and want to utilize a different one
+
+torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
+
+##########################################################
+
+arguments_strModel = 'lf'
+arguments_strFirst = './images/first.png'
+arguments_strSecond = './images/second.png'
+arguments_strOut = './result.png'
+
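+# derive the list of accepted long options from the command line itself, such that each "--name value" pair is parsed below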
+for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:
+ if strOption == '--model':
+ arguments_strModel = strArgument # which model to use, l1 or lf, please see our paper for more details
+
+ elif strOption == '--first':
+ arguments_strFirst = strArgument # path to the first frame
+
+ elif strOption == '--second':
+ arguments_strSecond = strArgument # path to the second frame
+
+ elif strOption == '--out':
+ arguments_strOut = strArgument # path to where the output should be stored
+
+ # end
+# end
+
+##########################################################
+
+class Network(torch.nn.Module):
+ def __init__(self):
+ super(Network, self).__init__()
+
+ def Basic(intInput, intOutput):
+ return torch.nn.Sequential(
+ torch.nn.Conv2d(in_channels=intInput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+ # end
+
+ def Subnet():
+ return torch.nn.Sequential(
+ torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Conv2d(in_channels=64, out_channels=51, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False),
+ torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+ torch.nn.Conv2d(in_channels=51, out_channels=51, kernel_size=3, stride=1, padding=1)
+ )
+ # end
+
+ self.moduleConv1 = Basic(6, 32)
+ self.modulePool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.moduleConv2 = Basic(32, 64)
+ self.modulePool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.moduleConv3 = Basic(64, 128)
+ self.modulePool3 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.moduleConv4 = Basic(128, 256)
+ self.modulePool4 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.moduleConv5 = Basic(256, 512)
+ self.modulePool5 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+ self.moduleDeconv5 = Basic(512, 512)
+ self.moduleUpsample5 = torch.nn.Sequential(
+ torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+ torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+
+ self.moduleDeconv4 = Basic(512, 256)
+ self.moduleUpsample4 = torch.nn.Sequential(
+ torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+ torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+
+ self.moduleDeconv3 = Basic(256, 128)
+ self.moduleUpsample3 = torch.nn.Sequential(
+ torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+ torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+
+ self.moduleDeconv2 = Basic(128, 64)
+ self.moduleUpsample2 = torch.nn.Sequential(
+ torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+ torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+ torch.nn.ReLU(inplace=False)
+ )
+
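+ # four subnets estimate the per-pixel 1D kernels: one vertical and one horizontal kernel for each of the two input frames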
+ self.moduleVertical1 = Subnet()
+ self.moduleVertical2 = Subnet()
+ self.moduleHorizontal1 = Subnet()
+ self.moduleHorizontal2 = Subnet()
+
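+ # replicate floor(51 / 2) = 25 pixels on each side, such that the 51-tap filters can be evaluated at the image boundary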
+ self.modulePad = torch.nn.ReplicationPad2d([ int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)) ])
+
+ self.load_state_dict(torch.load('./network-' + arguments_strModel + '.pytorch'))
+ # end
+
+ def forward(self, variableInput1, variableInput2):
+ variableJoin = torch.cat([variableInput1, variableInput2], 1)
+
+ variableConv1 = self.moduleConv1(variableJoin)
+ variablePool1 = self.modulePool1(variableConv1)
+
+ variableConv2 = self.moduleConv2(variablePool1)
+ variablePool2 = self.modulePool2(variableConv2)
+
+ variableConv3 = self.moduleConv3(variablePool2)
+ variablePool3 = self.modulePool3(variableConv3)
+
+ variableConv4 = self.moduleConv4(variablePool3)
+ variablePool4 = self.modulePool4(variableConv4)
+
+ variableConv5 = self.moduleConv5(variablePool4)
+ variablePool5 = self.modulePool5(variableConv5)
+
+ variableDeconv5 = self.moduleDeconv5(variablePool5)
+ variableUpsample5 = self.moduleUpsample5(variableDeconv5)
+
+ variableDeconv4 = self.moduleDeconv4(variableUpsample5 + variableConv5)
+ variableUpsample4 = self.moduleUpsample4(variableDeconv4)
+
+ variableDeconv3 = self.moduleDeconv3(variableUpsample4 + variableConv4)
+ variableUpsample3 = self.moduleUpsample3(variableDeconv3)
+
+ variableDeconv2 = self.moduleDeconv2(variableUpsample3 + variableConv3)
+ variableUpsample2 = self.moduleUpsample2(variableDeconv2)
+
+ variableCombine = variableUpsample2 + variableConv2
+
+ variableDot1 = SeparableConvolution()(self.modulePad(variableInput1), self.moduleVertical1(variableCombine), self.moduleHorizontal1(variableCombine))
+ variableDot2 = SeparableConvolution()(self.modulePad(variableInput2), self.moduleVertical2(variableCombine), self.moduleHorizontal2(variableCombine))
+
+ return variableDot1 + variableDot2
+ # end
+# end
+
+moduleNetwork = Network().cuda()
+
+##########################################################
+
+tensorInputFirst = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strFirst))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
+tensorInputSecond = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strSecond))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
+tensorOutput = torch.FloatTensor()
+
+assert(tensorInputFirst.size(1) == tensorInputSecond.size(1))
+assert(tensorInputFirst.size(2) == tensorInputSecond.size(2))
+
+intWidth = tensorInputFirst.size(2)
+intHeight = tensorInputFirst.size(1)
+
+assert(intWidth <= 1280) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
+assert(intHeight <= 720) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
+
+intPaddingLeft = int(math.floor(51 / 2.0))
+intPaddingTop = int(math.floor(51 / 2.0))
+intPaddingRight = int(math.floor(51 / 2.0))
+intPaddingBottom = int(math.floor(51 / 2.0))
+modulePaddingFirst = torch.nn.Module()
+modulePaddingSecond = torch.nn.Module()
+modulePaddingOutput = torch.nn.Module()
+
+if True:
+ intPaddingWidth = intPaddingLeft + intWidth + intPaddingRight
+ intPaddingHeight = intPaddingTop + intHeight + intPaddingBottom
+
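+ # round the padded width / height up to the next multiple of 128 using bit shifts; since the network downsamples five times, a multiple of 32 would already suffice, hence the "more than necessary" remarks below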
+ if intPaddingWidth != ((intPaddingWidth >> 7) << 7):
+ intPaddingWidth = (((intPaddingWidth >> 7) + 1) << 7) # more than necessary
+ # end
+
+ if intPaddingHeight != ((intPaddingHeight >> 7) << 7):
+ intPaddingHeight = (((intPaddingHeight >> 7) + 1) << 7) # more than necessary
+ # end
+
+ intPaddingWidth = intPaddingWidth - (intPaddingLeft + intWidth + intPaddingRight)
+ intPaddingHeight = intPaddingHeight - (intPaddingTop + intHeight + intPaddingBottom)
+
+ modulePaddingFirst = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
+ modulePaddingSecond = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
+ modulePaddingOutput = torch.nn.ReplicationPad2d([0 - intPaddingLeft, 0 - intPaddingRight - intPaddingWidth, 0 - intPaddingTop, 0 - intPaddingBottom - intPaddingHeight])
+
+ modulePaddingFirst = modulePaddingFirst.cuda()
+ modulePaddingSecond = modulePaddingSecond.cuda()
+ modulePaddingOutput = modulePaddingOutput.cuda()
+# end
+
+if True:
+ tensorInputFirst = tensorInputFirst.cuda()
+ tensorInputSecond = tensorInputSecond.cuda()
+ tensorOutput = tensorOutput.cuda()
+# end
+
+if True:
+ variablePaddingFirst = modulePaddingFirst(torch.autograd.Variable(data=tensorInputFirst.view(1, 3, intHeight, intWidth), volatile=True))
+ variablePaddingSecond = modulePaddingSecond(torch.autograd.Variable(data=tensorInputSecond.view(1, 3, intHeight, intWidth), volatile=True))
+ variablePaddingOutput = modulePaddingOutput(moduleNetwork(variablePaddingFirst, variablePaddingSecond))
+
+ tensorOutput.resize_(3, intHeight, intWidth).copy_(variablePaddingOutput.data[0])
+# end
+
+if True:
+ tensorInputFirst = tensorInputFirst.cpu()
+ tensorInputSecond = tensorInputSecond.cpu()
+ tensorOutput = tensorOutput.cpu()
+# end
+
+PIL.Image.fromarray((numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8)).save(arguments_strOut)
\ No newline at end of file
diff --git a/src/SeparableConvolution_cuda.c b/src/SeparableConvolution_cuda.c
new file mode 100644
index 0000000..a48622e
--- /dev/null
+++ b/src/SeparableConvolution_cuda.c
@@ -0,0 +1,23 @@
+#include <THC.h>
+#include <THCGeneral.h>
+
+#include "SeparableConvolution_kernel.h"
+
+extern THCState* state;
+
+int SeparableConvolution_cuda_forward(
+ THCudaTensor* input1,
+ THCudaTensor* input2,
+ THCudaTensor* input3,
+ THCudaTensor* output
+) {
+ SeparableConvolution_kernel_forward(
+ state,
+ input1,
+ input2,
+ input3,
+ output
+ );
+
+ return 1;
+}
\ No newline at end of file
diff --git a/src/SeparableConvolution_cuda.h b/src/SeparableConvolution_cuda.h
new file mode 100644
index 0000000..7320fee
--- /dev/null
+++ b/src/SeparableConvolution_cuda.h
@@ -0,0 +1,6 @@
+int SeparableConvolution_cuda_forward(
+ THCudaTensor* input1,
+ THCudaTensor* input2,
+ THCudaTensor* input3,
+ THCudaTensor* output
+);
\ No newline at end of file
diff --git a/src/SeparableConvolution_kernel.cu b/src/SeparableConvolution_kernel.cu
new file mode 100644
index 0000000..b40786d
--- /dev/null
+++ b/src/SeparableConvolution_kernel.cu
@@ -0,0 +1,70 @@
+#include <THC.h>
+#include <THCGeneral.h>
+
+#define VEC_0(ARRAY) ((ARRAY).x)
+#define VEC_1(ARRAY) ((ARRAY).y)
+#define VEC_2(ARRAY) ((ARRAY).z)
+#define VEC_3(ARRAY) ((ARRAY).w)
+
+#define IDX_1(ARRAY, X) ((ARRAY)[((X) * (ARRAY##_stride.x))])
+#define IDX_2(ARRAY, X, Y) ((ARRAY)[((X) * (ARRAY##_stride.x)) + ((Y) * (ARRAY##_stride.y))])
+#define IDX_3(ARRAY, X, Y, Z) ((ARRAY)[((X) * (ARRAY##_stride.x)) + ((Y) * (ARRAY##_stride.y)) + ((Z) * (ARRAY##_stride.z))])
+#define IDX_4(ARRAY, X, Y, Z, W) ((ARRAY)[((X) * (ARRAY##_stride.x)) + ((Y) * (ARRAY##_stride.y)) + ((Z) * (ARRAY##_stride.z)) + ((W) * (ARRAY##_stride.w))])
+
+#ifdef __cplusplus
+ extern "C" {
+#endif
+
+__global__ void kernel_SeparableConvolution_updateOutput(
+ const int n,
+ const float* input1, const long4 input1_size, const long4 input1_stride,
+ const float* input2, const long4 input2_size, const long4 input2_stride,
+ const float* input3, const long4 input3_size, const long4 input3_stride,
+ float* output, const long4 output_size, const long4 output_stride
+) {
+ int intIndex = blockIdx.x * blockDim.x + threadIdx.x;
+
+ if (intIndex >= n) {
+ return;
+ }
+
+ float dblOutput = 0.0;
+
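+ // decompose the flat thread index into the (batch, depth, y, x) position of the output element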
+ int intBatch = ( intIndex / VEC_3(output_size) / VEC_2(output_size) / VEC_1(output_size) ) % VEC_0(output_size);
+ int intDepth = ( intIndex / VEC_3(output_size) / VEC_2(output_size) ) % VEC_1(output_size);
+ int intY = ( intIndex / VEC_3(output_size) ) % VEC_2(output_size);
+ int intX = ( intIndex ) % VEC_3(output_size);
+
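+ // accumulate over the 51x51 window, where the 2D kernel weight at (intFilterY, intFilterX) is the product of the vertical coefficient in input2 and the horizontal coefficient in input3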
+ for (int intFilterY = 0; intFilterY < 51; intFilterY += 1) {
+ for (int intFilterX = 0; intFilterX < 51; intFilterX += 1) {
+ dblOutput += IDX_4(input1, intBatch, intDepth, intY + intFilterY, intX + intFilterX) * IDX_4(input2, intBatch, intFilterY, intY, intX) * IDX_4(input3, intBatch, intFilterX, intY, intX);
+ }
+ }
+
+ output[intIndex] = dblOutput;
+}
+
+void SeparableConvolution_kernel_forward(
+ THCState* state,
+ THCudaTensor* input1,
+ THCudaTensor* input2,
+ THCudaTensor* input3,
+ THCudaTensor* output
+) {
+ int n = 0;
+
+ n = THCudaTensor_nElement(state, output);
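+ // launch one thread per output element, with 512 threads per block and the grid size rounded up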
+ kernel_SeparableConvolution_updateOutput<<< (n + 512 - 1) / 512, 512, 0, THCState_getCurrentStream(state) >>>(
+ n,
+ THCudaTensor_data(state, input1), make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]), make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]),
+ THCudaTensor_data(state, input2), make_long4(input2->size[0], input2->size[1], input2->size[2], input2->size[3]), make_long4(input2->stride[0], input2->stride[1], input2->stride[2], input2->stride[3]),
+ THCudaTensor_data(state, input3), make_long4(input3->size[0], input3->size[1], input3->size[2], input3->size[3]), make_long4(input3->stride[0], input3->stride[1], input3->stride[2], input3->stride[3]),
+ THCudaTensor_data(state, output), make_long4(output->size[0], output->size[1], output->size[2], output->size[3]), make_long4(output->stride[0], output->stride[1], output->stride[2], output->stride[3])
+ );
+
+ THCudaCheck(cudaGetLastError());
+}
+
+#ifdef __cplusplus
+ }
+#endif
\ No newline at end of file
diff --git a/src/SeparableConvolution_kernel.h b/src/SeparableConvolution_kernel.h
new file mode 100644
index 0000000..b400579
--- /dev/null
+++ b/src/SeparableConvolution_kernel.h
@@ -0,0 +1,15 @@
+#ifdef __cplusplus
+ extern "C" {
+#endif
+
+void SeparableConvolution_kernel_forward(
+ THCState* state,
+ THCudaTensor* input1,
+ THCudaTensor* input2,
+ THCudaTensor* input3,
+ THCudaTensor* output
+);
+
+#ifdef __cplusplus
+ }
+#endif
\ No newline at end of file
diff --git a/src/SeparableConvolution_kernel.o b/src/SeparableConvolution_kernel.o
new file mode 100644
index 0000000..396aba1
--- /dev/null
+++ b/src/SeparableConvolution_kernel.o
Binary files differ