summaryrefslogtreecommitdiff
path: root/install.py
diff options
context:
space:
mode:
authorsniklaus <simon.niklaus@outlook.com>2017-09-09 22:59:59 -0700
committersniklaus <simon.niklaus@outlook.com>2017-09-09 22:59:59 -0700
commitcb73882b7f6b48f4ba73426b140e77d0d1d97468 (patch)
treeb2a45d643d3703e489ae2fd18ffd1143b4c7df3e /install.py
no message
Diffstat (limited to 'install.py')
-rw-r--r--install.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/install.py b/install.py
new file mode 100644
index 0000000..42847fa
--- /dev/null
+++ b/install.py
@@ -0,0 +1,33 @@
+import os
+import torch
+import torch.utils.ffi
+
+strBasepath = os.path.split(os.path.abspath(__file__))[0] + '/'
+strHeaders = []
+strSources = []
+strDefines = []
+strObjects = []
+
+if torch.cuda.is_available() == True:
+ strHeaders += ['src/SeparableConvolution_cuda.h']
+ strSources += ['src/SeparableConvolution_cuda.c']
+ strDefines += [('WITH_CUDA', None)]
+ strObjects += ['src/SeparableConvolution_kernel.o']
+# end
+
+objectExtension = torch.utils.ffi.create_extension(
+ name='_ext.cunnex',
+ headers=strHeaders,
+ sources=strSources,
+ verbose=False,
+ with_cuda=any(strDefine[0] == 'WITH_CUDA' for strDefine in strDefines),
+ package=False,
+ relative_to=strBasepath,
+ include_dirs=[os.path.expandvars('$CUDA_HOME') + '/include'],
+ define_macros=strDefines,
+ extra_objects=[os.path.join(strBasepath, strObject) for strObject in strObjects]
+)
+
+if __name__ == '__main__':
+ objectExtension.build()
+# end \ No newline at end of file