import pytest
import numpy as np
import sympy as sp

import pystencils as ps
from pystencils.sympyextensions import reduced_assign


@pytest.mark.parametrize('dtype', ["float64"])
@pytest.mark.parametrize("op", ["+", "-", "*", "min", "max"])
def test_reduction(dtype, op):
    """Compile and run an OpenMP CPU kernel containing a reduction assignment.

    A 1D field ``x`` of type ``dtype`` is combined into the symbol ``w`` with
    the binary operation ``op`` via ``reduced_assign``. The generated code is
    printed (pytest captures it, which helps when the test fails), then the
    kernel is executed on an array of ones.
    """
    num_points = 10

    x = ps.fields(f'x: {dtype}[1d]')
    w = sp.Symbol("w")

    # Kernel with a single reduction assignment: x.center() is folded into w
    # using `op` (for "+", presumably w += x[i]).
    reduction_assignment = reduced_assign(w, op, x.center())

    config = ps.CreateKernelConfig(cpu_openmp=True)

    ast_reduction = ps.create_kernel([reduction_assignment], config, default_dtype=dtype)
    kernel_reduction = ast_reduction.compile()

    ps.show_code(ast_reduction)

    array = np.ones((num_points,), dtype=dtype)
    kernel_reduction(x=array, w=0)

    # The only assignment reads x and writes w, so the kernel must leave its
    # input field untouched.
    np.testing.assert_array_equal(array, np.ones((num_points,), dtype=dtype))

    # TODO: check the reduced value itself (e.g. w == num_points for "+").
    # `w` is passed as a plain Python int, which the kernel cannot write back
    # to, so the result is not observable from here. This needs an output the
    # caller can read (presumably a one-element buffer) -- confirm against the
    # pystencils reduction API.