From cb73882b7f6b48f4ba73426b140e77d0d1d97468 Mon Sep 17 00:00:00 2001 From: sniklaus Date: Sat, 9 Sep 2017 22:59:59 -0700 Subject: no message --- SeparableConvolution.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 SeparableConvolution.py (limited to 'SeparableConvolution.py') diff --git a/SeparableConvolution.py b/SeparableConvolution.py new file mode 100644 index 0000000..9e5d1e9 --- /dev/null +++ b/SeparableConvolution.py @@ -0,0 +1,48 @@ +import torch + +import _ext.cunnex + +class SeparableConvolution(torch.autograd.Function): + def __init__(self): + super(SeparableConvolution, self).__init__() + # end + + def forward(self, input1, input2, input3): + intBatches = input1.size(0) + intInputDepth = input1.size(1) + intInputHeight = input1.size(2) + intInputWidth = input1.size(3) + intFilterSize = min(input2.size(1), input3.size(1)) + intOutputHeight = min(input2.size(2), input3.size(2)) + intOutputWidth = min(input2.size(3), input3.size(3)) + + assert(intInputHeight - 51 == intOutputHeight - 1) + assert(intInputWidth - 51 == intOutputWidth - 1) + assert(intFilterSize == 51) + + assert(input1.is_contiguous() == True) + assert(input2.is_contiguous() == True) + assert(input3.is_contiguous() == True) + + output = input1.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_() + + if input1.is_cuda == True: + _ext.cunnex.SeparableConvolution_cuda_forward( + input1, + input2, + input3, + output + ) + + elif input1.is_cuda == False: + assert(False) # NOT IMPLEMENTED + + # end + + return output + # end + + def backward(self, gradOutput): + assert(false) # NOT IMPLEMENTED + # end +# end \ No newline at end of file -- cgit v1.2.3-70-g09d2