Skip to content
Snippets Groups Projects
Commit 26df2dfb authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add some extra flags to TorchModule.compile and remove GOOGLE_CUDA from Template

parent c143c626
No related branches found
No related tags found
No related merge requests found
......@@ -137,7 +137,7 @@ class TorchModule(JinjaCppFile):
return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
function_name='call_' + kernel_ast.function_name)
def compile(self):
def compile(self, extra_source_files=[], extra_cuda_flags=[], with_cuda=None):
from torch.utils.cpp_extension import load
file_extension = '.cu' if self.is_cuda else '.cpp'
source_code = str(self)
......@@ -155,10 +155,12 @@ class TorchModule(JinjaCppFile):
os.environ['CXX'] = get_compiler_config()['command']
torch_extension = load(hash,
[file_name],
with_cuda=self.is_cuda,
extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
[file_name] + extra_source_files,
with_cuda=self.is_cuda or with_cuda,
extra_cflags=['--std=c++14', get_compiler_config()
['flags'].replace('--std=c++11', '')],
extra_cuda_cflags=['-std=c++14', '-ccbin',
get_compiler_config()['command']] + extra_cuda_flags,
build_directory=build_dir,
extra_include_paths=[get_pycuda_include_path(),
get_pystencils_include_path(),
......
#define RESTRICT __restrict__
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#endif
//#if GOOGLE_CUDA
//#define EIGEN_USE_GPU
//#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
//#endif
{% for header in headers -%}
#include {{ header }}
......
......@@ -369,14 +369,18 @@ class CustomFunctionDeclaration(JinjaCppFile):
class CustomFunctionCall(JinjaCppFile):
TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined)
def __init__(self, function_name, *args, fields_accessed=[]):
def __init__(self, function_name, *args, fields_accessed=[], custom_signature=None):
ast_dict = {
'function_name': function_name,
'args': args,
'fields_accessed': [f.center for f in fields_accessed]
}
super().__init__(ast_dict)
self.required_global_declarations = [CustomFunctionDeclaration(self.ast_dict.function_name, self.ast_dict.args)]
if custom_signature:
self.required_global_declarations = [CustomCodeNode(custom_signature, (), ())]
else:
self.required_global_declarations = [CustomFunctionDeclaration(
self.ast_dict.function_name, self.ast_dict.args)]
@property
def symbols_defined(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment