diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 388160f3053e57c19d4edccadb08589d71064896..f3d18f3498079dcd06c2f4ffa61a87ff59e47117 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -78,6 +78,7 @@ class MathFunctions(Enum): ASin = ("asin", 1) ACos = ("acos", 1) ATan = ("atan", 1) + Sqrt = ("sqrt", 1) Abs = ("abs", 1) Floor = ("floor", 1) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 5d18ecf82c51d3687fd4ff7324843031cab705de..4fd09f879dd8d98903753c8709543e0bcc3fd3e1 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -212,40 +212,49 @@ class FreezeExpressions: base = expr.args[0] exponent = expr.args[1] - base_frozen = self.visit_expr(base) - reciprocal = False - expand_product = False - - if exponent.is_Integer: - if exponent == 0: - return PsExpression.make(PsConstant(1)) - - if exponent.is_negative: - reciprocal = True - exponent = -exponent - - if exponent <= sp.Integer( - 5 - ): # TODO: is this a sensible limit? maybe make this configurable. - expand_product = True - - if expand_product: - frozen_expr = reduce( - mul, - [base_frozen] - + [base_frozen.clone() for _ in range(0, int(exponent) - 1)], - ) - else: - exponent_frozen = self.visit_expr(exponent) - frozen_expr = PsMathFunction(MathFunctions.Pow)( - base_frozen, exponent_frozen - ) + expr_frozen = self.visit_expr(base) + + if isinstance(exponent, sp.Rational): + # Decompose rational exponent + num: int = exponent.numerator + denom: int = exponent.denominator - if reciprocal: - one = PsExpression.make(PsConstant(1)) - frozen_expr = one / frozen_expr + if denom <= 2 and abs(num) <= 8: + # At most a square root, and at most eight factors - return frozen_expr + reciprocal = False + + if num < 0: + reciprocal = True + num = -num + + if denom == 2: + expr_frozen = PsMathFunction(MathFunctions.Sqrt)(expr_frozen) + denom = 1 + + assert denom == 1 + + # Pairwise multiplication for logarithmic runtime + factors = [expr_frozen] + [expr_frozen.clone() for _ in range(num - 1)] + while len(factors) > 1: + combined = [x * y for x, y in zip(factors[::2], factors[1::2])] + if len(factors) % 2 == 1: + combined.append(factors[-1]) + factors = combined + + expr_frozen = factors.pop() + + if reciprocal: + one = PsExpression.make(PsConstant(1)) + expr_frozen = one / expr_frozen + + return expr_frozen + + # If we got this far, use pow + exponent_frozen = self.visit_expr(exponent) + expr_frozen = PsMathFunction(MathFunctions.Pow)(expr_frozen, exponent_frozen) + + return expr_frozen def map_Integer(self, expr: sp.Integer) -> PsConstantExpr: value = int(expr) diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 2559ac6d2456e094fe325956264b34ee859edf04..ff41d3a68888347e81e9ff65d1359467666f2e32 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -91,6 +91,7 @@ class CudaPlatform(GenericGpu): | MathFunctions.Log | MathFunctions.Sin | MathFunctions.Cos + | MathFunctions.Sqrt | MathFunctions.Ceil | MathFunctions.Floor ) if dtype.width in (16, 32, 64): diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index b6d7dd551bb70b799241cb0ca6257fd48b57a3d6..fa8e54002679b9727578f923d4d409208268782c 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -43,7 +43,7 @@ class GenericCpu(Platform): @property def required_headers(self) -> set[str]: - return {"<math.h>"} + return {"<cmath>"} def materialize_iteration_space( self, body: PsBlock, ispace: IterationSpace @@ -78,6 +78,7 @@ class GenericCpu(Platform): | MathFunctions.ATan | MathFunctions.ATan2 | MathFunctions.Pow + | MathFunctions.Sqrt | MathFunctions.Floor | MathFunctions.Ceil ): diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index 594c87b145fbe39640321634f5e5e33dd93d7378..e1da9e2237b504c4b6b681d48812e4d079a5b463 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -82,6 +82,7 @@ class SyclPlatform(GenericGpu): | MathFunctions.ATan | MathFunctions.ATan2 | MathFunctions.Pow + | MathFunctions.Sqrt | MathFunctions.Floor | MathFunctions.Ceil ): diff --git a/src/pystencils/backend/platforms/x86.py b/src/pystencils/backend/platforms/x86.py index 7d2fe650fc23a54ea301817d57a3816b6780bd85..fff1433009c1272f0eabd375f92f77bc744e8f21 100644 --- a/src/pystencils/backend/platforms/x86.py +++ b/src/pystencils/backend/platforms/x86.py @@ -202,6 +202,9 @@ class X86VectorCpu(GenericVectorCpu): opstr = expr.function.func.function_name if vtype.width > 256: raise MaterializationError("512bit ceil/floor require SVML.") + + case MathFunctions.Sqrt if vtype.is_float(): + opstr = expr.function.name case MathFunctions.Min | MathFunctions.Max: opstr = expr.function.func.function_name diff --git a/tests/kernelcreation/test_functions.py b/tests/kernelcreation/test_functions.py index 9b7dd2852c73242d8a8549f1a7c39e26024848b2..a4d154d4b0c86ea694bbe94f66372aa2ba3a190c 100644 --- a/tests/kernelcreation/test_functions.py +++ b/tests/kernelcreation/test_functions.py @@ -28,6 +28,7 @@ def unary_function(name, xp): "asin": (sp.asin, xp.arcsin), "acos": (sp.acos, xp.arccos), "atan": (sp.atan, xp.arctan), + "sqrt": (sp.sqrt, xp.sqrt), "abs": (sp.Abs, xp.abs), "floor": (sp.floor, xp.floor), "ceil": (sp.ceiling, xp.ceil), @@ -46,6 +47,82 @@ def binary_function(name, xp): AVAIL_TARGETS = Target.available_targets() +@pytest.fixture +def function_domain(function_name, dtype): + eps = 1e-6 + rng = np.random.default_rng() + + match function_name: + case ( + "exp" | "sin" | "cos" | "sinh" | "cosh" | "atan" | "abs" | "floor" | "ceil" + ): + return np.concatenate( + [ + [0.0, -1.0, 1.0], + rng.uniform(-0.1, 0.1, 5), + rng.uniform(-5.0, 5.0, 8), + rng.uniform(-10, 10, 8), + ] + ).astype(dtype) + case "tan": + return np.concatenate( + [ + [0.0, -1.0, 1.0], + rng.uniform(-np.pi / 2.0 + eps, np.pi / 2.0 - eps, 13), + ] + ).astype(dtype) + case "asin" | "acos": + return np.concatenate( + [ + [0.0, 1.0, -1.0], + rng.uniform(-1.0, 1.0, 13), + ] + ).astype(dtype) + case "log" | "sqrt": + return np.concatenate( + [ + [1.0], + rng.uniform(eps, 0.1, 7), + rng.uniform(eps, 1.0, 8), + rng.uniform(eps, 1e6, 8), + ] + ).astype(dtype) + case "min" | "max" | "atan2": + return np.concatenate( + [ + rng.uniform(-0.1, 0.1, 8), + rng.uniform(-5.0, 5.0, 8), + rng.uniform(-10, 10, 8), + ] + ).astype(dtype), np.concatenate( + [ + rng.uniform(-0.1, 0.1, 8), + rng.uniform(-5.0, 5.0, 8), + rng.uniform(-10, 10, 8), + ] + ).astype( + dtype + ) + case "pow": + return np.concatenate( + [ + [0., 1., 1.], + rng.uniform(-1., 1., 8), + rng.uniform(0., 5., 8), + ] + ).astype(dtype), np.concatenate( + [ + [1., 0., 2.], + np.arange(2., 10., 1.), + rng.uniform(-2.0, 2.0, 8), + ] + ).astype( + dtype + ) + case _: + assert False, "I don't know the domain of that function" + + @pytest.mark.parametrize( "function_name, target", list( @@ -70,15 +147,15 @@ AVAIL_TARGETS = Target.available_targets() ["floor", "ceil"], [t for t in AVAIL_TARGETS if Target._AVX512 not in t] ) ) - + list(product(["abs"], AVAIL_TARGETS)), + + list(product(["sqrt", "abs"], AVAIL_TARGETS)), ) @pytest.mark.parametrize("dtype", (np.float32, np.float64)) -def test_unary_functions(gen_config, xp, function_name, dtype): +def test_unary_functions(gen_config, xp, function_name, dtype, function_domain): sp_func, xp_func = unary_function(function_name, xp) resolution = np.finfo(dtype).resolution # Array size should be larger than eight, such that vectorized kernels don't just run their remainder loop - inp = xp.array([0.1, 0.2, 0.0, -0.8, -1.6, -12.592, xp.pi, xp.e, -0.3], dtype=dtype) + inp = xp.array(function_domain) outp = xp.zeros_like(inp) reference = xp_func(inp) @@ -104,15 +181,12 @@ def test_unary_functions(gen_config, xp, function_name, dtype): ), ) @pytest.mark.parametrize("dtype", (np.float32, np.float64)) -def test_binary_functions(gen_config, xp, function_name, dtype): +def test_binary_functions(gen_config, xp, function_name, dtype, function_domain): sp_func, xp_func = binary_function(function_name, xp) resolution: dtype = np.finfo(dtype).resolution - inp = xp.array([0.1, 0.2, 0.3, -0.8, -1.6, -12.592, xp.pi, xp.e, 0.0], dtype=dtype) - inp2 = xp.array( - [3.1, -0.5, 21.409, 11.0, 1.0, -14e3, 2.0 * xp.pi, -xp.e, 0.0], - dtype=dtype, - ) + inp = xp.array(function_domain[0]) + inp2 = xp.array(function_domain[1]) outp = xp.zeros_like(inp) reference = xp_func(inp, inp2) diff --git a/tests/kernelcreation/test_staggered_kernel.py b/tests/kernelcreation/test_staggered_kernel.py index 9bc9e71af036370285475ff48f57c048304c70f4..99b61d07f77ea476395b4fb3d337b65ff5ac42cb 100644 --- a/tests/kernelcreation/test_staggered_kernel.py +++ b/tests/kernelcreation/test_staggered_kernel.py @@ -5,7 +5,7 @@ import pytest import pystencils as ps from pystencils import x_staggered_vector, TypedSymbol -from pystencils.enums import Target +from pystencils import Target class TestStaggeredDiffusion: diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index b7b2ed19e114a3c3d7568fbe6b6ea035848f42e4..f6c8f85b2b3df2289e809728b9e7b014d6428976 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -113,14 +113,8 @@ def test_freeze_fields(): zero = PsExpression.make(PsConstant(0)) - lhs = PsBufferAcc( - f_arr.base_pointer, - (PsExpression.make(counter) + zero, zero) - ) - rhs = PsBufferAcc( - g_arr.base_pointer, - (PsExpression.make(counter) + zero, zero) - ) + lhs = PsBufferAcc(f_arr.base_pointer, (PsExpression.make(counter) + zero, zero)) + rhs = PsBufferAcc(g_arr.base_pointer, (PsExpression.make(counter) + zero, zero)) should = PsAssignment(lhs, rhs) @@ -357,6 +351,80 @@ def test_add_sub(): assert expr.structurally_equal(PsAdd(x2, PsMul(minus_two, y2))) +def test_powers(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x, y, z = sp.symbols("x, y, z") + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + + # Integer powers + expr = freeze(x**2) + assert expr.structurally_equal(x2 * x2) + + expr = freeze(x**3) + assert expr.structurally_equal(x2 * x2 * x2) + + expr = freeze(x**4) + assert expr.structurally_equal((x2 * x2) * (x2 * x2)) + + expr = freeze(x**5) + assert expr.structurally_equal((x2 * x2) * (x2 * x2) * x2) + + # Negative integer powers + one = PsExpression.make(PsConstant(1)) + + expr = freeze(x**-2) + assert expr.structurally_equal(one / (x2 * x2)) + + expr = freeze(x**-3) + assert expr.structurally_equal(one / (x2 * x2 * x2)) + + expr = freeze(x**-4) + assert expr.structurally_equal(one / ((x2 * x2) * (x2 * x2))) + + expr = freeze(x**-5) + assert expr.structurally_equal(one / ((x2 * x2) * (x2 * x2) * x2)) + + # Integer powers of the square root + sqrt = PsMathFunction(MathFunctions.Sqrt) + + expr = freeze(x ** sp.Rational(1, 2)) + assert expr.structurally_equal(sqrt(x2)) + + expr = freeze(x ** sp.Rational(2, 2)) + assert expr.structurally_equal(x2) + + expr = freeze(x ** sp.Rational(3, 2)) + assert expr.structurally_equal(sqrt(x2) * sqrt(x2) * sqrt(x2)) + + expr = freeze(x ** sp.Rational(4, 2)) + assert expr.structurally_equal(x2 * x2) + + expr = freeze(x ** sp.Rational(5, 2)) + assert expr.structurally_equal( + (sqrt(x2) * sqrt(x2)) * (sqrt(x2) * sqrt(x2)) * sqrt(x2) + ) + + # Negative integer powers of sqrt + expr = freeze(x ** sp.Rational(-1, 2)) + assert expr.structurally_equal(one / sqrt(x2)) + + expr = freeze(x ** sp.Rational(-3, 2)) + assert expr.structurally_equal(one / (sqrt(x2) * sqrt(x2) * sqrt(x2))) + + # Cube root + pow = PsMathFunction(MathFunctions.Pow) + expr = freeze(x ** sp.Rational(1, 3)) + assert expr.structurally_equal(pow(x2, freeze(sp.Rational(1, 3)))) + + # Unknown exponent + expr = freeze(x**y) + assert expr.structurally_equal(pow(x2, y2)) + + def test_tuple_array_literals(): ctx = KernelCreationContext() freeze = FreezeExpressions(ctx) diff --git a/tests/runtime/test_boundary.py b/tests/runtime/test_boundary.py index a94d3782020cd494a4b01009f04016e889fca9c0..fb8f827e88106fd3b7a45b25f9c962f4ebb14f14 100644 --- a/tests/runtime/test_boundary.py +++ b/tests/runtime/test_boundary.py @@ -8,7 +8,7 @@ import pystencils from pystencils import Assignment, create_kernel from pystencils.boundaries import BoundaryHandling, Dirichlet, Neumann, add_neumann_boundary from pystencils.datahandling import SerialDataHandling -from pystencils.enums import Target +from pystencils import Target from pystencils.slicing import slice_from_direction from pystencils.timeloop import TimeLoop diff --git a/tests/runtime/test_data/.gitignore b/tests/runtime/test_data/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..278bd1136cc562e5a67e12655d7d22e774727a72 --- /dev/null +++ b/tests/runtime/test_data/.gitignore @@ -0,0 +1 @@ +datahandling_save_test* \ No newline at end of file diff --git a/tests/runtime/test_data/datahandling_parallel_save_test/dst.dat b/tests/runtime/test_data/datahandling_parallel_save_test/dst.dat deleted file mode 100644 index 204552486f77c485a4dd333a10eff82f9d44aa9f..0000000000000000000000000000000000000000 Binary files a/tests/runtime/test_data/datahandling_parallel_save_test/dst.dat and /dev/null differ diff --git a/tests/runtime/test_data/datahandling_parallel_save_test/src.dat b/tests/runtime/test_data/datahandling_parallel_save_test/src.dat deleted file mode 100644 index 204552486f77c485a4dd333a10eff82f9d44aa9f..0000000000000000000000000000000000000000 Binary files a/tests/runtime/test_data/datahandling_parallel_save_test/src.dat and /dev/null differ diff --git a/tests/runtime/test_data/datahandling_save_test.npz b/tests/runtime/test_data/datahandling_save_test.npz deleted file mode 100644 index 486c7ee74d4421d563c3b1c2e3739d8db6308b07..0000000000000000000000000000000000000000 Binary files a/tests/runtime/test_data/datahandling_save_test.npz and /dev/null differ diff --git a/tests/runtime/test_datahandling.py b/tests/runtime/test_datahandling.py index c73ec829d53112ef917e233de7f39919ae9d6b41..9d7ff924e8d7eba9039f8f0796145bd7de116ef5 100644 --- a/tests/runtime/test_datahandling.py +++ b/tests/runtime/test_datahandling.py @@ -7,7 +7,7 @@ import numpy as np import pystencils as ps from pystencils import create_data_handling, create_kernel from pystencils.gpu.gpu_array_handler import GPUArrayHandler -from pystencils.enums import Target +from pystencils import Target try: import pytest