From faed31109dbbdcc62c608dc605660f6c0e85a7ad Mon Sep 17 00:00:00 2001
From: Stephan Seitz <stephan.seitz@fau.de>
Date: Tue, 17 Nov 2020 15:37:17 +0100
Subject: [PATCH] (Optionally) deactivate unify_shape_symbols

---
 pystencils/astnodes.py               | 80 ++++++++++++++--------------
 pystencils/gpucuda/kernelcreation.py |  6 ++-
 pystencils/kernelcreation.py         |  4 +-
 3 files changed, 47 insertions(+), 43 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index 85c6033c..ca37e054 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -247,7 +247,7 @@ class KernelFunction(Node):
             if hasattr(symbol, 'field_name'):
                 return field_map[symbol.field_name],
             elif hasattr(symbol, 'field_names'):
-                return tuple(field_map[fn] for fn in symbol.field_names)
+                return tuple(field_map.get(fn, field_map.get('diff' + fn)) for fn in symbol.field_names)
             return ()
 
         argument_symbols = self._body.undefined_symbols - self.global_variables
@@ -297,7 +297,7 @@ class Block(Node):
             except AttributeError:
                 pass
 
-    @property
+    @ property
     def args(self):
         return self._nodes
 
@@ -361,7 +361,7 @@ class Block(Node):
             replacements.parent = self
             self._nodes.insert(idx, replacements)
 
-    @property
+    @ property
     def symbols_defined(self):
         result = set()
         for a in self.args:
@@ -371,7 +371,7 @@ class Block(Node):
                 result.update(a.symbols_defined)
         return result
 
-    @property
+    @ property
     def undefined_symbols(self):
         result = set()
         defined_symbols = set()
@@ -443,7 +443,7 @@ class LoopOverCoordinate(Node):
             self.step = fast_subs(self.step, subs_dict, skip)
         return self
 
-    @property
+    @ property
     def args(self):
         result = [self.body]
         for e in [self.start, self.stop, self.step]:
@@ -461,11 +461,11 @@ class LoopOverCoordinate(Node):
         elif child == self.stop:
             self.stop = replacement
 
-    @property
+    @ property
     def symbols_defined(self):
         return {self.loop_counter_symbol}
 
-    @property
+    @ property
     def undefined_symbols(self):
         result = self.body.undefined_symbols
         for possible_symbol in [self.start, self.stop, self.step]:
@@ -473,22 +473,22 @@ class LoopOverCoordinate(Node):
                 result.update(possible_symbol.atoms(sp.Symbol))
         return result - {self.loop_counter_symbol}
 
-    @staticmethod
+    @ staticmethod
     def get_loop_counter_name(coordinate_to_loop_over):
         return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
 
-    @staticmethod
+    @ staticmethod
     def get_block_loop_counter_name(coordinate_to_loop_over):
         return f"{LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}"
 
-    @property
+    @ property
     def loop_counter_name(self):
         if self.is_block_loop:
             return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
         else:
             return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
 
-    @staticmethod
+    @ staticmethod
     def is_loop_counter_symbol(symbol):
         prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX
         if not symbol.name.startswith(prefix):
@@ -498,29 +498,29 @@ class LoopOverCoordinate(Node):
         coordinate = int(symbol.name[len(prefix) + 1:])
         return coordinate
 
-    @staticmethod
+    @ staticmethod
     def get_loop_counter_symbol(coordinate_to_loop_over):
         return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)
 
-    @staticmethod
+    @ staticmethod
     def get_block_loop_counter_symbol(coordinate_to_loop_over):
         return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
                            'int',
                            nonnegative=True)
 
-    @property
+    @ property
     def loop_counter_symbol(self):
         if self.is_block_loop:
             return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
         else:
             return self.get_loop_counter_symbol(self.coordinate_to_loop_over)
 
-    @property
+    @ property
     def is_outermost_loop(self):
         from pystencils.transformations import get_next_parent_of_type
         return get_next_parent_of_type(self, LoopOverCoordinate) is None
 
-    @property
+    @ property
     def is_innermost_loop(self):
         return len(self.atoms(LoopOverCoordinate)) == 0
 
@@ -552,11 +552,11 @@ class SympyAssignment(Node):
             return False
         return True
 
-    @property
+    @ property
     def lhs(self):
         return self._lhs_symbol
 
-    @lhs.setter
+    @ lhs.setter
     def lhs(self, new_value):
         self._lhs_symbol = new_value
         self._is_declaration = self.__is_declaration()
@@ -572,17 +572,17 @@ class SympyAssignment(Node):
         except Exception:
             pass
 
-    @property
+    @ property
     def args(self):
         return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]
 
-    @property
+    @ property
     def symbols_defined(self):
         if not self._is_declaration:
             return set()
         return {self._lhs_symbol}
 
-    @property
+    @ property
     def undefined_symbols(self):
         result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)}
         # Add loop counters if there a field accesses
@@ -596,11 +596,11 @@ class SympyAssignment(Node):
         result.update(self._lhs_symbol.atoms(sp.Symbol))
         return result
 
-    @property
+    @ property
     def is_declaration(self):
         return self._is_declaration
 
-    @property
+    @ property
     def is_const(self):
         return self._is_const
 
@@ -657,7 +657,7 @@ class ResolvedFieldAccess(sp.Indexed):
         super_class_contents = super(ResolvedFieldAccess, self)._hashable_content()
         return super_class_contents + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field))
 
-    @property
+    @ property
     def typed_symbol(self):
         return self.base.label
 
@@ -687,18 +687,18 @@ class TemporaryMemoryAllocation(Node):
         self.headers = ['<stdlib.h>']
         self._align_offset = align_offset
 
-    @property
+    @ property
     def symbols_defined(self):
         return {self.symbol}
 
-    @property
+    @ property
     def undefined_symbols(self):
         if isinstance(self.size, sp.Basic):
             return self.size.atoms(sp.Symbol)
         else:
             return set()
 
-    @property
+    @ property
     def args(self):
         return [self.symbol]
 
@@ -714,22 +714,22 @@ class TemporaryMemoryFree(Node):
         super(TemporaryMemoryFree, self).__init__(parent=None)
         self.alloc_node = alloc_node
 
-    @property
+    @ property
     def symbol(self):
         return self.alloc_node.symbol
 
     def offset(self, byte_alignment):
         return self.alloc_node.offset(byte_alignment)
 
-    @property
+    @ property
     def symbols_defined(self):
         return set()
 
-    @property
+    @ property
     def undefined_symbols(self):
         return set()
 
-    @property
+    @ property
     def args(self):
         return []
 
@@ -747,15 +747,15 @@ class SourceCodeComment(Node):
     def __init__(self, text):
         self.text = text
 
-    @property
+    @ property
     def args(self):
         return []
 
-    @property
+    @ property
     def symbols_defined(self):
         return set()
 
-    @property
+    @ property
     def undefined_symbols(self):
         return set()
 
@@ -770,15 +770,15 @@ class EmptyLine(Node):
     def __init__(self):
         pass
 
-    @property
+    @ property
     def args(self):
         return []
 
-    @property
+    @ property
     def symbols_defined(self):
         return set()
 
-    @property
+    @ property
     def undefined_symbols(self):
         return set()
 
@@ -798,15 +798,15 @@ class ConditionalFieldAccess(sp.Function):
     def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
         return sp.Function.__new__(cls, field_access, outofbounds_condition, sp.S(outofbounds_value))
 
-    @property
+    @ property
     def access(self):
         return self.args[0]
 
-    @property
+    @ property
     def outofbounds_condition(self):
         return self.args[1]
 
-    @property
+    @ property
     def outofbounds_value(self):
         return self.args[2]
 
diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py
index 52a4dc8b..c7ad0c68 100644
--- a/pystencils/gpucuda/kernelcreation.py
+++ b/pystencils/gpucuda/kernelcreation.py
@@ -15,7 +15,8 @@ def create_cuda_kernel(assignments,
                        iteration_slice=None,
                        ghost_layers=None,
                        skip_independence_check=False,
-                       use_textures_for_interpolation=True):
+                       use_textures_for_interpolation=True,
+                       do_unify_shape_symbols=True):
     assert assignments, "Assignments must not be empty!"
     fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check)
     all_fields = fields_read.union(fields_written)
@@ -60,7 +61,8 @@ def create_cuda_kernel(assignments,
     block = Block(assignments)
 
     block = indexing.guard(block, common_shape)
-    unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)
+    if do_unify_shape_symbols:
+        unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers)
 
     ast = KernelFunction(block,
                          'gpu',
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index b158754c..ee6cf8a5 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -30,6 +30,7 @@ def create_kernel(assignments,
                   use_textures_for_interpolation=True,
                   cpu_prepend_optimizations=[],
                   use_auto_for_assignments=False,
+                  do_unify_shape_symbols=True,
                   opencl_queue=None,
                   opencl_ctx=None):
     """
@@ -119,7 +120,8 @@ def create_kernel(assignments,
                                  indexing_creator=indexing_creator_from_params(gpu_indexing, gpu_indexing_params),
                                  iteration_slice=iteration_slice, ghost_layers=ghost_layers,
                                  skip_independence_check=skip_independence_check,
-                                 use_textures_for_interpolation=use_textures_for_interpolation)
+                                 use_textures_for_interpolation=use_textures_for_interpolation,
+                                 do_unify_shape_symbols=do_unify_shape_symbols)
         if target == 'opencl':
             from pystencils.opencl.opencljit import make_python_function
             ast._backend = 'opencl'
-- 
GitLab