diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index fa936e50646f39d02c838e680bdc315a58ce92cb..ff7754ac2e28971797853e4c637ae52f6fe97bed 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -414,12 +414,19 @@ class FreezeExpressions: return PsBitwiseOr(*args) case integer_functions.int_power_of_2(): return PsLeftShift(PsExpression.make(PsConstant(1)), args[0]) - # TODO: what exactly are the semantics? - # case integer_functions.modulo_floor(): - # case integer_functions.div_floor() - # TODO: requires if *expression* - # case integer_functions.modulo_ceil(): - # case integer_functions.div_ceil(): + case integer_functions.round_to_multiple_towards_zero(): + return PsIntDiv(args[0], args[1]) * args[1] + case integer_functions.ceil_to_multiple(): + return ( + PsIntDiv( + args[0] + args[1] - PsExpression.make(PsConstant(1)), args[1] + ) + * args[1] + ) + case integer_functions.div_ceil(): + return PsIntDiv( + args[0] + args[1] - PsExpression.make(PsConstant(1)), args[1] + ) case AddressOf(): return PsAddressOf(*args) case _: diff --git a/src/pystencils/symb.py b/src/pystencils/symb.py index 0c682b26113c70ca2304bc63a15a6aa7e8d8ad9f..8e293405817c2189ebe7428bc1a53bbde8ca8073 100644 --- a/src/pystencils/symb.py +++ b/src/pystencils/symb.py @@ -9,6 +9,9 @@ from .sympyextensions.integer_functions import ( int_div, int_rem, int_power_of_2, + round_to_multiple_towards_zero, + ceil_to_multiple, + div_ceil, ) __all__ = [ @@ -20,4 +23,7 @@ __all__ = [ "int_div", "int_rem", "int_power_of_2", + "round_to_multiple_towards_zero", + "ceil_to_multiple", + "div_ceil", ] diff --git a/src/pystencils/sympyextensions/integer_functions.py b/src/pystencils/sympyextensions/integer_functions.py index eb3bb4ccc79d06d54e320bb0b442ea7dad1c670a..cf25472c89cd4deda18de889e92139e7c2a28067 100644 --- a/src/pystencils/sympyextensions/integer_functions.py +++ b/src/pystencils/sympyextensions/integer_functions.py @@ -1,4 +1,5 @@ import sympy as sp +import warnings from pystencils.sympyextensions import is_integer_sequence @@ -46,17 +47,19 @@ class bitwise_or(IntegerFunctionTwoArgsMixIn): # noinspection PyPep8Naming class int_div(IntegerFunctionTwoArgsMixIn): """C-style round-to-zero integer division""" - + def _eval_op(self, arg1, arg2): from ..utils import c_intdiv + return c_intdiv(arg1, arg2) class int_rem(IntegerFunctionTwoArgsMixIn): """C-style round-to-zero integer remainder""" - + def _eval_op(self, arg1, arg2): from ..utils import c_rem + return c_rem(arg1, arg2) @@ -68,66 +71,65 @@ class int_power_of_2(IntegerFunctionTwoArgsMixIn): # noinspection PyPep8Naming -class modulo_floor(sp.Function): - """Returns the next smaller integer divisible by given divisor. +class round_to_multiple_towards_zero(IntegerFunctionTwoArgsMixIn): + """Returns the next smaller/equal in magnitude integer divisible by given + divisor. Examples: - >>> modulo_floor(9, 4) + >>> round_to_multiple_towards_zero(9, 4) 8 - >>> modulo_floor(11, 4) + >>> round_to_multiple_towards_zero(11, -4) 8 - >>> modulo_floor(12, 4) + >>> round_to_multiple_towards_zero(12, 4) 12 + >>> round_to_multiple_towards_zero(-9, 4) + -8 + >>> round_to_multiple_towards_zero(-9, -4) + -8 """ - nargs = 2 - is_integer = True - def __new__(cls, integer, divisor): - if is_integer_sequence((integer, divisor)): - return (int(integer) // int(divisor)) * divisor - else: - return super().__new__(cls, integer, divisor) + @classmethod + def eval(cls, arg1, arg2): + from ..utils import c_intdiv - # TODO: Implement this in FreezeExpressions - # def to_c(self, print_func): - # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - # assert dtype.is_int() - # return "({dtype})(({0}) / ({1})) * ({1})".format(print_func(self.args[0]), - # print_func(self.args[1]), dtype=dtype) + if is_integer_sequence((arg1, arg2)): + return c_intdiv(arg1, arg2) * arg2 + + def _eval_op(self, arg1, arg2): + return self.eval(arg1, arg2) # noinspection PyPep8Naming -class modulo_ceil(sp.Function): - """Returns the next bigger integer divisible by given divisor. +class ceil_to_multiple(IntegerFunctionTwoArgsMixIn): + """For positive input, returns the next greater/equal integer divisible + by given divisor. The return value is unspecified if either argument is + negative. Examples: - >>> modulo_ceil(9, 4) + >>> ceil_to_multiple(9, 4) 12 - >>> modulo_ceil(11, 4) + >>> ceil_to_multiple(11, 4) 12 - >>> modulo_ceil(12, 4) + >>> ceil_to_multiple(12, 4) 12 """ - nargs = 2 - is_integer = True - def __new__(cls, integer, divisor): - if is_integer_sequence((integer, divisor)): - return integer if integer % divisor == 0 else ((integer // divisor) + 1) * divisor - else: - return super().__new__(cls, integer, divisor) + @classmethod + def eval(cls, arg1, arg2): + from ..utils import c_intdiv - # TODO: Implement this in FreezeExpressions - # def to_c(self, print_func): - # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - # assert dtype.is_int() - # code = "(({0}) % ({1}) == 0 ? {0} : (({dtype})(({0}) / ({1}))+1) * ({1}))" - # return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) + if is_integer_sequence((arg1, arg2)): + return c_intdiv(arg1 + arg2 - 1, arg2) * arg2 + + def _eval_op(self, arg1, arg2): + return self.eval(arg1, arg2) # noinspection PyPep8Naming -class div_ceil(sp.Function): - """Integer division that is always rounded up +class div_ceil(IntegerFunctionTwoArgsMixIn): + """For positive input, integer division that is always rounded up, i.e. + `div_ceil(a, b) = ceil(div(a, b))`. The return value is unspecified if + either argument is negative. Examples: >>> div_ceil(9, 4) @@ -135,45 +137,46 @@ class div_ceil(sp.Function): >>> div_ceil(8, 4) 2 """ - nargs = 2 - is_integer = True - def __new__(cls, integer, divisor): - if is_integer_sequence((integer, divisor)): - return integer // divisor if integer % divisor == 0 else (integer // divisor) + 1 - else: - return super().__new__(cls, integer, divisor) + @classmethod + def eval(cls, arg1, arg2): + from ..utils import c_intdiv - # TODO: Implement this in FreezeExpressions - # def to_c(self, print_func): - # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - # assert dtype.is_int() - # code = "( ({0}) % ({1}) == 0 ? ({dtype})({0}) / ({dtype})({1}) : ( ({dtype})({0}) / ({dtype})({1}) ) +1 )" - # return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) + if is_integer_sequence((arg1, arg2)): + return c_intdiv(arg1 + arg2 - 1, arg2) + + def _eval_op(self, arg1, arg2): + return self.eval(arg1, arg2) + + +# Deprecated functions. # noinspection PyPep8Naming -class div_floor(sp.Function): - """Integer division +class modulo_floor: + def __new__(cls, integer, divisor): + warnings.warn( + "`modulo_floor` is deprecated. Use `round_to_multiple_towards_zero` instead.", + DeprecationWarning, + ) + return round_to_multiple_towards_zero(integer, divisor) - Examples: - >>> div_floor(9, 4) - 2 - >>> div_floor(8, 4) - 2 - """ - nargs = 2 - is_integer = True +# noinspection PyPep8Naming +class modulo_ceil(sp.Function): + def __new__(cls, integer, divisor): + warnings.warn( + "`modulo_ceil` is deprecated. Use `ceil_to_multiple` instead.", + DeprecationWarning, + ) + return ceil_to_multiple(integer, divisor) + + +# noinspection PyPep8Naming +class div_floor(sp.Function): def __new__(cls, integer, divisor): - if is_integer_sequence((integer, divisor)): - return integer // divisor - else: - return super().__new__(cls, integer, divisor) - - # TODO: Implement this in FreezeExpressions - # def to_c(self, print_func): - # dtype = collate_types((get_type_of_expression(self.args[0]), get_type_of_expression(self.args[1]))) - # assert dtype.is_int() - # code = "(({dtype})({0}) / ({dtype})({1}))" - # return code.format(print_func(self.args[0]), print_func(self.args[1]), dtype=dtype) + warnings.warn( + "`div_floor` is deprecated. Use `int_div` instead.", + DeprecationWarning, + ) + return int_div(integer, divisor) diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 2761d9b3fd2a839c5af3fbfd878f0009edfb3981..072882a7bba69b23da377148f9c9875ece5dd231 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -58,6 +58,9 @@ from pystencils.sympyextensions.integer_functions import ( bitwise_xor, int_div, int_power_of_2, + round_to_multiple_towards_zero, + ceil_to_multiple, + div_ceil, ) @@ -153,10 +156,13 @@ def test_freeze_integer_functions(): z2 = PsExpression.make(ctx.get_symbol("z", ctx.index_dtype)) x, y, z = sp.symbols("x, y, z") + one = PsExpression.make(PsConstant(1)) asms = [ Assignment(z, int_div(x, y)), Assignment(z, int_power_of_2(x, y)), - # Assignment(z, modulo_floor(x, y)), + Assignment(z, round_to_multiple_towards_zero(x, y)), + Assignment(z, ceil_to_multiple(x, y)), + Assignment(z, div_ceil(x, y)), ] fasms = [freeze(asm) for asm in asms] @@ -164,7 +170,9 @@ def test_freeze_integer_functions(): should = [ PsDeclaration(z2, PsIntDiv(x2, y2)), PsDeclaration(z2, PsLeftShift(PsExpression.make(PsConstant(1)), x2)), - # PsDeclaration(z2, PsMul(PsIntDiv(x2, y2), y2)), + PsDeclaration(z2, PsIntDiv(x2, y2) * y2), + PsDeclaration(z2, PsIntDiv(x2 + y2 - one, y2) * y2), + PsDeclaration(z2, PsIntDiv(x2 + y2 - one, y2)), ] for fasm, correct in zip(fasms, should): diff --git a/tests/test_sympyextensions.py b/tests/test_sympyextensions.py index 05c11996864073be98c6aaea51de10db3867dcfb..ad5d2513b4400db938b8372a83ef43cc9339b35d 100644 --- a/tests/test_sympyextensions.py +++ b/tests/test_sympyextensions.py @@ -3,6 +3,7 @@ import numpy as np import sympy as sp import pystencils +from pystencils import Assignment from pystencils.sympyextensions import replace_second_order_products from pystencils.sympyextensions import remove_higher_order_terms from pystencils.sympyextensions import complete_the_squares_in_exp @@ -13,10 +14,18 @@ from pystencils.sympyextensions import common_denominator from pystencils.sympyextensions import get_symmetric_part from pystencils.sympyextensions import scalar_product from pystencils.sympyextensions import kronecker_delta - -from pystencils import Assignment -from pystencils.sympyextensions.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, - insert_fast_divisions, insert_fast_sqrts) +from pystencils.sympyextensions.fast_approximation import ( + fast_division, + fast_inv_sqrt, + fast_sqrt, + insert_fast_divisions, + insert_fast_sqrts, +) +from pystencils.sympyextensions.integer_functions import ( + round_to_multiple_towards_zero, + ceil_to_multiple, + div_ceil, +) def test_utility(): @@ -39,10 +48,10 @@ def test_utility(): def test_replace_second_order_products(): - x, y = sympy.symbols('x y') + x, y = sympy.symbols("x y") expr = 4 * x * y - expected_expr_positive = 2 * ((x + y) ** 2 - x ** 2 - y ** 2) - expected_expr_negative = 2 * (-(x - y) ** 2 + x ** 2 + y ** 2) + expected_expr_positive = 2 * ((x + y) ** 2 - x**2 - y**2) + expected_expr_negative = 2 * (-((x - y) ** 2) + x**2 + y**2) result = replace_second_order_products(expr, search_symbols=[x, y], positive=True) assert result == expected_expr_positive @@ -55,15 +64,17 @@ def test_replace_second_order_products(): result = replace_second_order_products(expr, search_symbols=[x, y], positive=None) assert result == expected_expr_positive - a = [Assignment(sympy.symbols('z'), x + y)] - replace_second_order_products(expr, search_symbols=[x, y], positive=True, replace_mixed=a) + a = [Assignment(sympy.symbols("z"), x + y)] + replace_second_order_products( + expr, search_symbols=[x, y], positive=True, replace_mixed=a + ) assert len(a) == 2 assert replace_second_order_products(4 + y, search_symbols=[x, y]) == y + 4 def test_remove_higher_order_terms(): - x, y = sympy.symbols('x y') + x, y = sympy.symbols("x y") expr = sympy.Mul(x, y) @@ -81,19 +92,19 @@ def test_remove_higher_order_terms(): def test_complete_the_squares_in_exp(): - a, b, c, s, n = sympy.symbols('a b c s n') - expr = a * s ** 2 + b * s + c + a, b, c, s, n = sympy.symbols("a b c s n") + expr = a * s**2 + b * s + c result = complete_the_squares_in_exp(expr, symbols_to_complete=[s]) assert result == expr - expr = sympy.exp(a * s ** 2 + b * s + c) - expected_result = sympy.exp(a*s**2 + c - b**2 / (4*a)) + expr = sympy.exp(a * s**2 + b * s + c) + expected_result = sympy.exp(a * s**2 + c - b**2 / (4 * a)) result = complete_the_squares_in_exp(expr, symbols_to_complete=[s]) assert result == expected_result def test_extract_most_common_factor(): - x, y = sympy.symbols('x y') + x, y = sympy.symbols("x y") expr = 1 / (x + y) + 3 / (x + y) + 3 / (x + y) most_common_factor = extract_most_common_factor(expr) @@ -115,98 +126,98 @@ def test_extract_most_common_factor(): def test_count_operations(): - x, y, z = sympy.symbols('x y z') - expr = 1/x + y * sympy.sqrt(z) + x, y, z = sympy.symbols("x y z") + expr = 1 / x + y * sympy.sqrt(z) ops = count_operations(expr, only_type=None) - assert ops['adds'] == 1 - assert ops['muls'] == 1 - assert ops['divs'] == 1 - assert ops['sqrts'] == 1 + assert ops["adds"] == 1 + assert ops["muls"] == 1 + assert ops["divs"] == 1 + assert ops["sqrts"] == 1 expr = 1 / sympy.sqrt(z) ops = count_operations(expr, only_type=None) - assert ops['adds'] == 0 - assert ops['muls'] == 0 - assert ops['divs'] == 1 - assert ops['sqrts'] == 1 + assert ops["adds"] == 0 + assert ops["muls"] == 0 + assert ops["divs"] == 1 + assert ops["sqrts"] == 1 expr = sympy.Rel(1 / sympy.sqrt(z), 5) ops = count_operations(expr, only_type=None) - assert ops['adds'] == 0 - assert ops['muls'] == 0 - assert ops['divs'] == 1 - assert ops['sqrts'] == 1 + assert ops["adds"] == 0 + assert ops["muls"] == 0 + assert ops["divs"] == 1 + assert ops["sqrts"] == 1 expr = sympy.sqrt(x + y) expr = insert_fast_sqrts(expr).atoms(fast_sqrt) ops = count_operations(*expr, only_type=None) - assert ops['fast_sqrts'] == 1 + assert ops["fast_sqrts"] == 1 expr = sympy.sqrt(x / y) expr = insert_fast_divisions(expr).atoms(fast_division) ops = count_operations(*expr, only_type=None) - assert ops['fast_div'] == 1 + assert ops["fast_div"] == 1 - expr = pystencils.Assignment(sympy.Symbol('tmp'), 3 / sympy.sqrt(x + y)) + expr = pystencils.Assignment(sympy.Symbol("tmp"), 3 / sympy.sqrt(x + y)) expr = insert_fast_sqrts(expr).atoms(fast_inv_sqrt) ops = count_operations(*expr, only_type=None) - assert ops['fast_inv_sqrts'] == 1 + assert ops["fast_inv_sqrts"] == 1 expr = sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z ops = count_operations(expr, only_type=None) - assert ops['adds'] == 1 + assert ops["adds"] == 1 - expr = sympy.Pow(1/x + y * sympy.sqrt(z), 100) + expr = sympy.Pow(1 / x + y * sympy.sqrt(z), 100) ops = count_operations(expr, only_type=None) - assert ops['adds'] == 1 - assert ops['muls'] == 99 - assert ops['divs'] == 1 - assert ops['sqrts'] == 1 + assert ops["adds"] == 1 + assert ops["muls"] == 99 + assert ops["divs"] == 1 + assert ops["sqrts"] == 1 expr = x / y ops = count_operations(expr, only_type=None) - assert ops['divs'] == 1 + assert ops["divs"] == 1 expr = x + z / y + z ops = count_operations(expr, only_type=None) - assert ops['adds'] == 2 - assert ops['divs'] == 1 + assert ops["adds"] == 2 + assert ops["divs"] == 1 - expr = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) + expr = sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False)) ops = count_operations(expr, only_type=None) - assert ops['muls'] == 99 + assert ops["muls"] == 99 - expr = 1 / sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) + expr = 1 / sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False)) ops = count_operations(expr, only_type=None) - assert ops['divs'] == 1 - assert ops['muls'] == 99 + assert ops["divs"] == 1 + assert ops["muls"] == 99 - expr = (y + z) / sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) + expr = (y + z) / sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False)) ops = count_operations(expr, only_type=None) - assert ops['adds'] == 1 - assert ops['divs'] == 1 - assert ops['muls'] == 99 + assert ops["adds"] == 1 + assert ops["divs"] == 1 + assert ops["muls"] == 99 def test_common_denominator(): - x = sympy.symbols('x') + x = sympy.symbols("x") expr = sympy.Rational(1, 2) + x * sympy.Rational(2, 3) cm = common_denominator(expr) assert cm == 6 def test_get_symmetric_part(): - x, y, z = sympy.symbols('x y z') - expr = x / 9 - y ** 2 / 6 + z ** 2 / 3 + z / 3 - expected_result = x / 9 - y ** 2 / 6 + z ** 2 / 3 - sym_part = get_symmetric_part(expr, sympy.symbols(f'y z')) + x, y, z = sympy.symbols("x y z") + expr = x / 9 - y**2 / 6 + z**2 / 3 + z / 3 + expected_result = x / 9 - y**2 / 6 + z**2 / 3 + sym_part = get_symmetric_part(expr, sympy.symbols(f"y z")) assert sym_part == expected_result def test_simplify_by_equality(): - x, y, z = sp.symbols('x, y, z') - p, q = sp.symbols('p, q') + x, y, z = sp.symbols("x, y, z") + p, q = sp.symbols("p, q") # Let x = y + z expr = x * p - y * p + z * q @@ -219,9 +230,24 @@ def test_simplify_by_equality(): expr = x * (y + z) - y * z expr = simplify_by_equality(expr, x, y, z) - assert expr == x*y + z**2 + assert expr == x * y + z**2 # Let x = y + 2 expr = x * p - 2 * p expr = simplify_by_equality(expr, x, y, 2) assert expr == y * p + + +def test_integer_functions(): + assert round_to_multiple_towards_zero(9, 4) == 8 + assert round_to_multiple_towards_zero(11, -4) == 8 + assert round_to_multiple_towards_zero(12, 4) == 12 + assert round_to_multiple_towards_zero(-9, 4) == -8 + assert round_to_multiple_towards_zero(-9, -4) == -8 + + assert ceil_to_multiple(9, 4) == 12 + assert ceil_to_multiple(11, 4) == 12 + assert ceil_to_multiple(12, 4) == 12 + + assert div_ceil(9, 4) == 3 + assert div_ceil(8, 4) == 2