diff options
| author | sniklaus <simon.niklaus@outlook.com> | 2017-09-18 22:20:04 -0700 |
|---|---|---|
| committer | sniklaus <simon.niklaus@outlook.com> | 2017-09-18 22:20:04 -0700 |
| commit | cfd6a91a628c603eeeecf517340ac0474a126496 (patch) | |
| tree | 7c891bc8660904771bb612f56aca8f22760d0cdb /SeparableConvolution.py | |
| parent | e123297d61dc9915b70060def498560ca5d3d073 (diff) | |
no message
Diffstat (limited to 'SeparableConvolution.py')
| -rw-r--r-- | SeparableConvolution.py | 34 |
1 files changed, 17 insertions, 17 deletions
diff --git a/SeparableConvolution.py b/SeparableConvolution.py index 2ac31f9..d97a87f 100644 --- a/SeparableConvolution.py +++ b/SeparableConvolution.py @@ -7,34 +7,34 @@ class SeparableConvolution(torch.autograd.Function): super(SeparableConvolution, self).__init__() # end - def forward(self, input1, input2, input3): - intBatches = input1.size(0) - intInputDepth = input1.size(1) - intInputHeight = input1.size(2) - intInputWidth = input1.size(3) - intFilterSize = min(input2.size(1), input3.size(1)) - intOutputHeight = min(input2.size(2), input3.size(2)) - intOutputWidth = min(input2.size(3), input3.size(3)) + def forward(self, input, vertical, horizontal): + intBatches = input.size(0) + intInputDepth = input.size(1) + intInputHeight = input.size(2) + intInputWidth = input.size(3) + intFilterSize = min(vertical.size(1), horizontal.size(1)) + intOutputHeight = min(vertical.size(2), horizontal.size(2)) + intOutputWidth = min(vertical.size(3), horizontal.size(3)) assert(intInputHeight - 51 == intOutputHeight - 1) assert(intInputWidth - 51 == intOutputWidth - 1) assert(intFilterSize == 51) - assert(input1.is_contiguous() == True) - assert(input2.is_contiguous() == True) - assert(input3.is_contiguous() == True) + assert(input.is_contiguous() == True) + assert(vertical.is_contiguous() == True) + assert(horizontal.is_contiguous() == True) - output = input1.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_() + output = input.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_() - if input1.is_cuda == True: + if input.is_cuda == True: _ext.cunnex.SeparableConvolution_cuda_forward( - input1, - input2, - input3, + input, + vertical, + horizontal, output ) - elif input1.is_cuda == False: + elif input.is_cuda == False: raise NotImplementedError() # CPU VERSION NOT IMPLEMENTED # end |
