From c21d4621ea558dacd8e9a8395ed0d7114006a547 Mon Sep 17 00:00:00 2001
From: Markus Holzer <markus.holzer@fau.de>
Date: Tue, 29 Aug 2023 10:28:59 +0200
Subject: [PATCH] Distinguish between SymPy and pystencils Assignement better

---
 pystencils/__init__.py                        |  4 +-
 pystencils/assignment.py                      | 13 ++-
 pystencils/astnodes.py                        | 23 ++++-
 pystencils/backends/cbackend.py               | 18 ++--
 pystencils/cpu/kernelcreation.py              |  2 +-
 pystencils/kernelcreation.py                  | 15 ++--
 pystencils/node_collection.py                 | 83 +++++++++----------
 pystencils/transformations.py                 |  2 +-
 pystencils/typing/leaf_typing.py              | 13 ++-
 pystencils/typing/transformations.py          | 10 ++-
 pystencils_tests/test_augmented_assignment.py | 35 ++++++++
 pystencils_tests/test_modulo.py               |  2 +-
 12 files changed, 139 insertions(+), 81 deletions(-)
 create mode 100644 pystencils_tests/test_augmented_assignment.py

diff --git a/pystencils/__init__.py b/pystencils/__init__.py
index 92fdda9c5..0003a8b9a 100644
--- a/pystencils/__init__.py
+++ b/pystencils/__init__.py
@@ -2,7 +2,7 @@
 from .enums import Backend, Target
 from . import fd
 from . import stencil as stencil
-from .assignment import Assignment, assignment_from_stencil
+from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
 from .typing.typed_sympy import TypedSymbol
 from .display_utils import get_code_obj, get_code_str, show_code, to_dot
 from .field import Field, FieldType, fields
@@ -24,7 +24,7 @@ __all__ = ['Field', 'FieldType', 'fields',
            'Target', 'Backend',
            'show_code', 'to_dot', 'get_code_obj', 'get_code_str',
            'AssignmentCollection',
-           'Assignment',
+           'Assignment', 'AddAugmentedAssignment',
            'assignment_from_stencil',
            'SymbolCreator',
            'create_data_handling',
diff --git a/pystencils/assignment.py b/pystencils/assignment.py
index c3ae4b436..d0e849954 100644
--- a/pystencils/assignment.py
+++ b/pystencils/assignment.py
@@ -1,20 +1,22 @@
 import numpy as np
 import sympy as sp
-from sympy.codegen.ast import Assignment
+from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment
 from sympy.printing.latex import LatexPrinter
 
-__all__ = ['Assignment', 'assignment_from_stencil']
+__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'assignment_from_stencil']
 
 
 def print_assignment_latex(printer, expr):
+    binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else ''
     """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
     printed_lhs = printer.doprint(expr.lhs)
     printed_rhs = printer.doprint(expr.rhs)
-    return fr"{printed_lhs} \leftarrow {printed_rhs}"
+    return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}"
 
 
 def assignment_str(assignment):
-    return fr"{assignment.lhs} ← {assignment.rhs}"
+    op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else '←'
+    return fr"{assignment.lhs} {op} {assignment.rhs}"
 
 
 _old_new = sp.codegen.ast.Assignment.__new__
@@ -32,6 +34,9 @@ Assignment.__str__ = assignment_str
 Assignment.__new__ = _Assignment__new__
 LatexPrinter._print_Assignment = print_assignment_latex
 
+AugmentedAssignment.__str__ = assignment_str
+LatexPrinter._print_AugmentedAssignment = print_assignment_latex
+
 sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
 
 
diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index c9d66ae26..0c49160c8 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -561,10 +561,10 @@ class SympyAssignment(Node):
     def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
         super(SympyAssignment, self).__init__(parent=None)
         self._lhs_symbol = sp.sympify(lhs_symbol)
-        self.rhs = sp.sympify(rhs_expr)
+        self._rhs = sp.sympify(rhs_expr)
         self._is_const = is_const
         self._is_declaration = self.__is_declaration()
-        self.use_auto = use_auto
+        self._use_auto = use_auto
 
     def __is_declaration(self):
         from pystencils.typing import CastFunc
@@ -578,15 +578,28 @@ class SympyAssignment(Node):
     def lhs(self):
         return self._lhs_symbol
 
+    @property
+    def rhs(self):
+        return self._rhs
+
     @lhs.setter
     def lhs(self, new_value):
         self._lhs_symbol = new_value
         self._is_declaration = self.__is_declaration()
 
+    @rhs.setter
+    def rhs(self, new_rhs_expr):
+        self._rhs = new_rhs_expr
+
     def subs(self, subs_dict):
         self.lhs = fast_subs(self.lhs, subs_dict)
         self.rhs = fast_subs(self.rhs, subs_dict)
 
+    def fast_subs(self, subs_dict, skip=None):
+        self.lhs = fast_subs(self.lhs, subs_dict, skip)
+        self.rhs = fast_subs(self.rhs, subs_dict, skip)
+        return self
+
     def optimize(self, optimizations):
         try:
             from sympy.codegen.rewriting import optimize
@@ -596,7 +609,7 @@ class SympyAssignment(Node):
 
     @property
     def args(self):
-        return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]
+        return [self._lhs_symbol, self.rhs]
 
     @property
     def symbols_defined(self):
@@ -627,6 +640,10 @@ class SympyAssignment(Node):
     def is_const(self):
         return self._is_const
 
+    @property
+    def use_auto(self):
+        return self._use_auto
+
     def replace(self, child, replacement):
         if child == self.lhs:
             replacement.parent = self
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 7665f9dfc..cc1de06c0 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -262,19 +262,17 @@ class CBackend:
         return f"{prefix}{loop_str}\n{self._print(node.body)}"
 
     def _print_SympyAssignment(self, node):
+        printed_lhs = self.sympy_printer.doprint(node.lhs)
+        printed_rhs = self.sympy_printer.doprint(node.rhs)
+
         if node.is_declaration:
             if node.use_auto:
-                data_type = 'auto '
+                data_type = 'auto'
             else:
+                data_type = self._print(node.lhs.dtype).replace(' const', '')
                 if node.is_const:
-                    prefix = 'const '
-                else:
-                    prefix = ''
-                data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " "
-
-            return "%s%s = %s;" % (data_type,
-                                   self.sympy_printer.doprint(node.lhs),
-                                   self.sympy_printer.doprint(node.rhs))
+                    data_type = f'const {data_type}'
+            return f"{data_type} {printed_lhs} = {printed_rhs};"
         else:
             lhs_type = get_type_of_expression(node.lhs)  # TOOD: this should have been typed
             printed_mask = ""
@@ -350,7 +348,7 @@ class CBackend:
                     code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}"
                 return pre_code + code
             else:
-                return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};"
+                return f"{printed_lhs} = {printed_rhs};"
 
     def _print_NontemporalFence(self, _):
         if 'streamFence' in self._vector_instruction_set:
diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py
index 4cf0955a5..c93d5ed72 100644
--- a/pystencils/cpu/kernelcreation.py
+++ b/pystencils/cpu/kernelcreation.py
@@ -18,7 +18,7 @@ from pystencils.transformations import (
     resolve_field_accesses, split_inner_loop)
 
 
-def create_kernel(assignments: Union[AssignmentCollection, NodeCollection],
+def create_kernel(assignments: Union[NodeCollection],
                   config: CreateKernelConfig) -> KernelFunction:
     """Creates an abstract syntax tree for a kernel function, by taking a list of update rules.
 
diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py
index 6e13bcfbd..385f42d2f 100644
--- a/pystencils/kernelcreation.py
+++ b/pystencils/kernelcreation.py
@@ -5,7 +5,7 @@ from typing import Union, List
 import sympy as sp
 from pystencils.config import CreateKernelConfig
 
-from pystencils.assignment import Assignment
+from pystencils.assignment import Assignment, AddAugmentedAssignment
 from pystencils.astnodes import Node, Block, Conditional, LoopOverCoordinate, SympyAssignment
 from pystencils.cpu.vectorization import vectorize
 from pystencils.enums import Target, Backend
@@ -19,7 +19,10 @@ from pystencils.transformations import (
     loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel)
 
 
-def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCollection, List[Node], NodeCollection], *,
+def create_kernel(assignments: Union[Assignment, List[Assignment],
+                                     AddAugmentedAssignment, List[AddAugmentedAssignment],
+                                     AssignmentCollection, List[Node], NodeCollection],
+                  *,
                   config: CreateKernelConfig = None, **kwargs):
     """
     Creates abstract syntax tree (AST) of kernel, using a list of update equations.
@@ -59,7 +62,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
             setattr(config, k, v)
 
     # ----  Normalizing parameters
-    if isinstance(assignments, Assignment):
+    if isinstance(assignments, (Assignment, AddAugmentedAssignment)):
         assignments = [assignments]
     assert assignments, "Assignments must not be empty!"
     if isinstance(assignments, list):
@@ -86,13 +89,13 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
 
 def create_domain_kernel(assignments: NodeCollection, *, config: CreateKernelConfig):
     """
-    Creates abstract syntax tree (AST) of kernel, using a list of update equations.
+    Creates abstract syntax tree (AST) of kernel, using a NodeCollection.
 
     Note that `create_domain_kernel` is a lower level function which shoul be accessed by not providing `index_fields`
     to create_kernel
 
     Args:
-        assignments: can be a single assignment, sequence of assignments or an `AssignmentCollection`
+        assignments: `pystencils.node_collection.NodeCollection` containing all assignements and nodes to be processed
         config: CreateKernelConfig which includes the needed configuration
 
     Returns:
@@ -187,7 +190,7 @@ def create_indexed_kernel(assignments: NodeCollection, *, config: CreateKernelCo
     to create_kernel
 
     Args:
-        assignments: can be a single assignment, sequence of assignments or an `AssignmentCollection`
+        assignments: `pystencils.node_collection.NodeCollection` containing all assignements and nodes to be processed
         config: CreateKernelConfig which includes the needed configuration
 
     Returns:
diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py
index 227e1a10d..352406566 100644
--- a/pystencils/node_collection.py
+++ b/pystencils/node_collection.py
@@ -1,8 +1,9 @@
-from typing import List, Union
+from collections.abc import Iterable
+from typing import Any, Dict, List, Union, Optional, Set
 
 import sympy
 import sympy as sp
-from sympy.codegen import Assignment
+from sympy.codegen.ast import Assignment, AddAugmentedAssignment
 from sympy.codegen.rewriting import ReplaceOptim, optimize
 
 from pystencils.astnodes import Block, Node, SympyAssignment
@@ -12,33 +13,32 @@ from pystencils.simp import AssignmentCollection
 
 
 class NodeCollection:
-    def __init__(self, assignments: List[Union[Node, Assignment]]):
-        self.all_assignments = assignments
-
-        if all((isinstance(a, Assignment) for a in assignments)):
-            self.is_Nodes = False
-            self.is_Assignments = True
-        elif all((isinstance(n, Node) for n in assignments)):
-            self.is_Nodes = True
-            self.is_Assignments = False
-        else:
-            raise ValueError(f'The list "{assignments}" is mixed. Pass either a list of "pystencils.Assignments" '
-                             f'or a list of "pystencils.astnodes.Node')
+    def __init__(self, assignments: List[Union[Node, Assignment]],
+                 simplification_hints: Optional[Dict[str, Any]] = None,
+                 bound_fields: Set[sp.Symbol] = None, rhs_fields: Set[sp.Symbol] = None):
+        nodes = list()
+        assignments = [assignments, ] if not isinstance(assignments, Iterable) else assignments
+        for assignment in assignments:
+            if isinstance(assignment, Assignment):
+                nodes.append(SympyAssignment(assignment.lhs, assignment.rhs))
+            elif isinstance(assignment, AddAugmentedAssignment):
+                nodes.append(SympyAssignment(assignment.lhs, assignment.lhs + assignment.rhs))
+            elif isinstance(assignment, Node):
+                nodes.append(assignment)
+            else:
+                raise ValueError(f"Unknown node in the AssignmentCollection: {assignment}")
 
-        self.simplification_hints = {}
+        self.all_assignments = nodes
+        self.simplification_hints = simplification_hints if simplification_hints else {}
+        self.bound_fields = bound_fields if bound_fields else {}
+        self.rhs_fields = rhs_fields if rhs_fields else {}
 
     @staticmethod
     def from_assignment_collection(assignment_collection: AssignmentCollection):
-        nodes = list()
-        for assignemt in assignment_collection.all_assignments:
-            if isinstance(assignemt, Assignment):
-                nodes.append(SympyAssignment(assignemt.lhs, assignemt.rhs))
-            elif isinstance(assignemt, Node):
-                nodes.append(assignemt)
-            else:
-                raise ValueError(f"Unknown node in the AssignmentCollection: {assignemt}")
-
-        return NodeCollection(nodes)
+        return NodeCollection(assignments=assignment_collection.all_assignments,
+                              simplification_hints=assignment_collection.simplification_hints,
+                              bound_fields=assignment_collection.bound_fields,
+                              rhs_fields=assignment_collection.rhs_fields)
 
     def evaluate_terms(self):
         evaluate_constant_terms = ReplaceOptim(
@@ -54,21 +54,20 @@ class NodeCollection:
         )
         sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
 
-        if self.is_Nodes:
-            def visitor(node):
-                if isinstance(node, CustomCodeNode):
-                    return node
-                elif isinstance(node, Block):
-                    return node.func([visitor(child) for child in node.args])
-                elif isinstance(node, Node):
-                    return node.func(*[visitor(child) for child in node.args])
-                elif isinstance(node, sympy.Basic):
-                    return optimize(node, sympy_optimisations)
-                else:
-                    raise NotImplementedError(f'{node} {type(node)} has no valid visitor')
+        def visitor(node):
+            if isinstance(node, CustomCodeNode):
+                return node
+            elif isinstance(node, Block):
+                return node.func([visitor(child) for child in node.args])
+            elif isinstance(node, SympyAssignment):
+                new_lhs = visitor(node.lhs)
+                new_rhs = visitor(node.rhs)
+                return node.func(new_lhs, new_rhs, node.is_const, node.use_auto)
+            elif isinstance(node, Node):
+                return node.func(*[visitor(child) for child in node.args])
+            elif isinstance(node, sympy.Basic):
+                return optimize(node, sympy_optimisations)
+            else:
+                raise NotImplementedError(f'{node} {type(node)} has no valid visitor')
 
-            self.all_assignments = [visitor(assignment) for assignment in self.all_assignments]
-        else:
-            self.all_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
-                                    if hasattr(a, 'lhs')
-                                    else a for a in self.all_assignments]
+        self.all_assignments = [visitor(assignment) for assignment in self.all_assignments]
diff --git a/pystencils/transformations.py b/pystencils/transformations.py
index e07d871e9..5cde907b5 100644
--- a/pystencils/transformations.py
+++ b/pystencils/transformations.py
@@ -520,7 +520,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=None,
                 coord_dict = create_coordinate_dict(group)
                 new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
                 if new_ptr not in enclosing_block.symbols_defined:
-                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
+                    new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False, use_auto=False)
                     enclosing_block.insert_before(new_assignment, sympy_assignment)
                 last_pointer = new_ptr
 
diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py
index ecb82bab8..6c30a6abf 100644
--- a/pystencils/typing/leaf_typing.py
+++ b/pystencils/typing/leaf_typing.py
@@ -10,7 +10,6 @@ from sympy.core.relational import Relational
 from sympy.functions.elementary.piecewise import ExprCondPair
 from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
 from sympy.functions.elementary.hyperbolic import HyperbolicFunction
-from sympy.codegen import Assignment
 from sympy.logic.boolalg import BooleanFunction
 from sympy.logic.boolalg import BooleanAtom
 
@@ -51,7 +50,7 @@ class TypeAdder:
     def visit(self, obj):
         if isinstance(obj, (list, tuple)):
             return [self.visit(e) for e in obj]
-        if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
+        if isinstance(obj, ast.SympyAssignment):
             return self.process_assignment(obj)
         elif isinstance(obj, ast.Conditional):
             condition, condition_type = self.figure_out_type(obj.condition_expr)
@@ -67,7 +66,7 @@ class TypeAdder:
         else:
             raise ValueError("Invalid object in kernel " + str(type(obj)))
 
-    def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]) -> ast.SympyAssignment:
+    def process_assignment(self, assignment: ast.SympyAssignment) -> ast.SympyAssignment:
         # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
         new_rhs, rhs_type = self.figure_out_type(assignment.rhs)
 
@@ -81,11 +80,11 @@ class TypeAdder:
         assert isinstance(new_lhs, (Field.Access, TypedSymbol))
 
         if lhs_type != rhs_type:
-            logging.warning(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype '
-                            f'rhs: "{new_rhs}" of type "{rhs_type}".')
-            return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type))
+            logging.debug(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype '
+                          f'rhs: "{new_rhs}" of type "{rhs_type}".')
+            return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type), assignment.is_const, assignment.use_auto)
         else:
-            return ast.SympyAssignment(new_lhs, new_rhs)
+            return ast.SympyAssignment(new_lhs, new_rhs, assignment.is_const, assignment.use_auto)
 
     # Type System Specification
     # - Defined Types: TypedSymbol, Field, Field.Access, ...?
diff --git a/pystencils/typing/transformations.py b/pystencils/typing/transformations.py
index 74ecf19f1..43e69eb28 100644
--- a/pystencils/typing/transformations.py
+++ b/pystencils/typing/transformations.py
@@ -1,17 +1,19 @@
 from typing import List
 
+from pystencils.astnodes import Node
 from pystencils.config import CreateKernelConfig
 from pystencils.typing.leaf_typing import TypeAdder
-from sympy.codegen import Assignment
 
 
-def add_types(eqs: List[Assignment], config: CreateKernelConfig):
+def add_types(node_list: List[Node], config: CreateKernelConfig):
     """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
+    The AST needs to be a pystencils AST. Thus, in the list of nodes every entry must be inherited from
+    `pystencils.astnodes.Node`
 
     Additionally returns sets of all fields which are read/written
 
     Args:
-        eqs: list of equations
+        node_list: List of pystencils Nodes.
         config: CreateKernelConfig
 
     Returns:
@@ -22,4 +24,4 @@ def add_types(eqs: List[Assignment], config: CreateKernelConfig):
                       default_number_float=config.default_number_float,
                       default_number_int=config.default_number_int)
 
-    return check.visit(eqs)
+    return check.visit(node_list)
diff --git a/pystencils_tests/test_augmented_assignment.py b/pystencils_tests/test_augmented_assignment.py
new file mode 100644
index 000000000..43fa7e8e1
--- /dev/null
+++ b/pystencils_tests/test_augmented_assignment.py
@@ -0,0 +1,35 @@
+import pytest
+import pystencils as ps
+
+
+@pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU])
+def test_add_augmented_assignment(target):
+    if target == ps.Target.GPU:
+        pytest.importorskip("cupy")
+
+    domain_size = (5, 5)
+    dh = ps.create_data_handling(domain_size=domain_size, periodicity=True, default_target=target)
+
+    f = dh.add_array("f", values_per_cell=1)
+    dh.fill(f.name, 0.0)
+
+    g = dh.add_array("g", values_per_cell=1)
+    dh.fill(g.name, 1.0)
+
+    up = ps.AddAugmentedAssignment(f.center, g.center)
+
+    config = ps.CreateKernelConfig(target=dh.default_target)
+    ast = ps.create_kernel(up, config=config)
+
+    kernel = ast.compile()
+    for i in range(10):
+        dh.run_kernel(kernel)
+
+    if target == ps.Target.GPU:
+        dh.all_to_cpu()
+
+    result = dh.gather_array(f.name)
+
+    for x in range(domain_size[0]):
+        for y in range(domain_size[1]):
+            assert result[x, y] == 10
diff --git a/pystencils_tests/test_modulo.py b/pystencils_tests/test_modulo.py
index 959daddb9..5a32acf5c 100644
--- a/pystencils_tests/test_modulo.py
+++ b/pystencils_tests/test_modulo.py
@@ -10,7 +10,7 @@ from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAss
 def test_mod(target, iteration_slice):
     if target == ps.Target.GPU:
         pytest.importorskip("cupy")
-    dh = ps.create_data_handling(domain_size=(5, 5), periodicity=True, default_target=ps.Target.CPU)
+    dh = ps.create_data_handling(domain_size=(5, 5), periodicity=True, default_target=target)
 
     loop_ctrs = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dh.dim)]
     cond = [sp.Eq(sp.Mod(loop_ctrs[i], 2), 1) for i in range(dh.dim)]
-- 
GitLab