Skip to content
Snippets Groups Projects
Commit 26df2dfb authored by Stephan Seitz
Browse files

Add some extra flags to TorchModule.compile and remove GOOGLE_CUDA from Template

parent c143c626
Branches
Tags
No related merge requests found
...@@ -137,7 +137,7 @@ class TorchModule(JinjaCppFile): ...@@ -137,7 +137,7 @@ class TorchModule(JinjaCppFile):
return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)), return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
function_name='call_' + kernel_ast.function_name) function_name='call_' + kernel_ast.function_name)
def compile(self): def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None):
from torch.utils.cpp_extension import load from torch.utils.cpp_extension import load
file_extension = '.cu' if self.is_cuda else '.cpp' file_extension = '.cu' if self.is_cuda else '.cpp'
source_code = str(self) source_code = str(self)
...@@ -155,10 +155,12 @@ class TorchModule(JinjaCppFile): ...@@ -155,10 +155,12 @@ class TorchModule(JinjaCppFile):
os.environ['CXX'] = get_compiler_config()['command'] os.environ['CXX'] = get_compiler_config()['command']
torch_extension = load(hash, torch_extension = load(hash,
[file_name], [file_name] + extra_source_files,
with_cuda=self.is_cuda, with_cuda=self.is_cuda or with_cuda,
extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')], extra_cflags=['--std=c++14', get_compiler_config()
extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']], ['flags'].replace('--std=c++11', '')],
extra_cuda_cflags=['-std=c++14', '-ccbin',
get_compiler_config()['command']] + extra_cuda_flags,
build_directory=build_dir, build_directory=build_dir,
extra_include_paths=[get_pycuda_include_path(), extra_include_paths=[get_pycuda_include_path(),
get_pystencils_include_path(), get_pystencils_include_path(),
......
#define RESTRICT __restrict__ #define RESTRICT __restrict__
#if GOOGLE_CUDA //#if GOOGLE_CUDA
#define EIGEN_USE_GPU //#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" //#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#endif //#endif
{% for header in headers -%} {% for header in headers -%}
#include {{ header }} #include {{ header }}
......
...@@ -369,14 +369,18 @@ class CustomFunctionDeclaration(JinjaCppFile): ...@@ -369,14 +369,18 @@ class CustomFunctionDeclaration(JinjaCppFile):
class CustomFunctionCall(JinjaCppFile): class CustomFunctionCall(JinjaCppFile):
TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined) TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined)
def __init__(self, function_name, *args, fields_accessed=[]): def __init__(self, function_name, *args, fields_accessed=[], custom_signature=None):
ast_dict = { ast_dict = {
'function_name': function_name, 'function_name': function_name,
'args': args, 'args': args,
'fields_accessed': [f.center for f in fields_accessed] 'fields_accessed': [f.center for f in fields_accessed]
} }
super().__init__(ast_dict) super().__init__(ast_dict)
self.required_global_declarations = [CustomFunctionDeclaration(self.ast_dict.function_name, self.ast_dict.args)] if custom_signature:
self.required_global_declarations = [CustomCodeNode(custom_signature, (), ())]
else:
self.required_global_declarations = [CustomFunctionDeclaration(
self.ast_dict.function_name, self.ast_dict.args)]
@property @property
def symbols_defined(self): def symbols_defined(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment