Skip to content
Snippets Groups Projects
test_functions.py 9.78 KiB
import sympy as sp
import numpy as np
import pytest
from dataclasses import replace
from itertools import product

from pystencils import (
    fields,
    create_kernel,
    CreateKernelConfig,
    Target,
    Assignment,
    Field,
)
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.expressions import PsCall


def unary_function(name, xp):
    """Look up the implementations of the unary function called ``name``.

    The returned pair is ``(sympy_function, xp_function)``: the SymPy function
    used to build the kernel, and the array-module function computing the
    reference result.

    ``xp`` is the array module providing the reference (e.g. NumPy).
    Raises ``KeyError`` for an unknown ``name``.
    """
    table = dict(
        exp=(sp.exp, xp.exp),
        log=(sp.log, xp.log),
        sin=(sp.sin, xp.sin),
        cos=(sp.cos, xp.cos),
        tan=(sp.tan, xp.tan),
        sinh=(sp.sinh, xp.sinh),
        cosh=(sp.cosh, xp.cosh),
        asin=(sp.asin, xp.arcsin),
        acos=(sp.acos, xp.arccos),
        atan=(sp.atan, xp.arctan),
        sqrt=(sp.sqrt, xp.sqrt),
        abs=(sp.Abs, xp.abs),
        floor=(sp.floor, xp.floor),
        ceil=(sp.ceiling, xp.ceil),
    )
    return table[name]


def binary_function(name, xp):
    """Look up the implementations of the binary function called ``name``.

    The returned pair is ``(sympy_function, xp_function)``: the SymPy function
    used to build the kernel, and the array-module function computing the
    reference result.

    ``xp`` is the array module providing the reference (e.g. NumPy).
    Raises ``KeyError`` for an unknown ``name``.
    """
    table = dict(
        min=(sp.Min, xp.fmin),
        max=(sp.Max, xp.fmax),
        pow=(sp.Pow, xp.power),
        atan2=(sp.atan2, xp.arctan2),
    )
    return table[name]


# Targets reported by ``Target.available_targets()`` in the current environment.
# The test parametrizations below filter this collection per function and dtype.
AVAIL_TARGETS = Target.available_targets()


@pytest.fixture
def function_domain(function_name, dtype):
    eps = 1e-6
    rng = np.random.default_rng()

    match function_name:
        case (
            "exp" | "sin" | "cos" | "sinh" | "cosh" | "atan" | "abs" | "floor" | "ceil"
        ):
            return np.concatenate(
                [
                    [0.0, -1.0, 1.0],
                    rng.uniform(-0.1, 0.1, 5),
                    rng.uniform(-5.0, 5.0, 8),
                    rng.uniform(-10, 10, 8),
                ]
            ).astype(dtype)
        case "tan":
            return np.concatenate(
                [
                    [0.0, -1.0, 1.0],
                    rng.uniform(-np.pi / 2.0 + eps, np.pi / 2.0 - eps, 13),
                ]
            ).astype(dtype)
        case "asin" | "acos":
            return np.concatenate(
                [
                    [0.0, 1.0, -1.0],
                    rng.uniform(-1.0, 1.0, 13),
                ]
            ).astype(dtype)
        case "log" | "sqrt":
            return np.concatenate(
                [
                    [1.0],
                    rng.uniform(eps, 0.1, 7),
                    rng.uniform(eps, 1.0, 8),
                    rng.uniform(eps, 1e6, 8),
                ]
            ).astype(dtype)
        case "min" | "max" | "atan2":
            return np.concatenate(
                [
                    rng.uniform(-0.1, 0.1, 8),
                    rng.uniform(-5.0, 5.0, 8),
                    rng.uniform(-10, 10, 8),
                ]
            ).astype(dtype), np.concatenate(
                [
                    rng.uniform(-0.1, 0.1, 8),
                    rng.uniform(-5.0, 5.0, 8),
                    rng.uniform(-10, 10, 8),
                ]
            ).astype(
                dtype
            )
        case "pow":
            return np.concatenate(
                [
                    [0., 1., 1.],
                    rng.uniform(-1., 1., 8),
                    rng.uniform(0., 5., 8),
                ]
            ).astype(dtype), np.concatenate(
                [
                    [1., 0., 2.],
                    np.arange(2., 10., 1.),
                    rng.uniform(-2.0, 2.0, 8),
                ]
            ).astype(
                dtype
            )
        case _:
            assert False, "I don't know the domain of that function"


