| author | sniklaus <simon.niklaus@outlook.com> | 2017-09-18 22:20:04 -0700 |
|---|---|---|
| committer | sniklaus <simon.niklaus@outlook.com> | 2017-09-18 22:20:04 -0700 |
| commit | cfd6a91a628c603eeeecf517340ac0474a126496 | |
| tree | 7c891bc8660904771bb612f56aca8f22760d0cdb | |
| parent | e123297d61dc9915b70060def498560ca5d3d073 | |
no message
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | README.md | 2 |
| -rw-r--r-- | SeparableConvolution.py | 34 |
| -rw-r--r-- | run.py | 17 |
| -rw-r--r-- | src/SeparableConvolution_cuda.c | 12 |
| -rw-r--r-- | src/SeparableConvolution_cuda.h | 6 |
| -rw-r--r-- | src/SeparableConvolution_kernel.cu | 20 |
| -rw-r--r-- | src/SeparableConvolution_kernel.h | 6 |
7 files changed, 47 insertions, 50 deletions
````diff
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ python run.py --model lf --first ./images/first.png --second ./images/second.png
 <a href="http://web.cecs.pdx.edu/~fliu/project/sepconv/demo.mp4" rel="Video"><img src="http://web.cecs.pdx.edu/~fliu/project/sepconv/screen.jpg" alt="Video" width="100%"></a>

 ## license
-The provided implementation is strictly for academic purposes only. Should you be interested in using our intellectual property, please feel free to contact us.
+The provided implementation is strictly for academic purposes only. Should you be interested in using our technology for any commercial use, please feel free to contact us.

 ## references
 ```
diff --git a/SeparableConvolution.py b/SeparableConvolution.py
index 2ac31f9..d97a87f 100644
--- a/SeparableConvolution.py
+++ b/SeparableConvolution.py
@@ -7,34 +7,34 @@ class SeparableConvolution(torch.autograd.Function):
 		super(SeparableConvolution, self).__init__()
 	# end

-	def forward(self, input1, input2, input3):
-		intBatches = input1.size(0)
-		intInputDepth = input1.size(1)
-		intInputHeight = input1.size(2)
-		intInputWidth = input1.size(3)
-		intFilterSize = min(input2.size(1), input3.size(1))
-		intOutputHeight = min(input2.size(2), input3.size(2))
-		intOutputWidth = min(input2.size(3), input3.size(3))
+	def forward(self, input, vertical, horizontal):
+		intBatches = input.size(0)
+		intInputDepth = input.size(1)
+		intInputHeight = input.size(2)
+		intInputWidth = input.size(3)
+		intFilterSize = min(vertical.size(1), horizontal.size(1))
+		intOutputHeight = min(vertical.size(2), horizontal.size(2))
+		intOutputWidth = min(vertical.size(3), horizontal.size(3))

 		assert(intInputHeight - 51 == intOutputHeight - 1)
 		assert(intInputWidth - 51 == intOutputWidth - 1)
 		assert(intFilterSize == 51)

-		assert(input1.is_contiguous() == True)
-		assert(input2.is_contiguous() == True)
-		assert(input3.is_contiguous() == True)
+		assert(input.is_contiguous() == True)
+		assert(vertical.is_contiguous() == True)
+		assert(horizontal.is_contiguous() == True)

-		output = input1.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_()
+		output = input.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_()

-		if input1.is_cuda == True:
+		if input.is_cuda == True:
 			_ext.cunnex.SeparableConvolution_cuda_forward(
-				input1,
-				input2,
-				input3,
+				input,
+				vertical,
+				horizontal,
 				output
 			)

-		elif input1.is_cuda == False:
+		elif input.is_cuda == False:
 			raise NotImplementedError() # CPU VERSION NOT IMPLEMENTED

 	# end
diff --git a/run.py b/run.py
--- a/run.py
+++ b/run.py
@@ -181,8 +181,7 @@ intPaddingLeft = int(math.floor(51 / 2.0))
 intPaddingTop = int(math.floor(51 / 2.0))
 intPaddingRight = int(math.floor(51 / 2.0))
 intPaddingBottom = int(math.floor(51 / 2.0))
-modulePaddingFirst = torch.nn.Module()
-modulePaddingSecond = torch.nn.Module()
+modulePaddingInput = torch.nn.Module()
 modulePaddingOutput = torch.nn.Module()

 if True:
@@ -200,24 +199,22 @@ if True:
 	intPaddingWidth = intPaddingWidth - (intPaddingLeft + intWidth + intPaddingRight)
 	intPaddingHeight = intPaddingHeight - (intPaddingTop + intHeight + intPaddingBottom)

-	modulePaddingFirst = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
-	modulePaddingSecond = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
+	modulePaddingInput = torch.nn.ReplicationPad2d([intPaddingLeft, intPaddingRight + intPaddingWidth, intPaddingTop, intPaddingBottom + intPaddingHeight])
 	modulePaddingOutput = torch.nn.ReplicationPad2d([0 - intPaddingLeft, 0 - intPaddingRight - intPaddingWidth, 0 - intPaddingTop, 0 - intPaddingBottom - intPaddingHeight])
-
-	modulePaddingFirst = modulePaddingFirst.cuda()
-	modulePaddingSecond = modulePaddingSecond.cuda()
-	modulePaddingOutput = modulePaddingOutput.cuda()
 # end

 if True:
 	tensorInputFirst = tensorInputFirst.cuda()
 	tensorInputSecond = tensorInputSecond.cuda()
 	tensorOutput = tensorOutput.cuda()
+
+	modulePaddingInput = modulePaddingInput.cuda()
+	modulePaddingOutput = modulePaddingOutput.cuda()
 # end

 if True:
-	variablePaddingFirst = modulePaddingFirst(torch.autograd.Variable(data=tensorInputFirst.view(1, 3, intHeight, intWidth), volatile=True))
-	variablePaddingSecond = modulePaddingSecond(torch.autograd.Variable(data=tensorInputSecond.view(1, 3, intHeight, intWidth), volatile=True))
+	variablePaddingFirst = modulePaddingInput(torch.autograd.Variable(data=tensorInputFirst.view(1, 3, intHeight, intWidth), volatile=True))
+	variablePaddingSecond = modulePaddingInput(torch.autograd.Variable(data=tensorInputSecond.view(1, 3, intHeight, intWidth), volatile=True))
 	variablePaddingOutput = modulePaddingOutput(moduleNetwork(variablePaddingFirst, variablePaddingSecond))

 	tensorOutput.resize_(3, intHeight, intWidth).copy_(variablePaddingOutput.data[0])
diff --git a/src/SeparableConvolution_cuda.c b/src/SeparableConvolution_cuda.c
index a48622e..973e4c1 100644
--- a/src/SeparableConvolution_cuda.c
+++ b/src/SeparableConvolution_cuda.c
@@ -6,16 +6,16 @@
 extern THCState* state;

 int SeparableConvolution_cuda_forward(
-	THCudaTensor* input1,
-	THCudaTensor* input2,
-	THCudaTensor* input3,
+	THCudaTensor* input,
+	THCudaTensor* vertical,
+	THCudaTensor* horizontal,
 	THCudaTensor* output
 ) {
 	SeparableConvolution_kernel_forward(
 		state,
-		input1,
-		input2,
-		input3,
+		input,
+		vertical,
+		horizontal,
 		output
 	);

diff --git a/src/SeparableConvolution_cuda.h b/src/SeparableConvolution_cuda.h
index 7320fee..a3d5ed0 100644
--- a/src/SeparableConvolution_cuda.h
+++ b/src/SeparableConvolution_cuda.h
@@ -1,6 +1,6 @@
 int SeparableConvolution_cuda_forward(
-	THCudaTensor* input1,
-	THCudaTensor* input2,
-	THCudaTensor* input3,
+	THCudaTensor* input,
+	THCudaTensor* vertical,
+	THCudaTensor* horizontal,
 	THCudaTensor* output
 );
\ No newline at end of file
diff --git a/src/SeparableConvolution_kernel.cu b/src/SeparableConvolution_kernel.cu
index b40786d..b4e6d59 100644
--- a/src/SeparableConvolution_kernel.cu
+++ b/src/SeparableConvolution_kernel.cu
@@ -17,9 +17,9 @@
 __global__ void kernel_SeparableConvolution_updateOutput(
 	const int n,
-	const float* input1, const long4 input1_size, const long4 input1_stride,
-	const float* input2, const long4 input2_size, const long4 input2_stride,
-	const float* input3, const long4 input3_size, const long4 input3_stride,
+	const float* input, const long4 input_size, const long4 input_stride,
+	const float* vertical, const long4 vertical_size, const long4 vertical_stride,
+	const float* horizontal, const long4 horizontal_size, const long4 horizontal_stride,
 	float* output, const long4 output_size, const long4 output_stride
 ) {
 	int intIndex = blockIdx.x * blockDim.x + threadIdx.x;

@@ -37,7 +37,7 @@ __global__ void kernel_SeparableConvolution_updateOutput(

 	for (int intFilterY = 0; intFilterY < 51; intFilterY += 1) {
 		for (int intFilterX = 0; intFilterX < 51; intFilterX += 1) {
-			dblOutput += IDX_4(input1, intBatch, intDepth, intY + intFilterY, intX + intFilterX) * IDX_4(input2, intBatch, intFilterY, intY, intX) * IDX_4(input3, intBatch, intFilterX, intY, intX);
+			dblOutput += IDX_4(input, intBatch, intDepth, intY + intFilterY, intX + intFilterX) * IDX_4(vertical, intBatch, intFilterY, intY, intX) * IDX_4(horizontal, intBatch, intFilterX, intY, intX);
 		}
 	}

@@ -46,9 +46,9 @@
 void SeparableConvolution_kernel_forward(
 	THCState* state,
-	THCudaTensor* input1,
-	THCudaTensor* input2,
-	THCudaTensor* input3,
+	THCudaTensor* input,
+	THCudaTensor* vertical,
+	THCudaTensor* horizontal,
 	THCudaTensor* output
 ) {
 	int n = 0;

@@ -56,9 +56,9 @@ void SeparableConvolution_kernel_forward(
 	n = THCudaTensor_nElement(state, output);
 	kernel_SeparableConvolution_updateOutput<<< (n + 512 - 1) / 512, 512, 0, THCState_getCurrentStream(state) >>>(
 		n,
-		THCudaTensor_data(state, input1), make_long4(input1->size[0], input1->size[1], input1->size[2], input1->size[3]), make_long4(input1->stride[0], input1->stride[1], input1->stride[2], input1->stride[3]),
-		THCudaTensor_data(state, input2), make_long4(input2->size[0], input2->size[1], input2->size[2], input2->size[3]), make_long4(input2->stride[0], input2->stride[1], input2->stride[2], input2->stride[3]),
-		THCudaTensor_data(state, input3), make_long4(input3->size[0], input3->size[1], input3->size[2], input3->size[3]), make_long4(input3->stride[0], input3->stride[1], input3->stride[2], input3->stride[3]),
+		THCudaTensor_data(state, input), make_long4(input->size[0], input->size[1], input->size[2], input->size[3]), make_long4(input->stride[0], input->stride[1], input->stride[2], input->stride[3]),
+		THCudaTensor_data(state, vertical), make_long4(vertical->size[0], vertical->size[1], vertical->size[2], vertical->size[3]), make_long4(vertical->stride[0], vertical->stride[1], vertical->stride[2], vertical->stride[3]),
+		THCudaTensor_data(state, horizontal), make_long4(horizontal->size[0], horizontal->size[1], horizontal->size[2], horizontal->size[3]), make_long4(horizontal->stride[0], horizontal->stride[1], horizontal->stride[2], horizontal->stride[3]),
 		THCudaTensor_data(state, output), make_long4(output->size[0], output->size[1], output->size[2], output->size[3]), make_long4(output->stride[0], output->stride[1], output->stride[2], output->stride[3])
 	);

diff --git a/src/SeparableConvolution_kernel.h b/src/SeparableConvolution_kernel.h
index b400579..11a0238 100644
--- a/src/SeparableConvolution_kernel.h
+++ b/src/SeparableConvolution_kernel.h
@@ -4,9 +4,9 @@

 void SeparableConvolution_kernel_forward(
 	THCState* state,
-	THCudaTensor* input1,
-	THCudaTensor* input2,
-	THCudaTensor* input3,
+	THCudaTensor* input,
+	THCudaTensor* vertical,
+	THCudaTensor* horizontal,
 	THCudaTensor* output
 );

````
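For context, the rename reflects what the CUDA kernel above computes: for every output pixel, a 51x51 patch of `input` is weighted by the outer product of a per-pixel `vertical` and `horizontal` 1D filter, which is also why the asserts in `SeparableConvolution.py` require the input to be 50 pixels larger than the output in each spatial dimension. Below is a minimal NumPy sketch of that accumulation, not code from this repository; the function name `sepconv_reference` and the NumPy dependency are assumptions made purely for illustration.

```python
import numpy as np

def sepconv_reference(input, vertical, horizontal):
	# input:      (batch, channels, height + 50, width + 50)
	# vertical:   (batch, 51, height, width)  -- one vertical 1D kernel per output pixel
	# horizontal: (batch, 51, height, width)  -- one horizontal 1D kernel per output pixel
	# returns:    (batch, channels, height, width)
	batches, channels = input.shape[0], input.shape[1]
	filter_size = vertical.shape[1]  # 51 in the CUDA kernel
	out_height, out_width = vertical.shape[2], vertical.shape[3]
	output = np.zeros((batches, channels, out_height, out_width), dtype=input.dtype)

	for y in range(out_height):
		for x in range(out_width):
			# 51x51 patch of the input contributing to output[:, :, y, x]
			patch = input[:, :, y:y + filter_size, x:x + filter_size]

			# outer product of the two per-pixel 1D kernels -> (batch, 51, 51)
			weights = vertical[:, :, y, x][:, :, None] * horizontal[:, :, y, x][:, None, :]

			# same accumulation as the intFilterY / intFilterX loop in the CUDA kernel
			output[:, :, y, x] = (patch * weights[:, None, :, :]).sum(axis=(2, 3))
		# end
	# end

	return output
# end

output = sepconv_reference(
	np.random.rand(1, 3, 128 + 50, 128 + 50).astype(np.float32),
	np.random.rand(1, 51, 128, 128).astype(np.float32),
	np.random.rand(1, 51, 128, 128).astype(np.float32)
)  # -> shape (1, 3, 128, 128)
```

The CUDA path accumulates the same triple product directly per thread instead of materializing a 51x51 weight matrix for each pixel.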
