From dbb91b95a659a0786415068edfd4f2093e133552 Mon Sep 17 00:00:00 2001
From: Your Name <stephan.seitz@fau.de>
Date: Wed, 28 Aug 2019 16:57:15 +0200
Subject: [PATCH] Add TypedImaginaryUnit

---
 pystencils/astnodes.py                   |  3 +-
 pystencils/backends/cbackend.py          | 14 ++++++--
 pystencils/data_types.py                 | 45 +++++++++++++++++++-----
 pystencils_tests/test_complex_numbers.py | 18 ++++++----
 4 files changed, 62 insertions(+), 18 deletions(-)

diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py
index b2413828..55617313 100644
--- a/pystencils/astnodes.py
+++ b/pystencils/astnodes.py
@@ -3,7 +3,7 @@ from typing import Any, List, Optional, Sequence, Set, Union
 
 import sympy as sp
 
-from pystencils.data_types import TypedSymbol, cast_func, create_type
+from pystencils.data_types import TypedSymbol, cast_func, create_type, TypedImaginaryUnit
 from pystencils.field import Field
 from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
 from pystencils.sympyextensions import fast_subs
@@ -537,6 +537,7 @@ class SympyAssignment(Node):
                     loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i))
         result.update(loop_counters)
         result.update(self._lhs_symbol.atoms(sp.Symbol))
+        result = { r for r in result if not isinstance(r, TypedImaginaryUnit)}
         return result
 
     @property
diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py
index 213df6f5..e38b8fde 100644
--- a/pystencils/backends/cbackend.py
+++ b/pystencils/backends/cbackend.py
@@ -5,6 +5,7 @@ import numpy as np
 import sympy as sp
 from sympy.core import S
 from sympy.printing.ccode import C89CodePrinter
+from sympy.printing.codeprinter import requires
 
 from pystencils.astnodes import KernelFunction, Node
 from pystencils.cpu.vectorization import vec_all, vec_any
@@ -17,7 +18,6 @@ from pystencils.integer_functions import (
     int_div, int_power_of_2, modulo_ceil)
 from pystencils.kernelparameters import FieldPointerSymbol
 
-from sympy.printing.codeprinter import requires
 try:
     from sympy.printing.ccode import C99CodePrinter as CCodePrinter
 except ImportError:
@@ -122,7 +122,8 @@ def get_headers(ast_node: Node) -> Set[str]:
     if isinstance(ast_node, KernelFunction):
         if any(
                 np.issubdtype(a.dtype.numpy_dtype, np.complexfloating)
-                for a in ast_node.atoms(sp.Symbol) if hasattr(a,'dtype') and hasattr(a.dtype, 'numpy_dtype')):
+                for a in ast_node.atoms(sp.Symbol)
+                if hasattr(a, 'dtype') and hasattr(a.dtype, 'numpy_dtype')):
             if ast_node.backend == 'c':
                 headers.update({"<complex>"})
 
@@ -510,6 +511,15 @@ class CustomSympyPrinter(CCodePrinter):
     def _print_ImaginaryUnit(self, expr):
         return "std::complex<double>{0,1}"
 
+    def _print_TypedImaginaryUnit(self, expr):
+        if expr.dtype.numpy_dtype == np.complex64:
+            return "std::complex<float>{0,1}"
+        elif expr.dtype.numpy_dtype == np.complex128:
+            return "std::complex<double>{0,1}"
+        else:
+            raise NotImplementedError(
+                "only complex64 and complex128 supported")
+
     def _print_Complex(self, expr):
         return self._typed_number(expr, np.complex64)
 
diff --git a/pystencils/data_types.py b/pystencils/data_types.py
index b52a9a1b..3060f0da 100644
--- a/pystencils/data_types.py
+++ b/pystencils/data_types.py
@@ -87,7 +87,8 @@ class cast_func(sp.Function):
     @property
     def is_integer(self):
         if hasattr(self.dtype, 'numpy_dtype'):
-            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
+            return np.issubdtype(self.dtype.numpy_dtype,
+                                 np.integer) or super().is_integer
         else:
             return super().is_integer
 
@@ -368,21 +369,27 @@ def peel_off_type(dtype, type_to_peel_off):
     return dtype
 
 
-def collate_types(types, forbid_collation_to_complex=False, forbid_collation_to_float=False):
+def collate_types(types,
+                  forbid_collation_to_complex=False,
+                  forbid_collation_to_float=False):
     """
     Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
     Uses the collation rules from numpy.
     """
     if forbid_collation_to_complex:
-        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)]
+        types = [
+            t for t in types
+            if not np.issubdtype(t.numpy_dtype, np.complexfloating)
+        ]
         if not types:
-            types = [ create_type(np.float64)]
+            types = [create_type(np.float64)]
 
     if forbid_collation_to_float:
-        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)]
+        types = [
+            t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)
+        ]
         if not types:
-            types = [ create_type(np.int64) ]
-
+            types = [create_type(np.int64)]
 
     # Pointer arithmetic case i.e. pointer + integer is allowed
     if any(type(t) is PointerType for t in types):
@@ -439,7 +446,7 @@ def get_type_of_expression(expr,
     expr = sp.sympify(expr)
     if isinstance(expr, sp.Integer):
         return create_type(default_int_type)
-    elif  expr.is_real == False:
+    elif expr.is_real == False:
         return create_type(default_complex_type)
     elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
         return create_type(default_float_type)
@@ -479,7 +486,9 @@ def get_type_of_expression(expr,
         expr: sp.Expr
         if expr.args:
             types = tuple(get_type(a) for a in expr.args)
-            return collate_types(types)
+            return collate_types(types
+                forbid_collation_to_complex=expr.is_real == True,
+                forbid_collation_to_float=expr.is_integer == True)
         else:
             if expr.is_integer:
                 return create_type(default_int_type)
@@ -724,3 +733,21 @@ class StructType:
 
     def __hash__(self):
         return hash((self.numpy_dtype, self.const))
+
+
+class TypedImaginaryUnit(TypedSymbol):
+    def __new__(cls, *args, **kwds):
+        obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
+        return obj
+
+    def __new_stage2__(cls, dtype, *args, **kwargs):
+        obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
+                                                      "_i",
+                                                      dtype,
+                                                      is_imaginary=True,
+                                                      *args,
+                                                      **kwargs)
+        return obj
+
+    __xnew__ = staticmethod(__new_stage2__)
+    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
diff --git a/pystencils_tests/test_complex_numbers.py b/pystencils_tests/test_complex_numbers.py
index d1230f3a..98a5b70d 100644
--- a/pystencils_tests/test_complex_numbers.py
+++ b/pystencils_tests/test_complex_numbers.py
@@ -8,15 +8,14 @@
 """
 
 import itertools
-
+import numpy as np
 import pytest
 import sympy
 from sympy.functions import im, re
 
 import pystencils
 from pystencils import AssignmentCollection
-from pystencils.data_types import create_type, TypedSymbol
-
+from pystencils.data_types import TypedSymbol, create_type, TypedImaginaryUnit
 
 X, Y = pystencils.fields('x, y: complex64[2d]')
 A, B = pystencils.fields('a, b: float32[2d]')
@@ -49,7 +48,9 @@ SCALAR_DTYPES = ['float32', 'float64']
 @pytest.mark.parametrize("assignment, scalar_dtypes",
                          itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES))
 def test_complex_numbers(assignment, scalar_dtypes):
-    ast = pystencils.create_kernel(assignment,
+    ast = pystencils.create_kernel(assignment.subs(
+        {sympy.sympify(1j).args[1]:
+         TypedImaginaryUnit(create_type('complex64'))}),
                                    target='cpu',
                                    data_type=scalar_dtypes)
     code = str(pystencils.show_code(ast))
@@ -60,10 +61,11 @@ def test_complex_numbers(assignment, scalar_dtypes):
     kernel = ast.compile()
     assert kernel is not None
 
+
 X, Y = pystencils.fields('x, y: complex128[2d]')
 A, B = pystencils.fields('a, b: float64[2d]')
 S1, S2 = sympy.symbols('S1, S2')
-T128 = TypedSymbol('t', create_type('complex128'))
+T128 = TypedSymbol('ts', create_type('complex128'))
 
 TEST_ASSIGNMENTS = [
     AssignmentCollection({X[0, 0]: 1j}),
@@ -86,10 +88,14 @@ TEST_ASSIGNMENTS = [
 ]
 
 SCALAR_DTYPES = ['float32', 'float64']
+
+
 @pytest.mark.parametrize("assignment, scalar_dtypes",
                          itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES))
 def test_complex_numbers_64(assignment, scalar_dtypes):
-    ast = pystencils.create_kernel(assignment,
+    ast = pystencils.create_kernel(assignment.subs(
+        {sympy.sympify(1j).args[1]:
+         TypedImaginaryUnit(create_type('complex128'))}),
                                    target='cpu',
                                    data_type=scalar_dtypes)
     code = str(pystencils.show_code(ast))
-- 
GitLab