diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 59d9011d4e644401049260931ebb9d15f66cce43..ff96581b091ab3e2a21bc3c3964c1b0030de2266 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -4,16 +4,17 @@ from typing import Set import numpy as np import sympy as sp from sympy.core import S +from sympy.logic.boolalg import BooleanFalse, BooleanTrue from sympy.printing.ccode import C89CodePrinter + from pystencils.astnodes import KernelFunction, Node from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.data_types import ( - PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, - reinterpret_cast_func, vector_memory_access) + PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func, + vector_memory_access) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.integer_functions import ( - bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, - int_div, int_power_of_2, modulo_ceil) + bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) try: from sympy.printing.ccode import C99CodePrinter as CCodePrinter @@ -292,9 +293,9 @@ class CBackend: return "" def _print_Conditional(self, node): - if type(node.condition_expr) is sp.boolalg.BooleanTrue: + if type(node.condition_expr) is BooleanTrue: return self._print_Block(node.true_block) - elif type(node.condition_expr) is sp.boolalg.BooleanFalse: + elif type(node.condition_expr) is BooleanFalse: return self._print_Block(node.false_block) cond_type = get_type_of_expression(node.condition_expr) if isinstance(cond_type, VectorType): diff --git a/pystencils/cpu/vectorization.py b/pystencils/cpu/vectorization.py index 1e3a434bc0251c86074a5260f16f7cdf6ab53b58..f99c77f2d95f1d45191726101718beff7c9d3b0e 100644 --- a/pystencils/cpu/vectorization.py +++ b/pystencils/cpu/vectorization.py @@ -2,18 +2,17 @@ import warnings from typing import Container, Union import sympy as sp +from sympy.logic.boolalg import BooleanFunction import pystencils.astnodes as ast from pystencils.backends.simd_instruction_sets import get_vector_instruction_set from pystencils.data_types import ( - PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, - vector_memory_access) + PointerType, TypedSymbol, VectorType, cast_func, collate_types, get_type_of_expression, vector_memory_access) from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.field import Field from pystencils.integer_functions import modulo_ceil, modulo_floor from pystencils.sympyextensions import fast_subs -from pystencils.transformations import ( - cut_loop, filtered_tree_iteration, replace_inner_stride_with_one) +from pystencils.transformations import cut_loop, filtered_tree_iteration, replace_inner_stride_with_one # noinspection PyPep8Naming @@ -177,7 +176,7 @@ def insert_vector_casts(ast_node): visit_expr(expr.args[4])) elif isinstance(expr, cast_func): return expr - elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, sp.boolalg.BooleanFunction): + elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction): new_args = [visit_expr(a) for a in expr.args] arg_types = [get_type_of_expression(a) for a in new_args] if not any(type(t) is VectorType for t in arg_types): diff --git a/pystencils/data_types.py b/pystencils/data_types.py index af085ab3cc6ff33aa8f8e4edfd7d6b086c6d3f5a..5d553006bcca4c1cccbddb681d8710a92b75c437 100644 --- a/pystencils/data_types.py +++ b/pystencils/data_types.py @@ -4,14 +4,14 @@ from functools import partial from typing import Tuple import numpy as np - -import pystencils import sympy as sp import sympy.codegen.ast +from sympy.core.cache import cacheit +from sympy.logic.boolalg import Boolean, BooleanFunction + +import pystencils from pystencils.cache import memorycache, memorycache_if_hashable from pystencils.utils import all_equal -from sympy.core.cache import cacheit -from sympy.logic.boolalg import Boolean try: import llvmlite.ir as ir @@ -541,7 +541,7 @@ def get_type_of_expression(expr, elif isinstance(expr, sp.Indexed): typed_symbol = expr.base.label return typed_symbol.dtype.base_type - elif isinstance(expr, (sp.boolalg.Boolean, sp.boolalg.BooleanFunction)): + elif isinstance(expr, (Boolean, BooleanFunction)): # if any arg is of vector type return a vector boolean, else return a normal scalar boolean result = create_type("bool") vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)] diff --git a/pystencils/transformations.py b/pystencils/transformations.py index 762c36136cd7f3eb541a6075f4b24021813c82ad..2c279403e62fd4644b2202d3ee01d050781a34d5 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -8,14 +8,14 @@ from types import MappingProxyType import numpy as np import sympy as sp from sympy.core.numbers import ImaginaryUnit -from sympy.logic.boolalg import Boolean +from sympy.logic.boolalg import Boolean, BooleanFunction import pystencils.astnodes as ast import pystencils.integer_functions from pystencils.assignment import Assignment from pystencils.data_types import ( - PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, - get_base_type, get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) + PointerType, StructType, TypedImaginaryUnit, TypedSymbol, cast_func, collate_types, create_type, get_base_type, + get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) from pystencils.field import AbstractField, Field, FieldType from pystencils.kernelparameters import FieldPointerSymbol from pystencils.simp.assignment_collection import AssignmentCollection @@ -851,7 +851,7 @@ class KernelConstraintsCheck: return cast_func( self.process_expression(rhs.args[0], type_constants=False), rhs.dtype) - elif isinstance(rhs, sp.boolalg.BooleanFunction) or \ + elif isinstance(rhs, BooleanFunction) or \ type(rhs) in pystencils.integer_functions.__dict__.values(): new_args = [self.process_expression(a, type_constants) for a in rhs.args] types_of_expressions = [get_type_of_expression(a) for a in new_args] @@ -1030,7 +1030,7 @@ def insert_casts(node): types = [get_type_of_expression(arg) for arg in args] assert len(types) > 0 # Never ever, ever collate to float type for boolean functions! - target = collate_types(types, forbid_collation_to_float=isinstance(node.func, sp.boolalg.BooleanFunction)) + target = collate_types(types, forbid_collation_to_float=isinstance(node.func, BooleanFunction)) zipped = list(zip(args, types)) if target.func is PointerType: assert node.func is sp.Add