diff --git a/src/pystencils_autodiff/_backport.py b/src/pystencils_autodiff/_backport.py index e9ba2790bac9a743e3175a464526a579eb878b00..c49103ab7c0f21ca5b08cf150671dc4a9ebe52ef 100644 --- a/src/pystencils_autodiff/_backport.py +++ b/src/pystencils_autodiff/_backport.py @@ -10,8 +10,6 @@ import itertools -import sympy as sp - import pystencils from pystencils.astnodes import KernelFunction, ResolvedFieldAccess, SympyAssignment from pystencils.interpolation_astnodes import InterpolatorAccess @@ -44,7 +42,6 @@ def compatibility_hacks(): pystencils.fields = fields KernelFunction.fields_read = property(fields_read) KernelFunction.fields_written = property(fields_written) - sp.Expr.undefined_symbols = sp.Expr.free_symbols compatibility_hacks() diff --git a/src/pystencils_autodiff/framework_integration/astnodes.py b/src/pystencils_autodiff/framework_integration/astnodes.py index f724079dc11c6975c20aab97cfc57ce2490a3ed6..815738611dd89a8d8b30ebdbb2ca177f4c7f144a 100644 --- a/src/pystencils_autodiff/framework_integration/astnodes.py +++ b/src/pystencils_autodiff/framework_integration/astnodes.py @@ -174,7 +174,7 @@ class JinjaCppFile(Node): @property def args(self): """Returns all arguments/children of this node.""" - ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, str))] + ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr, str))] iterables_of_ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, Iterable) and not isinstance(a, str)] return ast_nodes + list(itertools.chain.from_iterable(iterables_of_ast_nodes)) @@ -184,14 +184,14 @@ class JinjaCppFile(Node): """Set of symbols which are defined by this node.""" return set(itertools.chain.from_iterable(a.symbols_defined for a in self.args - if hasattr(a, 'symbols_defined'))) + if isinstance(a, Node))) @property def undefined_symbols(self): """Symbols which are used but are not defined inside this node.""" - return set(itertools.chain.from_iterable(a.undefined_symbols + return set(itertools.chain.from_iterable(a.undefined_symbols if isinstance(a, Node) else a.free_symbols for a in self.args - if hasattr(a, 'undefined_symbols'))) - self.symbols_defined + if isinstance(a, (Node, sp.Expr)))) - self.symbols_defined def _print(self, node): if isinstance(node, Node): @@ -296,3 +296,25 @@ class CudaErrorCheck(CustomCodeNode): err_check_function = CudaErrorCheckDefinition() required_global_declarations = [err_check_function] headers = ['<cuda.h>'] + + +class DynamicFunction(sp.Function): + """ + Function that is passed as an argument to a kernel. + Can be printed for example as `std::function` or as a functor template. + """ + + def __new__(cls, typed_function_symbol, return_dtype, *args): + return sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args) + + @property + def function_dtype(self): + return self.args[0].dtype + + @property + def dtype(self): + return self.args[1].dtype + + @property + def name(self): + return self.args[0].name diff --git a/src/pystencils_autodiff/framework_integration/printer.py b/src/pystencils_autodiff/framework_integration/printer.py index cf94e9584c891d59e504d92b6ea387f830ab027a..35ac8895ebb558603b055c9744b5904a68db91be 100644 --- a/src/pystencils_autodiff/framework_integration/printer.py +++ b/src/pystencils_autodiff/framework_integration/printer.py @@ -1,7 +1,9 @@ import sympy as sp import pystencils.backends.cbackend +from pystencils.data_types import TypedSymbol from pystencils.kernelparameters import FieldPointerSymbol +from pystencils_autodiff.framework_integration.types import TemplateType class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): @@ -15,12 +17,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): def __init__(self): super().__init__(dialect='c') + self.sympy_printer.__class__._print_DynamicFunction = self._print_DynamicFunction def _print(self, node): from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile if isinstance(node, JinjaCppFile): node.printer = self - if isinstance(node, sp.Expr): return self.sympy_printer._print(node) else: @@ -40,6 +42,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): else: prefix = '#define FUNC_PREFIX static\n' kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='c', with_globals=False) + template_types = sorted([x.dtype for x in node.atoms(TypedSymbol) + if isinstance(x.dtype, TemplateType)], key=str) + template_types = list(map(lambda x: 'class ' + str(x), template_types)) + if template_types: + prefix = f'{prefix}template <{",".join(template_types)}>\n' + return prefix + kernel_code def _print_FunctionCall(self, node): @@ -83,7 +91,8 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): if hasattr(u, 'field_name') else u.field_names[0]]), field_name=(u.field_name if hasattr(u, "field_name") else ""), - dim=("" if type(u) == FieldPointerSymbol else u.coordinate) + dim=("" if type(u) == FieldPointerSymbol else u.coordinate), + dim_letter=("" if type(u) == FieldPointerSymbol else 'xyz'[u.coordinate]) ) ) for u in undefined_field_symbols @@ -105,14 +114,28 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): def _print_SwapBuffer(self, node): return f"""std::swap({node.first_array}, {node.second_array});""" + def _print_DynamicFunction(self, expr): + name = expr.name + arg_str = ', '.join(self._print(a) for a in expr.args[2:]) + return f'{name}({arg_str})' + class DebugFrameworkPrinter(FrameworkIntegrationPrinter): + """ + Printer with information on nodes inlined in code as comments. + + Should not be used in production, will modify your SymPy printer, destroy your whole life! + """ + + def __init__(self): + super().__init__() + self.sympy_printer._old_print = self.sympy_printer._print + self.sympy_printer.__class__._print = self._print def _print(self, node): if isinstance(node, sp.Expr): - return self.sympy_printer._print(node) + return self.sympy_printer._old_print(node) + f'/* {node.__class__.__name__}: free_symbols: {node.free_symbols} */' # noqa elif isinstance(node, pystencils.astnodes.Node): return super()._print(node) + f'/* {node.__class__.__name__} symbols_undefined: {node.undefined_symbols}, symbols_defined: {node.symbols_defined}, args {[a if isinstance(a,str) else a.__class__.__name__ for a in node.args]} */' # noqa - else: return super()._print(node) diff --git a/src/pystencils_autodiff/framework_integration/types.py b/src/pystencils_autodiff/framework_integration/types.py new file mode 100644 index 0000000000000000000000000000000000000000..f3929348fb06bcc1f711233def968161fe428aba --- /dev/null +++ b/src/pystencils_autodiff/framework_integration/types.py @@ -0,0 +1,18 @@ +# +# Copyright © 2020 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. + +""" + +""" +from pystencils.data_types import Type + + +class TemplateType(Type): + + def __init__(self, name): + self._name = name + + def _sympystr(self, *args, **kwargs): + return str(self._name) diff --git a/tests/test_dynamic_function.py b/tests/test_dynamic_function.py new file mode 100644 index 0000000000000000000000000000000000000000..0423116585c4801ccf13a483a24c72580961d5be --- /dev/null +++ b/tests/test_dynamic_function.py @@ -0,0 +1,35 @@ +import sympy as sp + +import pystencils +from pystencils.data_types import TypedSymbol, create_type +from pystencils_autodiff.framework_integration.astnodes import DynamicFunction +from pystencils_autodiff.framework_integration.printer import ( + DebugFrameworkPrinter, FrameworkIntegrationPrinter) +from pystencils_autodiff.framework_integration.types import TemplateType + + +def test_dynamic_function(): + x, y = pystencils.fields('x, y: float32[3d]') + + a = sp.symbols('a') + + my_fun_call = DynamicFunction(TypedSymbol('my_fun', + 'std::function<double(double)>'), create_type('double'), a) + + assignments = pystencils.AssignmentCollection({ + y.center: x.center + my_fun_call + }) + + ast = pystencils.create_kernel(assignments) + pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) + + template_fun_call = DynamicFunction(TypedSymbol('my_fun', + TemplateType('Functor_T')), create_type('double'), a, x.center) + + assignments = pystencils.AssignmentCollection({ + y.center: x.center + template_fun_call + }) + + ast = pystencils.create_kernel(assignments) + pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter()) + pystencils.show_code(ast, custom_backend=DebugFrameworkPrinter())