Skip to content
Snippets Groups Projects
Commit 3b9860fb authored by Frederik Hennig
Browse files

Merge branch 'bauerd/fix-vector-set' into 'v2.0-dev'

Fix order of arguments in x86 set intrinsics

See merge request !450
parents 4876f0b7 d3a713b9
No related branches found
No related tags found
1 merge request: !450 Fix order of arguments in x86 set intrinsics
Pipeline #74202 passed
......@@ -153,7 +153,7 @@ class X86VectorCpu(GenericVectorCpu):
)
values = [PsConstantExpr(PsConstant(v, stype)) for v in c.value]
return set_func(*values)
return set_func(*values[::-1])
def op_intrinsic(
self, expr: PsExpression, operands: Sequence[PsExpression]
......@@ -202,7 +202,7 @@ class X86VectorCpu(GenericVectorCpu):
opstr = expr.function.func.function_name
if vtype.width > 256:
raise MaterializationError("512bit ceil/floor require SVML.")
case MathFunctions.Sqrt if vtype.is_float():
opstr = expr.function.name
......
......@@ -6,6 +6,7 @@ from itertools import chain
from functools import partial
from typing import Callable
from pystencils import DEFAULTS
from pystencils.backend.kernelcreation import (
KernelCreationContext,
AstFactory,
......@@ -21,7 +22,6 @@ from pystencils.backend.transformations import (
from pystencils.backend.constants import PsConstant
from pystencils.codegen.driver import create_cpu_kernel_function
from pystencils.jit import LegacyCpuJit
from pystencils import Target, fields, Assignment, Field
from pystencils.field import create_numpy_array_with_layout
from pystencils.types import PsScalarType, PsIntegerType
......@@ -38,7 +38,9 @@ class VectorTestSetup:
@property
def name(self) -> str:
return f"{self.target.name}/{self.numeric_dtype}<{self.lanes}>/{self.index_dtype}"
return (
f"{self.target.name}/{self.numeric_dtype}<{self.lanes}>/{self.index_dtype}"
)
def get_setups(target: Target) -> list[VectorTestSetup]:
......@@ -71,7 +73,9 @@ def get_setups(target: Target) -> list[VectorTestSetup]:
]
case Target.X86_AVX512_FP16:
avx512_platform = partial(X86VectorCpu, vector_arch=X86VectorArch.AVX512_FP16)
avx512_platform = partial(
X86VectorCpu, vector_arch=X86VectorArch.AVX512_FP16
)
return [
VectorTestSetup(target, avx512_platform, 8, Fp(16), SInt(32)),
VectorTestSetup(target, avx512_platform, 16, Fp(16), SInt(32)),
......@@ -187,7 +191,7 @@ def test_update_kernel(vectorization_setup: VectorTestSetup, ghost_layers: int):
resolution = np.finfo(setup.numeric_dtype.numpy_dtype).resolution
gls = ghost_layers
np.testing.assert_allclose(
dst_arr[gls:-gls, gls:-gls, :],
check_arr[gls:-gls, gls:-gls, :],
......@@ -242,3 +246,24 @@ def test_only_trailing_iterations(vectorization_setup: VectorTestSetup):
kernel(f=f_arr)
np.testing.assert_equal(f_arr, 2.0)
def test_set(vectorization_setup: VectorTestSetup):
    """Check that a vector kernel stores the loop counter into each cell.

    A 1D integer field ``f`` is assigned the first spatial counter, so after
    the kernel has run, cell ``i`` must hold the value ``i``.

    NOTE(review): presumably this guards the lane ordering of vectorized
    constants (a reversed order would permute the counter values inside each
    vector) -- confirm against the platform's set-intrinsic handling.
    """
    setup = vectorization_setup

    # Kernel under test: f(0) <- index of the current iteration.
    f = fields(f"f(1): {setup.index_dtype}[1D]", layout="fzyx")
    counter_store = Assignment(f(0), DEFAULTS.spatial_counters[0])
    kernel = create_vector_kernel([counter_store], f, setup)

    shape = (23, 1)
    n_cells = shape[0]

    f_arr = create_numpy_array_with_layout(
        shape,
        layout=(1, 0),
        dtype=setup.index_dtype.numpy_dtype,
    )
    # Pre-fill with a sentinel so cells the kernel never writes are detected.
    f_arr[:] = 42

    kernel(f=f_arr)

    expected = np.array(range(n_cells)).reshape(shape)
    np.testing.assert_equal(f_arr, expected)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment