Skip to content
Snippets Groups Projects
Commit 05c5c2f5 authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon: Committed by Frederik Hennig
Browse files

Clarify semantics of fancy integer division functions.

parent f7cad358
Branches
No related tags found
1 merge request!417Clarify semantics of fancy integer division functions.
......@@ -414,12 +414,19 @@ class FreezeExpressions:
return PsBitwiseOr(*args)
case integer_functions.int_power_of_2():
return PsLeftShift(PsExpression.make(PsConstant(1)), args[0])
# TODO: what exactly are the semantics?
# case integer_functions.modulo_floor():
# case integer_functions.div_floor()
# TODO: requires if *expression*
# case integer_functions.modulo_ceil():
# case integer_functions.div_ceil():
case integer_functions.round_to_multiple_towards_zero():
return PsIntDiv(args[0], args[1]) * args[1]
case integer_functions.ceil_to_multiple():
return (
PsIntDiv(
args[0] + args[1] - PsExpression.make(PsConstant(1)), args[1]
)
* args[1]
)
case integer_functions.div_ceil():
return PsIntDiv(
args[0] + args[1] - PsExpression.make(PsConstant(1)), args[1]
)
case AddressOf():
return PsAddressOf(*args)
case _:
......
......@@ -9,6 +9,9 @@ from .sympyextensions.integer_functions import (
int_div,
int_rem,
int_power_of_2,
round_to_multiple_towards_zero,
ceil_to_multiple,
div_ceil,
)
__all__ = [
......@@ -20,4 +23,7 @@ __all__ = [
"int_div",
"int_rem",
"int_power_of_2",
"round_to_multiple_towards_zero",
"ceil_to_multiple",
"div_ceil",
]
import sympy as sp
import warnings
from pystencils.sympyextensions import is_integer_sequence
......@@ -46,17 +47,19 @@ class bitwise_or(IntegerFunctionTwoArgsMixIn):
# noinspection PyPep8Naming
class int_div(IntegerFunctionTwoArgsMixIn):
"""C-style round-to-zero integer division"""
def _eval_op(self, arg1, arg2):
from ..utils import c_intdiv
return c_intdiv(arg1, arg2)
class int_rem(IntegerFunctionTwoArgsMixIn):
"""C-style round-to-zero integer remainder"""
def _eval_op(self, arg1, arg2):
from ..utils import c_rem
return c_rem(arg1, arg2)
......@@ -68,66 +71,65 @@ class int_power_of_2(IntegerFunctionTwoArgsMixIn):
# noinspection PyPep8Naming
class modulo_floor(sp.Function):
"""Returns the next smaller integer divisible by given divisor.
class round_to_multiple_towards_zero(IntegerFunctionTwoArgsMixIn):
"""Returns the next smaller/equal in magnitude integer divisible by given
divisor.
Examples:
>>> modulo_floor(9, 4)
>>> round_to_multiple_towards_zero(9, 4)
8
>>> modulo_floor(11, 4)
>>> round_to_multiple_towards_zero(11, -4)
8
>>> modulo_floor(12, 4)
>>> round_to_multiple_towards_zero(12, 4)
12
>>> round_to_multiple_towards_zero(-9, 4)
-8
>>> round_to_multiple_towards_zero(-9, -4)
-8
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return (int(integer) // int(divisor)) * divisor
else:
return super().__new__(cls, integer, divisor)
@classmethod
def eval(cls, arg1, arg2):
from ..utils import c_intdiv
# TODO: Implement this in FreezeExpressions
# def to_c(self, print_func):
# dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
# assert dtype.is_int()
# return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]),
# print_func(self.args[1]), dtype=dtype)
if is_integer_sequence((arg1, arg2)):
return c_intdiv(arg1, arg2) * arg2
def _eval_op(self, arg1, arg2):
return self.eval(arg1, arg2)
# noinspection PyPep8Naming
class modulo_ceil(sp.Function):
"""Returns the next bigger integer divisible by given divisor.
class ceil_to_multiple(IntegerFunctionTwoArgsMixIn):
"""For positive input, returns the next greater/equal integer divisible
by given divisor. The return value is unspecified if either argument is
negative.
Examples:
>>> modulo_ceil(9, 4)
>>> ceil_to_multiple(9, 4)
12
>>> modulo_ceil(11, 4)
>>> ceil_to_multiple(11, 4)
12
>>> modulo_ceil(12, 4)
>>> ceil_to_multiple(12, 4)
12
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return integer if integer % divisor == 0 else ((integer // divisor) + 1) * divisor
else:
return super().__new__(cls, integer, divisor)
@classmethod
def eval(cls, arg1, arg2):
from ..utils import c_intdiv
# TODO: Implement this in FreezeExpressions
# def to_c(self, print_func):
# dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
# assert dtype.is_int()
# code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))"
# return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
if is_integer_sequence((arg1, arg2)):
return c_intdiv(arg1 + arg2 - 1, arg2) * arg2
def _eval_op(self, arg1, arg2):
return self.eval(arg1, arg2)
# noinspection PyPep8Naming
class div_ceil(sp.Function):
"""Integer division that is always rounded up
class div_ceil(IntegerFunctionTwoArgsMixIn):
"""For positive input, integer division that is always rounded up, i.e.
`div_ceil(a, b) = ceil(div(a, b))`. The return value is unspecified if
either argument is negative.
Examples:
>>> div_ceil(9, 4)
......@@ -135,45 +137,46 @@ class div_ceil(sp.Function):
>>> div_ceil(8, 4)
2
"""
nargs = 2
is_integer = True
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return integer // divisor if integer % divisor == 0 else (integer // divisor) + 1
else:
return super().__new__(cls, integer, divisor)
@classmethod
def eval(cls, arg1, arg2):
from ..utils import c_intdiv
# TODO: Implement this in FreezeExpressions
# def to_c(self, print_func):
# dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
# assert dtype.is_int()
# code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )"
# return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
if is_integer_sequence((arg1, arg2)):
return c_intdiv(arg1 + arg2 - 1, arg2)
def _eval_op(self, arg1, arg2):
return self.eval(arg1, arg2)
# Deprecated functions.
# noinspection PyPep8Naming
class div_floor(sp.Function):
"""Integer division
class modulo_floor:
def __new__(cls, integer, divisor):
warnings.warn(
"`modulo_floor` is deprecated. Use `round_to_multiple_towards_zero` instead.",
DeprecationWarning,
)
return round_to_multiple_towards_zero(integer, divisor)
Examples:
>>> div_floor(9, 4)
2
>>> div_floor(8, 4)
2
"""
nargs = 2
is_integer = True
# noinspection PyPep8Naming
class modulo_ceil(sp.Function):
def __new__(cls, integer, divisor):
warnings.warn(
"`modulo_ceil` is deprecated. Use `ceil_to_multiple` instead.",
DeprecationWarning,
)
return ceil_to_multiple(integer, divisor)
# noinspection PyPep8Naming
class div_floor(sp.Function):
def __new__(cls, integer, divisor):
if is_integer_sequence((integer, divisor)):
return integer // divisor
else:
return super().__new__(cls, integer, divisor)
# TODO: Implement this in FreezeExpressions
# def to_c(self, print_func):
# dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1])))
# assert dtype.is_int()
# code = "(({dtype})({0}) / ({dtype})({1}))"
# return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype)
warnings.warn(
"`div_floor` is deprecated. Use `int_div` instead.",
DeprecationWarning,
)
return int_div(integer, divisor)
......@@ -58,6 +58,9 @@ from pystencils.sympyextensions.integer_functions import (
bitwise_xor,
int_div,
int_power_of_2,
round_to_multiple_towards_zero,
ceil_to_multiple,
div_ceil,
)
......@@ -153,10 +156,13 @@ def test_freeze_integer_functions():
z2 = PsExpression.make(ctx.get_symbol("z", ctx.index_dtype))
x, y, z = sp.symbols("x, y, z")
one = PsExpression.make(PsConstant(1))
asms = [
Assignment(z, int_div(x, y)),
Assignment(z, int_power_of_2(x, y)),
# Assignment(z, modulo_floor(x, y)),
Assignment(z, round_to_multiple_towards_zero(x, y)),
Assignment(z, ceil_to_multiple(x, y)),
Assignment(z, div_ceil(x, y)),
]
fasms = [freeze(asm) for asm in asms]
......@@ -164,7 +170,9 @@ def test_freeze_integer_functions():
should = [
PsDeclaration(z2, PsIntDiv(x2, y2)),
PsDeclaration(z2, PsLeftShift(PsExpression.make(PsConstant(1)), x2)),
# PsDeclaration(z2, PsMul(PsIntDiv(x2, y2), y2)),
PsDeclaration(z2, PsIntDiv(x2, y2) * y2),
PsDeclaration(z2, PsIntDiv(x2 + y2 - one, y2) * y2),
PsDeclaration(z2, PsIntDiv(x2 + y2 - one, y2)),
]
for fasm, correct in zip(fasms, should):
......
......@@ -3,6 +3,7 @@ import numpy as np
import sympy as sp
import pystencils
from pystencils import Assignment
from pystencils.sympyextensions import replace_second_order_products
from pystencils.sympyextensions import remove_higher_order_terms
from pystencils.sympyextensions import complete_the_squares_in_exp
......@@ -13,10 +14,18 @@ from pystencils.sympyextensions import common_denominator
from pystencils.sympyextensions import get_symmetric_part
from pystencils.sympyextensions import scalar_product
from pystencils.sympyextensions import kronecker_delta
from pystencils import Assignment
from pystencils.sympyextensions.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt,
insert_fast_divisions, insert_fast_sqrts)
from pystencils.sympyextensions.fast_approximation import (
fast_division,
fast_inv_sqrt,
fast_sqrt,
insert_fast_divisions,
insert_fast_sqrts,
)
from pystencils.sympyextensions.integer_functions import (
round_to_multiple_towards_zero,
ceil_to_multiple,
div_ceil,
)
def test_utility():
......@@ -39,10 +48,10 @@ def test_utility():
def test_replace_second_order_products():
x, y = sympy.symbols('x y')
x, y = sympy.symbols("x y")
expr = 4 * x * y
expected_expr_positive = 2 * ((x + y) ** 2 - x ** 2 - y ** 2)
expected_expr_negative = 2 * (-(x - y) ** 2 + x ** 2 + y ** 2)
expected_expr_positive = 2 * ((x + y) ** 2 - x**2 - y**2)
expected_expr_negative = 2 * (-((x - y) ** 2) + x**2 + y**2)
result = replace_second_order_products(expr, search_symbols=[x, y], positive=True)
assert result == expected_expr_positive
......@@ -55,15 +64,17 @@ def test_replace_second_order_products():
result = replace_second_order_products(expr, search_symbols=[x, y], positive=None)
assert result == expected_expr_positive
a = [Assignment(sympy.symbols('z'), x + y)]
replace_second_order_products(expr, search_symbols=[x, y], positive=True, replace_mixed=a)
a = [Assignment(sympy.symbols("z"), x + y)]
replace_second_order_products(
expr, search_symbols=[x, y], positive=True, replace_mixed=a
)
assert len(a) == 2
assert replace_second_order_products(4 + y, search_symbols=[x, y]) == y + 4
def test_remove_higher_order_terms():
x, y = sympy.symbols('x y')
x, y = sympy.symbols("x y")
expr = sympy.Mul(x, y)
......@@ -81,19 +92,19 @@ def test_remove_higher_order_terms():
def test_complete_the_squares_in_exp():
a, b, c, s, n = sympy.symbols('a b c s n')
expr = a * s ** 2 + b * s + c
a, b, c, s, n = sympy.symbols("a b c s n")
expr = a * s**2 + b * s + c
result = complete_the_squares_in_exp(expr, symbols_to_complete=[s])
assert result == expr
expr = sympy.exp(a * s ** 2 + b * s + c)
expected_result = sympy.exp(a*s**2 + c - b**2 / (4*a))
expr = sympy.exp(a * s**2 + b * s + c)
expected_result = sympy.exp(a * s**2 + c - b**2 / (4 * a))
result = complete_the_squares_in_exp(expr, symbols_to_complete=[s])
assert result == expected_result
def test_extract_most_common_factor():
x, y = sympy.symbols('x y')
x, y = sympy.symbols("x y")
expr = 1 / (x + y) + 3 / (x + y) + 3 / (x + y)
most_common_factor = extract_most_common_factor(expr)
......@@ -115,98 +126,98 @@ def test_extract_most_common_factor():
def test_count_operations():
x, y, z = sympy.symbols('x y z')
expr = 1/x + y * sympy.sqrt(z)
x, y, z = sympy.symbols("x y z")
expr = 1 / x + y * sympy.sqrt(z)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
assert ops['muls'] == 1
assert ops['divs'] == 1
assert ops['sqrts'] == 1
assert ops["adds"] == 1
assert ops["muls"] == 1
assert ops["divs"] == 1
assert ops["sqrts"] == 1
expr = 1 / sympy.sqrt(z)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 0
assert ops['muls'] == 0
assert ops['divs'] == 1
assert ops['sqrts'] == 1
assert ops["adds"] == 0
assert ops["muls"] == 0
assert ops["divs"] == 1
assert ops["sqrts"] == 1
expr = sympy.Rel(1 / sympy.sqrt(z), 5)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 0
assert ops['muls'] == 0
assert ops['divs'] == 1
assert ops['sqrts'] == 1
assert ops["adds"] == 0
assert ops["muls"] == 0
assert ops["divs"] == 1
assert ops["sqrts"] == 1
expr = sympy.sqrt(x + y)
expr = insert_fast_sqrts(expr).atoms(fast_sqrt)
ops = count_operations(*expr, only_type=None)
assert ops['fast_sqrts'] == 1
assert ops["fast_sqrts"] == 1
expr = sympy.sqrt(x / y)
expr = insert_fast_divisions(expr).atoms(fast_division)
ops = count_operations(*expr, only_type=None)
assert ops['fast_div'] == 1
assert ops["fast_div"] == 1
expr = pystencils.Assignment(sympy.Symbol('tmp'), 3 / sympy.sqrt(x + y))
expr = pystencils.Assignment(sympy.Symbol("tmp"), 3 / sympy.sqrt(x + y))
expr = insert_fast_sqrts(expr).atoms(fast_inv_sqrt)
ops = count_operations(*expr, only_type=None)
assert ops['fast_inv_sqrts'] == 1
assert ops["fast_inv_sqrts"] == 1
expr = sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
assert ops["adds"] == 1
expr = sympy.Pow(1/x + y * sympy.sqrt(z), 100)
expr = sympy.Pow(1 / x + y * sympy.sqrt(z), 100)
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
assert ops['muls'] == 99
assert ops['divs'] == 1
assert ops['sqrts'] == 1
assert ops["adds"] == 1
assert ops["muls"] == 99
assert ops["divs"] == 1
assert ops["sqrts"] == 1
expr = x / y
ops = count_operations(expr, only_type=None)
assert ops['divs'] == 1
assert ops["divs"] == 1
expr = x + z / y + z
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 2
assert ops['divs'] == 1
assert ops["adds"] == 2
assert ops["divs"] == 1
expr = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))
expr = sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False))
ops = count_operations(expr, only_type=None)
assert ops['muls'] == 99
assert ops["muls"] == 99
expr = 1 / sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))
expr = 1 / sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False))
ops = count_operations(expr, only_type=None)
assert ops['divs'] == 1
assert ops['muls'] == 99
assert ops["divs"] == 1
assert ops["muls"] == 99
expr = (y + z) / sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))
expr = (y + z) / sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False))
ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1
assert ops['divs'] == 1
assert ops['muls'] == 99
assert ops["adds"] == 1
assert ops["divs"] == 1
assert ops["muls"] == 99
def test_common_denominator():
x = sympy.symbols('x')
x = sympy.symbols("x")
expr = sympy.Rational(1, 2) + x * sympy.Rational(2, 3)
cm = common_denominator(expr)
assert cm == 6
def test_get_symmetric_part():
x, y, z = sympy.symbols('x y z')
expr = x / 9 - y ** 2 / 6 + z ** 2 / 3 + z / 3
expected_result = x / 9 - y ** 2 / 6 + z ** 2 / 3
sym_part = get_symmetric_part(expr, sympy.symbols(f'y z'))
x, y, z = sympy.symbols("x y z")
expr = x / 9 - y**2 / 6 + z**2 / 3 + z / 3
expected_result = x / 9 - y**2 / 6 + z**2 / 3
sym_part = get_symmetric_part(expr, sympy.symbols(f"y z"))
assert sym_part == expected_result
def test_simplify_by_equality():
x, y, z = sp.symbols('x, y, z')
p, q = sp.symbols('p, q')
x, y, z = sp.symbols("x, y, z")
p, q = sp.symbols("p, q")
# Let x = y + z
expr = x * p - y * p + z * q
......@@ -219,9 +230,24 @@ def test_simplify_by_equality():
expr = x * (y + z) - y * z
expr = simplify_by_equality(expr, x, y, z)
assert expr == x*y + z**2
assert expr == x * y + z**2
# Let x = y + 2
expr = x * p - 2 * p
expr = simplify_by_equality(expr, x, y, 2)
assert expr == y * p
def test_integer_functions():
assert round_to_multiple_towards_zero(9, 4) == 8
assert round_to_multiple_towards_zero(11, -4) == 8
assert round_to_multiple_towards_zero(12, 4) == 12
assert round_to_multiple_towards_zero(-9, 4) == -8
assert round_to_multiple_towards_zero(-9, -4) == -8
assert ceil_to_multiple(9, 4) == 12
assert ceil_to_multiple(11, 4) == 12
assert ceil_to_multiple(12, 4) == 12
assert div_ceil(9, 4) == 3
assert div_ceil(8, 4) == 2
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment