| -rw-r--r-- | .gitignore                         |   1 |
| -rw-r--r-- | README.md                          |  44 |
| -rw-r--r-- | SeparableConvolution.py            |  48 |
| -rw-r--r-- | SeparableConvolution.pyc           | bin | 0 -> 1920 bytes |
| -rw-r--r-- | images/README.md                   |   1 |
| -rw-r--r-- | images/first.png                   | bin | 0 -> 327694 bytes |
| -rw-r--r-- | images/second.png                  | bin | 0 -> 328562 bytes |
| -rw-r--r-- | install.bash                       |  10 |
| -rw-r--r-- | install.py                         |  33 |
| -rw-r--r-- | run.py                             | 232 |
| -rw-r--r-- | src/SeparableConvolution_cuda.c    |  23 |
| -rw-r--r-- | src/SeparableConvolution_cuda.h    |   6 |
| -rw-r--r-- | src/SeparableConvolution_kernel.cu |  70 |
| -rw-r--r-- | src/SeparableConvolution_kernel.h  |  15 |
| -rw-r--r-- | src/SeparableConvolution_kernel.o  | bin | 0 -> 10984 bytes |
15 files changed, 483 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..d8fe4fa
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+/.project
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..883c040
--- /dev/null
+++ b/README.md
@@ -0,0 +1,44 @@
+# pytorch-sepconv
+This is a reference implementation of Video Frame Interpolation via Adaptive Separable Convolution [1] using PyTorch. Given two frames, it will make use of <a href="http://graphics.cs.pdx.edu/project/adaconv">adaptive convolution</a> [2] in a separable manner to interpolate the intermediate frame. Should you be making use of our work, please cite our paper [1].
+
+<a href="https://arxiv.org/abs/1703.07514" rel="Paper"><img src="http://content.sniklaus.com/SepConv/Paper.jpg" alt="Paper" width="100%"></a>
+
+For the Torch version of this work, please see: https://github.com/sniklaus/torch-sepconv
+
+## setup
+To build the implementation and download the pretrained networks, run `bash install.bash` and make sure that you have configured the `CUDA_HOME` environment variable. After successfully completing this step, run `python run.py` to test it. Should you receive an error message regarding an invalid device function during execution, configure the utilized CUDA architecture within `install.bash` to something your graphics card supports.
+
+## usage
+To run it on your own pair of frames, use the following command. You can select either the `l1` or the `lf` model; please see our paper for more details.
+
+```
+python run.py --model lf --first ./images/first.png --second ./images/second.png --out ./result.png
+```
+
+## video
+<a href="http://web.cecs.pdx.edu/~fliu/project/sepconv/demo.mp4" rel="Video"><img src="http://web.cecs.pdx.edu/~fliu/project/sepconv/screen.jpg" alt="Video" width="100%"></a>
+
+## license
+The provided implementation is strictly for academic purposes only. Should you be interested in using our intellectual property, please feel free to contact us.
+
+## references
+```
+[1]  @inproceedings{Niklaus_ICCV_2017,
+         author = {Simon Niklaus and Long Mai and Feng Liu},
+         title = {Video Frame Interpolation via Adaptive Separable Convolution},
+         booktitle = {IEEE International Conference on Computer Vision},
+         year = {2017}
+     }
+```
+
+```
+[2]  @inproceedings{Niklaus_CVPR_2017,
+         author = {Simon Niklaus and Long Mai and Feng Liu},
+         title = {Video Frame Interpolation via Adaptive Convolution},
+         booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
+         year = {2017}
+     }
+```
+
+## acknowledgment
+This work was supported by NSF IIS-1321119. This video uses materials under a Creative Commons license or with the owner's permission, as detailed at the end.
\ No newline at end of file
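Editor's note: the setup instructions above depend on a correctly configured `CUDA_HOME`. A minimal pre-flight sketch in Python, assuming nothing beyond the standard library; the `/usr/local/cuda` fallback is an assumption and may differ on your system:

```
import os

# hypothetical pre-flight check for the setup step described in the README;
# the '/usr/local/cuda' fallback is an assumption and may differ on your system
strCudaHome = os.environ.get('CUDA_HOME', '/usr/local/cuda')

if not os.path.isdir(os.path.join(strCudaHome, 'include')):
	raise RuntimeError('CUDA_HOME does not point to a CUDA installation: ' + strCudaHome)

print('using CUDA from ' + strCudaHome)
```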
diff --git a/SeparableConvolution.py b/SeparableConvolution.py
new file mode 100644
index 0000000..9e5d1e9
--- /dev/null
+++ b/SeparableConvolution.py
@@ -0,0 +1,48 @@
+import torch
+
+import _ext.cunnex
+
+class SeparableConvolution(torch.autograd.Function):
+	def __init__(self):
+		super(SeparableConvolution, self).__init__()
+	# end
+
+	def forward(self, input1, input2, input3):
+		intBatches = input1.size(0)
+		intInputDepth = input1.size(1)
+		intInputHeight = input1.size(2)
+		intInputWidth = input1.size(3)
+		intFilterSize = min(input2.size(1), input3.size(1))
+		intOutputHeight = min(input2.size(2), input3.size(2))
+		intOutputWidth = min(input2.size(3), input3.size(3))
+
+		assert(intInputHeight - 51 == intOutputHeight - 1)
+		assert(intInputWidth - 51 == intOutputWidth - 1)
+		assert(intFilterSize == 51)
+
+		assert(input1.is_contiguous() == True)
+		assert(input2.is_contiguous() == True)
+		assert(input3.is_contiguous() == True)
+
+		output = input1.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_()
+
+		if input1.is_cuda == True:
+			_ext.cunnex.SeparableConvolution_cuda_forward(
+				input1,
+				input2,
+				input3,
+				output
+			)
+
+		elif input1.is_cuda == False:
+			assert(False) # NOT IMPLEMENTED
+
+		# end
+
+		return output
+	# end
+
+	def backward(self, gradOutput):
+		assert(False) # NOT IMPLEMENTED
+	# end
+# end
\ No newline at end of file
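Editor's note: `forward()` above asserts that the padded input is exactly 50 pixels larger than the output in each spatial dimension and that both filter tensors carry 51 channels. A minimal shape sketch, assuming a CUDA device and a successfully built extension; the sizes are illustrative only:

```
import torch

from SeparableConvolution import SeparableConvolution

# illustrative sizes: a 128x128 output requires a (128+50)x(128+50) padded input
variableInput = torch.autograd.Variable(torch.rand(1, 3, 178, 178).cuda())        # padded frame
variableVertical = torch.autograd.Variable(torch.rand(1, 51, 128, 128).cuda())    # per-pixel vertical filters
variableHorizontal = torch.autograd.Variable(torch.rand(1, 51, 128, 128).cuda())  # per-pixel horizontal filters

variableOutput = SeparableConvolution()(variableInput, variableVertical, variableHorizontal)
print(variableOutput.size()) # expected: (1, 3, 128, 128)
```

Note that only the CUDA forward pass is implemented: the CPU path and `backward()` both fail by design.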
diff --git a/SeparableConvolution.pyc b/SeparableConvolution.pyc
new file mode 100644
index 0000000..b1e4d60
--- /dev/null
+++ b/SeparableConvolution.pyc
Binary files differ
diff --git a/images/README.md b/images/README.md
new file mode 100644
index 0000000..e9b6b81
--- /dev/null
+++ b/images/README.md
@@ -0,0 +1 @@
+The example used here originates from the Middlebury benchmark for optical flow: http://vision.middlebury.edu/flow
\ No newline at end of file
diff --git a/images/first.png b/images/first.png
new file mode 100644
index 0000000..6aa7501
--- /dev/null
+++ b/images/first.png
Binary files differ
diff --git a/images/second.png b/images/second.png
new file mode 100644
index 0000000..de55c16
--- /dev/null
+++ b/images/second.png
Binary files differ
diff --git a/install.bash b/install.bash
new file mode 100644
index 0000000..75c4b04
--- /dev/null
+++ b/install.bash
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+TORCH=$(python -c "import os; import torch; print(os.path.dirname(torch.__file__))")
+
+nvcc -c -o src/SeparableConvolution_kernel.o src/SeparableConvolution_kernel.cu --gpu-architecture=compute_52 --gpu-code=compute_52 --compiler-options -fPIC -I ${TORCH}/lib/include/TH -I ${TORCH}/lib/include/THC
+
+python install.py
+
+wget --timestamping http://content.sniklaus.com/SepConv/network-l1.pytorch
+wget --timestamping http://content.sniklaus.com/SepConv/network-lf.pytorch
\ No newline at end of file
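Editor's note: `install.bash` hard-codes `compute_52`, and the "invalid device function" error mentioned in the README usually means this value does not match the GPU. A sketch for querying a suitable value through PyTorch, assuming `torch.cuda.get_device_capability` is available in the installed version:

```
import torch

# query the compute capability of the current GPU so the nvcc flags in
# install.bash can be adjusted, e.g. (6, 1) -> compute_61
intMajor, intMinor = torch.cuda.get_device_capability(torch.cuda.current_device())
print('--gpu-architecture=compute_{0}{1} --gpu-code=compute_{0}{1}'.format(intMajor, intMinor))
```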
diff --git a/install.py b/install.py
new file mode 100644
index 0000000..42847fa
--- /dev/null
+++ b/install.py
@@ -0,0 +1,33 @@
+import os
+import torch
+import torch.utils.ffi
+
+strBasepath = os.path.split(os.path.abspath(__file__))[0] + '/'
+strHeaders = []
+strSources = []
+strDefines = []
+strObjects = []
+
+if torch.cuda.is_available() == True:
+	strHeaders += ['src/SeparableConvolution_cuda.h']
+	strSources += ['src/SeparableConvolution_cuda.c']
+	strDefines += [('WITH_CUDA', None)]
+	strObjects += ['src/SeparableConvolution_kernel.o']
+# end
+
+objectExtension = torch.utils.ffi.create_extension(
+	name='_ext.cunnex',
+	headers=strHeaders,
+	sources=strSources,
+	verbose=False,
+	with_cuda=any(strDefine[0] == 'WITH_CUDA' for strDefine in strDefines),
+	package=False,
+	relative_to=strBasepath,
+	include_dirs=[os.path.expandvars('$CUDA_HOME') + '/include'],
+	define_macros=strDefines,
+	extra_objects=[os.path.join(strBasepath, strObject) for strObject in strObjects]
+)
+
+if __name__ == '__main__':
+	objectExtension.build()
+# end
\ No newline at end of file
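Editor's note: once `python install.py` has run, the compiled FFI module should be importable as `_ext.cunnex`, which is exactly what `SeparableConvolution.py` imports. A minimal smoke test, assuming the build completed without errors:

```
# hypothetical smoke test: checks that the compiled extension is importable
# and exposes the forward entry point declared in SeparableConvolution_cuda.h
import _ext.cunnex

assert hasattr(_ext.cunnex, 'SeparableConvolution_cuda_forward')
print('extension built and importable')
```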
diff --git a/run.py b/run.py
new file mode 100644
--- /dev/null
+++ b/run.py
@@ -0,0 +1,232 @@
+#!/usr/bin/env python2.7
+
+import sys
+import getopt
+import math
+import numpy
+import torch
+import torch.utils.serialization
+import PIL
+import PIL.Image
+
+from SeparableConvolution import SeparableConvolution # the custom SeparableConvolution layer
+
+torch.cuda.device(1) # change this if you have multiple graphics cards and you want to utilize a particular one
+
+torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
+
+##########################################################
+
+arguments_strModel = 'lf'
+arguments_strFirst = './images/first.png'
+arguments_strSecond = './images/second.png'
+arguments_strOut = './result.png'
+
+for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:
+	if strOption == '--model':
+		arguments_strModel = strArgument # which model to use, l1 or lf, please see our paper for more details
+
+	elif strOption == '--first':
+		arguments_strFirst = strArgument # path to the first frame
+
+	elif strOption == '--second':
+		arguments_strSecond = strArgument # path to the second frame
+
+	elif strOption == '--out':
+		arguments_strOut = strArgument # path to where the output should be stored
+
+	# end
+# end
+
+##########################################################
+
+class Network(torch.nn.Module):
+	def __init__(self):
+		super(Network, self).__init__()
+
+		def Basic(intInput, intOutput):
+			return torch.nn.Sequential(
+				torch.nn.Conv2d(in_channels=intInput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
+				torch.nn.ReLU(inplace=False),
+				torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
+				torch.nn.ReLU(inplace=False),
+				torch.nn.Conv2d(in_channels=intOutput, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
+				torch.nn.ReLU(inplace=False)
+			)
+		# end
+
+		def Subnet():
+			return torch.nn.Sequential(
+				torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+				torch.nn.ReLU(inplace=False),
+				torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+				torch.nn.ReLU(inplace=False),
+				torch.nn.Conv2d(in_channels=64, out_channels=51, kernel_size=3, stride=1, padding=1),
+				torch.nn.ReLU(inplace=False),
+				torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+				torch.nn.Conv2d(in_channels=51, out_channels=51, kernel_size=3, stride=1, padding=1)
+			)
+		# end
+
+		self.moduleConv1 = Basic(6, 32)
+		self.modulePool1 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+		self.moduleConv2 = Basic(32, 64)
+		self.modulePool2 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+		self.moduleConv3 = Basic(64, 128)
+		self.modulePool3 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+		self.moduleConv4 = Basic(128, 256)
+		self.modulePool4 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+		self.moduleConv5 = Basic(256, 512)
+		self.modulePool5 = torch.nn.AvgPool2d(kernel_size=2, stride=2)
+
+		self.moduleDeconv5 = Basic(512, 512)
+		self.moduleUpsample5 = torch.nn.Sequential(
+			torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+			torch.nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
+			torch.nn.ReLU(inplace=False)
+		)
+
+		self.moduleDeconv4 = Basic(512, 256)
+		self.moduleUpsample4 = torch.nn.Sequential(
+			torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+			torch.nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
+			torch.nn.ReLU(inplace=False)
+		)
+
+		self.moduleDeconv3 = Basic(256, 128)
+		self.moduleUpsample3 = torch.nn.Sequential(
+			torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+			torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+			torch.nn.ReLU(inplace=False)
+		)
+
+		self.moduleDeconv2 = Basic(128, 64)
+		self.moduleUpsample2 = torch.nn.Sequential(
+			torch.nn.Upsample(scale_factor=2, mode='bilinear'),
+			torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+			torch.nn.ReLU(inplace=False)
+		)
+
+		self.moduleVertical1 = Subnet()
+		self.moduleVertical2 = Subnet()
+		self.moduleHorizontal1 = Subnet()
+		self.moduleHorizontal2 = Subnet()
+
+		self.modulePad = torch.nn.ReplicationPad2d([ int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)), int(math.floor(51 / 2.0)) ])
+
+		self.load_state_dict(torch.load('./network-' + arguments_strModel + '.pytorch'))
+	# end
+
+	def forward(self, variableInput1, variableInput2):
+		variableJoin = torch.cat([variableInput1, variableInput2], 1)
+
+		variableConv1 = self.moduleConv1(variableJoin)
+		variablePool1 = self.modulePool1(variableConv1)
+
+		variableConv2 = self.moduleConv2(variablePool1)
+		variablePool2 = self.modulePool2(variableConv2)
+
+		variableConv3 = self.moduleConv3(variablePool2)
+		variablePool3 = self.modulePool3(variableConv3)
+
+		variableConv4 = self.moduleConv4(variablePool3)
+		variablePool4 = self.modulePool4(variableConv4)
+
+		variableConv5 = self.moduleConv5(variablePool4)
+		variablePool5 = self.modulePool5(variableConv5)
+
+		variableDeconv5 = self.moduleDeconv5(variablePool5)
+		variableUpsample5 = self.moduleUpsample5(variableDeconv5)
+
+		variableDeconv4 = self.moduleDeconv4(variableUpsample5 + variableConv5)
+		variableUpsample4 = self.moduleUpsample4(variableDeconv4)
+
+		variableDeconv3 = self.moduleDeconv3(variableUpsample4 + variableConv4)
+		variableUpsample3 = self.moduleUpsample3(variableDeconv3)
+
+		variableDeconv2 = self.moduleDeconv2(variableUpsample3 + variableConv3)
+		variableUpsample2 = self.moduleUpsample2(variableDeconv2)
+
+		variableCombine = variableUpsample2 + variableConv2
+
+		variableDot1 = SeparableConvolution()(self.modulePad(variableInput1), self.moduleVertical1(variableCombine), self.moduleHorizontal1(variableCombine))
+		variableDot2 = SeparableConvolution()(self.modulePad(variableInput2), self.moduleVertical2(variableCombine), self.moduleHorizontal2(variableCombine))
+
+		return variableDot1 + variableDot2
+	# end
+# end
+
+moduleNetwork = Network().cuda()
+
+##########################################################
+
+tensorInputFirst = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strFirst))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
+tensorInputSecond = torch.FloatTensor(numpy.rollaxis(numpy.asarray(PIL.Image.open(arguments_strSecond))[:,:,::-1], 2, 0).astype(numpy.float32) / 255.0)
+tensorOutput = torch.FloatTensor()
+
+assert(tensorInputFirst.size(1) == tensorInputSecond.size(1))
+assert(tensorInputFirst.size(2) == tensorInputSecond.size(2))
+
+intWidth = tensorInputFirst.size(2)
+intHeight = tensorInputFirst.size(1)
+
+assert(intWidth <= 1280) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
+assert(intHeight <= 720) # while our approach works with larger images, we do not recommend it unless you are aware of the implications
+
+intPaddingLeft = int(math.floor(51 / 2.0))
+intPaddingTop = int(math.floor(51 / 2.0))
+intPaddingRight = int(math.floor(51 / 2.0))
+intPaddingBottom = int(math.floor(51 / 2.0))
+
+modulePaddingFirst = torch.nn.Module()
+modulePaddingSecond = torch.nn.Module()
+modulePaddingOutput = torch.nn.Module()
+
+if True:
+	intPaddingWidth = intPaddingLeft + intWidth + intPaddingRight
+	intPaddingHeight = intPaddingTop + intHeight + intPaddingBottom
+
+	if intPaddingWidth != ((intPaddingWidth >> 7) << 7):
+		intPaddingWidth = (((intPaddingWidth >> 7) + 1) << 7) # more than necessary
+	# end
+
+	if intPaddingHeight != ((intPaddingHeight >> 7) << 7):
+		intPaddingHeight = (((intPaddingHeight >> 7) + 1) << 7) # more than necessary
+	# end
+
+	intPaddingWidth = intPaddingWidth - (intPaddingLeft + intWidth + intPaddingRight)
+	intPaddingHeight = intPaddingHeight - (intPaddingTop + intHeight + intPaddingBottom)
+
+	modulePaddingFirst = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
+	modulePaddingSecond = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
+	modulePaddingOutput = torch.nn.ReplicationPad2d([0 - intPaddingLeft, 0 - intPaddingRight - intPaddingWidth, 0 - intPaddingTop, 0 - intPaddingBottom - intPaddingHeight])
+
+	modulePaddingFirst = modulePaddingFirst.cuda()
+	modulePaddingSecond = modulePaddingSecond.cuda()
+	modulePaddingOutput = modulePaddingOutput.cuda()
+# end
+
+if True:
+	tensorInputFirst = tensorInputFirst.cuda()
+	tensorInputSecond = tensorInputSecond.cuda()
+	tensorOutput = tensorOutput.cuda()
+# end
+
+if True:
+	variablePaddingFirst = modulePaddingFirst(torch.autograd.Variable(data=tensorInputFirst.view(1, 3, intHeight, intWidth), volatile=True))
+	variablePaddingSecond = modulePaddingSecond(torch.autograd.Variable(data=tensorInputSecond.view(1, 3, intHeight, intWidth), volatile=True))
+
+	variablePaddingOutput = modulePaddingOutput(moduleNetwork(variablePaddingFirst, variablePaddingSecond))
+
+	tensorOutput.resize_(3, intHeight, intWidth).copy_(variablePaddingOutput.data[0])
+# end
+
+if True:
+	tensorInputFirst = tensorInputFirst.cpu()
+	tensorInputSecond = tensorInputSecond.cpu()
+	tensorOutput = tensorOutput.cpu()
+# end
+
+PIL.Image.fromarray((numpy.rollaxis(tensorOutput.clamp(0.0, 1.0).numpy(), 0, 3)[:,:,::-1] * 255.0).astype(numpy.uint8)).save(arguments_strOut)
\ No newline at end of file
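Editor's note: the bit-shift block in `run.py` rounds the padded extent up to the next multiple of 128, that is 2^7, before converting the surplus back into extra padding; a generous guarantee (as the inline comment notes, more than necessary) that the five 2x pooling stages divide evenly. A standalone sketch of the same arithmetic with illustrative numbers:

```
def roundUpToMultiple(intSize, intShift=7):
	# round intSize up to the next multiple of 2**intShift (128 by default),
	# mirroring the bit-shift arithmetic in run.py
	if intSize != ((intSize >> intShift) << intShift):
		intSize = ((intSize >> intShift) + 1) << intShift
	# end
	return intSize
# end

# e.g. a 640x480 frame plus 25 pixels of filter padding on each side
print(roundUpToMultiple(25 + 640 + 25)) # 768
print(roundUpToMultiple(25 + 480 + 25)) # 640
```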
diff --git a/src/SeparableConvolution_cuda.c b/src/SeparableConvolution_cuda.c
new file mode 100644
index 0000000..a48622e
--- /dev/null
+++ b/src/SeparableConvolution_cuda.c
@@ -0,0 +1,23 @@
+#include <THC.h>
+#include <THCGeneral.h>
+
+#include "SeparableConvolution_kernel.h"
+
+extern THCState* state;
+
+int SeparableConvolution_cuda_forward(
+	THCudaTensor* input1,
+	THCudaTensor* input2,
+	THCudaTensor* input3,
+	THCudaTensor* output
+) {
+	SeparableConvolution_kernel_forward(
+		state,
+		input1,
+		input2,
+		input3,
+		output
+	);
+
+	return 1;
+}
\ No newline at end of file
diff --git a/src/SeparableConvolution_cuda.h b/src/SeparableConvolution_cuda.h
new file mode 100644
index 0000000..7320fee
--- /dev/null
+++ b/src/SeparableConvolution_cuda.h
@@ -0,0 +1,6 @@
+int SeparableConvolution_cuda_forward(
+	THCudaTensor* input1,
+	THCudaTensor* input2,
+	THCudaTensor* input3,
+	THCudaTensor* output
+);
\ No newline at end of file
diff --git a/src/SeparableConvolution_kernel.cu b/src/SeparableConvolution_kernel.cu
new file mode 100644
index 0000000..b40786d
--- /dev/null
+++ b/src/SeparableConvolution_kernel.cu
@@ -0,0 +1,70 @@
+#include <THC.h>
+#include <THCGeneral.h>
+
+#define VEC_0(ARRAY) ((ARRAY).x)
+#define VEC_1(ARRAY) ((ARRAY).y)
+#define VEC_2(ARRAY) ((ARRAY).z)
+#define VEC_3(ARRAY) ((ARRAY).w)
+
+#define IDX_1(ARRAY, X) ((ARRAY)[((X) * (ARRAY##_stride.x))])
+#define IDX_2(ARRAY, X, Y) ((ARRAY)[((X) * (ARRAY##_stride.x)) + ((Y) * (ARRAY##_stride.y))])
+#define IDX_3(ARRAY, X, Y, Z) ((ARRAY)[((X) * (ARRAY##_stride.x)) + ((Y) * (ARRAY##_stride.y)) + ((Z) * (ARRAY##_stride.z))])
+#define IDX_4(ARRAY, X, Y, Z, W) ((ARRAY)[((X) * (ARRAY##_stride.x)) + ((Y) * (ARRAY##_stride.y)) + ((Z) * (ARRAY##_stride.z)) + ((W) * (ARRAY##_stride.w))])
+
+#ifdef __cplusplus
+	extern "C" {
+#endif
+
+__global__ void kernel_SeparableConvolution_updateOutput(
+	const int n,
+	const float* input1, const long4 input1_size, const long4 input1_stride,
+	const float* input2, const long4 input2_size, const long4 input2_stride,
+	const float* input3, const long4 input3_size, const long4 input3_stride,
+	float* output, const long4 output_size, const long4 output_stride
+) {
+	int intIndex = blockIdx.x * blockDim.x + threadIdx.x;
+
+	if (intIndex >= n) {
+		return;
+	}
+
+	float dblOutput = 0.0;
+
+	int intBatch = ( intIndex / VEC_3(output_size) / VEC_2(output_size) / VEC_1(output_size) ) % VEC_0(output_size);
+	int intDepth = ( intIndex / VEC_3(output_size) / VEC_2(output_size)                      ) % VEC_1(output_size);
+	int intY     = ( intIndex / VEC_3(output_size)                                           ) % VEC_2(output_size);
+	int intX     = ( intIndex                                                                ) % VEC_3(output_size);
+
+	for (int intFilterY = 0; intFilterY < 51; intFilterY += 1) {
+		for (int intFilterX = 0; intFilterX < 51; intFilterX += 1) {
+			dblOutput += IDX_4(input1, intBatch, intDepth, intY + intFilterY, intX + intFilterX) * IDX_4(input2, intBatch, intFilterY, intY, intX) * IDX_4(input3, intBatch, intFilterX, intY, intX);
+		}
+	}
+
+	output[intIndex] = dblOutput;
+}
+
+void SeparableConvolution_kernel_forward(
+	THCState* state,
+	THCudaTensor* input1,
+	THCudaTensor* input2,
+	THCudaTensor* input3,
+	THCudaTensor* output
+) {
+	int n = 0;
+
+	n = THCudaTensor_nElement(state, output);
+	kernel_SeparableConvolution_updateOutput<<< (n + 512 - 1) / 512, 512, 0, THCState_getCurrentStream(state) >>>(
+		n,
+		THCudaTensor_data(state, input1), make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]), make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]),
+		THCudaTensor_data(state, input2), make_long4(input2->size[0], input2->size[1], input2->size[2], input2->size[3]), make_long4(input2->stride[0], input2->stride[1], input2->stride[2], input2->stride[3]),
+		THCudaTensor_data(state, input3), make_long4(input3->size[0], input3->size[1], input3->size[2], input3->size[3]), make_long4(input3->stride[0], input3->stride[1], input3->stride[2], input3->stride[3]),
+		THCudaTensor_data(state, output), make_long4(output->size[0], output->size[1], output->size[2], output->size[3]), make_long4(output->stride[0], output->stride[1], output->stride[2], output->stride[3])
+	);
+
+	THCudaCheck(cudaGetLastError());
+}
+
+#ifdef __cplusplus
+	}
+#endif
\ No newline at end of file
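Editor's note: each CUDA thread above computes a single output element by decomposing its flat index into a (batch, depth, y, x) coordinate and accumulating the 51x51 input neighborhood, weighted by the per-pixel vertical filter (`input2`) and horizontal filter (`input3`). The following slow Python sketch, not part of the repository, expresses the same accumulation and can be used to verify the CUDA path on small inputs:

```
import torch

def SeparableConvolution_reference(input1, input2, input3):
	# naive re-implementation of kernel_SeparableConvolution_updateOutput:
	# output[b, d, y, x] = sum_{fy, fx} input1[b, d, y + fy, x + fx]
	#                      * input2[b, fy, y, x] * input3[b, fx, y, x]
	intBatches, intDepth = input1.size(0), input1.size(1)
	intFilterSize = input2.size(1) # 51 in this repository
	intHeight, intWidth = input2.size(2), input2.size(3)

	output = torch.zeros(intBatches, intDepth, intHeight, intWidth)

	for intFilterY in range(intFilterSize):
		for intFilterX in range(intFilterSize):
			patch = input1[:, :, intFilterY:intFilterY + intHeight, intFilterX:intFilterX + intWidth]
			weight = input2[:, intFilterY:intFilterY + 1, :, :] * input3[:, intFilterX:intFilterX + 1, :, :]
			output += patch * weight
		# end
	# end

	return output
# end
```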
diff --git a/src/SeparableConvolution_kernel.h b/src/SeparableConvolution_kernel.h
new file mode 100644
index 0000000..b400579
--- /dev/null
+++ b/src/SeparableConvolution_kernel.h
@@ -0,0 +1,15 @@
+#ifdef __cplusplus
+	extern "C" {
+#endif
+
+void SeparableConvolution_kernel_forward(
+	THCState* state,
+	THCudaTensor* input1,
+	THCudaTensor* input2,
+	THCudaTensor* input3,
+	THCudaTensor* output
+);
+
+#ifdef __cplusplus
+	}
+#endif
\ No newline at end of file
diff --git a/src/SeparableConvolution_kernel.o b/src/SeparableConvolution_kernel.o
new file mode 100644
index 0000000..396aba1
--- /dev/null
+++ b/src/SeparableConvolution_kernel.o
Binary files differ