From 53986db2d297238321aafa4f4c434261294ff756 Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Sat, 13 Jul 2019 01:17:28 +0200
Subject: [PATCH] Fixup for DestructuringBindingsForFieldClass

- rename header Field.h is not a unique name in waLBerla context
- add PyStencilsField.h
- bindings were lacking data type
---
 pystencils/astnodes.py                        |  8 ++---
 pystencils/backends/cbackend.py               | 12 ++++----
 pystencils/include/PyStencilsField.h          | 19 ++++++++++++
 .../test_destructuring_field_class.py         | 30 ++++++++++++++++++-
 4 files changed, 58 insertions(+), 11 deletions(-)
 create mode 100644 pystencils/include/PyStencilsField.h

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 2d3174a1a..83b12f4b0 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -653,10 +653,10 @@ class DestructuringBindingsForFieldClass(Node):
     """
     CLASS_TO_MEMBER_DICT = {
         FieldPointerSymbol: "data",
-        FieldShapeSymbol: "shape",
-        FieldStrideSymbol: "stride"
+        FieldShapeSymbol: "shape[%i]",
+        FieldStrideSymbol: "stride[%i]"
     }
-    CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>")
+    CLASS_NAME_TEMPLATE = jinja2.Template("PyStencilsField<{{ dtype }}, {{ ndim }}>")
 
     @property
     def fields_accessed(self) -> Set['ResolvedFieldAccess']:
@@ -665,7 +665,7 @@ class DestructuringBindingsForFieldClass(Node):
 
     def __init__(self, body):
         super(DestructuringBindingsForFieldClass, self).__init__()
-        self.headers = ['<Field.h>']
+        self.headers = ['<PyStencilsField.h>']
         self.body = body
 
     @property
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 7c4937d1f..4a1352b49 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -6,8 +6,7 @@ import sympy as sp
 from sympy.core import S
 from sympy.printing.ccode import C89CodePrinter
 
-from pystencils.astnodes import (DestructuringBindingsForFieldClass,
-                                 KernelFunction, Node)
+from pystencils.astnodes import KernelFunction, Node
 from pystencils.cpu.vectorization import vec_all, vec_any
 from pystencils.data_types import (PointerType, VectorType, address_of,
                                    cast_func, create_type,
@@ -264,11 +263,12 @@ class CBackend:
     def _print_DestructuringBindingsForFieldClass(self, node: Node):
         # Define all undefined symbols
         undefined_field_symbols = node.symbols_defined
-        destructuring_bindings = ["%s = %s.%s%s;" %
-                                  (u.name,
+        destructuring_bindings = ["%s %s = %s.%s;" %
+                                  (u.dtype,
+                                   u.name,
                                    u.field_name if hasattr(u, 'field_name') else u.field_names[0],
-                                   DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__],
-                                   "" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate))
+                                   node.CLASS_TO_MEMBER_DICT[u.__class__] %
+                                   (() if type(u) == FieldPointerSymbol else (u.coordinate,)))
                                   for u in undefined_field_symbols
                                   ]
         destructuring_bindings.sort()  # only for code aesthetics
diff --git a/pystencils/include/PyStencilsField.h b/pystencils/include/PyStencilsField.h
new file mode 100644
index 000000000..3055cae23
--- /dev/null
+++ b/pystencils/include/PyStencilsField.h
@@ -0,0 +1,19 @@
+#pragma once
+
+extern "C++" {
+#ifdef __CUDA_ARCH__
+template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField {
+  DTYPE_T *data;
+  DTYPE_T shape[DIMENSION];
+  DTYPE_T stride[DIMENSION];
+};
+#else
+#include <array>
+
+template <typename DTYPE_T, std::size_t DIMENSION> struct PyStencilsField {
+  DTYPE_T *data;
+  std::array<DTYPE_T, DIMENSION> shape;
+  std::array<DTYPE_T, DIMENSION> stride;
+};
+#endif
+}
diff --git a/pystencils_tests/test_destructuring_field_class.py b/pystencils_tests/test_destructuring_field_class.py
index 248963ae3..ff3aae12e 100644
--- a/pystencils_tests/test_destructuring_field_class.py
+++ b/pystencils_tests/test_destructuring_field_class.py
@@ -8,9 +8,13 @@
 
 """
 import sympy
+import jinja2
+
 
 import pystencils
 from pystencils.astnodes import DestructuringBindingsForFieldClass
+from pystencils.kernelparameters import  FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
+
 
 
 def test_destructuring_field_class():
@@ -19,15 +23,39 @@ def test_destructuring_field_class():
     normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
         z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
 
-    ast = pystencils.create_kernel(normal_assignments)
+    ast = pystencils.create_kernel(normal_assignments, target='gpu')
     print(pystencils.show_code(ast))
 
     ast.body = DestructuringBindingsForFieldClass(ast.body)
     print(pystencils.show_code(ast))
+    ast.compile()
+
+
+class DestructuringEmojiClass(DestructuringBindingsForFieldClass):
+    CLASS_TO_MEMBER_DICT = {
+        FieldPointerSymbol: "🥶",
+        FieldShapeSymbol: "😳_%i",
+        FieldStrideSymbol: "🥵_%i"
+    }
+    CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>")
+    def __init__(self, node):
+        super().__init__(node)
+        self.headers = []
+        
+    
+def test_destructuring_alternative_field_class():
+    z, x, y = pystencils.fields("z, y, x: [2d]")
 
+    normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
+        z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
+
+    ast = pystencils.create_kernel(normal_assignments, target='gpu')
+    ast.body = DestructuringEmojiClass(ast.body)
+    print(pystencils.show_code(ast))
 
 def main():
     test_destructuring_field_class()
+    test_destructuring_alternative_field_class()
 
 
 if __name__ == '__main__':
-- 
GitLab