diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 7c743a3997071a2f1515dd5802ee6f69cb741375..aacddfc8c1d3e7340df6aa0764f2fe03983f0610 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -385,6 +385,24 @@ class PsCall(PsExpression): return super().structurally_equal(other) and self._function == other._function +class PsNumericOpTrait: + """Trait for operations valid only on numerical types""" + + pass + + +class PsIntOpTrait: + """Trait for operations valid only on integer types""" + + pass + + +class PsBoolOpTrait: + """Trait for boolean operations""" + + pass + + class PsUnOp(PsExpression): __match_args__ = ("operand",) @@ -415,7 +433,7 @@ class PsUnOp(PsExpression): return None -class PsNeg(PsUnOp): +class PsNeg(PsUnOp, PsNumericOpTrait): @property def python_operator(self): return operator.neg @@ -503,31 +521,31 @@ class PsBinOp(PsExpression): return None -class PsAdd(PsBinOp): +class PsAdd(PsBinOp, PsNumericOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.add -class PsSub(PsBinOp): +class PsSub(PsBinOp, PsNumericOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.sub -class PsMul(PsBinOp): +class PsMul(PsBinOp, PsNumericOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.mul -class PsDiv(PsBinOp): +class PsDiv(PsBinOp, PsNumericOpTrait): # python_operator not implemented because can't unambigously decide # between intdiv and truediv pass -class PsIntDiv(PsBinOp): +class PsIntDiv(PsBinOp, PsIntOpTrait): """C-like integer division (round to zero).""" # python_operator not implemented because both floordiv and truediv have @@ -535,36 +553,94 @@ class PsIntDiv(PsBinOp): pass -class PsLeftShift(PsBinOp): +class PsLeftShift(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.lshift -class PsRightShift(PsBinOp): +class PsRightShift(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.rshift -class PsBitwiseAnd(PsBinOp): +class PsBitwiseAnd(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.and_ -class PsBitwiseXor(PsBinOp): +class PsBitwiseXor(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.xor -class PsBitwiseOr(PsBinOp): +class PsBitwiseOr(PsBinOp, PsIntOpTrait): @property def python_operator(self) -> Callable[[Any, Any], Any] | None: return operator.or_ +class PsAnd(PsBinOp, PsBoolOpTrait): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.and_ + + +class PsOr(PsBinOp, PsBoolOpTrait): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.and_ + + +class PsNot(PsUnOp, PsBoolOpTrait): + @property + def python_operator(self) -> Callable[[Any], Any] | None: + return operator.not_ + + +class PsRel(PsBinOp): + """Base class for binary relational operators""" + + +class PsEq(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.eq + + +class PsNe(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.ne + + +class PsGe(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.ge + + +class PsLe(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.le + + +class PsGt(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.gt + + +class PsLt(PsRel): + @property + def python_operator(self) -> Callable[[Any, Any], Any] | None: + return operator.lt + + class PsArrayInitList(PsExpression): __match_args__ = ("items",) diff --git a/src/pystencils/backend/ast/logical_expressions.py b/src/pystencils/backend/ast/logical_expressions.py deleted file mode 100644 index 2d739e020c2261a7d0b6fe917172223f1495c0e3..0000000000000000000000000000000000000000 --- a/src/pystencils/backend/ast/logical_expressions.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Callable, Any -import operator - -from .expressions import PsExpression -from .astnode import PsAstNode -from .util import failing_cast - - -class PsLogicalExpression(PsExpression): - __match_args__ = ("operand1", "operand2") - - def __init__(self, op1: PsExpression, op2: PsExpression): - super().__init__() - self._op1 = op1 - self._op2 = op2 - - @property - def operand1(self) -> PsExpression: - return self._op1 - - @operand1.setter - def operand1(self, expr: PsExpression): - self._op1 = expr - - @property - def operand2(self) -> PsExpression: - return self._op2 - - @operand2.setter - def operand2(self, expr: PsExpression): - self._op2 = expr - - def clone(self): - return type(self)(self._op1.clone(), self._op2.clone()) - - def get_children(self) -> tuple[PsAstNode, ...]: - return self._op1, self._op2 - - def set_child(self, idx: int, c: PsAstNode): - idx = [0, 1][idx] - match idx: - case 0: - self._op1 = failing_cast(PsExpression, c) - case 1: - self._op2 = failing_cast(PsExpression, c) - - def __repr__(self) -> str: - opname = self.__class__.__name__ - return f"{opname}({repr(self._op1)}, {repr(self._op2)})" - - @property - def python_operator(self) -> None | Callable[[Any, Any], Any]: - return None - - -class PsAnd(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.and_ - - -class PsEq(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.eq - - -class PsGe(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.ge - - -class PsGt(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.gt - - -class PsLe(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.le - - -class PsLt(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.lt - - -class PsNe(PsLogicalExpression): - @property - def python_operator(self) -> Callable[[Any, Any], Any] | None: - return operator.ne diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index b9bbe8cce84dca53a02e2297e0e8cb5199a25b26..c2334f54c34d476207eddc5466b2b13bff0d39d8 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -119,17 +119,19 @@ class AstFactory: body, ) - def loop_nest(self, counters: Sequence[str], slices: Sequence[slice], body: PsBlock) -> PsLoop: + def loop_nest( + self, counters: Sequence[str], slices: Sequence[slice], body: PsBlock + ) -> PsLoop: """Create a loop nest from a sequence of slices. **Example:** This snippet creates a 3D loop nest with ten iterations in each dimension:: - + >>> from pystencils import make_slice >>> ctx = KernelCreationContext() >>> factory = AstFactory(ctx) >>> loop = factory.loop_nest(("i", "j", "k"), make_slice[:10,:10,:10], PsBlock([])) - + Args: counters: Sequence of names for the loop counters slices: Sequence of iteration slices; see also `parse_slice` diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index a9f760e9718742bfc16a2bee7c60f14f3f272be3..7d9a501a249379c15b5891fd034b71af3184d428 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -34,6 +34,16 @@ from ..ast.expressions import ( PsRightShift, PsSubscript, PsVectorArrayAccess, + PsRel, + PsEq, + PsNe, + PsLt, + PsGt, + PsLe, + PsGe, + PsAnd, + PsOr, + PsNot, ) from ..constants import PsConstant @@ -46,6 +56,20 @@ class FreezeError(Exception): """Signifies an error during expression freezing.""" +ExprLike = ( + sp.Expr + | sp.Tuple + | sp.core.relational.Relational + | sp.logic.boolalg.BooleanFunction +) +_ExprLike = ( + sp.Expr, + sp.Tuple, + sp.core.relational.Relational, + sp.logic.boolalg.BooleanFunction, +) + + class FreezeExpressions: """Convert expressions and kernels expressed in the SymPy language to the code generator's internal representation. @@ -65,7 +89,7 @@ class FreezeExpressions: pass @overload - def __call__(self, obj: sp.Expr) -> PsExpression: + def __call__(self, obj: ExprLike) -> PsExpression: pass @overload @@ -77,7 +101,7 @@ class FreezeExpressions: return PsBlock([self.visit(asm) for asm in obj.all_assignments]) elif isinstance(obj, AssignmentBase): return cast(PsAssignment, self.visit(obj)) - elif isinstance(obj, sp.Expr): + elif isinstance(obj, _ExprLike): return cast(PsExpression, self.visit(obj)) else: raise PsInputError(f"Don't know how to freeze {obj}") @@ -97,8 +121,8 @@ class FreezeExpressions: raise FreezeError(f"Don't know how to freeze expression {node}") - def visit_expr_like(self, obj: Any) -> PsExpression: - if isinstance(obj, sp.Basic): + def visit_expr_or_builtin(self, obj: Any) -> PsExpression: + if isinstance(obj, _ExprLike): return self.visit_expr(obj) elif isinstance(obj, (int, float, bool)): return PsExpression.make(PsConstant(obj)) @@ -106,7 +130,7 @@ class FreezeExpressions: raise FreezeError(f"Don't know how to freeze {obj}") def visit_expr(self, expr: sp.Basic): - if not isinstance(expr, (sp.Expr, sp.Tuple)): + if not isinstance(expr, _ExprLike): raise FreezeError(f"Cannot freeze {expr} to an expression") return cast(PsExpression, self.visit(expr)) @@ -257,7 +281,9 @@ class FreezeExpressions: array = self._ctx.get_array(field) ptr = array.base_pointer - offsets: list[PsExpression] = [self.visit_expr_like(o) for o in access.offsets] + offsets: list[PsExpression] = [ + self.visit_expr_or_builtin(o) for o in access.offsets + ] indices: list[PsExpression] if not access.is_absolute_access: @@ -303,7 +329,7 @@ class FreezeExpressions: ) else: struct_member_name = None - indices = [self.visit_expr_like(i) for i in access.index] + indices = [self.visit_expr_or_builtin(i) for i in access.index] if not indices: # For canonical representation, there must always be at least one index dimension indices = [PsExpression.make(PsConstant(0))] @@ -371,5 +397,35 @@ class FreezeExpressions: args = tuple(self.visit_expr(arg) for arg in expr.args) return PsCall(PsMathFunction(MathFunctions.Max), args) - def map_CastFunc(self, cast_expr: CastFunc): + def map_CastFunc(self, cast_expr: CastFunc) -> PsCast: return PsCast(cast_expr.dtype, self.visit_expr(cast_expr.expr)) + + def map_Relational(self, rel: sp.core.relational.Relational) -> PsRel: + arg1, arg2 = [self.visit_expr(arg) for arg in rel.args] + match rel.rel_op: # type: ignore + case "==": + return PsEq(arg1, arg2) + case "!=": + return PsNe(arg1, arg2) + case ">=": + return PsGe(arg1, arg2) + case "<=": + return PsLe(arg1, arg2) + case ">": + return PsGt(arg1, arg2) + case "<": + return PsLt(arg1, arg2) + case other: + raise FreezeError(f"Unsupported relation: {other}") + + def map_And(self, conj: sp.logic.And) -> PsAnd: + arg1, arg2 = [self.visit_expr(arg) for arg in conj.args] + return PsAnd(arg1, arg2) + + def map_Or(self, conj: sp.logic.Or) -> PsOr: + arg1, arg2 = [self.visit_expr(arg) for arg in conj.args] + return PsOr(arg1, arg2) + + def map_Not(self, conj: sp.Not) -> PsNot: + arg = self.visit_expr(conj.args[0]) + return PsNot(arg) diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 5a093031cb4a498a4a00ff2020df0d69747a2b70..ba215f822ea7372211bf764425d44e44487cc46b 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -125,7 +125,7 @@ class FullIterationSpace(IterationSpace): archetype_field: Field | None = None, ): """Create an iteration space from a sequence of slices, optionally over an archetype field. - + Args: ctx: The kernel creation context iteration_slice: The iteration slices for each dimension; for valid formats, see `AstFactory.parse_slice` @@ -157,6 +157,7 @@ class FullIterationSpace(IterationSpace): ] from .ast_factory import AstFactory + factory = AstFactory(ctx) def to_dim(slic: slice, size: PsSymbol | PsConstant | None, ctr: PsSymbol): diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 707b02c667ee1aae815ea4d7176bdb7694eb95d5..301540c592e328ea2d185ec0685d81743bc5e488 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -27,20 +27,20 @@ from ..ast.expressions import ( PsArrayAccess, PsArrayInitList, PsBinOp, - PsBitwiseAnd, - PsBitwiseOr, - PsBitwiseXor, + PsIntOpTrait, + PsNumericOpTrait, + PsBoolOpTrait, PsCall, PsCast, PsDeref, PsAddressOf, PsConstantExpr, - PsIntDiv, - PsLeftShift, PsLookup, - PsRightShift, PsSubscript, PsSymbolExpr, + PsRel, + PsNeg, + PsNot, ) from ..functions import PsMathFunction @@ -167,19 +167,29 @@ class TypeContext: f" Target type: {self._target_type}" ) - case ( - PsIntDiv() - | PsLeftShift() - | PsRightShift() - | PsBitwiseAnd() - | PsBitwiseXor() - | PsBitwiseOr() - ) if not isinstance(self._target_type, PsIntegerType): + case PsNumericOpTrait() if not isinstance( + self._target_type, PsNumericType + ) or isinstance(self._target_type, PsBoolType): + # FIXME: PsBoolType derives from PsNumericType, but is not numeric + raise TypificationError( + f"Numerical operation encountered in non-numerical type context:\n" + f" Expression: {expr}" + f" Type Context: {self._target_type}" + ) + + case PsIntOpTrait() if not isinstance(self._target_type, PsIntegerType): raise TypificationError( f"Integer operation encountered in non-integer type context:\n" f" Expression: {expr}" f" Type Context: {self._target_type}" ) + + case PsBoolOpTrait() if not isinstance(self._target_type, PsBoolType): + raise TypificationError( + f"Boolean operation encountered in non-boolean type context:\n" + f" Expression: {expr}" + f" Type Context: {self._target_type}" + ) # endif expr.dtype = self._target_type @@ -297,7 +307,7 @@ class Typifier: self.visit_expr(rhs, tc_rhs) case PsConditional(cond, branch_true, branch_false): - cond_tc = TypeContext(PsBoolType(const=True)) + cond_tc = TypeContext(PsBoolType()) self.visit_expr(cond, cond_tc) self.visit(branch_true) @@ -420,11 +430,33 @@ class Typifier: tc.apply_dtype(member_type, expr) + case PsRel(op1, op2): + args_tc = TypeContext() + self.visit_expr(op1, args_tc) + self.visit_expr(op2, args_tc) + + if args_tc.target_type is None: + raise TypificationError( + f"Unable to determine type of arguments to relation: {expr}" + ) + if not isinstance(args_tc.target_type, PsNumericType): + raise TypificationError( + f"Invalid type in arguments to relation\n" + f" Expression: {expr}\n" + f" Arguments Type: {args_tc.target_type}" + ) + + tc.apply_dtype(PsBoolType(), expr) + case PsBinOp(op1, op2): self.visit_expr(op1, tc) self.visit_expr(op2, tc) tc.infer_dtype(expr) + case PsNeg(op) | PsNot(op): + self.visit_expr(op, tc) + tc.infer_dtype(expr) + case PsCall(function, args): match function: case PsMathFunction(): diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index 27fcdfac2b6730448b46f0d784292ca9cb577684..ef4861aa7dd3b7e92275cb61a651e7f3bb0875c4 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -14,7 +14,7 @@ from ..ast.expressions import ( PsSymbolExpr, PsAdd, ) -from ..ast.logical_expressions import PsLt, PsAnd +from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType from ..symbols import PsSymbol diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 269435257bcd8dbb486290fe1d3f35aee8e21319..99aac49df90ba694ce9501eec389bd22f00a4070 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -1,4 +1,5 @@ import sympy as sp +import pytest from pystencils import Assignment, fields @@ -15,8 +16,16 @@ from pystencils.backend.ast.expressions import ( PsExpression, PsIntDiv, PsLeftShift, - PsMul, PsRightShift, + PsAnd, + PsOr, + PsNot, + PsEq, + PsNe, + PsLt, + PsLe, + PsGt, + PsGe ) from pystencils.backend.constants import PsConstant from pystencils.backend.kernelcreation import ( @@ -33,7 +42,6 @@ from pystencils.sympyextensions.integer_functions import ( bitwise_xor, int_div, int_power_of_2, - modulo_floor, ) @@ -145,3 +153,44 @@ def test_freeze_integer_functions(): for fasm, correct in zip(fasms, should): assert fasm.structurally_equal(correct) + + +def test_freeze_booleans(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + + x, y, z = sp.symbols("x, y, z") + + expr1 = freeze(sp.Not(sp.And(x, y))) + assert expr1.structurally_equal(PsNot(PsAnd(x2, y2))) + + expr2 = freeze(sp.Or(sp.Not(z), sp.And(y, sp.Not(x)))) + assert expr2.structurally_equal(PsOr(PsNot(z2), PsAnd(y2, PsNot(x2)))) + + +@pytest.mark.parametrize("rel_pair", [ + (sp.Eq, PsEq), + (sp.Ne, PsNe), + (sp.Lt, PsLt), + (sp.Gt, PsGt), + (sp.Le, PsLe), + (sp.Ge, PsGe) +]) +def test_freeze_relations(rel_pair): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + sp_op, ps_op = rel_pair + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + z2 = PsExpression.make(ctx.get_symbol("z")) + + x, y, z = sp.symbols("x, y, z") + + expr1 = freeze(sp_op(x, y + z)) + assert expr1.structurally_equal(ps_op(x2, y2 + z2)) diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index d9cc5f9cef240e923a8d80a7651947b1ca762ff4..0afa5b9e8da6dc18bb47fb2fdf49933b3f7d28d9 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -6,11 +6,30 @@ from typing import cast from pystencils import Assignment, TypedSymbol, Field, FieldType -from pystencils.backend.ast.structural import PsDeclaration, PsAssignment, PsExpression -from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp +from pystencils.backend.ast.structural import ( + PsDeclaration, + PsAssignment, + PsExpression, + PsConditional, + PsBlock, +) +from pystencils.backend.ast.expressions import ( + PsConstantExpr, + PsSymbolExpr, + PsBinOp, + PsAnd, + PsOr, + PsNot, + PsEq, + PsNe, + PsGe, + PsLe, + PsGt, + PsLt, +) from pystencils.backend.constants import PsConstant from pystencils.types import constify -from pystencils.types.quick import Fp, create_type, create_numeric_type +from pystencils.types.quick import Fp, Bool, create_type, create_numeric_type from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError @@ -289,3 +308,57 @@ def test_typify_constant_clones(): assert expr_clone.operand1.dtype is None assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None + + +def test_typify_bools_and_relations(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + true = PsConstantExpr(PsConstant(True, Bool())) + p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + + expr = PsAnd(PsEq(x, y), PsAnd(true, PsNot(PsOr(p, q)))) + expr = typify(expr) + + assert expr.dtype == Bool(const=True) + + +def test_bool_in_numerical_context(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + true = PsConstantExpr(PsConstant(True, Bool())) + p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] + + expr = true + (p - q) + with pytest.raises(TypificationError): + typify(expr) + + +@pytest.mark.parametrize("rel", [PsEq, PsNe, PsLt, PsGt, PsLe, PsGe]) +def test_typify_conditionals(rel): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + + cond = PsConditional(rel(x, y), PsBlock([])) + cond = typify(cond) + assert cond.condition.dtype == Bool(const=True) + + +def test_invalid_conditions(): + ctx = KernelCreationContext() + typify = Typifier(ctx) + + x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] + p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] + + cond = PsConditional(x + y, PsBlock([])) + with pytest.raises(TypificationError): + typify(cond) + + cond = PsConditional(PsAnd(p, PsOr(x, q)), PsBlock([])) + with pytest.raises(TypificationError): + typify(cond)