@pytest.mark.parametrize(
    "function_name, target",
    [
        (fname, tgt)
        for fname in (
            "exp",
            "log",
            "sin",
            "cos",
            "tan",
            "sinh",
            "cosh",
            "asin",
            "acos",
            "atan",
        )
        for tgt in AVAIL_TARGETS
        if Target._X86 not in tgt
    ]
    + [
        (fname, tgt)
        for fname in ("floor", "ceil")
        for tgt in AVAIL_TARGETS
        if Target._AVX512 not in tgt
    ]
    + [(fname, tgt) for fname in ("sqrt", "abs") for tgt in AVAIL_TARGETS],
)
@pytest.mark.parametrize("dtype", (np.float32, np.float64))
def test_unary_functions(gen_config, xp, function_name, dtype, function_domain):
    """Check a generated kernel applying a unary function against the ``xp`` reference."""
    sym_fn, ref_fn = unary_function(function_name, xp)
    rtol = np.finfo(dtype).resolution

    # The domain has more than eight entries, so vectorized kernels
    # do not just run their remainder loop.
    src = xp.array(function_domain)
    dst = xp.zeros_like(src)

    expected = ref_fn(src)

    src_field = Field.create_from_numpy_array("inp", src)
    dst_field = src_field.new_field_with_different_name("outp")

    assignments = [Assignment(dst_field.center(), sym_fn(src_field.center()))]
    config = replace(gen_config, default_dtype=dtype)

    compiled = create_kernel(assignments, config).compile()
    compiled(inp=src, outp=dst)

    xp.testing.assert_allclose(dst, expected, rtol=rtol)


@pytest.mark.parametrize(
    "function_name,target",
    list(product(["min", "max"], AVAIL_TARGETS))
    + list(
        product(["pow", "atan2"], [t for t in AVAIL_TARGETS if Target._X86 not in t])
    ),
)
@pytest.mark.parametrize("dtype", (np.float32, np.float64))
def test_binary_functions(gen_config, xp, function_name, dtype, function_domain):
    """Check a generated kernel applying a binary function against the ``xp`` reference.

    For binary functions, ``function_domain`` supplies a pair of equally long
    argument arrays.
    """
    sp_func, xp_func = binary_function(function_name, xp)
    resolution = np.finfo(dtype).resolution

    inp = xp.array(function_domain[0])
    inp2 = xp.array(function_domain[1])
    outp = xp.zeros_like(inp)

    reference = xp_func(inp, inp2)

    # Each field is created from the array that is later passed for it.
    inp_field = Field.create_from_numpy_array("inp", inp)
    inp2_field = Field.create_from_numpy_array("inp2", inp2)
    outp_field = inp_field.new_field_with_different_name("outp")

    asms = [
        Assignment(
            outp_field.center(), sp_func(inp_field.center(), inp2_field.center())
        )
    ]
    gen_config = replace(gen_config, default_dtype=dtype)

    kernel = create_kernel(asms, gen_config)
    kfunc = kernel.compile()
    kfunc(inp=inp, inp2=inp2, outp=outp)

    xp.testing.assert_allclose(outp, reference, rtol=resolution)


# Shared (dtype, target) parametrization for the integer-function tests:
# int32 runs on every available target except CUDA; int64 additionally
# skips the X86_SSE and X86_AVX targets.
dtype_and_target_for_integer_funcs = pytest.mark.parametrize(
    "dtype, target",
    [(np.int32, tgt) for tgt in AVAIL_TARGETS if tgt is not Target.CUDA]
    + [
        (np.int64, tgt)
        for tgt in AVAIL_TARGETS
        if tgt not in (Target.X86_SSE, Target.X86_AVX, Target.CUDA)
    ],
)


@dtype_and_target_for_integer_funcs
def test_integer_abs(gen_config, xp, dtype):
    """Check integer ``abs`` in a generated kernel, including values near the type limits."""
    sym_abs, ref_abs = unary_function("abs", xp)

    limits = np.iinfo(dtype)
    # Uses ``limits.min + 1`` rather than ``limits.min``: the absolute value of
    # the minimum of a signed integer type is not representable in that type.
    samples = [-1, 0, 1, 3, -5, -312, limits.min + 1, limits.max]
    src = xp.array(samples, dtype=dtype)

    dst = xp.zeros_like(src)
    expected = ref_abs(src)

    src_field = Field.create_from_numpy_array("inp", src)
    dst_field = src_field.new_field_with_different_name("outp")

    assignments = [Assignment(dst_field.center(), sym_abs(src_field.center()))]
    config = replace(gen_config, default_dtype=dtype)

    compiled = create_kernel(assignments, config).compile()
    compiled(inp=src, outp=dst)

    xp.testing.assert_array_equal(dst, expected)


@pytest.mark.parametrize("function_name", ("min", "max"))
@dtype_and_target_for_integer_funcs
def test_integer_binary_functions(gen_config, xp, function_name, dtype):
    """Check integer ``min``/``max`` in a generated kernel against the ``xp`` reference."""
    sym_fn, ref_fn = binary_function(function_name, xp)

    limits = np.iinfo(dtype)
    lo, hi = limits.min, limits.max

    lhs = xp.array([-1, 0, 1, 3, -5, -312, lo + 1, hi], dtype=dtype)
    rhs = xp.array([3, -5, 1, 12, 1, 11, lo + 42, hi - 3], dtype=dtype)

    dst = xp.zeros_like(lhs)
    expected = ref_fn(lhs, rhs)

    lhs_field = Field.create_from_numpy_array("inp1", lhs)
    rhs_field = Field.create_from_numpy_array("inp2", rhs)
    dst_field = lhs_field.new_field_with_different_name("outp")

    rhs_expr = sym_fn(lhs_field.center(), rhs_field.center())
    assignments = [Assignment(dst_field.center(), rhs_expr)]
    config = replace(gen_config, default_dtype=dtype)

    compiled = create_kernel(assignments, config).compile()
    compiled(inp1=lhs, inp2=rhs, outp=dst)

    xp.testing.assert_array_equal(dst, expected)


@pytest.mark.parametrize("a", [sp.Symbol("a"), fields("a:  float64[2d]").center])
def test_avoid_pow(a):
    """The kernel for ``2 * a**2 / 3`` must not contain any call whose name includes ``pow``."""
    x = fields("x:  float64[2d]")

    update = Assignment(x.center_vector[0], 2 * a**2 / 3)
    kernel = create_kernel(update)

    def is_pow_call(node):
        return isinstance(node, PsCall) and "pow" in node.function.name

    pow_calls = list(dfs_preorder(kernel.body, is_pow_call))
    assert not pow_calls


@pytest.mark.xfail(reason="fast_div not available yet")
def test_avoid_pow_fast_div():
    """The GPU kernel for ``fast_division(1, a**2)`` should contain no ``pow`` call.

    Marked ``xfail``: the feature is not available yet.
    NOTE(review): ``fast_division`` does not appear in this module's imports,
    so the test presumably fails with a ``NameError`` before any kernel is
    created. Confirm this is intended until the feature lands.
    """
    x = fields("x:  float64[2d]")
    a = fields("a:  float64[2d]").center

    up = Assignment(x.center_vector[0], fast_division(1, (a**2)))
    func = create_kernel(up, config=CreateKernelConfig(target=Target.GPU))

    # Collect every call node whose function name contains "pow".
    powers = list(
        dfs_preorder(
            func.body, lambda n: isinstance(n, PsCall) and "pow" in n.function.name
        )
    )
    assert not powers


def test_avoid_pow_move_constants():
    """A polynomial in zero-initialised symbols must compile without ``pow`` calls.

    As the test author notes, kernel creation finishes by calling
    ``move_constants_before_loop``. That step also substitutes symbols that
    hold the same value, which simplifies the equations once more.
    """
    x = fields("x:  float64[2d]")
    a, b, c = sp.symbols("a, b, c")

    polynomial = a**2 / 18 - a * b / 6 - a / 18 + b**2 / 18 + b / 18 - c**2 / 36

    assignments = [Assignment(sym, 0.0) for sym in (a, b, c)]
    assignments.append(Assignment(x.center_vector[0], polynomial))
    kernel = create_kernel(assignments)

    def is_pow_call(node):
        return isinstance(node, PsCall) and "pow" in node.function.name

    pow_calls = list(dfs_preorder(kernel.body, is_pow_call))
    assert not pow_calls