Skip to content
Snippets Groups Projects
Commit f694044b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Introduce C remainder:

 - Ast node
 - Typification (as IntOpTrait)
 - Emission
 - Constant Folding
parent ef0ddbf5
No related branches found
No related tags found
1 merge request!393Ternary Expressions, Improved Integer Divisions, and Iteration Space Fix
Pipeline #67188 passed
......@@ -633,6 +633,12 @@ class PsIntDiv(PsBinOp, PsIntOpTrait):
pass
class PsRem(PsBinOp, PsIntOpTrait):
"""C-style integer division remainder"""
# TODO: Implement remainder Python operator
pass
class PsLeftShift(PsBinOp, PsIntOpTrait):
@property
def python_operator(self) -> Callable[[Any, Any], Any] | None:
......
......@@ -26,6 +26,7 @@ from .ast.expressions import (
PsConstantExpr,
PsDeref,
PsDiv,
PsRem,
PsIntDiv,
PsLeftShift,
PsLookup,
......@@ -378,6 +379,8 @@ class CAstPrinter:
return ("*", Ops.Mul)
case PsDiv() | PsIntDiv():
return ("/", Ops.Div)
case PsRem():
return ("%", Ops.Rem)
case PsLeftShift():
return ("<<", Ops.LeftShift)
case PsRightShift():
......
......@@ -15,6 +15,8 @@ from ..ast.expressions import (
PsSub,
PsMul,
PsDiv,
PsIntDiv,
PsRem,
PsAnd,
PsOr,
PsRel,
......@@ -199,9 +201,16 @@ class EliminateConstants:
case PsMul(other_op, PsConstantExpr(c)) if c.value == 1:
return other_op, all(subtree_constness)
case PsDiv(other_op, PsConstantExpr(c)) if c.value == 1:
case PsDiv(other_op, PsConstantExpr(c)) | PsIntDiv(
other_op, PsConstantExpr(c)
) if c.value == 1:
return other_op, all(subtree_constness)
# Trivial remainder at division by one
case PsRem(other_op, PsConstantExpr(c)) if c.value == 1:
zero = self._typify(PsConstantExpr(PsConstant(0, c.get_dtype())))
return zero, True
# Multiplicative dominance: 0 * x = 0
case PsMul(PsConstantExpr(c), other_op) if c.value == 0:
return PsConstantExpr(c), True
......
......@@ -55,17 +55,18 @@ def test_printing_integer_functions():
PsBitwiseOr,
PsBitwiseXor,
PsIntDiv,
PsRem
)
expr = PsBitwiseAnd(
PsBitwiseXor(
PsBitwiseXor(j, k),
PsBitwiseOr(PsLeftShift(i, PsRightShift(j, k)), PsIntDiv(i, k)),
),
) + PsRem(i, k),
i,
)
code = cprint(expr)
assert code == "(j ^ k ^ (i << (j >> k) | i / k)) & i"
assert code == "(j ^ k ^ (i << (j >> k) | i / k)) + i % k & i"
def test_logical_precedence():
......
......@@ -10,7 +10,8 @@ from pystencils.backend.ast.expressions import (
PsNot,
PsEq,
PsGt,
PsTernary
PsTernary,
PsRem
)
from pystencils.types.quick import Int, Fp, Bool
......@@ -87,6 +88,24 @@ def test_zero_dominance():
assert result.structurally_equal(i0)
def test_divisions():
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx)
expr = typify(f3p5 / f1)
result = elim(expr)
assert result.structurally_equal(f3p5)
expr = typify(i3 / i1)
result = elim(expr)
assert result.structurally_equal(i3)
expr = typify(PsRem(i3, i1))
result = elim(expr)
assert result.structurally_equal(i0)
def test_boolean_folding():
ctx = KernelCreationContext()
typify = Typifier(ctx)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment