diff --git a/src/pystencils_autodiff/backends/astnodes.py b/src/pystencils_autodiff/backends/astnodes.py index 3f3b0fa39be4fc27c54456714c9e649eb3c58ec8..dd65c4e3ea76024d0c1a03b63b8a36cd8eae183b 100644 --- a/src/pystencils_autodiff/backends/astnodes.py +++ b/src/pystencils_autodiff/backends/astnodes.py @@ -31,6 +31,28 @@ def get_cubic_interpolation_include_paths(): join(dirname(pystencils.gpucuda.__file__), 'CubicInterpolationCUDA', 'code', 'internal')] +class Header(JinjaCppFile): + TEMPLATE = read_template_from_file(join(dirname(__file__), 'module.tmpl.hpp')) + + def __init__(self, exported_functions, module_name): + ast_dict = { + 'declarations': exported_functions, + 'module_name': module_name, + } + + super().__init__(ast_dict) + + def __str__(self): + self.printer._signatureOnly = True + rtn = JinjaCppFile.__str__(self) + self.printer._signatureOnly = False + return rtn + + @property + def backend(self): + return 'c' + + class TorchTensorDestructuring(DestructuringBindingsForFieldClass): CLASS_TO_MEMBER_DICT = { FieldPointerSymbol: "data_ptr<{dtype}>()", @@ -143,6 +165,10 @@ class TorchModule(JinjaCppFile): *get_cubic_interpolation_include_paths()]) return torch_extension + @property + def header(self): + return Header(self.ast_dict.kernel_wrappers, self.module_name) + class TensorflowModule(TorchModule): DESTRUCTURING_CLASS = TensorflowTensorDestructuring diff --git a/src/pystencils_autodiff/backends/module.tmpl.hpp b/src/pystencils_autodiff/backends/module.tmpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..08bbffed59cd0dd632aa48fce414f5359279e05c --- /dev/null +++ b/src/pystencils_autodiff/backends/module.tmpl.hpp @@ -0,0 +1,11 @@ +/* + * {{ module_name }}.hpp + * Copyright (C) 2020 Stephan Seitz <stephan.seitz@fau.de> + * + * Distributed under terms of the GPLv3 license. + */ + +#pragma once + +{{ declarations | join('\n\n') }} + diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index 4b0b8300e5665b5e9f55fd4a76312394bdc6f9ce..35620f00cc9d7f0145332af3df2b58a393dded53 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -174,7 +174,7 @@ class JinjaCppFile(Node): @property def args(self): """Returns all arguments/children of this node.""" - ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr, str))] + ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr))] iterables_of_ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, Iterable) and not isinstance(a, str)] return ast_nodes + list(itertools.chain.from_iterable(iterables_of_ast_nodes)) @@ -194,7 +194,7 @@ class JinjaCppFile(Node): if isinstance(a, (Node, sp.Expr)))) - self.symbols_defined def _print(self, node): - if isinstance(node, Node): + if isinstance(node, (Node, sp.Expr)): return self.printer(node) else: return str(node) @@ -345,3 +345,8 @@ class MeshNormalFunctor(DynamicFunction): def __getnewargs__(self): return self.mesh_name, self.dtype.base_dtype, self.args[2:] + + @property + def name(self): + return self.mesh_name + diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py index 35ac8895ebb558603b055c9744b5904a68db91be..f038f8598ad95b8ead298871a061152bc5c3c7ba 100644 --- a/src/pystencils_autodiff/framework_integration/printer.py +++ b/src/pystencils_autodiff/framework_integration/printer.py @@ -18,6 +18,7 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): def __init__(self): super().__init__(dialect='c') self.sympy_printer.__class__._print_DynamicFunction = self._print_DynamicFunction + self.sympy_printer.__class__._print_MeshNormalFunctor = self._print_DynamicFunction def _print(self, node): from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile @@ -30,6 +31,8 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): def _print_WrapperFunction(self, node): super_result = super()._print_KernelFunction(node) + if self._signatureOnly: + super_result += ';' return super_result.replace('FUNC_PREFIX ', '') def _print_TextureDeclaration(self, node): @@ -47,8 +50,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): template_types = list(map(lambda x: 'class ' + str(x), template_types)) if template_types: prefix = f'{prefix}template <{",".join(template_types)}>\n' + if self._signatureOnly: + suffix = ';' + else: + suffix = '' - return prefix + kernel_code + return prefix + kernel_code + suffix def _print_FunctionCall(self, node): @@ -119,6 +126,14 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): arg_str = ', '.join(self._print(a) for a in expr.args[2:]) return f'{name}({arg_str})' + def _print_CustomCodeNode(self, node): + super_code = super()._print_CustomCodeNode(node) + if super_code: + # Without leading new line + return super_code[1:] + else: + return super_code + class DebugFrameworkPrinter(FrameworkIntegrationPrinter): """