Skip to content
Snippets Groups Projects
Commit f65844f4 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Implement DynamicFunction, that can be passed as a parameter/template to the kernel

parent 721e60cf
Branches
Tags
No related merge requests found
Pipeline #21819 failed
...@@ -10,8 +10,6 @@ ...@@ -10,8 +10,6 @@
import itertools import itertools
import sympy as sp
import pystencils import pystencils
from pystencils.astnodes import KernelFunction, ResolvedFieldAccess, SympyAssignment from pystencils.astnodes import KernelFunction, ResolvedFieldAccess, SympyAssignment
from pystencils.interpolation_astnodes import InterpolatorAccess from pystencils.interpolation_astnodes import InterpolatorAccess
...@@ -44,7 +42,6 @@ def compatibility_hacks(): ...@@ -44,7 +42,6 @@ def compatibility_hacks():
pystencils.fields = fields pystencils.fields = fields
KernelFunction.fields_read = property(fields_read) KernelFunction.fields_read = property(fields_read)
KernelFunction.fields_written = property(fields_written) KernelFunction.fields_written = property(fields_written)
sp.Expr.undefined_symbols = sp.Expr.free_symbols
compatibility_hacks() compatibility_hacks()
...@@ -174,7 +174,7 @@ class JinjaCppFile(Node): ...@@ -174,7 +174,7 @@ class JinjaCppFile(Node):
@property @property
def args(self): def args(self):
"""Returns all arguments/children of this node.""" """Returns all arguments/children of this node."""
ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, str))] ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, (Node, sp.Expr, str))]
iterables_of_ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, Iterable) iterables_of_ast_nodes = [a for a in self.ast_dict.values() if isinstance(a, Iterable)
and not isinstance(a, str)] and not isinstance(a, str)]
return ast_nodes + list(itertools.chain.from_iterable(iterables_of_ast_nodes)) return ast_nodes + list(itertools.chain.from_iterable(iterables_of_ast_nodes))
...@@ -184,14 +184,14 @@ class JinjaCppFile(Node): ...@@ -184,14 +184,14 @@ class JinjaCppFile(Node):
"""Set of symbols which are defined by this node.""" """Set of symbols which are defined by this node."""
return set(itertools.chain.from_iterable(a.symbols_defined return set(itertools.chain.from_iterable(a.symbols_defined
for a in self.args for a in self.args
if hasattr(a, 'symbols_defined'))) if isinstance(a, Node)))
@property @property
def undefined_symbols(self): def undefined_symbols(self):
"""Symbols which are used but are not defined inside this node.""" """Symbols which are used but are not defined inside this node."""
return set(itertools.chain.from_iterable(a.undefined_symbols return set(itertools.chain.from_iterable(a.undefined_symbols if isinstance(a, Node) else a.free_symbols
for a in self.args for a in self.args
if hasattr(a, 'undefined_symbols'))) - self.symbols_defined if isinstance(a, (Node, sp.Expr)))) - self.symbols_defined
def _print(self, node): def _print(self, node):
if isinstance(node, Node): if isinstance(node, Node):
...@@ -296,3 +296,25 @@ class CudaErrorCheck(CustomCodeNode): ...@@ -296,3 +296,25 @@ class CudaErrorCheck(CustomCodeNode):
err_check_function = CudaErrorCheckDefinition() err_check_function = CudaErrorCheckDefinition()
required_global_declarations = [err_check_function] required_global_declarations = [err_check_function]
headers = ['<cuda.h>'] headers = ['<cuda.h>']
class DynamicFunction(sp.Function):
"""
Function that is passed as an argument to a kernel.
Can be printed for example as `std::function` or as a functor template.
"""
def __new__(cls, typed_function_symbol, return_dtype, *args):
return sp.Function.__new__(cls, typed_function_symbol, return_dtype, *args)
@property
def function_dtype(self):
return self.args[0].dtype
@property
def dtype(self):
return self.args[1].dtype
@property
def name(self):
return self.args[0].name
import sympy as sp import sympy as sp
import pystencils.backends.cbackend import pystencils.backends.cbackend
from pystencils.data_types import TypedSymbol
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
from pystencils_autodiff.framework_integration.types import TemplateType
class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
...@@ -15,12 +17,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): ...@@ -15,12 +17,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
def __init__(self): def __init__(self):
super().__init__(dialect='c') super().__init__(dialect='c')
self.sympy_printer.__class__._print_DynamicFunction = self._print_DynamicFunction
def _print(self, node): def _print(self, node):
from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile from pystencils_autodiff.framework_integration.astnodes import JinjaCppFile
if isinstance(node, JinjaCppFile): if isinstance(node, JinjaCppFile):
node.printer = self node.printer = self
if isinstance(node, sp.Expr): if isinstance(node, sp.Expr):
return self.sympy_printer._print(node) return self.sympy_printer._print(node)
else: else:
...@@ -40,6 +42,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): ...@@ -40,6 +42,12 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
else: else:
prefix = '#define FUNC_PREFIX static\n' prefix = '#define FUNC_PREFIX static\n'
kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='c', with_globals=False) kernel_code = pystencils.backends.cbackend.generate_c(node, dialect='c', with_globals=False)
template_types = sorted([x.dtype for x in node.atoms(TypedSymbol)
if isinstance(x.dtype, TemplateType)], key=str)
template_types = list(map(lambda x: 'class ' + str(x), template_types))
if template_types:
prefix = f'{prefix}template <{",".join(template_types)}>\n'
return prefix + kernel_code return prefix + kernel_code
def _print_FunctionCall(self, node): def _print_FunctionCall(self, node):
...@@ -83,7 +91,8 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): ...@@ -83,7 +91,8 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
if hasattr(u, 'field_name') if hasattr(u, 'field_name')
else u.field_names[0]]), else u.field_names[0]]),
field_name=(u.field_name if hasattr(u, "field_name") else ""), field_name=(u.field_name if hasattr(u, "field_name") else ""),
dim=("" if type(u) == FieldPointerSymbol else u.coordinate) dim=("" if type(u) == FieldPointerSymbol else u.coordinate),
dim_letter=("" if type(u) == FieldPointerSymbol else 'xyz'[u.coordinate])
) )
) )
for u in undefined_field_symbols for u in undefined_field_symbols
...@@ -105,14 +114,28 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend): ...@@ -105,14 +114,28 @@ class FrameworkIntegrationPrinter(pystencils.backends.cbackend.CBackend):
def _print_SwapBuffer(self, node): def _print_SwapBuffer(self, node):
return f"""std::swap({node.first_array}, {node.second_array});""" return f"""std::swap({node.first_array}, {node.second_array});"""
def _print_DynamicFunction(self, expr):
name = expr.name
arg_str = ', '.join(self._print(a) for a in expr.args[2:])
return f'{name}({arg_str})'
class DebugFrameworkPrinter(FrameworkIntegrationPrinter): class DebugFrameworkPrinter(FrameworkIntegrationPrinter):
"""
Printer with information on nodes inlined in code as comments.
Should not be used in production, will modify your SymPy printer, destroy your whole life!
"""
def __init__(self):
super().__init__()
self.sympy_printer._old_print = self.sympy_printer._print
self.sympy_printer.__class__._print = self._print
def _print(self, node): def _print(self, node):
if isinstance(node, sp.Expr): if isinstance(node, sp.Expr):
return self.sympy_printer._print(node) return self.sympy_printer._old_print(node) + f'/* {node.__class__.__name__}: free_symbols: {node.free_symbols} */' # noqa
elif isinstance(node, pystencils.astnodes.Node): elif isinstance(node, pystencils.astnodes.Node):
return super()._print(node) + f'/* {node.__class__.__name__} symbols_undefined: {node.undefined_symbols}, symbols_defined: {node.symbols_defined}, args {[a if isinstance(a,str) else a.__class__.__name__ for a in node.args]} */' # noqa return super()._print(node) + f'/* {node.__class__.__name__} symbols_undefined: {node.undefined_symbols}, symbols_defined: {node.symbols_defined}, args {[a if isinstance(a,str) else a.__class__.__name__ for a in node.args]} */' # noqa
else: else:
return super()._print(node) return super()._print(node)
#
# Copyright © 2020 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
from pystencils.data_types import Type
class TemplateType(Type):
def __init__(self, name):
self._name = name
def _sympystr(self, *args, **kwargs):
return str(self._name)
import sympy as sp
import pystencils
from pystencils.data_types import TypedSymbol, create_type
from pystencils_autodiff.framework_integration.astnodes import DynamicFunction
from pystencils_autodiff.framework_integration.printer import (
DebugFrameworkPrinter, FrameworkIntegrationPrinter)
from pystencils_autodiff.framework_integration.types import TemplateType
def test_dynamic_function():
x, y = pystencils.fields('x, y: float32[3d]')
a = sp.symbols('a')
my_fun_call = DynamicFunction(TypedSymbol('my_fun',
'std::function<double(double)>'), create_type('double'), a)
assignments = pystencils.AssignmentCollection({
y.center: x.center + my_fun_call
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
template_fun_call = DynamicFunction(TypedSymbol('my_fun',
TemplateType('Functor_T')), create_type('double'), a, x.center)
assignments = pystencils.AssignmentCollection({
y.center: x.center + template_fun_call
})
ast = pystencils.create_kernel(assignments)
pystencils.show_code(ast, custom_backend=FrameworkIntegrationPrinter())
pystencils.show_code(ast, custom_backend=DebugFrameworkPrinter())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment