summaryrefslogtreecommitdiff
path: root/SeparableConvolution.py
diff options
context:
space:
mode:
authorsniklaus <simon.niklaus@outlook.com>2017-09-18 22:20:04 -0700
committersniklaus <simon.niklaus@outlook.com>2017-09-18 22:20:04 -0700
commitcfd6a91a628c603eeeecf517340ac0474a126496 (patch)
tree7c891bc8660904771bb612f56aca8f22760d0cdb /SeparableConvolution.py
parente123297d61dc9915b70060def498560ca5d3d073 (diff)
no message
Diffstat (limited to 'SeparableConvolution.py')
-rw-r--r--SeparableConvolution.py34
1 files changed, 17 insertions, 17 deletions
diff --git a/SeparableConvolution.py b/SeparableConvolution.py
index 2ac31f9..d97a87f 100644
--- a/SeparableConvolution.py
+++ b/SeparableConvolution.py
@@ -7,34 +7,34 @@ class SeparableConvolution(torch.autograd.Function):
super(SeparableConvolution, self).__init__()
# end
- def forward(self, input1, input2, input3):
- intBatches = input1.size(0)
- intInputDepth = input1.size(1)
- intInputHeight = input1.size(2)
- intInputWidth = input1.size(3)
- intFilterSize = min(input2.size(1), input3.size(1))
- intOutputHeight = min(input2.size(2), input3.size(2))
- intOutputWidth = min(input2.size(3), input3.size(3))
+ def forward(self, input, vertical, horizontal):
+ intBatches = input.size(0)
+ intInputDepth = input.size(1)
+ intInputHeight = input.size(2)
+ intInputWidth = input.size(3)
+ intFilterSize = min(vertical.size(1), horizontal.size(1))
+ intOutputHeight = min(vertical.size(2), horizontal.size(2))
+ intOutputWidth = min(vertical.size(3), horizontal.size(3))
assert(intInputHeight - 51 == intOutputHeight - 1)
assert(intInputWidth - 51 == intOutputWidth - 1)
assert(intFilterSize == 51)
- assert(input1.is_contiguous() == True)
- assert(input2.is_contiguous() == True)
- assert(input3.is_contiguous() == True)
+ assert(input.is_contiguous() == True)
+ assert(vertical.is_contiguous() == True)
+ assert(horizontal.is_contiguous() == True)
- output = input1.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_()
+ output = input.new().resize_(intBatches, intInputDepth, intOutputHeight, intOutputWidth).zero_()
- if input1.is_cuda == True:
+ if input.is_cuda == True:
_ext.cunnex.SeparableConvolution_cuda_forward(
- input1,
- input2,
- input3,
+ input,
+ vertical,
+ horizontal,
output
)
- elif input1.is_cuda == False:
+ elif input.is_cuda == False:
raise NotImplementedError() # CPU VERSION NOT IMPLEMENTED
# end