Forked from
pycodegen / pystencils
49 commits behind, 134 commits ahead of the upstream repository.
-
Frederik Hennig authored
test_typification.py 8.13 KiB
import pytest
import sympy as sp
import numpy as np
from typing import cast
from pystencils import Assignment, TypedSymbol, Field, FieldType
from pystencils.backend.ast.structural import PsDeclaration, PsAssignment, PsExpression
from pystencils.backend.ast.expressions import PsConstantExpr, PsSymbolExpr, PsBinOp
from pystencils.backend.constants import PsConstant
from pystencils.types import constify
from pystencils.types.quick import Fp, create_type, create_numeric_type
from pystencils.backend.kernelcreation.context import KernelCreationContext
from pystencils.backend.kernelcreation.freeze import FreezeExpressions
from pystencils.backend.kernelcreation.typification import Typifier, TypificationError
from pystencils.sympyextensions.integer_functions import (
bit_shift_left,
bit_shift_right,
bitwise_and,
bitwise_xor,
bitwise_or,
)
def test_typify_simple():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
asm = Assignment(z, 2 * x + y)
fasm = freeze(asm)
fasm = typify(fasm)
assert isinstance(fasm, PsDeclaration)
def check(expr):
assert expr.dtype == constify(ctx.default_dtype)
match expr:
case PsConstantExpr(cs):
assert cs.value == 2
assert cs.dtype == constify(ctx.default_dtype)
case PsSymbolExpr(symb):
assert symb.name in "xyz"
assert symb.dtype == ctx.default_dtype
case PsBinOp(op1, op2):
check(op1)
check(op2)
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(fasm.lhs)
check(fasm.rhs)
def test_rhs_constness():
    """Right-hand sides of typified assignments always receive const types.

    This holds whether the RHS reads from a mutable field, a const field,
    or from plain symbols.
    """
    default_type = Fp(32)
    ctx = KernelCreationContext(default_dtype=default_type)
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    f = Field.create_generic(
        "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
    )
    f_const = Field.create_generic(
        "f_const",
        1,
        index_shape=(1,),
        dtype=constify(default_type),
        field_type=FieldType.CUSTOM,
    )

    x, y = sp.symbols("x, y")

    # Plain read from a mutable field: the RHS is still const-qualified
    asm = typify(freeze(Assignment(x, f.absolute_access([0], [0]))))
    assert asm.rhs.get_dtype().const

    # Mixed mutable / const field accesses and symbols on the RHS
    asm = typify(
        freeze(
            Assignment(
                f.absolute_access([0], [0]),
                f.absolute_access([0], [0]) * f_const.absolute_access([0], [0]) * x + y,
            )
        )
    )
    assert asm.rhs.get_dtype().const
def test_lhs_constness():
    """Left-hand sides of plain assignments must not be const.

    The typifier rejects two kinds of plain assignment: writing to a const
    field and writing to a member of a const struct. A const symbol may
    only appear on the LHS of a declaration.
    """
    default_type = Fp(32)
    ctx = KernelCreationContext(default_dtype=default_type)
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    f = Field.create_generic(
        "f", 1, index_shape=(1,), dtype=default_type, field_type=FieldType.CUSTOM
    )
    f_const = Field.create_generic(
        "f_const",
        1,
        index_shape=(1,),
        dtype=constify(default_type),
        field_type=FieldType.CUSTOM,
    )

    x, y = sp.symbols("x, y")

    # The LHS of an assignment to a mutable field is not const
    asm = typify(freeze(Assignment(f.absolute_access([0], [0]), x + y)))
    assert not asm.lhs.get_dtype().const

    # Cannot assign to a const left-hand side
    with pytest.raises(TypificationError):
        typify(freeze(Assignment(f_const.absolute_access([0], [0]), x + y)))

    # Cannot assign to a member of a const struct either
    np_struct = np.dtype([("size", np.uint32), ("data", np.float32)])
    struct_type = constify(create_type(np_struct))
    struct_field = Field.create_generic(
        "struct_field", 1, dtype=struct_type, field_type=FieldType.CUSTOM
    )
    with pytest.raises(TypificationError):
        typify(freeze(Assignment(struct_field.absolute_access([0], "data"), x)))

    # A const LHS is only OK in declarations
    q = ctx.get_symbol("q", Fp(32, const=True))
    ast = PsDeclaration(PsExpression.make(q), PsExpression.make(q))
    ast = typify(ast)
    assert ast.lhs.dtype == Fp(32, const=True)

    # ... but not in a plain assignment
    ast = PsAssignment(PsExpression.make(q), PsExpression.make(q))
    with pytest.raises(TypificationError):
        typify(ast)
