| field | value | date |
|---|---|---|
| author | sniklaus <simon.niklaus@outlook.com> | 2017-09-09 22:59:59 -0700 |
| committer | sniklaus <simon.niklaus@outlook.com> | 2017-09-09 22:59:59 -0700 |
| commit | cb73882b7f6b48f4ba73426b140e77d0d1d97468 (patch) | |
| tree | b2a45d643d3703e489ae2fd18ffd1143b4c7df3e /SeparableConvolution.py | |
no message
Diffstat (limited to 'SeparableConvolution.py')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | SeparableConvolution.py | 48 |
1 file changed, 48 insertions, 0 deletions
diff --git a/SeparableConvolution.py b/SeparableConvolution.py
new file mode 100644
index 0000000..9e5d1e9
--- /dev/null
+++ b/SeparableConvolution.py
@@ -0,0 +1,48 @@
+import torch
+
+import _ext.cunnex
+
+class SeparableConvolution(torch.autograd.Function):
+	def __init__(self):
+		super(SeparableConvolution, self).__init__()
+	# end
+
+	def forward(self, input1, input2, input3):
+		intBatches = input1.size(0)
+		intInputDepth = input1.size(1)
+		intInputHeight = input1.size(2)
+		intInputWidth = input1.size(3)
+		intFilterSize = min(input2.size(1), input3.size(1))
+		intOutputHeight = min(input2.size(2), input3.size(2))
+		intOutputWidth = min(input2.size(3), input3.size(3))
+
+		assert(intInputHeight - 51 == intOutputHeight - 1)
+		assert(intInputWidth - 51 == intOutputWidth - 1)
+		assert(intFilterSize == 51)
+
+		assert(input1.is_contiguous() == True)
+		assert(input2.is_contiguous() == True)
+		assert(input3.is_contiguous() == True)
+
+		output = input1.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_()
+
+		if input1.is_cuda == True:
+			_ext.cunnex.SeparableConvolution_cuda_forward(
+				input1,
+				input2,
+				input3,
+				output
+			)
+
+		elif input1.is_cuda == False:
+			assert(False) # NOT IMPLEMENTED
+
+		# end
+
+		return output
+	# end
+
+	def backward(self, gradOutput):
+		assert(False) # NOT IMPLEMENTED
+	# end
+# end
\ No newline at end of file
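
For context, a minimal usage sketch of the Function introduced by this commit (not part of the diff itself): it assumes the `_ext.cunnex` CUDA extension has been built and a CUDA device is available. The tensor names `vertical` and `horizontal` and all sizes below are illustrative, chosen only to satisfy the shape asserts in `forward()` (a 51-tap separable filter, so the padded input is 50 pixels larger than the output in each spatial dimension).

```python
import torch
from torch.autograd import Variable

from SeparableConvolution import SeparableConvolution

# illustrative sizes: a 128x128 output requires a (128 + 50) x (128 + 50)
# input and 51 filter coefficients per output pixel in each direction
input = Variable(torch.randn(1, 3, 178, 178).cuda())

# 'vertical' / 'horizontal' are hypothetical names for input2 / input3
vertical = Variable(torch.randn(1, 51, 128, 128).cuda())
horizontal = Variable(torch.randn(1, 51, 128, 128).cuda())

# legacy (pre-0.4) autograd.Function usage of this era: instantiate, then call
output = SeparableConvolution()(input, vertical, horizontal)

print(output.size())  # torch.Size([1, 3, 128, 128])
```

Since `backward()` is not implemented in this commit, the sketch only exercises the forward pass.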
