diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 7d9a501a249379c15b5891fd034b71af3184d428..0f2485fe903af25e3c93777e055dae80b2b4209d 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -3,6 +3,8 @@ from functools import reduce from operator import add, mul, sub, truediv import sympy as sp +import sympy.core.relational +import sympy.logic.boolalg from sympy.codegen.ast import AssignmentBase, AugmentedAssignment from ...sympyextensions import Assignment, AssignmentCollection, integer_functions @@ -59,14 +61,14 @@ class FreezeError(Exception): ExprLike = ( sp.Expr | sp.Tuple - | sp.core.relational.Relational - | sp.logic.boolalg.BooleanFunction + | sympy.core.relational.Relational + | sympy.logic.boolalg.BooleanFunction ) _ExprLike = ( sp.Expr, sp.Tuple, - sp.core.relational.Relational, - sp.logic.boolalg.BooleanFunction, + sympy.core.relational.Relational, + sympy.logic.boolalg.BooleanFunction, ) @@ -400,7 +402,7 @@ class FreezeExpressions: def map_CastFunc(self, cast_expr: CastFunc) -> PsCast: return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr)) - def map_Relational(self, rel: sp.core.relational.Relational) -> PsRel: + def map_Relational(self, rel: sympy.core.relational.Relational) -> PsRel: arg1, arg2 = [self.visit_expr(arg) for arg in rel.args] match rel.rel_op: # type: ignore case "==": @@ -418,14 +420,14 @@ class FreezeExpressions: case other: raise FreezeError(f"Unsupported relation: {other}") - def map_And(self, conj: sp.logic.And) -> PsAnd: + def map_And(self, conj: sympy.logic.And) -> PsAnd: arg1, arg2 = [self.visit_expr(arg) for arg in conj.args] return PsAnd(arg1, arg2) - def map_Or(self, conj: sp.logic.Or) -> PsOr: - arg1, arg2 = [self.visit_expr(arg) for arg in conj.args] + def map_Or(self, disj: sympy.logic.Or) -> PsOr: + arg1, arg2 = [self.visit_expr(arg) for arg in disj.args] return PsOr(arg1, arg2) - def map_Not(self, conj: sp.Not) -> PsNot: - arg = self.visit_expr(conj.args[0]) + def map_Not(self, neg: sympy.logic.Not) -> PsNot: + arg = self.visit_expr(neg.args[0]) return PsNot(arg)