diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 55617313a648e5dd295d12d7b01de078442cbc3e..81127f66970728cc122e5f4c5bb1f8648b19bdf9 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional, Sequence, Set, Union import sympy as sp -from pystencils.data_types import TypedSymbol, cast_func, create_type, TypedImaginaryUnit +from pystencils.data_types import TypedImaginaryUnit, TypedSymbol, cast_func, create_type from pystencils.field import Field from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.sympyextensions import fast_subs @@ -537,7 +537,7 @@ class SympyAssignment(Node): loop_counters.add(LoopOverCoordinate.get_loop_counter_symbol(i)) result.update(loop_counters) result.update(self._lhs_symbol.atoms(sp.Symbol)) - result = { r for r in result if not isinstance(r, TypedImaginaryUnit)} + result = {r for r in result if not isinstance(r, TypedImaginaryUnit)} return result @property diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index e38b8fde8cd7f2a5769dc974ab93a5ca957bba4b..f2f478c6b040191a1e7f756ad0e82fd294f19764 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -212,9 +212,7 @@ class CBackend: method_name = "_print_" + cls.__name__ if hasattr(self, method_name): return getattr(self, method_name)(node) - raise NotImplementedError(self.__class__.__name__ + - " does not support node of type " + - node.__class__.__name__) + raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__) def _print_Type(self, node): return str(node) @@ -428,7 +426,7 @@ class CustomSympyPrinter(CCodePrinter): elif expr.func in infix_functions: return "(%s %s %s)" % (self._print( expr.args[0]), infix_functions[expr.func], - self._print(expr.args[1])) + self._print(expr.args[1])) elif expr.func == int_power_of_2: return "(1 << (%s))" % (self._print(expr.args[0])) elif expr.func == int_div: @@ -444,7 +442,7 @@ class CustomSympyPrinter(CCodePrinter): elif dtype.numpy_dtype == np.float64: return res + '.0' if '.' not in res else res elif dtype.numpy_dtype == np.complex64: - return f"{self._typed_number(number.real, np.float32)} + {self._typed_number(number.real, np.float32).replace('f', 'if')}" + return f"{self._typed_number(number.real, np.float32)} + {self._typed_number(number.real, np.float32).replace('f', 'if')}" # noqa elif dtype.numpy_dtype == np.complex128: return f"{self._typed_number(number.real, np.float64)} + {self._typed_number(number.real, np.float64)}i" else: @@ -469,8 +467,7 @@ class CustomSympyPrinter(CCodePrinter): end=self._print(end), expr=self._print(expr.function), increment=str(1), - condition=self._print(var) + ' <= ' + - self._print(end) # if start < end else '>=' + condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' ) return code @@ -493,8 +490,7 @@ class CustomSympyPrinter(CCodePrinter): end=self._print(end), expr=self._print(expr.function), increment=str(1), - condition=self._print(var) + ' <= ' + - self._print(end) # if start < end else '>=' + condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' ) return code @@ -566,8 +562,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): return self.instruction_set['rsqrt'].format( self._print(expr.args[0])) else: - return "({})".format(self._print(1 / - sp.sqrt(expr.args[0]))) + return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) elif isinstance(expr, vec_any): expr_type = get_type_of_expression(expr.args[0]) if type(expr_type) is not VectorType: