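# NOTE: the next three lines are residue from a repository web viewer; kept as comments so the module parses.
# Select Git revision
# test_basic_usage_llvm.ipynb
# test_typification.py 17.23 KiB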
import pytest
import sympy as sp
import numpy as np
from typing import cast
from pystencils import Assignment, TypedSymbol, Field, FieldType, AddAugmentedAssignment
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.structural import (
PsDeclaration,
PsAssignment,
PsExpression,
PsConditional,
PsBlock,
)
from pystencils.backend.ast.expressions import (
PsArrayInitList,
PsCast,
PsConstantExpr,
PsSymbolExpr,
PsSubscript,
PsBinOp,
PsAnd,
PsOr,
PsNot,
PsEq,
PsNe,
PsGe,
PsLe,
PsGt,
PsLt,
PsCall,
PsTernary,
)
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import CFunction
from pystencils.types import constify, create_type, create_numeric_type
from pystencils.types.quick import Fp, Int, Bool, Arr
from pystencils.backend.kernelcreation.context import KernelCreationContext
from pystencils.backend.kernelcreation.freeze import FreezeExpressions
from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
from pystencils.sympyextensions.integer_functions import (
bit_shift_left,
bit_shift_right,
bitwise_and,
bitwise_xor,
bitwise_or,
)
def test_typify_simple():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
asm = Assignment(z, 2 * x + y)
fasm = freeze(asm)
fasm = typify(fasm)
assert isinstance(fasm, PsDeclaration)
def check(expr):
assert expr.dtype == constify(ctx.default_dtype)
match expr:
case PsConstantExpr(cs):
assert cs.value == 2
assert cs.dtype == constify(ctx.default_dtype)
case PsSymbolExpr(symb):
assert symb.name in "xyz"
assert symb.dtype == ctx.default_dtype
case PsBinOp(op1, op2):
check(op1)
check(op2)
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(fasm.lhs)
check(fasm.rhs)
def test_rhs_constness():
    """Right-hand sides of typified assignments always carry const types.

    This holds when the RHS reads a non-const field. It also holds when the RHS
    mixes non-const and const field accesses with untyped symbols.
    """
    default_type = Fp(32)
    ctx = KernelCreationContext(default_dtype=default_type)
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    f = Field.create_generic(
        "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
    )
    f_const = Field.create_generic(
        "f_const",
        1,
        index_shape=(1,),
        dtype=constify(default_type),
        field_type=FieldType.CUSTOM,
    )

    x, y = sp.symbols("x, y")

    # Right-hand sides should always get const types
    asm = typify(freeze(Assignment(x, f.absolute_access([0], [0]))))
    assert asm.rhs.get_dtype().const

    asm = typify(
        freeze(
            Assignment(
                f.absolute_access([0], [0]),
                f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y,
            )
        )
    )
    assert asm.rhs.get_dtype().const
def test_lhs_constness():
    """Const-qualified left-hand sides are rejected everywhere except in declarations.

    Checked cases:
    - the LHS of a plain field assignment gets a non-const type;
    - assigning to a const field raises ``TypificationError``;
    - assigning to a member of a const struct raises ``TypificationError``;
    - a const symbol may be declared, but not assigned to.
    """
    default_type = Fp(32)
    ctx = KernelCreationContext(default_dtype=default_type)
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    f = Field.create_generic(
        "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
    )
    f_const = Field.create_generic(
        "f_const",
        1,
        index_shape=(1,),
        dtype=constify(default_type),
        field_type=FieldType.CUSTOM,
    )

    x, y = sp.symbols("x, y")

    # The LHS of a plain assignment must not be const
    asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y)))
    assert not asm.lhs.get_dtype().const

    # Cannot assign to const left-hand side
    with pytest.raises(TypificationError):
        _ = typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y)))

    # ... nor to a member of a const struct
    np_struct = np.dtype([("size", np.uint32), ("data", np.float32)], align=True)
    struct_type = constify(create_type(np_struct))
    struct_field = Field.create_generic(
        "struct_field", 1, dtype=struct_type, field_type=FieldType.CUSTOM
    )
    with pytest.raises(TypificationError):
        _ = typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x)))

    # Const LHS is only OK in declarations
    q = ctx.get_symbol("q", Fp(32, const=True))
    ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q))
    ast = typify(ast)
    assert ast.lhs.dtype == Fp(32, const=True)

    ast = PsAssignment(PsExpression.make(q), PsExpression.make(q))
    with pytest.raises(TypificationError):
        typify(ast)
def test_typify_structs():
    """Struct members are typed individually when read and when written."""
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    np_struct = np.dtype([("size", np.uint32), ("data", np.float32)], align=True)
    struct_field = Field.create_generic(
        "f", 1, dtype=np_struct, field_type=FieldType.CUSTOM
    )
    sym = sp.Symbol("x")

    # Good: the float32 member `data` can be read into and written from `x`
    typify(freeze(Assignment(sym, struct_field.absolute_access((0,), "data"))))
    typify(freeze(Assignment(struct_field.absolute_access((0,), "data"), sym)))

    # Bad: `x` is float32 at this point (default type, and already typed by the
    # assignments above), which conflicts with the uint32 member `size`.
    # Keep this after the "good" cases -- the order matters.
    bad_read = freeze(Assignment(sym, struct_field.absolute_access((0,), "size")))
    with pytest.raises(TypificationError):
        typify(bad_read)
def test_default_typing():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
expr = freeze(2 * x + 3 * y + z - 4)
expr = typify(expr)
def check(expr):
assert expr.dtype == constify(ctx.default_dtype)
match expr:
case PsConstantExpr(cs):
assert cs.value in (2, 3, -4)
assert cs.dtype == constify(ctx.default_dtype)
case PsSymbolExpr(symb):
assert symb.name in "xyz"
assert symb.dtype == ctx.default_dtype
case PsBinOp(op1, op2):
check(op1)
check(op2)
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(expr)
def test_nested_contexts_defaults():
    """Constants inside relations and casts get the default type (Fp(16) here).

    These constants do not take the type of the enclosing expression:
    - operands of ``==`` / ``>=`` do not become bool;
    - the operand of a cast to Fp(32) does not become Fp(32).
    """
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y = sp.symbols("x, y")

    fortytwo = PsExpression.make(PsConstant(42))

    expr = typify(PsEq(fortytwo, fortytwo))
    assert expr.dtype == Bool(const=True)
    assert fortytwo.dtype == Fp(16, const=True)
    assert fortytwo.constant.dtype == Fp(16, const=True)

    decl = freeze(Assignment(x, sp.Ge(3, 4, evaluate=False)))
    decl = typify(decl)

    assert ctx.get_symbol("x").dtype == Bool()

    # Materialize the matches first so the checks below cannot pass vacuously
    # if no constant is found in the relation.
    constants = list(
        dfs_preorder(decl.rhs, lambda n: isinstance(n, PsConstantExpr))
    )
    assert constants, "expected constant operands in the relation"
    for c in constants:
        assert c.dtype == Fp(16, const=True)
        assert c.constant.dtype == Fp(16, const=True)

    thirtyone = PsExpression.make(PsConstant(31))
    decl = PsDeclaration(freeze(y), PsCast(Fp(32), thirtyone))
    decl = typify(decl)

    assert ctx.get_symbol("y").dtype == Fp(32)
    assert thirtyone.dtype == Fp(16, const=True)
    assert thirtyone.constant.dtype == Fp(16, const=True)
def test_constant_decls():
    """Declaring an untyped symbol from a bare literal uses the default type.

    Float and integer literals behave the same. The symbol gets the plain
    default type; the literal gets its const-qualified variant.
    """
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y = sp.symbols("x, y")

    for symb, literal in ((x, 3.0), (y, 42)):
        typed_decl = typify(freeze(Assignment(symb, literal)))

        assert ctx.get_symbol(symb.name).dtype == Fp(16)
        assert typed_decl.rhs.dtype == Fp(16, const=True)
        assert typed_decl.rhs.constant.dtype == Fp(16, const=True)
def test_constant_array_decls():
    """Tuple right-hand sides declare (nested) arrays of the default element type."""
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y = sp.symbols("x, y")

    cases = (
        # 1D: four elements
        (x, (1, 2, 3, 4), Arr(Fp(16), 4)),
        # 2D: two rows of four elements
        (y, ((1, 2, 3, 4), (5, 6, 7, 8)), Arr(Arr(Fp(16), 4), 2)),
    )

    for symb, values, expected_type in cases:
        typify(freeze(Assignment(symb, values)))
        assert ctx.get_symbol(symb.name).dtype == expected_type
def test_inline_arrays_1d():
    """Elements of an inline 1D array literal are typed from the surrounding expression."""
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y = sp.symbols("x, y")
    idx = TypedSymbol("idx", Int(32))

    arr: PsArrayInitList = cast(PsArrayInitList, freeze(sp.Tuple(1, 2, 3, 4)))

    target = freeze(x)
    value = freeze(y) + PsSubscript(arr, freeze(idx))

    # The array elements should learn their type from the context, which gets it from `y`
    typed_decl = typify(PsDeclaration(target, value))

    elem_type = Fp(16, const=True)
    assert typed_decl.lhs.dtype == elem_type
    assert typed_decl.rhs.dtype == elem_type
    assert arr.dtype == Arr(Fp(16), 4, const=True)
    assert all(item.dtype == elem_type for item in arr.items)
def test_inline_arrays_3d():
    """Every level of a nested (3D) inline array literal is typed from its context."""
    ctx = KernelCreationContext(default_dtype=Fp(16))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y = sp.symbols("x, y")
    indices = [TypedSymbol(f"idx_{i}", Int(32)) for i in range(3)]

    arr = freeze(sp.Tuple(((1, 2), (3, 4), (5, 6)), ((5, 6), (7, 8), (9, 10))))

    target = freeze(x)
    base = freeze(y)

    # Build arr[idx_0][idx_1][idx_2], innermost subscript first
    element = arr
    for index in indices:
        element = PsSubscript(element, freeze(index))

    # The array elements should learn their type from the context, which gets it from `y`
    typed_decl = typify(PsDeclaration(target, base + element))

    assert typed_decl.lhs.dtype == Fp(16, const=True)
    assert typed_decl.rhs.dtype == Fp(16, const=True)
    assert arr.dtype == Arr(Arr(Arr(Fp(16), 2), 3), 2, const=True)

    # Expected type of the entries at each nesting level below the outermost array
    level_types = [
        Arr(Arr(Fp(16), 2), 3, const=True),
        Arr(Fp(16), 2, const=True),
        Fp(16, const=True),
    ]

    def check_level(entries, depth):
        for entry in entries:
            assert entry.dtype == level_types[depth]
            if depth + 1 < len(level_types):
                check_level(entry.items, depth + 1)

    check_level(arr.items, 0)
def test_lhs_inference():
    """Untyped left-hand side symbols take their type from the right-hand side.

    The inferred type wins over the context's default type (float64 here).
    Declarations give the LHS expression a const type; plain assignments do
    not. A relation on the RHS makes the LHS symbol a boolean.
    """
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y = sp.symbols("x, y")
    q = TypedSymbol("q", np.float32)
    w = TypedSymbol("w", np.float16)

    # Type of the RHS is propagated to the untyped LHS symbol
    asm = Assignment(x, 3 - q)
    fasm = typify(freeze(asm))
    assert ctx.get_symbol("x").dtype == Fp(32)
    assert fasm.lhs.dtype == constify(Fp(32))

    asm = Assignment(y, 3 - w)
    fasm = typify(freeze(asm))
    assert ctx.get_symbol("y").dtype == Fp(16)
    assert fasm.lhs.dtype == constify(Fp(16))

    # A plain assignment (not a declaration) keeps a non-const LHS type
    fasm = PsAssignment(PsExpression.make(ctx.get_symbol("z")), freeze(3 - w))
    fasm = typify(fasm)
    assert ctx.get_symbol("z").dtype == Fp(16)
    assert fasm.lhs.dtype == Fp(16)

    # A relation on the RHS makes the LHS a boolean
    fasm = PsDeclaration(
        PsExpression.make(ctx.get_symbol("r")), PsLe(freeze(q), freeze(2 * q))
    )
    fasm = typify(fasm)
    assert ctx.get_symbol("r").dtype == Bool()
    assert fasm.lhs.dtype == constify(Bool())
    assert fasm.rhs.dtype == constify(Bool())
def test_erronous_typing():
    """Conflicting symbol types are reported as ``TypificationError``.

    Cases are processed in order against one shared context, so the types
    assigned by earlier cases are visible to later ones.
    """
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y, z = sp.symbols("x, y, z")
    q = TypedSymbol("q", np.float32)
    w = TypedSymbol("w", np.float16)

    invalid_inputs = [
        # Default-typed (float64) x and y mixed with float32 q
        2 * x + 3 * y + q - 4,
        # Conflict between LHS and RHS symbols
        Assignment(q, 3 - w),
        # Do not propagate types back from LHS symbols to RHS symbols
        Assignment(q, 3 - x),
        # Augmented assignment with a float32 RHS; presumably `z` falls back to
        # the float64 default instead of adopting the RHS type -- TODO confirm
        AddAugmentedAssignment(z, 3 - q),
    ]

    for item in invalid_inputs:
        frozen = freeze(item)
        with pytest.raises(TypificationError):
            typify(frozen)
def test_invalid_indices():
    """Subscripting with default-typed (here float64) symbols is rejected on either side."""
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    typify = Typifier(ctx)

    arr = PsExpression.make(ctx.get_symbol("arr", Arr(Fp(64))))
    x, y, z = (PsExpression.make(ctx.get_symbol(name)) for name in "xyz")

    # Using default-typed symbols as array indices is illegal when the default type is a float.
    # Each assignment is built lazily, right before it is typified.
    builders = (
        lambda: PsAssignment(PsSubscript(arr, x + y), z),  # as write target
        lambda: PsAssignment(z, PsSubscript(arr, x + y)),  # as read source
    )
    for build in builders:
        asm = build()
        with pytest.raises(TypificationError):
            typify(asm)
def test_typify_integer_binops():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
ctx.get_symbol("x", ctx.index_dtype)
ctx.get_symbol("y", ctx.index_dtype)
x, y = sp.symbols("x, y")
expr = bit_shift_left(
bit_shift_right(bitwise_and(2, 2), bitwise_or(x, y)), bitwise_xor(2, 2)
)
expr = freeze(expr)
expr = typify(expr)
def check(expr):
match expr:
case PsConstantExpr(cs):
assert cs.value == 2
assert cs.dtype == constify(ctx.index_dtype)
case PsSymbolExpr(symb):
assert symb.name in "xyz"
assert symb.dtype == ctx.index_dtype
case PsBinOp(op1, op2):
check(op1)
check(op2)
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(expr)
def test_typify_integer_binops_floating_arg():
    """Shifting an untyped symbol is rejected.

    The symbol falls back to the context default, which this test expects to
    be floating-point.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    shifted = freeze(bit_shift_left(sp.Symbol("x"), 2))

    with pytest.raises(TypificationError):
        typify(shifted)
def test_typify_integer_binops_in_floating_context():
    """An integer shift cannot be added to a default-typed symbol.

    The default type is expected to be floating-point here.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    # `i` is registered with the index type; `x` stays untyped and uses the default
    ctx.get_symbol("i", ctx.index_dtype)
    x, i = sp.symbols("x, i")

    mixed = freeze(x + bit_shift_left(i, 2))

    with pytest.raises(TypificationError):
        typify(mixed)
def test_typify_constant_clones():
    """Typifying an expression must not affect a clone taken beforehand."""
    ctx = KernelCreationContext(default_dtype=Fp(32))
    typify = Typifier(ctx)

    c = PsConstantExpr(PsConstant(3.0))
    x = PsSymbolExpr(ctx.get_symbol("x"))
    expr = c + x
    expr_clone = expr.clone()

    expr = typify(expr)

    # Guard against a vacuous pass: the original constant must really have been typed,
    # otherwise an untyped clone proves nothing.
    assert c.dtype == Fp(32, const=True)

    assert expr_clone.operand1.dtype is None
    assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None
def test_typify_bools_and_relations():
    """Logical operators over relations and boolean operands yield a const boolean."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    true = PsConstantExpr(PsConstant(True, Bool()))
    p, q = (PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq")
    x, y = (PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy")

    # (x == y) && (true && !(p || q))
    relation = PsEq(x, y)
    not_p_or_q = PsNot(PsOr(p, q))
    conjunction = PsAnd(relation, PsAnd(true, not_p_or_q))

    assert typify(conjunction).dtype == Bool(const=True)
def test_bool_in_numerical_context():
    """Booleans may not be used as operands of arithmetic operators."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    true = PsConstantExpr(PsConstant(True, Bool()))
    p, q = (PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq")

    # true + (p - q): arithmetic on bool operands
    arithmetic_on_bools = true + (p - q)

    with pytest.raises(TypificationError):
        typify(arithmetic_on_bools)
@pytest.mark.parametrize("rel", [PsEq, PsNe, PsLt, PsGt, PsLe, PsGe])
def test_typify_conditionals(rel):
    """Each relation is accepted as a branch condition and typed as a const bool."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = (PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy")

    branch = PsConditional(rel(x, y), PsBlock([]))
    typed_branch = typify(branch)

    assert typed_branch.condition.dtype == Bool(const=True)
def test_invalid_conditions():
    """Non-boolean expressions are rejected as branch conditions."""
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = (PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy")
    p, q = (PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq")

    # A float-valued expression is not a valid condition
    numeric_condition = PsConditional(x + y, PsBlock([]))
    with pytest.raises(TypificationError):
        typify(numeric_condition)

    # Nor is a logical expression with a float operand (`x` inside the `||`)
    mixed_condition = PsConditional(PsAnd(p, PsOr(x, q)), PsBlock([]))
    with pytest.raises(TypificationError):
        typify(mixed_condition)
def test_typify_ternary():
    """Ternaries need a boolean condition and branches of a common type.

    The result gets the const-qualified branch type.
    """
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = (PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy")
    a, b = (PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "ab")
    p, q = (PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq")

    # float32 branches -> float32 result
    float_select = typify(PsTernary(p, x, y))
    assert float_select.dtype == Fp(32, const=True)

    # int32 branches under a compound boolean condition -> int32 result
    int_select = typify(PsTernary(PsAnd(p, q), a, b + a))
    assert int_select.dtype == Int(32, const=True)

    # Branches of different types (int32 vs. float32)
    mismatched_branches = PsTernary(PsAnd(p, q), a, x)
    with pytest.raises(TypificationError):
        typify(mismatched_branches)

    # Non-boolean (float32) condition
    numeric_condition = PsTernary(y, a, b)
    with pytest.raises(TypificationError):
        typify(numeric_condition)
def test_cfunction():
    """Calls to a ``CFunction`` are typed from the annotations it was parsed from.

    The call result takes the annotated return type. The arguments must match
    the annotated parameter types.
    """
    ctx = KernelCreationContext()
    typify = Typifier(ctx)

    x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
    p, q = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "pq"]

    # Presumably only the name and signature are inspected by `CFunction.parse`;
    # the body would fail if it were ever called.
    def _threeway(x: np.float32, y: np.float32) -> np.int32:
        assert False

    threeway = CFunction.parse(_threeway)

    typed_call = typify(PsCall(threeway, [x, y]))
    assert typed_call.get_dtype() == Int(32, const=True)
    for position in (0, 1):
        assert typed_call.args[position].get_dtype() == Fp(32, const=True)

    # An int32 argument where float32 is expected is rejected
    with pytest.raises(TypificationError):
        _ = typify(PsCall(threeway, (x, p)))