From 53986db2d297238321aafa4f4c434261294ff756 Mon Sep 17 00:00:00 2001 From: Stephan Seitz <stephan.seitz@fau.de> Date: Sat, 13 Jul 2019 01:17:28 +0200 Subject: [PATCH] Fixup for DestructuringBindingsForFieldClass - rename header Field.h is not a unique name in waLBerla context - add PyStencilsField.h - bindings were lacking data type --- pystencils/astnodes.py | 8 ++--- pystencils/backends/cbackend.py | 12 ++++---- pystencils/include/PyStencilsField.h | 19 ++++++++++++ .../test_destructuring_field_class.py | 30 ++++++++++++++++++- 4 files changed, 58 insertions(+), 11 deletions(-) create mode 100644 pystencils/include/PyStencilsField.h diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 2d3174a1a..83b12f4b0 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -653,10 +653,10 @@ class DestructuringBindingsForFieldClass(Node): """ CLASS_TO_MEMBER_DICT = { FieldPointerSymbol: "data", - FieldShapeSymbol: "shape", - FieldStrideSymbol: "stride" + FieldShapeSymbol: "shape[%i]", + FieldStrideSymbol: "stride[%i]" } - CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>") + CLASS_NAME_TEMPLATE = jinja2.Template("PyStencilsField<{{ dtype }}, {{ ndim }}>") @property def fields_accessed(self) -> Set['ResolvedFieldAccess']: @@ -665,7 +665,7 @@ class DestructuringBindingsForFieldClass(Node): def __init__(self, body): super(DestructuringBindingsForFieldClass, self).__init__() - self.headers = ['<Field.h>'] + self.headers = ['<PyStencilsField.h>'] self.body = body @property diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 7c4937d1f..4a1352b49 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -6,8 +6,7 @@ import sympy as sp from sympy.core import S from sympy.printing.ccode import C89CodePrinter -from pystencils.astnodes import (DestructuringBindingsForFieldClass, - KernelFunction, Node) +from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import (PointerType, VectorType, address_of, cast_func, create_type, @@ -264,11 +263,12 @@ class CBackend: def _print_DestructuringBindingsForFieldClass(self, node: Node): # Define all undefined symbols undefined_field_symbols = node.symbols_defined - destructuring_bindings = ["%s = %s.%s%s;" % - (u.name, + destructuring_bindings = ["%s %s = %s.%s;" % + (u.dtype, + u.name, u.field_name if hasattr(u, 'field_name') else u.field_names[0], - DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__], - "" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate)) + node.CLASS_TO_MEMBER_DICT[u.__class__] % + (() if type(u) == FieldPointerSymbol else (u.coordinate,))) for u in undefined_field_symbols ] destructuring_bindings.sort() # only for code aesthetics diff --git a/pystencils/include/PyStencilsField.h b/pystencils/include/PyStencilsField.h new file mode 100644 index 000000000..3055cae23 --- /dev/null +++ b/pystencils/include/PyStencilsField.h @@ -0,0 +1,19 @@ +#pragma once + +extern "C++" { +#ifdef __CUDA_ARCH__ +template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField { + DTYPE_T *data; + DTYPE_T shape[DIMENSION]; + DTYPE_T stride[DIMENSION]; +}; +#else +#include <array> + +template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField { + DTYPE_T *data; + std::array<DTYPE_T, DIMENSION> shape; + std::array<DTYPE_T, DIMENSION> stride; +}; +#endif +} diff --git a/pystencils_tests/test_destructuring_field_class.py b/pystencils_tests/test_destructuring_field_class.py index 248963ae3..ff3aae12e 100644 --- a/pystencils_tests/test_destructuring_field_class.py +++ b/pystencils_tests/test_destructuring_field_class.py @@ -8,9 +8,13 @@ """ import sympy +import jinja2 + import pystencils from pystencils.astnodes import DestructuringBindingsForFieldClass +from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol + def test_destructuring_field_class(): @@ -19,15 +23,39 @@ def test_destructuring_field_class(): normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) - ast = pystencils.create_kernel(normal_assignments) + ast = pystencils.create_kernel(normal_assignments, target='gpu') print(pystencils.show_code(ast)) ast.body = DestructuringBindingsForFieldClass(ast.body) print(pystencils.show_code(ast)) + ast.compile() + + +class DestructuringEmojiClass(DestructuringBindingsForFieldClass): + CLASS_TO_MEMBER_DICT = { + FieldPointerSymbol: "🥶", + FieldShapeSymbol: "😳_%i", + FieldStrideSymbol: "🥵_%i" + } + CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>") + def __init__(self, node): + super().__init__(node) + self.headers = [] + + +def test_destructuring_alternative_field_class(): + z, x, y = pystencils.fields("z, y, x: [2d]") + normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment( + z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], []) + + ast = pystencils.create_kernel(normal_assignments, target='gpu') + ast.body = DestructuringEmojiClass(ast.body) + print(pystencils.show_code(ast)) def main(): test_destructuring_field_class() + test_destructuring_alternative_field_class() if __name__ == '__main__': -- GitLab