diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 4e7c5cf2445c07f79dddc2c7eba93c520e2b320c..213df6f559fd18e3ba451451569e240971cf2dbf 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -17,17 +17,24 @@ from pystencils.integer_functions import ( int_div, int_power_of_2, modulo_ceil) from pystencils.kernelparameters import FieldPointerSymbol +from sympy.printing.codeprinter import requires try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter except ImportError: from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 -__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] +__all__ = [ + 'generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', + 'CustomSympyPrinter' +] KERNCRAFT_NO_TERNARY_MODE = False -def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str: +def generate_c(ast_node: Node, + signature_only: bool = False, + dialect='c', + custom_backend=None) -> str: """Prints an abstract syntax tree node as C or CUDA code. This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded @@ -76,14 +83,28 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom def get_global_declarations(ast): global_declarations = [] + # global_declarations.append( + # CustomCodeNode( + # "using namespace std::complex_literals;", set(), + # set())) def visit_node(sub_ast): + nonlocal global_declarations if hasattr(sub_ast, "required_global_declarations"): - nonlocal global_declarations global_declarations += sub_ast.required_global_declarations if hasattr(sub_ast, "args"): for node in sub_ast.args: visit_node(node) + if isinstance(sub_ast, KernelFunction): + if any( + np.issubdtype(a.dtype.numpy_dtype, np.complexfloating) + for a in sub_ast.atoms(sp.Expr) if hasattr(a, 'dtype') + and hasattr(a.dtype, 'numpy_dtype')): + if sub_ast.backend == 'cpu': + global_declarations.append( + CustomCodeNode( + "using namespace std::complex_literals;", set(), + set())) visit_node(ast) @@ -94,9 +115,17 @@ def get_headers(ast_node: Node) -> Set[str]: """Return a set of header files, necessary to compile the printed C-like code.""" headers = set() + headers.update({'"complex_helper.hpp"'}) if isinstance(ast_node, KernelFunction) and ast_node.instruction_set: headers.update(ast_node.instruction_set['headers']) + if isinstance(ast_node, KernelFunction): + if any( + np.issubdtype(a.dtype.numpy_dtype, np.complexfloating) + for a in ast_node.atoms(sp.Symbol) if hasattr(a,'dtype') and hasattr(a.dtype, 'numpy_dtype')): + if ast_node.backend == 'c': + headers.update({"<complex>"}) + if hasattr(ast_node, 'headers'): headers.update(ast_node.headers) for a in ast_node.args: @@ -136,8 +165,11 @@ class CustomCodeNode(Node): class PrintNode(CustomCodeNode): # noinspection SpellCheckingInspection def __init__(self, symbol_to_print): - code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (symbol_to_print.name, symbol_to_print.name) - super(PrintNode, self).__init__(code, symbols_read=[symbol_to_print], symbols_defined=set()) + code = '\nstd::cout << "%s = " << %s << std::endl; \n' % ( + symbol_to_print.name, symbol_to_print.name) + super(PrintNode, self).__init__(code, + symbols_read=[symbol_to_print], + symbols_defined=set()) self.headers.append("<iostream>") @@ -146,11 +178,15 @@ class PrintNode(CustomCodeNode): # noinspection PyPep8Naming class CBackend: - - def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'): + def __init__(self, + sympy_printer=None, + signature_only=False, + vector_instruction_set=None, + dialect='c'): if sympy_printer is None: if vector_instruction_set is not None: - self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) + self.sympy_printer = VectorizedCustomSympyPrinter( + vector_instruction_set) else: self.sympy_printer = CustomSympyPrinter() else: @@ -175,20 +211,25 @@ class CBackend: method_name = "_print_" + cls.__name__ if hasattr(self, method_name): return getattr(self, method_name)(node) - raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__) + raise NotImplementedError(self.__class__.__name__ + + " does not support node of type " + + node.__class__.__name__) def _print_Type(self, node): return str(node) def _print_KernelFunction(self, node): - function_arguments = ["%s %s" % (self._print(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()] + function_arguments = [ + "%s %s" % (self._print(s.symbol.dtype), s.symbol.name) + for s in node.get_parameters() + ] launch_bounds = "" if self._dialect == 'cuda': max_threads = node.indexing.max_threads_per_block() if max_threads: launch_bounds = "__launch_bounds__({}) ".format(max_threads) - func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name, - ", ".join(function_arguments)) + func_declaration = "FUNC_PREFIX %svoid %s(%s)" % ( + launch_bounds, node.function_name, ", ".join(function_arguments)) if self._signatureOnly: return func_declaration @@ -197,16 +238,22 @@ class CBackend: def _print_Block(self, node): block_contents = "\n".join([self._print(child) for child in node.args]) - return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True))) + return "{\n%s\n}" % ( + self._indent + self._indent.join(block_contents.splitlines(True))) def _print_PragmaBlock(self, node): return "%s\n%s" % (node.pragma_line, self._print_Block(node)) def _print_LoopOverCoordinate(self, node): counter_symbol = node.loop_counter_name - start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start)) - condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop)) - update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),) + start = "int %s = %s" % (counter_symbol, + self.sympy_printer.doprint(node.start)) + condition = "%s < %s" % (counter_symbol, + self.sympy_printer.doprint(node.stop)) + update = "%s += %s" % ( + counter_symbol, + self.sympy_printer.doprint(node.step), + ) loop_str = "for (%s; %s; %s)" % (start, condition, update) prefix = "\n".join(node.prefix_lines) @@ -221,11 +268,13 @@ class CBackend: else: prefix = '' data_type = prefix + self._print(node.lhs.dtype) + " " - return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), + return "%s%s = %s;" % (data_type, + self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) else: lhs_type = get_type_of_expression(node.lhs) - if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): + if type(lhs_type) is VectorType and isinstance( + node.lhs, cast_func): arg, data_type, aligned, nontemporal = node.lhs.args instr = 'storeU' if aligned: @@ -237,10 +286,12 @@ class CBackend: else: rhs = node.rhs - return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]), - self.sympy_printer.doprint(rhs)) + ';' + return self._vector_instruction_set[instr].format( + "&" + self.sympy_printer.doprint(node.lhs.args[0]), + self.sympy_printer.doprint(rhs)) + ';' else: - return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) + return "%s = %s;" % (self.sympy_printer.doprint( + node.lhs), self.sympy_printer.doprint(node.rhs)) def _print_TemporaryMemoryAllocation(self, node): align = 64 @@ -256,7 +307,8 @@ class CBackend: def _print_TemporaryMemoryFree(self, node): align = 64 - return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align)) + return "free(%s - %d);" % (self.sympy_printer.doprint( + node.symbol.name), node.offset(align)) def _print_SkipIteration(self, _): return "continue;" @@ -267,7 +319,9 @@ class CBackend: def _print_Conditional(self, node): cond_type = get_type_of_expression(node.condition_expr) if isinstance(cond_type, VectorType): - raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all") + raise ValueError( + "Problem with Conditional inside vectorized loop - use vec_any or vec_all" + ) condition_expr = self.sympy_printer.doprint(node.condition_expr) true_block = self._print_Block(node.true_block) result = "if (%s)\n%s " % (condition_expr, true_block) @@ -279,14 +333,13 @@ class CBackend: def _print_DestructuringBindingsForFieldClass(self, node): # Define all undefined symbols undefined_field_symbols = node.symbols_defined - destructuring_bindings = ["%s %s = %s.%s;" % - (u.dtype, - u.name, - u.field_name if hasattr(u, 'field_name') else u.field_names[0], - node.CLASS_TO_MEMBER_DICT[u.__class__] % - (() if type(u) == FieldPointerSymbol else (u.coordinate,))) - for u in undefined_field_symbols - ] + destructuring_bindings = [ + "%s %s = %s.%s;" % + (u.dtype, u.name, u.field_name if hasattr(u, 'field_name') else + u.field_names[0], node.CLASS_TO_MEMBER_DICT[u.__class__] % + (() if type(u) == FieldPointerSymbol else (u.coordinate, ))) + for u in undefined_field_symbols + ] destructuring_bindings.sort() # only for code aesthetics return "{\n" + self._indent + \ ("\n" + self._indent).join(destructuring_bindings) + \ @@ -300,7 +353,6 @@ class CBackend: # noinspection PyPep8Naming class CustomSympyPrinter(CCodePrinter): - def __init__(self): super(CustomSympyPrinter, self).__init__() self._float_type = create_type("float32") @@ -312,12 +364,16 @@ class CustomSympyPrinter(CCodePrinter): def _print_Pow(self, expr): """Don't use std::pow function, for small integer exponents, write as multiplication""" if not expr.free_symbols: - return self._typed_number(expr.evalf(), get_type_of_expression(expr)) + return self._typed_number(expr.evalf(), + get_type_of_expression(expr)) if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: - return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" - elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: - return "1 / ({})".format(self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) + return "(" + self._print( + sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" + elif expr.exp.is_integer and expr.exp.is_number and -8 < expr.exp < 0: + return "1 / ({})".format( + self._print(sp.Mul(*[expr.base] * (-expr.exp), + evaluate=False))) else: return super(CustomSympyPrinter, self)._print_Pow(expr) @@ -328,7 +384,8 @@ class CustomSympyPrinter(CCodePrinter): def _print_Equality(self, expr): """Equality operator is not printable in default printer""" - return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))' + return '((' + self._print(expr.lhs) + ") == (" + self._print( + expr.rhs) + '))' def _print_Piecewise(self, expr): """Print piecewise in one line (remove newlines)""" @@ -347,9 +404,11 @@ class CustomSympyPrinter(CCodePrinter): return expr.to_c(self._print) if isinstance(expr, reinterpret_cast_func): arg, data_type = expr.args - return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) + return "*((%s)(& %s))" % (PointerType( + data_type, restrict=False), self._print(arg)) elif isinstance(expr, address_of): - assert len(expr.args) == 1, "address_of must only have one argument" + assert len( + expr.args) == 1, "address_of must only have one argument" return "&(%s)" % self._print(expr.args[0]) elif isinstance(expr, cast_func): arg, data_type = expr.args @@ -366,11 +425,14 @@ class CustomSympyPrinter(CCodePrinter): elif isinstance(expr, fast_inv_sqrt): return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) elif expr.func in infix_functions: - return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1])) + return "(%s %s %s)" % (self._print( + expr.args[0]), infix_functions[expr.func], + self._print(expr.args[1])) elif expr.func == int_power_of_2: return "(1 << (%s))" % (self._print(expr.args[0])) elif expr.func == int_div: - return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1])) + return "((%s) / (%s))" % (self._print( + expr.args[0]), self._print(expr.args[1])) else: return super(CustomSympyPrinter, self)._print_Function(expr) @@ -381,9 +443,9 @@ class CustomSympyPrinter(CCodePrinter): elif dtype.numpy_dtype == np.float64: return res + '.0' if '.' not in res else res elif dtype.numpy_dtype == np.complex64: - return f"{_typed_number(number.real, np.float32)} + {_typed_number(number.real, np.float32).replace('f', 'if')}" + return f"{self._typed_number(number.real, np.float32)} + {self._typed_number(number.real, np.float32).replace('f', 'if')}" elif dtype.numpy_dtype == np.complex128: - return f"{_typed_number(number.real, np.float64)} + {_typed_number(number.real, np.float64)}i" + return f"{self._typed_number(number.real, np.float64)} + {self._typed_number(number.real, np.float64)}i" else: return res @@ -406,7 +468,8 @@ class CustomSympyPrinter(CCodePrinter): end=self._print(end), expr=self._print(expr.function), increment=str(1), - condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' + condition=self._print(var) + ' <= ' + + self._print(end) # if start < end else '>=' ) return code @@ -429,12 +492,15 @@ class CustomSympyPrinter(CCodePrinter): end=self._print(end), expr=self._print(expr.function), increment=str(1), - condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' + condition=self._print(var) + ' <= ' + + self._print(end) # if start < end else '>=' ) return code + _print_Max = C89CodePrinter._print_Max _print_Min = C89CodePrinter._print_Min + @requires(headers={'stdbool.h'}) def _print_re(self, expr): return f"std::real({self._print(expr.args[0])})" @@ -442,8 +508,11 @@ class CustomSympyPrinter(CCodePrinter): return f"std::imag({self._print(expr.args[0])})" def _print_ImaginaryUnit(self, expr): - return "1if" - + return "std::complex<double>{0,1}" + + def _print_Complex(self, expr): + return self._typed_number(expr, np.complex64) + # noinspection PyPep8Naming class VectorizedCustomSympyPrinter(CustomSympyPrinter): @@ -456,7 +525,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): def _scalarFallback(self, func_name, expr, *args, **kwargs): expr_type = get_type_of_expression(expr) if type(expr_type) is not VectorType: - return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs) + return getattr(super(VectorizedCustomSympyPrinter, self), + func_name)(expr, *args, **kwargs) else: assert self.instruction_set['width'] == expr_type.width return None @@ -464,7 +534,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): def _print_Function(self, expr): if isinstance(expr, vector_memory_access): arg, data_type, aligned, _ = expr.args - instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] + instruction = self.instruction_set[ + 'loadA'] if aligned else self.instruction_set['loadU'] return instruction.format("& " + self._print(arg)) elif isinstance(expr, cast_func): arg, data_type = expr.args @@ -473,7 +544,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): elif expr.func == fast_division: result = self._scalarFallback('_print_Function', expr) if not result: - result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) + result = self.instruction_set['/'].format( + self._print(expr.args[0]), self._print(expr.args[1])) return result elif expr.func == fast_sqrt: return "({})".format(self._print(sp.sqrt(expr.args[0]))) @@ -481,21 +553,25 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._scalarFallback('_print_Function', expr) if not result: if self.instruction_set['rsqrt']: - return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) + return self.instruction_set['rsqrt'].format( + self._print(expr.args[0])) else: - return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) + return "({})".format(self._print(1 / + sp.sqrt(expr.args[0]))) elif isinstance(expr, vec_any): expr_type = get_type_of_expression(expr.args[0]) if type(expr_type) is not VectorType: return self._print(expr.args[0]) else: - return self.instruction_set['any'].format(self._print(expr.args[0])) + return self.instruction_set['any'].format( + self._print(expr.args[0])) elif isinstance(expr, vec_all): expr_type = get_type_of_expression(expr.args[0]) if type(expr_type) is not VectorType: return self._print(expr.args[0]) else: - return self.instruction_set['all'].format(self._print(expr.args[0])) + return self.instruction_set['all'].format( + self._print(expr.args[0])) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) @@ -545,7 +621,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): assert len(summands) >= 2 processed = summands[0].term for summand in summands[1:]: - func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+'] + func = self.instruction_set[ + '-'] if summand.sign == -1 else self.instruction_set['+'] processed = func.format(processed, summand.term) return processed @@ -557,18 +634,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): one = self.instruction_set['makeVec'].format(1.0) if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: - return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" + return "(" + self._print( + sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" elif expr.exp == -1: one = self.instruction_set['makeVec'].format(1.0) - return self.instruction_set['/'].format(one, self._print(expr.base)) + return self.instruction_set['/'].format(one, + self._print(expr.base)) elif expr.exp == 0.5: return self.instruction_set['sqrt'].format(self._print(expr.base)) elif expr.exp == -0.5: root = self.instruction_set['sqrt'].format(self._print(expr.base)) return self.instruction_set['/'].format(one, root) - elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: - return self.instruction_set['/'].format(one, - self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) + elif expr.exp.is_integer and expr.exp.is_number and -8 < expr.exp < 0: + return self.instruction_set['/'].format( + one, + self._print(sp.Mul(*[expr.base] * (-expr.exp), + evaluate=False))) else: raise ValueError("Generic exponential not supported: " + str(expr)) @@ -612,14 +693,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): if len(b) > 0: denominator_str = b_str[0] for item in b_str[1:]: - denominator_str = self.instruction_set['*'].format(denominator_str, item) + denominator_str = self.instruction_set['*'].format( + denominator_str, item) result = self.instruction_set['/'].format(result, denominator_str) if inside_add: return sign, result else: if sign < 0: - return self.instruction_set['*'].format(self._print(S.NegativeOne), result) + return self.instruction_set['*'].format( + self._print(S.NegativeOne), result) else: return result @@ -627,13 +710,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._scalarFallback('_print_Relational', expr) if result: return result - return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs)) + return self.instruction_set[expr.rel_op].format( + self._print(expr.lhs), self._print(expr.rhs)) def _print_Equality(self, expr): result = self._scalarFallback('_print_Equality', expr) if result: return result - return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs)) + return self.instruction_set['=='].format(self._print(expr.lhs), + self._print(expr.rhs)) def _print_Piecewise(self, expr): result = self._scalarFallback('_print_Piecewise', expr) @@ -651,13 +736,17 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): result = self._print(expr.args[-1][0]) for true_expr, condition in reversed(expr.args[:-1]): - if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"): + if isinstance(condition, cast_func) and get_type_of_expression( + condition.args[0]) == create_type("bool"): if not KERNCRAFT_NO_TERNARY_MODE: - result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), - result) + result = "(({}) ? ({}) : ({}))".format( + self._print(condition.args[0]), self._print(true_expr), + result) else: print("Warning - skipping ternary op") else: # noinspection SpellCheckingInspection - result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition)) + result = self.instruction_set['blendv'].format( + result, self._print(true_expr), self._print(condition)) + return result return result diff --git a/pystencils/data_types.py b/pystencils/data_types.py index a0c6efc2a8abb4a496aa1f4cecd6e39ea937c3a7..b52a9a1bb4459f522effbe4effe174338232c620 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -123,7 +123,7 @@ class boolean_cast_func(cast_func, Boolean): # noinspection PyPep8Naming class vector_memory_access(cast_func): - nargs = (4,) + nargs = (4, ) # noinspection PyPep8Naming @@ -172,7 +172,8 @@ class TypedSymbol(sp.Symbol): @property def is_integer(self): if hasattr(self.dtype, 'numpy_dtype'): - return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer + return np.issubdtype(self.dtype.numpy_dtype, + np.integer) or super().is_integer else: return super().is_integer @@ -292,12 +293,10 @@ to_ctypes.map = { np.dtype(np.int16): ctypes.c_int16, np.dtype(np.int32): ctypes.c_int32, np.dtype(np.int64): ctypes.c_int64, - np.dtype(np.uint8): ctypes.c_uint8, np.dtype(np.uint16): ctypes.c_uint16, np.dtype(np.uint32): ctypes.c_uint32, np.dtype(np.uint64): ctypes.c_uint64, - np.dtype(np.float32): ctypes.c_float, np.dtype(np.float64): ctypes.c_double, } @@ -330,7 +329,8 @@ def ctypes_from_llvm(data_type): elif isinstance(data_type, ir.VoidType): return None # Void type is not supported by ctypes else: - raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type)) + raise NotImplementedError('Data type %s of %s is not supported yet' % + (type(data_type), data_type)) def to_llvm_type(data_type): @@ -353,12 +353,10 @@ if ir: np.dtype(np.int16): ir.IntType(16), np.dtype(np.int32): ir.IntType(32), np.dtype(np.int64): ir.IntType(64), - np.dtype(np.uint8): ir.IntType(8), np.dtype(np.uint16): ir.IntType(16), np.dtype(np.uint32): ir.IntType(32), np.dtype(np.uint64): ir.IntType(64), - np.dtype(np.float32): ir.FloatType(), np.dtype(np.float64): ir.DoubleType(), } @@ -370,11 +368,21 @@ def peel_off_type(dtype, type_to_peel_off): return dtype -def collate_types(types): +def collate_types(types, forbid_collation_to_complex=False, forbid_collation_to_float=False): """ Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Uses the collation rules from numpy. """ + if forbid_collation_to_complex: + types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)] + if not types: + types = [ create_type(np.float64)] + + if forbid_collation_to_float: + types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)] + if not types: + types = [ create_type(np.int64) ] + # Pointer arithmetic case i.e. pointer + integer is allowed if any(type(t) is PointerType for t in types): @@ -382,7 +390,8 @@ def collate_types(types): for t in types: if type(t) is PointerType: if pointer_type is not None: - raise ValueError("Cannot collate the combination of two pointer types") + raise ValueError( + "Cannot collate the combination of two pointer types") pointer_type = t elif type(t) is BasicType: if not (t.is_int() or t.is_uint()): @@ -394,7 +403,8 @@ def collate_types(types): # peel of vector types, if at least one vector type occurred the result will also be the vector type vector_type = [t for t in types if type(t) is VectorType] if not all_equal(t.width for t in vector_type): - raise ValueError("Collation failed because of vector types with different width") + raise ValueError( + "Collation failed because of vector types with different width") types = [peel_off_type(t, VectorType) for t in types] # now we should have a list of basic types - struct types are not yet supported @@ -429,6 +439,8 @@ def get_type_of_expression(expr, expr = sp.sympify(expr) if isinstance(expr, sp.Integer): return create_type(default_int_type) + elif expr.is_real == False: + return create_type(default_complex_type) elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): return create_type(default_float_type) elif isinstance(expr, ResolvedFieldAccess): @@ -453,7 +465,8 @@ def get_type_of_expression(expr, elif isinstance(expr, sp.Indexed): typed_symbol = expr.base.label return typed_symbol.dtype.base_type - elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction): + elif isinstance(expr, sp.boolalg.Boolean) or isinstance( + expr, sp.boolalg.BooleanFunction): # if any arg is of vector type return a vector boolean, else return a normal scalar boolean result = create_type("bool") vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)] @@ -499,14 +512,15 @@ class BasicType(Type): return 'std::complex<double>' elif name.startswith('int'): width = int(name[len("int"):]) - return "int%d_t" % (width,) + return "int%d_t" % (width, ) elif name.startswith('uint'): width = int(name[len("uint"):]) - return "uint%d_t" % (width,) + return "uint%d_t" % (width, ) elif name == 'bool': return 'bool' else: - raise NotImplementedError("Can map numpy to C name for %s" % (name,)) + raise NotImplementedError("Can map numpy to C name for %s" % + (name, )) def __init__(self, dtype, const=False): self.const = const @@ -534,7 +548,8 @@ class BasicType(Type): return 1 def is_int(self): - return self.numpy_dtype in np.sctypes['int'] or self.numpy_dtype in np.sctypes['uint'] + return self.numpy_dtype in np.sctypes[ + 'int'] or self.numpy_dtype in np.sctypes['uint'] def is_float(self): return self.numpy_dtype in np.sctypes['float'] @@ -565,7 +580,8 @@ class BasicType(Type): if not isinstance(other, BasicType): return False else: - return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) + return (self.numpy_dtype, self.const) == (other.numpy_dtype, + other.const) def __hash__(self): return hash(str(self)) @@ -590,7 +606,8 @@ class VectorType(Type): if not isinstance(other, VectorType): return False else: - return (self.base_type, self.width) == (other.base_type, other.width) + return (self.base_type, self.width) == (other.base_type, + other.width) def __str__(self): if self.instruction_set is None: @@ -639,7 +656,9 @@ class PointerType(Type): if not isinstance(other, PointerType): return False else: - return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict) + return (self.base_type, self.const, + self.restrict) == (other.base_type, other.const, + other.restrict) def __str__(self): components = [str(self.base_type), '*'] @@ -690,7 +709,8 @@ class StructType: if not isinstance(other, StructType): return False else: - return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) + return (self.numpy_dtype, self.const) == (other.numpy_dtype, + other.const) def __str__(self): # structs are handled byte-wise diff --git a/pystencils/include/complex_helper.hpp b/pystencils/include/complex_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..92b8f7c5454fa08d2b82d2ddf56a431ecb06a65d --- /dev/null +++ b/pystencils/include/complex_helper.hpp @@ -0,0 +1,66 @@ +/* + * complex_helper.hpp + * Copyright (C) 2019 Stephan Seitz <stephan.seitz@fau.de> + * + * Distributed under terms of the GPLv3 license. + */ + +#pragma once + +#include <complex> + +template <class U, class V> +auto operator*(const std::complex<U> &complexNumber, const V &scalar) + -> std::complex<U> { + return std::complex<U>{std::real(complexNumber) * scalar, + std::imag(complexNumber) * scalar}; +} + +template <class U, class V> +auto operator*(const V &scalar, const std::complex<U> &complexNumber) + -> std::complex<U> { + return std::complex<U>{std::real(complexNumber) * scalar, + std::imag(complexNumber) * scalar}; +} + +template <class U, class V> +auto operator+(const std::complex<U> &complexNumber, const V &scalar) + -> std::complex<U> { + return std::complex<U>{std::real(complexNumber) + scalar, + std::imag(complexNumber)}; +} + +template <class U, class V> +auto operator+(const V &scalar, const std::complex<U> &complexNumber) + -> std::complex<U> { + return std::complex<U>{std::real(complexNumber) + scalar, + std::imag(complexNumber)}; +} + +template <class U, class V> +auto operator-(const std::complex<U> &complexNumber, const V &scalar) + -> std::complex<U> { + return std::complex<U>{std::real(complexNumber) - scalar, + std::imag(complexNumber)}; +} + +template <class U, class V> +auto operator-(const V &scalar, const std::complex<U> &complexNumber) + -> std::complex<U> { + return std::complex<U>{scalar - std::real(complexNumber), + std::imag(complexNumber)}; +} + +template <class U, class V> +auto operator/(const std::complex<U> &complexNumber, const V &scalar) + -> std::complex<U> { + return std::complex<U>{std::real(complexNumber) / scalar, + std::imag(complexNumber) / scalar}; +} + +template <class U, class V> +auto operator/(const V &scalar, const std::complex<U> &complexNumber) + -> std::complex<U> { + return std::complex<U>{scalar, 0} / complexNumber; +} + diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 554b9cf9b2257e6b0ef97b24239ea210baa8b63b..2aa5f1603bfe3310666c2ec55fb8b09a720243cc 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -887,7 +887,7 @@ class KernelConstraintsCheck: def process_assignment(self, assignment): # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 new_rhs = self.process_expression(assignment.rhs) - new_lhs = self._process_lhs(assignment.lhs) + new_lhs = self._process_lhs(assignment.lhs, assignment.rhs) return ast.SympyAssignment(new_lhs, new_rhs) def process_expression(self, rhs, type_constants=True): @@ -933,7 +933,7 @@ class KernelConstraintsCheck: def fields_written(self): return set(k.field for k, v in self._field_writes.items() if len(v)) - def _process_lhs(self, lhs): + def _process_lhs(self, lhs, rhs): assert isinstance(lhs, sp.Symbol) self._update_accesses_lhs(lhs) if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)): diff --git a/pystencils_tests/test_complex_numbers.py b/pystencils_tests/test_complex_numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..d1230f3a31ab53346befa1f9db62359f34a7f606 --- /dev/null +++ b/pystencils_tests/test_complex_numbers.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# +# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de> +# +# Distributed under terms of the GPLv3 license. +""" + +""" + +import itertools + +import pytest +import sympy +from sympy.functions import im, re + +import pystencils +from pystencils import AssignmentCollection +from pystencils.data_types import create_type, TypedSymbol + + +X, Y = pystencils.fields('x, y: complex64[2d]') +A, B = pystencils.fields('a, b: float32[2d]') +S1, S2 = sympy.symbols('S1, S2') +T64 = TypedSymbol('t', create_type('complex64')) + +TEST_ASSIGNMENTS = [ + AssignmentCollection({X[0, 0]: 1j}), + AssignmentCollection({ + S1: re(Y.center), + S2: im(Y.center), + X[0, 0]: 2j * S1 + S2 + }), + AssignmentCollection({ + A.center: re(Y.center), + B.center: im(Y.center), + }), + AssignmentCollection({ + Y.center: re(Y.center) + X.center + 2j, + }), + AssignmentCollection({ + T64: 2 + 4j, + Y.center: X.center / T64, + }) +] + +SCALAR_DTYPES = ['float32', 'float64'] + + +@pytest.mark.parametrize("assignment, scalar_dtypes", + itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES)) +def test_complex_numbers(assignment, scalar_dtypes): + ast = pystencils.create_kernel(assignment, + target='cpu', + data_type=scalar_dtypes) + code = str(pystencils.show_code(ast)) + + print(code) + assert "Not supported" not in code + + kernel = ast.compile() + assert kernel is not None + +X, Y = pystencils.fields('x, y: complex128[2d]') +A, B = pystencils.fields('a, b: float64[2d]') +S1, S2 = sympy.symbols('S1, S2') +T128 = TypedSymbol('t', create_type('complex128')) + +TEST_ASSIGNMENTS = [ + AssignmentCollection({X[0, 0]: 1j}), + AssignmentCollection({ + S1: re(Y.center), + S2: im(Y.center), + X[0, 0]: 2j * S1 + S2 + }), + AssignmentCollection({ + A.center: re(Y.center), + B.center: im(Y.center), + }), + AssignmentCollection({ + Y.center: re(Y.center) + X.center + 2j, + }), + AssignmentCollection({ + T128: 2 + 4j, + Y.center: X.center / T128, + }) +] + +SCALAR_DTYPES = ['float32', 'float64'] +@pytest.mark.parametrize("assignment, scalar_dtypes", + itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES)) +def test_complex_numbers_64(assignment, scalar_dtypes): + ast = pystencils.create_kernel(assignment, + target='cpu', + data_type=scalar_dtypes) + code = str(pystencils.show_code(ast)) + + print(code) + assert "Not supported" not in code + + kernel = ast.compile() + assert kernel is not None