def test_typify_structs():
    """Struct members can be read and written when their types match.

    The context's default type is float32, so the float32 ``data`` member is
    compatible with the untyped symbol ``x``. The uint32 ``size`` member is
    not compatible with it.
    """
    ctx = KernelCreationContext(default_dtype=Fp(32))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    np_struct = np.dtype([("size", np.uint32), ("data", np.float32)])
    f = Field.create_generic("f", 1, dtype=np_struct, field_type=FieldType.CUSTOM)
    x = sp.Symbol("x")

    # Good: read the float32 member into a fresh symbol -> declaration
    fasm = typify(freeze(Assignment(x, f.absolute_access((0,), "data"))))
    assert isinstance(fasm, PsDeclaration)

    # Good: write the symbol back into the float32 member
    typify(freeze(Assignment(f.absolute_access((0,), "data"), x)))

    # Bad: uint32 member assigned to the float32-typed symbol x
    fasm = freeze(Assignment(x, f.absolute_access((0,), "size")))
    with pytest.raises(TypificationError):
        typify(fasm)
def test_contextual_typing():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
x, y, z = sp.symbols("x, y, z")
expr = freeze(2 * x + 3 * y + z - 4)
expr = typify(expr)
def check(expr):
assert expr.dtype == constify(ctx.default_dtype)
match expr:
case PsConstantExpr(cs):
assert cs.value in (2, 3, -4)
assert cs.dtype == constify(ctx.default_dtype)
case PsSymbolExpr(symb):
assert symb.name in "xyz"
assert symb.dtype == ctx.default_dtype
case PsBinOp(op1, op2):
check(op1)
check(op2)
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(expr)
def test_erronous_typing():
    """Mixing operands of different floating-point types must be rejected.

    The context's default type is float64. Untyped symbols therefore clash
    with the explicitly float32/float16-typed symbols ``q`` and ``w``.
    """
    ctx = KernelCreationContext(default_dtype=create_numeric_type(np.float64))
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    x, y = sp.symbols("x, y")
    q = TypedSymbol("q", np.float32)
    w = TypedSymbol("w", np.float16)

    # float64 symbols combined with a float32 symbol in one expression
    expr = freeze(2 * x + 3 * y + q - 4)
    with pytest.raises(TypificationError):
        typify(expr)

    # float16 right-hand side assigned to a float32 symbol
    fasm = freeze(Assignment(q, 3 - w))
    with pytest.raises(TypificationError):
        typify(fasm)

    # float64 right-hand side assigned to a float32 symbol
    fasm = freeze(Assignment(q, 3 - x))
    with pytest.raises(TypificationError):
        typify(fasm)
def test_typify_integer_binops():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
ctx.get_symbol("x", ctx.index_dtype)
ctx.get_symbol("y", ctx.index_dtype)
x, y = sp.symbols("x, y")
expr = bit_shift_left(
bit_shift_right(bitwise_and(2, 2), bitwise_or(x, y)), bitwise_xor(2, 2)
)
expr = freeze(expr)
expr = typify(expr)
def check(expr):
match expr:
case PsConstantExpr(cs):
assert cs.value == 2
assert cs.dtype == constify(ctx.index_dtype)
case PsSymbolExpr(symb):
assert symb.name in "xyz"
assert symb.dtype == ctx.index_dtype
case PsBinOp(op1, op2):
check(op1)
check(op2)
case _:
pytest.fail(f"Unexpected expression: {expr}")
check(expr)
def test_typify_integer_binops_floating_arg():
    """A bit shift applied to an untyped symbol must fail to typify.

    ``x`` is not registered with an integer type, so it takes the context's
    default dtype. The test name indicates that this is a floating-point type.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    frozen = freeze(bit_shift_left(sp.Symbol("x"), 2))

    with pytest.raises(TypificationError):
        typify(frozen)
def test_typify_integer_binops_in_floating_context():
    """An integer bit shift cannot be added to a default-typed symbol.

    ``i`` is registered with the index dtype, while ``x`` keeps the
    context's default dtype. Their sum must be rejected.
    """
    ctx = KernelCreationContext()
    freeze = FreezeExpressions(ctx)
    typify = Typifier(ctx)

    # Only `i` is given an integer type; `x` stays untyped
    ctx.get_symbol("i", ctx.index_dtype)
    x, i = sp.symbols("x, i")

    frozen = freeze(x + bit_shift_left(i, 2))

    with pytest.raises(TypificationError):
        typify(frozen)
def test_typify_constant_clones():
    """Typifying an expression must not affect a previously made clone.

    ``c + x`` is cloned before typification. Afterwards, the original's
    constant must carry a dtype, while the clone's constant must remain
    untyped.
    """
    ctx = KernelCreationContext(default_dtype=Fp(32))
    typify = Typifier(ctx)

    c = PsConstantExpr(PsConstant(3.0))
    x = PsSymbolExpr(ctx.get_symbol("x"))
    expr = c + x
    expr_clone = expr.clone()

    expr = typify(expr)

    # Positive control: the typified expression really received types.
    # Without this, the checks on the clone below would pass vacuously
    # if typification did nothing at all.
    assert expr.operand1.dtype is not None
    assert cast(PsConstantExpr, expr.operand1).constant.dtype is not None

    # The clone is left untouched
    assert expr_clone.operand1.dtype is None
    assert cast(PsConstantExpr, expr_clone.operand1).constant.dtype is None