From dbb91b95a659a0786415068edfd4f2093e133552 Mon Sep 17 00:00:00 2001 From: Your Name <stephan.seitz@fau.de> Date: Wed, 28 Aug 2019 16:57:15 +0200 Subject: [PATCH] Add TypedImaginaryUnit --- pystencils/astnodes.py | 3 +- pystencils/backends/cbackend.py | 14 ++++++-- pystencils/data_types.py | 45 +++++++++++++++++++----- pystencils_tests/test_complex_numbers.py | 18 ++++++---- 4 files changed, 62 insertions(+), 18 deletions(-) diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index b2413828..55617313 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional, Sequence, Set, Union import sympy as sp -from pystencils.data_types import TypedSymbol, cast_func, create_type +from pystencils.data_types import TypedSymbol, cast_func, create_type, TypedImaginaryUnit from pystencils.field import Field from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.sympyextensions import fast_subs @@ -537,6 +537,7 @@ class SympyAssignment(Node): loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i)) result.update(loop_counters) result.update(self._lhs_symbol.atoms(sp.Symbol)) + result = { r for r in result if not isinstance(r, TypedImaginaryUnit)} return result @property diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 213df6f5..e38b8fde 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -5,6 +5,7 @@ import numpy as np import sympy as sp from sympy.core import S from sympy.printing.ccode import C89CodePrinter +from sympy.printing.codeprinter import requires from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any @@ -17,7 +18,6 @@ from pystencils.integer_functions import ( int_div, int_power_of_2, modulo_ceil) from pystencils.kernelparameters import FieldPointerSymbol -from sympy.printing.codeprinter import requires try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter except ImportError: @@ -122,7 +122,8 @@ def get_headers(ast_node: Node) -> Set[str]: if isinstance(ast_node, KernelFunction): if any( np.issubdtype(a.dtype.numpy_dtype, np.complexfloating) - for a in ast_node.atoms(sp.Symbol) if hasattr(a,'dtype') and hasattr(a.dtype, 'numpy_dtype')): + for a in ast_node.atoms(sp.Symbol) + if hasattr(a, 'dtype') and hasattr(a.dtype, 'numpy_dtype')): if ast_node.backend == 'c': headers.update({"<complex>"}) @@ -510,6 +511,15 @@ class CustomSympyPrinter(CCodePrinter): def _print_ImaginaryUnit(self, expr): return "std::complex<double>{0,1}" + def _print_TypedImaginaryUnit(self, expr): + if expr.dtype.numpy_dtype == np.complex64: + return "std::complex<float>{0,1}" + elif expr.dtype.numpy_dtype == np.complex128: + return "std::complex<double>{0,1}" + else: + raise NotImplementedError( + "only complex64 and complex128 supported") + def _print_Complex(self, expr): return self._typed_number(expr, np.complex64) diff --git a/pystencils/data_types.py b/pystencils/data_types.py index b52a9a1b..3060f0da 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -87,7 +87,8 @@ class cast_func(sp.Function): @property def is_integer(self): if hasattr(self.dtype, 'numpy_dtype'): - return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer + return np.issubdtype(self.dtype.numpy_dtype, + np.integer) or super().is_integer else: return super().is_integer @@ -368,21 +369,27 @@ def peel_off_type(dtype, type_to_peel_off): return dtype -def collate_types(types, forbid_collation_to_complex=False, forbid_collation_to_float=False): +def collate_types(types, + forbid_collation_to_complex=False, + forbid_collation_to_float=False): """ Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Uses the collation rules from numpy. """ if forbid_collation_to_complex: - types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)] + types = [ + t for t in types + if not np.issubdtype(t.numpy_dtype, np.complexfloating) + ] if not types: - types = [ create_type(np.float64)] + types = [create_type(np.float64)] if forbid_collation_to_float: - types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)] + types = [ + t for t in types if not np.issubdtype(t.numpy_dtype, np.floating) + ] if not types: - types = [ create_type(np.int64) ] - + types = [create_type(np.int64)] # Pointer arithmetic case i.e. pointer + integer is allowed if any(type(t) is PointerType for t in types): @@ -439,7 +446,7 @@ def get_type_of_expression(expr, expr = sp.sympify(expr) if isinstance(expr, sp.Integer): return create_type(default_int_type) - elif expr.is_real == False: + elif expr.is_real == False: return create_type(default_complex_type) elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): return create_type(default_float_type) @@ -479,7 +486,9 @@ def get_type_of_expression(expr, expr: sp.Expr if expr.args: types = tuple(get_type(a) for a in expr.args) - return collate_types(types) + return collate_types(types + forbid_collation_to_complex=expr.is_real == True, + forbid_collation_to_float=expr.is_integer == True) else: if expr.is_integer: return create_type(default_int_type) @@ -724,3 +733,21 @@ class StructType: def __hash__(self): return hash((self.numpy_dtype, self.const)) + + +class TypedImaginaryUnit(TypedSymbol): + def __new__(cls, *args, **kwds): + obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds) + return obj + + def __new_stage2__(cls, dtype, *args, **kwargs): + obj = super(TypedImaginaryUnit, cls).__xnew__(cls, + "_i", + dtype, + is_imaginary=True, + *args, + **kwargs) + return obj + + __xnew__ = staticmethod(__new_stage2__) + __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) diff --git a/pystencils_tests/test_complex_numbers.py b/pystencils_tests/test_complex_numbers.py index d1230f3a..98a5b70d 100644 --- a/pystencils_tests/test_complex_numbers.py +++ b/pystencils_tests/test_complex_numbers.py @@ -8,15 +8,14 @@ """ import itertools - +import numpy as np import pytest import sympy from sympy.functions import im, re import pystencils from pystencils import AssignmentCollection -from pystencils.data_types import create_type, TypedSymbol - +from pystencils.data_types import TypedSymbol, create_type, TypedImaginaryUnit X, Y = pystencils.fields('x, y: complex64[2d]') A, B = pystencils.fields('a, b: float32[2d]') @@ -49,7 +48,9 @@ SCALAR_DTYPES = ['float32', 'float64'] @pytest.mark.parametrize("assignment, scalar_dtypes", itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES)) def test_complex_numbers(assignment, scalar_dtypes): - ast = pystencils.create_kernel(assignment, + ast = pystencils.create_kernel(assignment.subs( + {sympy.sympify(1j).args[1]: + TypedImaginaryUnit(create_type('complex64'))}), target='cpu', data_type=scalar_dtypes) code = str(pystencils.show_code(ast)) @@ -60,10 +61,11 @@ def test_complex_numbers(assignment, scalar_dtypes): kernel = ast.compile() assert kernel is not None + X, Y = pystencils.fields('x, y: complex128[2d]') A, B = pystencils.fields('a, b: float64[2d]') S1, S2 = sympy.symbols('S1, S2') -T128 = TypedSymbol('t', create_type('complex128')) +T128 = TypedSymbol('ts', create_type('complex128')) TEST_ASSIGNMENTS = [ AssignmentCollection({X[0, 0]: 1j}), @@ -86,10 +88,14 @@ TEST_ASSIGNMENTS = [ ] SCALAR_DTYPES = ['float32', 'float64'] + + @pytest.mark.parametrize("assignment, scalar_dtypes", itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES)) def test_complex_numbers_64(assignment, scalar_dtypes): - ast = pystencils.create_kernel(assignment, + ast = pystencils.create_kernel(assignment.subs( + {sympy.sympify(1j).args[1]: + TypedImaginaryUnit(create_type('complex128'))}), target='cpu', data_type=scalar_dtypes) code = str(pystencils.show_code(ast)) -- GitLab