Skip to content
Snippets Groups Projects
Commit a3ab67b0 authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon:
Browse files

add integer functions div_floor, int_power_of_2, modulo_floor

parent 7c3a2e6c
No related branches found
No related tags found
1 merge request!368Integer functions
...@@ -318,7 +318,7 @@ class FreezeExpressions: ...@@ -318,7 +318,7 @@ class FreezeExpressions:
return PsCall(PsMathFunction(MathFunctions.Cos), args) return PsCall(PsMathFunction(MathFunctions.Cos), args)
case sp.tan(): case sp.tan():
return PsCall(PsMathFunction(MathFunctions.Tan), args) return PsCall(PsMathFunction(MathFunctions.Tan), args)
case integer_functions.int_div(): case integer_functions.int_div() | integer_functions.div_floor():
return PsIntDiv(*args) return PsIntDiv(*args)
case integer_functions.bit_shift_left(): case integer_functions.bit_shift_left():
return PsLeftShift(*args) return PsLeftShift(*args)
...@@ -330,6 +330,13 @@ class FreezeExpressions: ...@@ -330,6 +330,13 @@ class FreezeExpressions:
return PsBitwiseXor(*args) return PsBitwiseXor(*args)
case integer_functions.bitwise_or(): case integer_functions.bitwise_or():
return PsBitwiseOr(*args) return PsBitwiseOr(*args)
case integer_functions.int_power_of_2():
return PsLeftShift(PsExpression.make(PsConstant(1)), args[0])
case integer_functions.modulo_floor():
return PsIntDiv(*args) * args[1]
# TODO: requires if *expression*
# case integer_functions.modulo_ceil():
# case integer_functions.div_ceil():
case _: case _:
raise FreezeError(f"Unsupported function: {func}") raise FreezeError(f"Unsupported function: {func}")
......
...@@ -51,6 +51,8 @@ class int_div(IntegerFunctionTwoArgsMixIn): ...@@ -51,6 +51,8 @@ class int_div(IntegerFunctionTwoArgsMixIn):
# noinspection PyPep8Naming # noinspection PyPep8Naming
# TODO: What do the *two* arguments mean?
# Apparently, the second is required but ignored?
class int_power_of_2(IntegerFunctionTwoArgsMixIn): class int_power_of_2(IntegerFunctionTwoArgsMixIn):
pass pass
......
...@@ -4,6 +4,7 @@ from pystencils import Assignment, fields ...@@ -4,6 +4,7 @@ from pystencils import Assignment, fields
from pystencils.backend.ast.structural import ( from pystencils.backend.ast.structural import (
PsAssignment, PsAssignment,
PsBlock,
PsDeclaration, PsDeclaration,
) )
from pystencils.backend.ast.expressions import ( from pystencils.backend.ast.expressions import (
...@@ -12,7 +13,9 @@ from pystencils.backend.ast.expressions import ( ...@@ -12,7 +13,9 @@ from pystencils.backend.ast.expressions import (
PsBitwiseOr, PsBitwiseOr,
PsBitwiseXor, PsBitwiseXor,
PsExpression, PsExpression,
PsIntDiv,
PsLeftShift, PsLeftShift,
PsMul,
PsRightShift, PsRightShift,
) )
from pystencils.backend.constants import PsConstant from pystencils.backend.constants import PsConstant
...@@ -26,8 +29,12 @@ from pystencils.sympyextensions.integer_functions import ( ...@@ -26,8 +29,12 @@ from pystencils.sympyextensions.integer_functions import (
bit_shift_left, bit_shift_left,
bit_shift_right, bit_shift_right,
bitwise_and, bitwise_and,
bitwise_xor,
bitwise_or, bitwise_or,
bitwise_xor,
div_floor,
int_div,
int_power_of_2,
modulo_floor,
) )
...@@ -112,3 +119,32 @@ def test_freeze_integer_binops(): ...@@ -112,3 +119,32 @@ def test_freeze_integer_binops():
) )
assert fexpr.structurally_equal(should) assert fexpr.structurally_equal(should)
def test_freeze_integer_functions():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
x2 = PsExpression.make(ctx.get_symbol("x", ctx.index_dtype))
y2 = PsExpression.make(ctx.get_symbol("y", ctx.index_dtype))
z2 = PsExpression.make(ctx.get_symbol("z", ctx.index_dtype))
x, y, z = sp.symbols("x, y, z")
asms = [
Assignment(z, int_div(x, y)),
Assignment(z, div_floor(x, y)),
Assignment(z, int_power_of_2(x, y)),
Assignment(z, modulo_floor(x, y)),
]
fasms = [freeze(asm) for asm in asms]
should = [
PsDeclaration(z2, PsIntDiv(x2, y2)),
PsDeclaration(z2, PsIntDiv(x2, y2)),
PsDeclaration(z2, PsLeftShift(PsExpression.make(PsConstant(1)), x2)),
PsDeclaration(z2, PsMul(PsIntDiv(x2, y2), y2)),
]
for fasm, correct in zip(fasms, should):
assert fasm.structurally_equal(correct)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment