diff --git a/pystencils/data_types.py b/pystencils/data_types.py index 81f373e079477bace01d77f0231e51ad914a74aa..09bf9a57b80846ecab8181b9589513374e1e9e05 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -431,11 +431,15 @@ def collate_types(types, def get_type_of_expression(expr, default_float_type='double', default_int_type='int', - default_complex_type='complex128', symbol_type_dict=None): from pystencils.astnodes import ResolvedFieldAccess from pystencils.cpu.vectorization import vec_all, vec_any + # TODO: determine more general + if default_float_type == 'double' or default_float_type == 'float64': + default_complex_type = 'complex128' + else: + default_complex_type = 'complex64' if not symbol_type_dict: symbol_type_dict = defaultdict(lambda: create_type('double')) @@ -443,7 +447,6 @@ def get_type_of_expression(expr, get_type = partial(get_type_of_expression, default_float_type=default_float_type, default_int_type=default_int_type, - default_complex_type=default_complex_type, symbol_type_dict=symbol_type_dict) expr = sp.sympify(expr) diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 2aa5f1603bfe3310666c2ec55fb8b09a720243cc..8469fc79a91268387aab90a8111609eead62d533 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -12,8 +12,8 @@ from sympy.logic.boolalg import Boolean import pystencils.astnodes as ast from pystencils.assignment import Assignment from pystencils.data_types import ( - PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type, - get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) + PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, + get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) from pystencils.field import AbstractField, Field, FieldType from pystencils.kernelparameters import FieldPointerSymbol from pystencils.simp.assignment_collection import AssignmentCollection @@ -898,6 +898,11 @@ class KernelConstraintsCheck: return rhs elif isinstance(rhs, TypedSymbol): return rhs + elif isinstance(rhs, sp.numbers.ImaginaryUnit): + return TypedImaginaryUnit(self._type_for_symbol['_ImaginaryUnit']) + elif isinstance(rhs, sp.Symbol): + return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name]) + return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name]) elif isinstance(rhs, sp.Symbol): return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name]) elif type_constants and isinstance(rhs, np.generic): @@ -1167,6 +1172,11 @@ def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='i dictionary, mapping symbol name to type """ result = defaultdict(lambda: default_type) + if default_type == 'double' or default_type == 'float64': # todo: fix + result['_ImaginaryUnit'] = create_type('complex128') + else: + result['_ImaginaryUnit'] = create_type('complex64') + for eq in eqs: if isinstance(eq, ast.Conditional): result.update(typing_from_sympy_inspection(eq.true_block.args)) diff --git a/pystencils_tests/test_complex_numbers.py b/pystencils_tests/test_complex_numbers.py index d1ac2bf7907ecc1afa6c9e5b3d66fcd865e6e4e7..5a161914d51423e37476f3ffab0e50a6df7435d1 100644 --- a/pystencils_tests/test_complex_numbers.py +++ b/pystencils_tests/test_complex_numbers.py @@ -20,7 +20,8 @@ from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, create_type X, Y = pystencils.fields('x, y: complex64[2d]') A, B = pystencils.fields('a, b: float32[2d]') S1, S2 = sympy.symbols('S1, S2') -T64 = TypedSymbol('t', create_type('complex64')) +# T64 = TypedSymbol('t', create_type('complex64')) +T64 = sympy.Symbol('t') TEST_ASSIGNMENTS = [ AssignmentCollection({X[0, 0]: 1j}), @@ -48,11 +49,9 @@ SCALAR_DTYPES = ['float32', 'float64'] @pytest.mark.parametrize("assignment, scalar_dtypes", itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES)) def test_complex_numbers(assignment, scalar_dtypes): - ast = pystencils.create_kernel(assignment.subs( - {sympy.sympify(1j).args[1]: - TypedImaginaryUnit(create_type('complex64'))}), - target='cpu', - data_type=scalar_dtypes) + ast = pystencils.create_kernel(assignment, + target='cpu', + data_type='float32') code = str(pystencils.show_code(ast)) print(code) @@ -94,11 +93,9 @@ SCALAR_DTYPES = ['float32', 'float64'] @pytest.mark.parametrize("assignment, scalar_dtypes", itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES)) def test_complex_numbers_64(assignment, scalar_dtypes): - ast = pystencils.create_kernel(assignment.subs( - {sympy.sympify(1j).args[1]: - TypedImaginaryUnit(create_type('complex128'))}), - target='cpu', - data_type=scalar_dtypes) + ast = pystencils.create_kernel(assignment, + target='cpu', + data_type='double') code = str(pystencils.show_code(ast)) print(code) @@ -113,5 +110,8 @@ def test_get_data_type(): from pystencils.data_types import get_type_of_expression i = TypedImaginaryUnit(create_type('complex128')) - # assert get_type_of_expression(i+3).numpy_dtype == np.complex128 + assert get_type_of_expression(i+3).numpy_dtype == np.complex128 assert get_type_of_expression(i+3.).numpy_dtype == np.complex128 + i = TypedImaginaryUnit(create_type('complex64')) + assert get_type_of_expression(i+3, default_float_type='float32').numpy_dtype == np.complex64 + assert get_type_of_expression(i+3., default_float_type='float32').numpy_dtype == np.complex64