summaryrefslogtreecommitdiff
path: root/install.py
blob: 42847fa1616664b9846b1be112d2170d14995315 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os
import torch
import torch.utils.ffi

# Absolute directory that contains this script, with a trailing '/' appended.
strBasepath = os.path.dirname(os.path.abspath(__file__)) + '/'

# Build inputs handed to the extension builder further down. All four stay
# empty unless CUDA is usable on this machine.
strHeaders, strSources, strDefines, strObjects = [], [], [], []

if torch.cuda.is_available():
	# CUDA build: C header/source pair, the WITH_CUDA macro (defined without a
	# value), and an object file that is linked in as-is.
	# NOTE(review): the .o is presumably compiled separately (e.g. by nvcc)
	# before this script runs - confirm against the build instructions.
	strHeaders.append('src/SeparableConvolution_cuda.h')
	strSources.append('src/SeparableConvolution_cuda.c')
	strDefines.append(('WITH_CUDA', None))
	strObjects.append('src/SeparableConvolution_kernel.o')
# end

# Locate the CUDA headers through the CUDA_HOME environment variable.
# os.path.expandvars() leaves '$CUDA_HOME' untouched when the variable is not
# set, which would pass the literal directory '$CUDA_HOME/include' to the
# compiler. Read the variable explicitly instead and only add the include
# directory when it is actually defined; the resulting path is unchanged
# whenever CUDA_HOME is set.
strCudahome = os.environ.get('CUDA_HOME')
strIncludes = [strCudahome + '/include'] if strCudahome else []

# Describe the FFI extension '_ext.cunnex' from the header/source/macro/object
# lists assembled above. Nothing is compiled here; that happens in build().
# NOTE(review): torch.utils.ffi is the legacy FFI builder and is presumably
# unavailable in PyTorch >= 1.0 (replaced by torch.utils.cpp_extension) -
# confirm which PyTorch version this script is meant to run against.
objectExtension = torch.utils.ffi.create_extension(
	name='_ext.cunnex',
	headers=strHeaders,
	sources=strSources,
	verbose=False,
	# CUDA support is requested exactly when the WITH_CUDA macro was registered.
	with_cuda=any(strDefine[0] == 'WITH_CUDA' for strDefine in strDefines),
	package=False,
	relative_to=strBasepath,
	include_dirs=strIncludes,
	define_macros=strDefines,
	# Object files are passed as paths anchored at the script directory.
	extra_objects=[os.path.join(strBasepath, strObject) for strObject in strObjects]
)

# Compile only when executed as a script; importing this module merely defines
# objectExtension.
if __name__ == '__main__':
	objectExtension.build()
# end