Skip to content
Snippets Groups Projects
Commit 92ae036d authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Add Header AST node

parent 703816fe
No related branches found
No related tags found
No related merge requests found
Pipeline #21959 failed
......@@ -31,6 +31,28 @@ def get_cubic_interpolation_include_paths():
join(dirname(pystencils.gpucuda.__file__), 'CubicInterpolationCUDA', 'code', 'internal')]
class Header(JinjaCppFile):
TEMPLATE = read_template_from_file(join(dirname(__file__), 'module.tmpl.hpp'))
def __init__(self, exported_functions, module_name):
ast_dict = {
'declarations': exported_functions,
'module_name': module_name,
}
super().__init__(ast_dict)
def __str__(self):
self.printer._signatureOnly = True
rtn = JinjaCppFile.__str__(self)
self.printer._signatureOnly = False
return rtn
@property
def backend(self):
return 'c'
class TorchTensorDestructuring(DestructuringBindingsForFieldClass):
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "data_ptr<{dtype}>()",
......@@ -143,6 +165,10 @@ class TorchModule(JinjaCppFile):
*get_cubic_interpolation_include_paths()])
return torch_extension
@property
def header(self):
return Header(self.ast_dict.kernel_wrappers, self.module_name)
class TensorflowModule(TorchModule):
DESTRUCTURING_CLASS = TensorflowTensorDestructuring
......
/*
* {{ module_name }}.hpp
* Copyright (C) 2020 Stephan Seitz <stephan.seitz@fau.de>
*
* Distributed under terms of the GPLv3 license.
*/
#pragma once
{{ declarations | join('\n\n') }}
......@@ -174,7 +174,7 @@ class JinjaCppFile(Node):
@property
def args(self):
"""Returns all arguments/children of this node."""
ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr, str))]
ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr))]
iterables_of_ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, Iterable)
and not isinstance(a, str)]
return ast_nodes + list(itertools.chain.from_iterable(iterables_of_ast_nodes))
......@@ -194,7 +194,7 @@ class JinjaCppFile(Node):
if isinstance(a, (Node, sp.Expr)))) - self.symbols_defined
def _print(self, node):
if isinstance(node, Node):
if isinstance(node, (Node, sp.Expr)):
return self.printer(node)
else:
return str(node)
......@@ -345,3 +345,8 @@ class MeshNormalFunctor(DynamicFunction):
def __getnewargs__(self):
return self.mesh_name, self.dtype.base_dtype, self.args[2:]
@property
def name(self):
return self.mesh_name
......@@ -18,6 +18,7 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
def __init__(self):
super().__init__(dialect='c')
self.sympy_printer.__class__._print_DynamicFunction = self._print_DynamicFunction
self.sympy_printer.__class__._print_MeshNormalFunctor = self._print_DynamicFunction
def _print(self, node):
from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
......@@ -30,6 +31,8 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
def _print_WrapperFunction(self, node):
super_result = super()._print_KernelFunction(node)
if self._signatureOnly:
super_result += ';'
return super_result.replace('FUNC_PREFIX ', '')
def _print_TextureDeclaration(self, node):
......@@ -47,8 +50,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
template_types = list(map(lambda x: 'class ' + str(x), template_types))
if template_types:
prefix = f'{prefix}template <{",".join(template_types)}>\n'
if self._signatureOnly:
suffix = ';'
else:
suffix = ''
return prefix + kernel_code
return prefix + kernel_code + suffix
def _print_FunctionCall(self, node):
......@@ -119,6 +126,14 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
arg_str = ', '.join(self._print(a) for a in expr.args[2:])
return f'{name}({arg_str})'
def _print_CustomCodeNode(self, node):
super_code = super()._print_CustomCodeNode(node)
if super_code:
# Without leading new line
return super_code[1:]
else:
return super_code
class DebugFrameworkPrinter(FrameworkIntegrationPrinter):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment