import pytest import sympy as sp import numpy as np from typing import cast from pystencils import Assignment, TypedSymbol, Field, FieldType from pystencils.backend.ast.structural import ( PsDeclaration, PsAssignment, PsExpression, PsConditional, PsBlock, ) from pystencils.backend.ast.expressions import ( PsConstantExpr, PsSymbolExpr, PsBinOp, PsAnd, PsOr, PsNot, PsEq, PsNe, PsGe, PsLe, PsGt, PsLt, PsCall, PsTernary ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import CFunction from pystencils.types import constify, create_type, create_numeric_type from pystencils.types.quick import Fp, Int, Bool from pystencils.backend.kernelcreation.context import KernelCreationContext from pystencils.backend.kernelcreation.freeze import FreezeExpressions from pystencils.backend.kernelcreation.typification import Typifier, TypificationError from pystencils.sympyextensions.integer_functions import ( bit_shift_left, bit_shift_right, bitwise_and, bitwise_xor, bitwise_or, ) def test_typify_simple(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) typify = Typifier(ctx) x, y, z = sp.symbols("x, y, z") asm = Assignment(z, 2 * x + y) fasm = freeze(asm) fasm = typify(fasm) assert isinstance(fasm, PsDeclaration) def check(expr): assert expr.dtype == constify(ctx.default_dtype) match expr: case PsConstantExpr(cs): assert cs.value == 2 assert cs.dtype == constify(ctx.default_dtype) case PsSymbolExpr(symb): assert symb.name in "xyz" assert symb.dtype == ctx.default_dtype case PsBinOp(op1, op2): check(op1) check(op2) case _: pytest.fail(f"Unexpected expression: {expr}") check(fasm.lhs) check(fasm.rhs) def test_rhs_constness(): default_type = Fp(32) ctx = KernelCreationContext(default_dtype=default_type) freeze = FreezeExpressions(ctx) typify = Typifier(ctx) f = Field.create_generic( "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM ) f_const = Field.create_generic( "f_const", 1, index_shape=(1,), dtype=constify(default_type), field_type=FieldType.CUSTOM, ) x, y, z = sp.symbols("x, y, z") # Right-hand sides should always get const types asm = typify(freeze(Assignment(x, f.absolute_access([0], [0])))) assert asm.rhs.get_dtype().const asm = typify( freeze( Assignment( f.absolute_access([0], [0]), f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y, ) ) ) assert asm.rhs.get_dtype().const def test_lhs_constness(): default_type = Fp(32) ctx = KernelCreationContext(default_dtype=default_type) freeze = FreezeExpressions(ctx) typify = Typifier(ctx) f = Field.create_generic( "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM ) f_const = Field.create_generic( "f_const", 1, index_shape=(1,), dtype=constify(default_type), field_type=FieldType.CUSTOM, ) x, y, z = sp.symbols("x, y, z") # Assignment RHS may not be const asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y))) assert not asm.lhs.get_dtype().const # Cannot assign to const left-hand side with pytest.raises(TypificationError): _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y))) np_struct = np.dtype([("size", np.uint32), ("data", np.float32)]) struct_type = constify(create_type(np_struct)) struct_field = Field.create_generic( "struct_field", 1, dtype=struct_type, field_type=FieldType.CUSTOM ) with pytest.raises(TypificationError): _ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x))) # Const LHS is only OK in declarations q = ctx.get_symbol("q", Fp(32, const=True)) ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q)) ast = typify(ast) assert ast.lhs.dtype == Fp(32, const=True) ast = PsAssignment(PsExpression.make(q), PsExpression.make(q)) with pytest.raises(TypificationError): typify(ast) def test_typify_structs(): ctx = KernelCreationContext(default_dtype=Fp(32)) freeze = FreezeExpressions(ctx) typify = Typifier(ctx) np_struct = np.dtype([("size", np.uint32), ("data", np.float32)]) f = Field.create_generic("f", 1, dtype=np_struct, field_type=FieldType.CUSTOM) x = sp.Symbol("x") # Good asm = Assignment(x, f.absolute_access((0,), "data")) fasm = freeze(asm) fasm = typify(fasm) asm = Assignment(f.absolute_access((0,), "data"), x) fasm = freeze(asm) fasm = typify(fasm) # Bad asm = Assignment(x, f.absolute_access((0,), "size")) fasm = freeze(asm) with pytest.raises(TypificationError): fasm = typify(fasm) def test_contextual_typing(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) typify = Typifier(ctx) x, y, z = sp.symbols("x, y, z") expr = freeze(2 * x + 3 * y + z - 4) expr = typify(expr) def check(expr): assert expr.dtype == constify(ctx.default_dtype) match expr: case PsConstantExpr(cs): assert cs.value in (2, 3, -4) assert cs.dtype == constify(ctx.default_dtype) case PsSymbolExpr(symb): assert symb.name in "xyz" assert symb.dtype == ctx.default_dtype case PsBinOp(op1, op2): check(op1) check(op2) case _: pytest.fail(f"Unexpected expression: {expr}") check(expr) def test_erronous_typing(): ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64)) freeze = FreezeExpressions(ctx) typify = Typifier(ctx) x, y, z = sp.symbols("x, y, z") q = TypedSymbol("q", np.float32) w = TypedSymbol("w", np.float16) expr = freeze(2 * x + 3 * y + q - 4) with pytest.raises(TypificationError): typify(expr) asm = Assignment(q, 3 - w) fasm = freeze(asm) with pytest.raises(TypificationError): typify(fasm) asm = Assignment(q, 3 - x) fasm = freeze(asm) with pytest.raises(TypificationError): typify(fasm) def test_typify_integer_binops(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) typify = Typifier(ctx) ctx.get_symbol("x", ctx.index_dtype) ctx.get_symbol("y", ctx.index_dtype) x, y = sp.symbols("x, y") expr = bit_shift_left( bit_shift_right(bitwise_and(2, 2), bitwise_or(x, y)), bitwise_xor(2, 2) ) expr = freeze(expr) expr = typify(expr) def check(expr): match expr: case PsConstantExpr(cs): assert cs.value == 2 assert cs.dtype == constify(ctx.index_dtype) case PsSymbolExpr(symb): assert symb.name in "xyz" assert symb.dtype == ctx.index_dtype case PsBinOp(op1, op2): check(op1) check(op2) case _: pytest.fail(f"Unexpected expression: {expr}") check(expr) def test_typify_integer_binops_floating_arg(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) typify = Typifier(ctx) x = sp.Symbol("x") expr = bit_shift_left(x, 2) expr = freeze(expr) with pytest.raises(TypificationError): expr = typify(expr) def test_typify_integer_binops_in_floating_context(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) typify = Typifier(ctx) ctx.get_symbol("i", ctx.index_dtype) x, i = sp.symbols("x, i") expr = x + bit_shift_left(i, 2) expr = freeze(expr) with pytest.raises(TypificationError): expr = typify(expr) def test_typify_constant_clones(): ctx = KernelCreationContext(default_dtype=Fp(32)) typify = Typifier(ctx) c = PsConstantExpr(PsConstant(3.0)) x = PsSymbolExpr(ctx.get_symbol("x")) expr = c + x expr_clone = expr.clone() expr = typify(expr) assert expr_clone.operand1.dtype is None assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None def test_typify_bools_and_relations(): ctx = KernelCreationContext() typify = Typifier(ctx) true = PsConstantExpr(PsConstant(True, Bool())) p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] expr = PsAnd(PsEq(x, y), PsAnd(true, PsNot(PsOr(p, q)))) expr = typify(expr) assert expr.dtype == Bool(const=True) def test_bool_in_numerical_context(): ctx = KernelCreationContext() typify = Typifier(ctx) true = PsConstantExpr(PsConstant(True, Bool())) p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] expr = true + (p - q) with pytest.raises(TypificationError): typify(expr) @pytest.mark.parametrize("rel", [PsEq, PsNe, PsLt, PsGt, PsLe, PsGe]) def test_typify_conditionals(rel): ctx = KernelCreationContext() typify = Typifier(ctx) x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] cond = PsConditional(rel(x, y), PsBlock([])) cond = typify(cond) assert cond.condition.dtype == Bool(const=True) def test_invalid_conditions(): ctx = KernelCreationContext() typify = Typifier(ctx) x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] cond = PsConditional(x + y, PsBlock([])) with pytest.raises(TypificationError): typify(cond) cond = PsConditional(PsAnd(p, PsOr(x, q)), PsBlock([])) with pytest.raises(TypificationError): typify(cond) def test_typify_ternary(): ctx = KernelCreationContext() typify = Typifier(ctx) x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] a, b = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "ab"] p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"] expr = PsTernary(p, x, y) expr = typify(expr) assert expr.dtype == Fp(32, const=True) expr = PsTernary(PsAnd(p, q), a, b + a) expr = typify(expr) assert expr.dtype == Int(32, const=True) expr = PsTernary(PsAnd(p, q), a, x) with pytest.raises(TypificationError): typify(expr) expr = PsTernary(y, a, b) with pytest.raises(TypificationError): typify(expr) def test_cfunction(): ctx = KernelCreationContext() typify = Typifier(ctx) x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"] p, q = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "pq"] def _threeway(x: np.float32, y: np.float32) -> np.int32: assert False threeway = CFunction.parse(_threeway) result = typify(PsCall(threeway, [x, y])) assert result.get_dtype() == Int(32, const=True) assert result.args[0].get_dtype() == Fp(32, const=True) assert result.args[1].get_dtype() == Fp(32, const=True) with pytest.raises(TypificationError): _ = typify(PsCall(threeway, (x, p)))