diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py
index e853a4e765c8db24e4401b6ffb33207e6fbc8e17..db753059432a78a0aaff49b14d3bc96cddd95dfe 100644
--- a/src/pystencils_autodiff/backends/astnodes.py
+++ b/src/pystencils_autodiff/backends/astnodes.py
@@ -137,7 +137,16 @@ class TorchModule(JinjaCppFile):
         return WrapperFunction(cls.DESTRUCTURING_CLASS(generate_kernel_call(kernel_ast)),
                                function_name='call_' + kernel_ast.function_name)
 
-    def compile(self):
+    def compile(self, extra_source_files=(), extra_cuda_flags=(), with_cuda=None):
+        """Build the module through ``torch.utils.cpp_extension.load``.
+
+        Args:
+            extra_source_files: extra sources handed to ``load`` after the
+                generated file.
+            extra_cuda_flags: extra flags appended to ``extra_cuda_cflags``.
+            with_cuda: combined with ``self.is_cuda`` by ``or`` to form the
+                ``with_cuda`` argument of ``load``.
+        """
         from torch.utils.cpp_extension import load
         file_extension = '.cu' if self.is_cuda else '.cpp'
         source_code = str(self)
@@ -155,10 +164,12 @@ class TorchModule(JinjaCppFile):
         os.environ['CXX'] = get_compiler_config()['command']
 
         torch_extension = load(hash,
-                               [file_name],
-                               with_cuda=self.is_cuda,
-                               extra_cflags=['--std=c++14', get_compiler_config()['flags'].replace('--std=c++11', '')],
-                               extra_cuda_cflags=['-std=c++14', '-ccbin', get_compiler_config()['command']],
+                               [file_name] + list(extra_source_files),
+                               with_cuda=self.is_cuda or with_cuda,
+                               extra_cflags=['--std=c++14',
+                                             get_compiler_config()['flags'].replace('--std=c++11', '')],
+                               extra_cuda_cflags=['-std=c++14', '-ccbin',
+                                                  get_compiler_config()['command']] + list(extra_cuda_flags),
                                build_directory=build_dir,
                                extra_include_paths=[get_pycuda_include_path(),
                                                     get_pystencils_include_path(),
diff --git a/src/pystencils_autodiff/backends/module.tmpl.cpp b/src/pystencils_autodiff/backends/module.tmpl.cpp
index f7f7ebfcb15500be81095634fb1cc1342a9978d4..9ea13c3c20ae7c06b6f23deb562559fd18a77ca6 100644
--- a/src/pystencils_autodiff/backends/module.tmpl.cpp
+++ b/src/pystencils_autodiff/backends/module.tmpl.cpp
@@ -1,9 +1,9 @@
 #define RESTRICT __restrict__
 
-#if GOOGLE_CUDA
-#define EIGEN_USE_GPU
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#endif
+//#if GOOGLE_CUDA
+//#define EIGEN_USE_GPU
+//#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+//#endif
 
 {% for header in headers -%}
 #include {{ header }}
diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py
index 8e9cbba59a2be98672575bdd41846b6387ef8765..75dd5a1ca6b1d0484533d3264a11b94f2928081d 100644
--- a/src/pystencils_autodiff/framework_integration/astnodes.py
+++ b/src/pystencils_autodiff/framework_integration/astnodes.py
@@ -369,14 +369,28 @@ class CustomFunctionDeclaration(JinjaCppFile):
 class CustomFunctionCall(JinjaCppFile):
     TEMPLATE = jinja2.Template("""{{function_name}}({{ args | join(', ') }});""", undefined=jinja2.StrictUndefined)
 
-    def __init__(self, function_name, *args, fields_accessed=[]):
+    def __init__(self, function_name, *args, fields_accessed=(), custom_signature=None):
+        """Create a call node for ``function_name`` applied to ``args``.
+
+        Args:
+            function_name: name rendered in front of the argument list.
+            *args: call arguments; ``TEMPLATE`` joins them with ``', '``.
+            fields_accessed: fields whose ``center`` is stored in the AST dict.
+            custom_signature: when truthy, it is wrapped in a ``CustomCodeNode``
+                and becomes the only required global declaration, replacing the
+                generated ``CustomFunctionDeclaration``.
+        """
         ast_dict = {
             'function_name': function_name,
             'args': args,
             'fields_accessed': [f.center for f in fields_accessed]
         }
         super().__init__(ast_dict)
-        self.required_global_declarations = [CustomFunctionDeclaration(self.ast_dict.function_name, self.ast_dict.args)]
+        if custom_signature:
+            self.required_global_declarations = [CustomCodeNode(custom_signature, (), ())]
+        else:
+            self.required_global_declarations = [CustomFunctionDeclaration(
+                self.ast_dict.function_name, self.ast_dict.args)]
 
     @property
     def symbols_defined(self):