Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (2)
Showing
with 376 additions and 20 deletions
...@@ -421,6 +421,54 @@ class PsCall(PsExpression): ...@@ -421,6 +421,54 @@ class PsCall(PsExpression):
return super().structurally_equal(other) and self._function == other._function return super().structurally_equal(other) and self._function == other._function
class PsTernary(PsExpression):
"""Ternary operator."""
__match_args__ = ("condition", "case_then", "case_else")
def __init__(
self, cond: PsExpression, then: PsExpression, els: PsExpression
) -> None:
super().__init__()
self._cond = cond
self._then = then
self._else = els
@property
def condition(self) -> PsExpression:
return self._cond
@property
def case_then(self) -> PsExpression:
return self._then
@property
def case_else(self) -> PsExpression:
return self._else
def clone(self) -> PsExpression:
return PsTernary(self._cond.clone(), self._then.clone(), self._else.clone())
def get_children(self) -> tuple[PsExpression, ...]:
return (self._cond, self._then, self._else)
def set_child(self, idx: int, c: PsAstNode):
idx = range(3)[idx]
match idx:
case 0:
self._cond = failing_cast(PsExpression, c)
case 1:
self._then = failing_cast(PsExpression, c)
case 2:
self._else = failing_cast(PsExpression, c)
def __str__(self) -> str:
return f"PsTernary({self._cond}, {self._then}, {self._else})"
def __repr__(self) -> str:
return f"PsTernary({repr(self._cond)}, {repr(self._then)}, {repr(self._else)})"
class PsNumericOpTrait: class PsNumericOpTrait:
"""Trait for operations valid only on numerical types""" """Trait for operations valid only on numerical types"""
...@@ -582,9 +630,21 @@ class PsDiv(PsBinOp, PsNumericOpTrait): ...@@ -582,9 +630,21 @@ class PsDiv(PsBinOp, PsNumericOpTrait):
class PsIntDiv(PsBinOp, PsIntOpTrait): class PsIntDiv(PsBinOp, PsIntOpTrait):
"""C-like integer division (round to zero).""" """C-like integer division (round to zero)."""
# python_operator not implemented because both floordiv and truediv have @property
# different semantics. def python_operator(self) -> Callable[[Any, Any], Any]:
pass from .util import c_intdiv
return c_intdiv
class PsRem(PsBinOp, PsIntOpTrait):
"""C-style integer division remainder"""
@property
def python_operator(self) -> Callable[[Any, Any], Any]:
from .util import c_rem
return c_rem
class PsLeftShift(PsBinOp, PsIntOpTrait): class PsLeftShift(PsBinOp, PsIntOpTrait):
......
...@@ -36,3 +36,14 @@ class AstEqWrapper: ...@@ -36,3 +36,14 @@ class AstEqWrapper:
# TODO: consider replacing this with smth. more performant # TODO: consider replacing this with smth. more performant
# TODO: Check that repr is implemented by all AST nodes # TODO: Check that repr is implemented by all AST nodes
return hash(repr(self._node)) return hash(repr(self._node))
def c_intdiv(num, denom):
"""C-style integer division"""
return int(num / denom)
def c_rem(num, denom):
"""C-style integer remainder"""
div = c_intdiv(num, denom)
return num - div * denom
...@@ -26,6 +26,7 @@ from .ast.expressions import ( ...@@ -26,6 +26,7 @@ from .ast.expressions import (
PsConstantExpr, PsConstantExpr,
PsDeref, PsDeref,
PsDiv, PsDiv,
PsRem,
PsIntDiv, PsIntDiv,
PsLeftShift, PsLeftShift,
PsLookup, PsLookup,
...@@ -37,6 +38,7 @@ from .ast.expressions import ( ...@@ -37,6 +38,7 @@ from .ast.expressions import (
PsSymbolExpr, PsSymbolExpr,
PsLiteralExpr, PsLiteralExpr,
PsVectorArrayAccess, PsVectorArrayAccess,
PsTernary,
PsAnd, PsAnd,
PsOr, PsOr,
PsNot, PsNot,
...@@ -112,6 +114,8 @@ class Ops(Enum): ...@@ -112,6 +114,8 @@ class Ops(Enum):
LogicOr = (15, LR.Left) LogicOr = (15, LR.Left)
Ternary = (16, LR.Right)
Weakest = (17, LR.Middle) Weakest = (17, LR.Middle)
def __init__(self, pred: int, assoc: LR) -> None: def __init__(self, pred: int, assoc: LR) -> None:
...@@ -329,6 +333,19 @@ class CAstPrinter: ...@@ -329,6 +333,19 @@ class CAstPrinter:
type_str = target_type.c_string() type_str = target_type.c_string()
return pc.parenthesize(f"({type_str}) {operand_code}", Ops.Cast) return pc.parenthesize(f"({type_str}) {operand_code}", Ops.Cast)
case PsTernary(cond, then, els):
pc.push_op(Ops.Ternary, LR.Left)
cond_code = self.visit(cond, pc)
pc.switch_branch(LR.Middle)
then_code = self.visit(then, pc)
pc.switch_branch(LR.Right)
else_code = self.visit(els, pc)
pc.pop_op()
return pc.parenthesize(
f"{cond_code} ? {then_code} : {else_code}", Ops.Ternary
)
case PsArrayInitList(items): case PsArrayInitList(items):
pc.push_op(Ops.Weakest, LR.Middle) pc.push_op(Ops.Weakest, LR.Middle)
items_str = ", ".join(self.visit(item, pc) for item in items) items_str = ", ".join(self.visit(item, pc) for item in items)
...@@ -362,6 +379,8 @@ class CAstPrinter: ...@@ -362,6 +379,8 @@ class CAstPrinter:
return ("*", Ops.Mul) return ("*", Ops.Mul)
case PsDiv() | PsIntDiv(): case PsDiv() | PsIntDiv():
return ("/", Ops.Div) return ("/", Ops.Div)
case PsRem():
return ("%", Ops.Rem)
case PsLeftShift(): case PsLeftShift():
return ("<<", Ops.LeftShift) return ("<<", Ops.LeftShift)
case PsRightShift(): case PsRightShift():
......
...@@ -36,6 +36,7 @@ from ..ast.expressions import ( ...@@ -36,6 +36,7 @@ from ..ast.expressions import (
PsRightShift, PsRightShift,
PsSubscript, PsSubscript,
PsVectorArrayAccess, PsVectorArrayAccess,
PsTernary,
PsRel, PsRel,
PsEq, PsEq,
PsNe, PsNe,
...@@ -391,6 +392,27 @@ class FreezeExpressions: ...@@ -391,6 +392,27 @@ class FreezeExpressions:
case _: case _:
raise FreezeError(f"Unsupported function: {func}") raise FreezeError(f"Unsupported function: {func}")
def map_Piecewise(self, expr: sp.Piecewise) -> PsTernary:
from sympy.functions.elementary.piecewise import ExprCondPair
cases: list[ExprCondPair] = cast(list[ExprCondPair], expr.args)
if cases[-1].cond != sp.true:
raise FreezeError(
"The last case of a `Piecewise` must be the fallback case, its condition must always be `True`."
)
conditions = [self.visit_expr(c.cond) for c in cases[:-1]]
subexprs = [self.visit_expr(c.expr) for c in cases]
last_expr = subexprs.pop()
ternary = PsTernary(conditions.pop(), subexprs.pop(), last_expr)
while conditions:
ternary = PsTernary(conditions.pop(), subexprs.pop(), ternary)
return ternary
def map_Min(self, expr: sp.Min) -> PsCall: def map_Min(self, expr: sp.Min) -> PsCall:
args = tuple(self.visit_expr(arg) for arg in expr.args) args = tuple(self.visit_expr(arg) for arg in expr.args)
return PsCall(PsMathFunction(MathFunctions.Min), args) return PsCall(PsMathFunction(MathFunctions.Min), args)
......
...@@ -11,7 +11,7 @@ from ...field import Field, FieldType ...@@ -11,7 +11,7 @@ from ...field import Field, FieldType
from ..symbols import PsSymbol from ..symbols import PsSymbol
from ..constants import PsConstant from ..constants import PsConstant
from ..ast.expressions import PsExpression, PsConstantExpr from ..ast.expressions import PsExpression, PsConstantExpr, PsTernary, PsEq, PsRem
from ..arrays import PsLinearizedArray from ..arrays import PsLinearizedArray
from ..ast.util import failing_cast from ..ast.util import failing_cast
from ...types import PsStructType, constify from ...types import PsStructType, constify
...@@ -210,14 +210,37 @@ class FullIterationSpace(IterationSpace): ...@@ -210,14 +210,37 @@ class FullIterationSpace(IterationSpace):
return self._archetype_field return self._archetype_field
def actual_iterations(self, dimension: int | None = None) -> PsExpression: def actual_iterations(self, dimension: int | None = None) -> PsExpression:
from .typification import Typifier
from ..transformations import EliminateConstants
typify = Typifier(self._ctx)
fold = EliminateConstants(self._ctx)
if dimension is None: if dimension is None:
return reduce( return fold(
mul, (self.actual_iterations(d) for d in range(len(self.dimensions))) typify(
reduce(
mul,
(
self.actual_iterations(d)
for d in range(len(self.dimensions))
),
)
)
) )
else: else:
dim = self.dimensions[dimension] dim = self.dimensions[dimension]
one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype)) one = PsConstantExpr(PsConstant(1, self._ctx.index_dtype))
return one + (dim.stop - dim.start - one) / dim.step zero = PsConstantExpr(PsConstant(0, self._ctx.index_dtype))
return fold(
typify(
PsTernary(
PsEq(PsRem((dim.stop - dim.start), dim.step), zero),
(dim.stop - dim.start) / dim.step,
(dim.stop - dim.start) / dim.step + one,
)
)
)
def compressed_counter(self) -> PsExpression: def compressed_counter(self) -> PsExpression:
"""Expression counting the actual number of items processed at the iteration defined by the counter tuple. """Expression counting the actual number of items processed at the iteration defined by the counter tuple.
......
...@@ -32,6 +32,7 @@ from ..ast.expressions import ( ...@@ -32,6 +32,7 @@ from ..ast.expressions import (
PsNumericOpTrait, PsNumericOpTrait,
PsBoolOpTrait, PsBoolOpTrait,
PsCall, PsCall,
PsTernary,
PsCast, PsCast,
PsDeref, PsDeref,
PsAddressOf, PsAddressOf,
...@@ -446,6 +447,14 @@ class Typifier: ...@@ -446,6 +447,14 @@ class Typifier:
tc.apply_dtype(member_type, expr) tc.apply_dtype(member_type, expr)
case PsTernary(cond, then, els):
cond_tc = TypeContext(target_type=PsBoolType())
self.visit_expr(cond, cond_tc)
self.visit_expr(then, tc)
self.visit_expr(els, tc)
tc.infer_dtype(expr)
case PsRel(op1, op2): case PsRel(op1, op2):
args_tc = TypeContext() args_tc = TypeContext()
self.visit_expr(op1, args_tc) self.visit_expr(op1, args_tc)
......
...@@ -15,6 +15,8 @@ from ..ast.expressions import ( ...@@ -15,6 +15,8 @@ from ..ast.expressions import (
PsSub, PsSub,
PsMul, PsMul,
PsDiv, PsDiv,
PsIntDiv,
PsRem,
PsAnd, PsAnd,
PsOr, PsOr,
PsRel, PsRel,
...@@ -27,6 +29,7 @@ from ..ast.expressions import ( ...@@ -27,6 +29,7 @@ from ..ast.expressions import (
PsLt, PsLt,
PsGt, PsGt,
PsNe, PsNe,
PsTernary,
) )
from ..ast.util import AstEqWrapper from ..ast.util import AstEqWrapper
...@@ -198,9 +201,16 @@ class EliminateConstants: ...@@ -198,9 +201,16 @@ class EliminateConstants:
case PsMul(other_op, PsConstantExpr(c)) if c.value == 1: case PsMul(other_op, PsConstantExpr(c)) if c.value == 1:
return other_op, all(subtree_constness) return other_op, all(subtree_constness)
case PsDiv(other_op, PsConstantExpr(c)) if c.value == 1: case PsDiv(other_op, PsConstantExpr(c)) | PsIntDiv(
other_op, PsConstantExpr(c)
) if c.value == 1:
return other_op, all(subtree_constness) return other_op, all(subtree_constness)
# Trivial remainder at division by one
case PsRem(other_op, PsConstantExpr(c)) if c.value == 1:
zero = self._typify(PsConstantExpr(PsConstant(0, c.get_dtype())))
return zero, True
# Multiplicative dominance: 0 * x = 0 # Multiplicative dominance: 0 * x = 0
case PsMul(PsConstantExpr(c), other_op) if c.value == 0: case PsMul(PsConstantExpr(c), other_op) if c.value == 0:
return PsConstantExpr(c), True return PsConstantExpr(c), True
...@@ -247,6 +257,13 @@ class EliminateConstants: ...@@ -247,6 +257,13 @@ class EliminateConstants:
false = self._typify(PsConstantExpr(PsConstant(False, PsBoolType()))) false = self._typify(PsConstantExpr(PsConstant(False, PsBoolType())))
return false, True return false, True
# Trivial ternaries
case PsTernary(PsConstantExpr(c), then, els):
if c.value:
return then, subtree_constness[1]
else:
return els, subtree_constness[2]
# end match: no idempotence or dominance encountered # end match: no idempotence or dominance encountered
# Detect constant expressions # Detect constant expressions
...@@ -299,9 +316,8 @@ class EliminateConstants: ...@@ -299,9 +316,8 @@ class EliminateConstants:
) )
elif isinstance(expr, PsDiv): elif isinstance(expr, PsDiv):
if is_int: if is_int:
pass from ..ast.util import c_intdiv
# TODO: C integer division! folded = PsConstant(c_intdiv(v1, v2), dtype)
# folded = PsConstant(v1 // v2, dtype)
elif isinstance(dtype, PsIeeeFloatType): elif isinstance(dtype, PsIeeeFloatType):
folded = PsConstant(v1 / v2, dtype) folded = PsConstant(v1 / v2, dtype)
......
...@@ -5,7 +5,6 @@ from pystencils import Assignment, fields ...@@ -5,7 +5,6 @@ from pystencils import Assignment, fields
from pystencils.backend.ast.structural import ( from pystencils.backend.ast.structural import (
PsAssignment, PsAssignment,
PsBlock,
PsDeclaration, PsDeclaration,
) )
from pystencils.backend.ast.expressions import ( from pystencils.backend.ast.expressions import (
...@@ -14,6 +13,7 @@ from pystencils.backend.ast.expressions import ( ...@@ -14,6 +13,7 @@ from pystencils.backend.ast.expressions import (
PsBitwiseOr, PsBitwiseOr,
PsBitwiseXor, PsBitwiseXor,
PsExpression, PsExpression,
PsTernary,
PsIntDiv, PsIntDiv,
PsLeftShift, PsLeftShift,
PsRightShift, PsRightShift,
...@@ -33,6 +33,7 @@ from pystencils.backend.kernelcreation import ( ...@@ -33,6 +33,7 @@ from pystencils.backend.kernelcreation import (
FreezeExpressions, FreezeExpressions,
FullIterationSpace, FullIterationSpace,
) )
from pystencils.backend.kernelcreation.freeze import FreezeError
from pystencils.sympyextensions.integer_functions import ( from pystencils.sympyextensions.integer_functions import (
bit_shift_left, bit_shift_left,
...@@ -194,3 +195,28 @@ def test_freeze_relations(rel_pair): ...@@ -194,3 +195,28 @@ def test_freeze_relations(rel_pair):
expr1 = freeze(sp_op(x, y + z)) expr1 = freeze(sp_op(x, y + z))
assert expr1.structurally_equal(ps_op(x2, y2 + z2)) assert expr1.structurally_equal(ps_op(x2, y2 + z2))
def test_freeze_piecewise():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
p, q, x, y, z = sp.symbols("p, q, x, y, z")
p2 = PsExpression.make(ctx.get_symbol("p"))
q2 = PsExpression.make(ctx.get_symbol("q"))
x2 = PsExpression.make(ctx.get_symbol("x"))
y2 = PsExpression.make(ctx.get_symbol("y"))
z2 = PsExpression.make(ctx.get_symbol("z"))
piecewise = sp.Piecewise((x, p), (y, q), (z, True))
expr = freeze(piecewise)
assert isinstance(expr, PsTernary)
should = PsTernary(p2, x2, PsTernary(q2, y2, z2))
assert expr.structurally_equal(should)
piecewise = sp.Piecewise((x, p), (y, q), (z, sp.Or(p, q)))
with pytest.raises(FreezeError):
freeze(piecewise)
import pytest import pytest
from pystencils.field import Field from pystencils import make_slice, Field, create_type
from pystencils.sympyextensions.typed_sympy import TypedSymbol, create_type from pystencils.sympyextensions.typed_sympy import TypedSymbol
from pystencils.backend.constants import PsConstant
from pystencils.backend.kernelcreation import KernelCreationContext, FullIterationSpace from pystencils.backend.kernelcreation import KernelCreationContext, FullIterationSpace
from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression from pystencils.backend.ast.expressions import PsAdd, PsConstantExpr, PsExpression
from pystencils.backend.kernelcreation.typification import TypificationError from pystencils.backend.kernelcreation.typification import TypificationError
from pystencils.types import PsTypeError from pystencils.types.quick import Int
def test_slices(): def test_slices():
...@@ -36,12 +36,12 @@ def test_slices(): ...@@ -36,12 +36,12 @@ def test_slices():
op.structurally_equal(PsExpression.make(archetype_arr.shape[0])) op.structurally_equal(PsExpression.make(archetype_arr.shape[0]))
for op in dims[0].stop.children for op in dims[0].stop.children
) )
assert isinstance(dims[1].stop, PsAdd) and any( assert isinstance(dims[1].stop, PsAdd) and any(
op.structurally_equal(PsExpression.make(archetype_arr.shape[1])) op.structurally_equal(PsExpression.make(archetype_arr.shape[1]))
for op in dims[1].stop.children for op in dims[1].stop.children
) )
assert dims[2].stop.structurally_equal(PsExpression.make(archetype_arr.shape[2])) assert dims[2].stop.structurally_equal(PsExpression.make(archetype_arr.shape[2]))
...@@ -58,3 +58,28 @@ def test_invalid_slices(): ...@@ -58,3 +58,28 @@ def test_invalid_slices():
islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),) islice = (slice(1, -1, TypedSymbol("w", dtype=create_type("double"))),)
with pytest.raises(TypificationError): with pytest.raises(TypificationError):
FullIterationSpace.create_from_slice(ctx, islice, archetype_field) FullIterationSpace.create_from_slice(ctx, islice, archetype_field)
def test_iteration_count():
ctx = KernelCreationContext()
i, j, k = [PsExpression.make(ctx.get_symbol(x, ctx.index_dtype)) for x in "ijk"]
zero = PsExpression.make(PsConstant(0, ctx.index_dtype))
two = PsExpression.make(PsConstant(2, ctx.index_dtype))
three = PsExpression.make(PsConstant(3, ctx.index_dtype))
ispace = FullIterationSpace.create_from_slice(
ctx, make_slice[three : i-two, 1:8:3]
)
iters = [ispace.actual_iterations(coord) for coord in range(2)]
assert iters[0].structurally_equal((i - two) - three)
assert iters[1].structurally_equal(three)
empty_ispace = FullIterationSpace.create_from_slice(
ctx, make_slice[4:4:1, 4:4:7]
)
iters = [empty_ispace.actual_iterations(coord) for coord in range(2)]
assert iters[0].structurally_equal(zero)
assert iters[1].structurally_equal(zero)
...@@ -27,6 +27,7 @@ from pystencils.backend.ast.expressions import ( ...@@ -27,6 +27,7 @@ from pystencils.backend.ast.expressions import (
PsGt, PsGt,
PsLt, PsLt,
PsCall, PsCall,
PsTernary
) )
from pystencils.backend.constants import PsConstant from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import CFunction from pystencils.backend.functions import CFunction
...@@ -365,6 +366,31 @@ def test_invalid_conditions(): ...@@ -365,6 +366,31 @@ def test_invalid_conditions():
with pytest.raises(TypificationError): with pytest.raises(TypificationError):
typify(cond) typify(cond)
def test_typify_ternary():
ctx = KernelCreationContext()
typify = Typifier(ctx)
x, y = [PsExpression.make(ctx.get_symbol(name, Fp(32))) for name in "xy"]
a, b = [PsExpression.make(ctx.get_symbol(name, Int(32))) for name in "ab"]
p, q = [PsExpression.make(ctx.get_symbol(name, Bool())) for name in "pq"]
expr = PsTernary(p, x, y)
expr = typify(expr)
assert expr.dtype == Fp(32, const=True)
expr = PsTernary(PsAnd(p, q), a, b + a)
expr = typify(expr)
assert expr.dtype == Int(32, const=True)
expr = PsTernary(PsAnd(p, q), a, x)
with pytest.raises(TypificationError):
typify(expr)
expr = PsTernary(y, a, b)
with pytest.raises(TypificationError):
typify(expr)
def test_cfunction(): def test_cfunction():
ctx = KernelCreationContext() ctx = KernelCreationContext()
......
...@@ -55,17 +55,18 @@ def test_printing_integer_functions(): ...@@ -55,17 +55,18 @@ def test_printing_integer_functions():
PsBitwiseOr, PsBitwiseOr,
PsBitwiseXor, PsBitwiseXor,
PsIntDiv, PsIntDiv,
PsRem
) )
expr = PsBitwiseAnd( expr = PsBitwiseAnd(
PsBitwiseXor( PsBitwiseXor(
PsBitwiseXor(j, k), PsBitwiseXor(j, k),
PsBitwiseOr(PsLeftShift(i, PsRightShift(j, k)), PsIntDiv(i, k)), PsBitwiseOr(PsLeftShift(i, PsRightShift(j, k)), PsIntDiv(i, k)),
), ) + PsRem(i, k),
i, i,
) )
code = cprint(expr) code = cprint(expr)
assert code == "(j ^ k ^ (i << (j >> k) | i / k)) & i" assert code == "(j ^ k ^ (i << (j >> k) | i / k)) + i % k & i"
def test_logical_precedence(): def test_logical_precedence():
...@@ -124,3 +125,32 @@ def test_relations_precedence(): ...@@ -124,3 +125,32 @@ def test_relations_precedence():
expr = PsOr(PsNe(x, y), PsNot(PsGt(y, z))) expr = PsOr(PsNe(x, y), PsNot(PsGt(y, z)))
code = cprint(expr) code = cprint(expr)
assert code == "x != y || !(y > z)" assert code == "x != y || !(y > z)"
def test_ternary():
from pystencils.backend.ast.expressions import PsTernary
from pystencils.backend.ast.expressions import PsNot, PsAnd, PsOr
p, q = [PsExpression.make(PsSymbol(x, Bool())) for x in "pq"]
x, y, z = [PsExpression.make(PsSymbol(x, Fp(32))) for x in "xyz"]
cprint = CAstPrinter()
expr = PsTernary(p, x, y)
code = cprint(expr)
assert code == "p ? x : y"
expr = PsTernary(PsAnd(p, q), x + y, z)
code = cprint(expr)
assert code == "p && q ? x + y : z"
expr = PsTernary(p, PsTernary(q, x, y), z)
code = cprint(expr)
assert code == "p ? (q ? x : y) : z"
expr = PsTernary(p, x, PsTernary(q, y, z))
code = cprint(expr)
assert code == "p ? x : q ? y : z"
expr = PsTernary(PsTernary(p, q, PsOr(p, q)), x, y)
code = cprint(expr)
assert code == "(p ? q : p || q) ? x : y"
import pytest
from pystencils import make_slice from pystencils import make_slice
from pystencils.backend.kernelcreation import ( from pystencils.backend.kernelcreation import (
KernelCreationContext, KernelCreationContext,
...@@ -62,6 +64,8 @@ def test_eliminate_nested_conditional(): ...@@ -62,6 +64,8 @@ def test_eliminate_nested_conditional():
def test_isl(): def test_isl():
pytest.importorskip("islpy")
ctx = KernelCreationContext() ctx = KernelCreationContext()
factory = AstFactory(ctx) factory = AstFactory(ctx)
typify = Typifier(ctx) typify = Typifier(ctx)
......
...@@ -10,6 +10,9 @@ from pystencils.backend.ast.expressions import ( ...@@ -10,6 +10,9 @@ from pystencils.backend.ast.expressions import (
PsNot, PsNot,
PsEq, PsEq,
PsGt, PsGt,
PsTernary,
PsRem,
PsIntDiv
) )
from pystencils.types.quick import Int, Fp, Bool from pystencils.types.quick import Int, Fp, Bool
...@@ -26,8 +29,10 @@ f1 = PsExpression.make(PsConstant(1.0, Fp(32))) ...@@ -26,8 +29,10 @@ f1 = PsExpression.make(PsConstant(1.0, Fp(32)))
i0 = PsExpression.make(PsConstant(0, Int(32))) i0 = PsExpression.make(PsConstant(0, Int(32)))
i1 = PsExpression.make(PsConstant(1, Int(32))) i1 = PsExpression.make(PsConstant(1, Int(32)))
im1 = PsExpression.make(PsConstant(-1, Int(32)))
i3 = PsExpression.make(PsConstant(3, Int(32))) i3 = PsExpression.make(PsConstant(3, Int(32)))
i4 = PsExpression.make(PsConstant(4, Int(32)))
im3 = PsExpression.make(PsConstant(-3, Int(32))) im3 = PsExpression.make(PsConstant(-3, Int(32)))
i12 = PsExpression.make(PsConstant(12, Int(32))) i12 = PsExpression.make(PsConstant(12, Int(32)))
...@@ -86,6 +91,64 @@ def test_zero_dominance(): ...@@ -86,6 +91,64 @@ def test_zero_dominance():
assert result.structurally_equal(i0) assert result.structurally_equal(i0)
def test_divisions():
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx)
expr = typify(f3p5 / f1)
result = elim(expr)
assert result.structurally_equal(f3p5)
expr = typify(i3 / i1)
result = elim(expr)
assert result.structurally_equal(i3)
expr = typify(PsRem(i3, i1))
result = elim(expr)
assert result.structurally_equal(i0)
expr = typify(PsIntDiv(i12, i3))
result = elim(expr)
assert result.structurally_equal(i4)
expr = typify(i12 / i3)
result = elim(expr)
assert result.structurally_equal(i4)
expr = typify(PsIntDiv(i4, i3))
result = elim(expr)
assert result.structurally_equal(i1)
expr = typify(PsIntDiv(-i4, i3))
result = elim(expr)
assert result.structurally_equal(im1)
expr = typify(PsIntDiv(i4, -i3))
result = elim(expr)
assert result.structurally_equal(im1)
expr = typify(PsIntDiv(-i4, -i3))
result = elim(expr)
assert result.structurally_equal(i1)
expr = typify(PsRem(i4, i3))
result = elim(expr)
assert result.structurally_equal(i1)
expr = typify(PsRem(-i4, i3))
result = elim(expr)
assert result.structurally_equal(im1)
expr = typify(PsRem(i4, -i3))
result = elim(expr)
assert result.structurally_equal(i1)
expr = typify(PsRem(-i4, -i3))
result = elim(expr)
assert result.structurally_equal(im1)
def test_boolean_folding(): def test_boolean_folding():
ctx = KernelCreationContext() ctx = KernelCreationContext()
typify = Typifier(ctx) typify = Typifier(ctx)
...@@ -128,3 +191,25 @@ def test_relations_folding(): ...@@ -128,3 +191,25 @@ def test_relations_folding():
expr = typify(PsGt(x + y, f1 * (x + y))) expr = typify(PsGt(x + y, f1 * (x + y)))
result = elim(expr) result = elim(expr)
assert result.structurally_equal(false) assert result.structurally_equal(false)
def test_ternary_folding():
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx)
expr = typify(PsTernary(true, x, y))
result = elim(expr)
assert result.structurally_equal(x)
expr = typify(PsTernary(false, x, y))
result = elim(expr)
assert result.structurally_equal(y)
expr = typify(PsTernary(PsGt(i1, i0), PsTernary(PsEq(i1, i12), x, y), z))
result = elim(expr)
assert result.structurally_equal(y)
expr = typify(PsTernary(PsGt(x, y), x + f0, y * f1))
result = elim(expr)
assert result.structurally_equal(PsTernary(PsGt(x, y), x, y))