From 3f32ceca0bafb9c93e0fbcae13e2679d29d95dc2 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 22 Oct 2024 14:17:06 +0200 Subject: [PATCH] adapt JIT compilers to changed kernelfunction interface. Fix sparse index translations. --- src/pystencils/backend/ast/expressions.py | 2 +- .../backend/jit/cpu_extension_module.py | 44 ++++++------- src/pystencils/backend/jit/gpu_cupy.py | 64 +++++++++++-------- src/pystencils/backend/kernelfunction.py | 9 +-- src/pystencils/backend/platforms/cuda.py | 4 +- .../backend/platforms/generic_cpu.py | 4 +- src/pystencils/backend/platforms/sycl.py | 6 +- src/pystencils/boundaries/boundaryhandling.py | 10 +-- tests/nbackend/kernelcreation/test_context.py | 3 +- tests/nbackend/test_extensions.py | 5 +- tests/nbackend/test_memory.py | 22 +++---- 11 files changed, 92 insertions(+), 81 deletions(-) diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 6a04f4f95..32d06b633 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -266,7 +266,7 @@ class PsBufferAcc(PsLvalue, PsExpression): return PsBufferAcc(self._base_ptr.symbol, [i.clone() for i in self._index]) def __repr__(self) -> str: - return f"PsArrayAccess({repr(self._base_ptr)}, {repr(self._index)})" + return f"PsBufferAcc({repr(self._base_ptr)}, {repr(self._index)})" class PsSubscript(PsLvalue, PsExpression): diff --git a/src/pystencils/backend/jit/cpu_extension_module.py b/src/pystencils/backend/jit/cpu_extension_module.py index b9b793589..dede60cba 100644 --- a/src/pystencils/backend/jit/cpu_extension_module.py +++ b/src/pystencils/backend/jit/cpu_extension_module.py @@ -13,11 +13,8 @@ from ..exceptions import PsInternalCompilerError from ..kernelfunction import ( KernelFunction, KernelParameter, - FieldParameter, - FieldShapeParam, - FieldStrideParam, - FieldPointerParam, ) +from ..properties import FieldBasePtr, FieldShape, FieldStride from ..constraints import KernelParamsConstraint from ...types import ( PsType, @@ -209,7 +206,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ self._array_extractions: dict[Field, str] = dict() self._array_frees: dict[Field, str] = dict() - self._array_assoc_var_extractions: dict[FieldParameter, str] = dict() + self._array_assoc_var_extractions: dict[KernelParameter, str] = dict() self._scalar_extractions: dict[KernelParameter, str] = dict() self._constraint_checks: list[str] = [] @@ -282,31 +279,34 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name - def extract_array_assoc_var(self, param: FieldParameter) -> str: + def extract_array_assoc_var(self, param: KernelParameter) -> str: if param not in self._array_assoc_var_extractions: - field = param.field + field = param.fields.pop() buffer = self.extract_field(field) - match param: - case FieldPointerParam(): - code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;" - case FieldShapeParam(): - coord = param.coordinate - code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];" - case FieldStrideParam(): - coord = param.coordinate - code = ( - f"{param.dtype} {param.name} = " - f"{buffer}.strides[{coord}] / {field.dtype.itemsize};" - ) - case _: - assert False, "unreachable code" + code: str | None = None + + for prop in param.properties: + match prop: + case FieldBasePtr(): + code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;" + break + case FieldShape(_, coord): + code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];" + break + case FieldStride(_, coord): + code = ( + f"{param.dtype} {param.name} = " + f"{buffer}.strides[{coord}] / {field.dtype.itemsize};" + ) + break + assert code is not None self._array_assoc_var_extractions[param] = code return param.name def extract_parameter(self, param: KernelParameter): - if isinstance(param, FieldParameter): + if param.is_field_parameter: self.extract_array_assoc_var(param) else: self.extract_scalar(param) diff --git a/src/pystencils/backend/jit/gpu_cupy.py b/src/pystencils/backend/jit/gpu_cupy.py index d6aaac2d2..15f5f6967 100644 --- a/src/pystencils/backend/jit/gpu_cupy.py +++ b/src/pystencils/backend/jit/gpu_cupy.py @@ -16,11 +16,9 @@ from .jit import JitBase, JitError, KernelWrapper from ..kernelfunction import ( KernelFunction, GpuKernelFunction, - FieldPointerParam, - FieldShapeParam, - FieldStrideParam, KernelParameter, ) +from ..properties import FieldShape, FieldStride, FieldBasePtr from ..emission import emit_code from ...types import PsStructType @@ -98,8 +96,8 @@ class CupyKernelWrapper(KernelWrapper): field_shapes = set() index_shapes = set() - def check_shape(field_ptr: FieldPointerParam, arr: cp.ndarray): - field = field_ptr.field + def check_shape(field_ptr: KernelParameter, arr: cp.ndarray): + field = field_ptr.fields.pop() if field.has_fixed_shape: expected_shape = tuple(int(s) for s in field.shape) @@ -118,7 +116,7 @@ class CupyKernelWrapper(KernelWrapper): if isinstance(field.dtype, PsStructType): assert expected_strides[-1] == 1 expected_strides = expected_strides[:-1] - + actual_strides = tuple(s // arr.dtype.itemsize for s in arr.strides) if expected_strides != actual_strides: raise ValueError( @@ -149,28 +147,38 @@ class CupyKernelWrapper(KernelWrapper): arr: cp.ndarray for kparam in self._kfunc.parameters: - match kparam: - case FieldPointerParam(_, dtype, field): - arr = kwargs[field.name] - if arr.dtype != field.dtype.numpy_dtype: - raise JitError( - f"Data type mismatch at array argument {field.name}:" - f"Expected {field.dtype}, got {arr.dtype}" - ) - check_shape(kparam, arr) - args.append(arr) - - case FieldShapeParam(name, dtype, field, coord): - arr = kwargs[field.name] - add_arg(name, arr.shape[coord], dtype) - - case FieldStrideParam(name, dtype, field, coord): - arr = kwargs[field.name] - add_arg(name, arr.strides[coord] // arr.dtype.itemsize, dtype) - - case KernelParameter(name, dtype): - val: Any = kwargs[name] - add_arg(name, val, dtype) + if kparam.is_field_parameter: + # Determine field-associated data to pass in + for prop in kparam.properties: + match prop: + case FieldBasePtr(field): + arr = kwargs[field.name] + if arr.dtype != field.dtype.numpy_dtype: + raise JitError( + f"Data type mismatch at array argument {field.name}:" + f"Expected {field.dtype}, got {arr.dtype}" + ) + check_shape(kparam, arr) + args.append(arr) + break + + case FieldShape(field, coord): + arr = kwargs[field.name] + add_arg(kparam.name, arr.shape[coord], kparam.dtype) + break + + case FieldStride(field, coord): + arr = kwargs[field.name] + add_arg( + kparam.name, + arr.strides[coord] // arr.dtype.itemsize, + kparam.dtype, + ) + break + else: + # scalar parameter + val: Any = kwargs[kparam.name] + add_arg(kparam.name, val, kparam.dtype) # Determine launch grid from ..ast.expressions import evaluate_expression diff --git a/src/pystencils/backend/kernelfunction.py b/src/pystencils/backend/kernelfunction.py index a5bdab623..da0b59e8f 100644 --- a/src/pystencils/backend/kernelfunction.py +++ b/src/pystencils/backend/kernelfunction.py @@ -83,9 +83,9 @@ class KernelParameter: return set(p.field for p in filter(lambda p: isinstance(p, _FieldProperty), self.properties)) # type: ignore def get_properties( - self, prop_type: type[PsSymbolProperty] + self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...] ) -> set[PsSymbolProperty]: - """Retrieve all properties of the given type attached to this parameter""" + """Retrieve all properties of the given type(s) attached to this parameter""" return set(filter(lambda p: isinstance(p, prop_type), self._properties)) @property @@ -94,11 +94,6 @@ class KernelParameter: @property def is_field_parameter(self) -> bool: - warn( - "`is_field_parameter` is deprecated and will be removed in a future version of pystencils. " - "Check `param.fields` for emptiness instead.", - DeprecationWarning, - ) return bool(self.fields) @property diff --git a/src/pystencils/backend/platforms/cuda.py b/src/pystencils/backend/platforms/cuda.py index 6100a371b..323dcc5a9 100644 --- a/src/pystencils/backend/platforms/cuda.py +++ b/src/pystencils/backend/platforms/cuda.py @@ -7,6 +7,7 @@ from ..kernelcreation import ( IterationSpace, FullIterationSpace, SparseIterationSpace, + AstFactory ) from ..kernelcreation.context import KernelCreationContext @@ -159,6 +160,7 @@ class CudaPlatform(GenericGpu): def _prepend_sparse_translation( self, body: PsBlock, ispace: SparseIterationSpace ) -> tuple[PsBlock, GpuThreadsRange]: + factory = AstFactory(self._ctx) ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype()) sparse_ctr = PsExpression.make(ispace.sparse_counter) @@ -173,7 +175,7 @@ class CudaPlatform(GenericGpu): PsLookup( PsBufferAcc( ispace.index_list.base_pointer, - (sparse_ctr,), + (sparse_ctr, factory.parse_index(0)), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/generic_cpu.py b/src/pystencils/backend/platforms/generic_cpu.py index 95aaf50c4..f8cae89fc 100644 --- a/src/pystencils/backend/platforms/generic_cpu.py +++ b/src/pystencils/backend/platforms/generic_cpu.py @@ -124,13 +124,15 @@ class GenericCpu(Platform): return PsBlock([loops]) def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace): + factory = AstFactory(self._ctx) + mappings = [ PsDeclaration( PsSymbolExpr(ctr), PsLookup( PsBufferAcc( ispace.index_list.base_pointer, - (PsExpression.make(ispace.sparse_counter),), + (PsExpression.make(ispace.sparse_counter), factory.parse_index(0)), ), coord.name, ), diff --git a/src/pystencils/backend/platforms/sycl.py b/src/pystencils/backend/platforms/sycl.py index b8684ce22..ec5e7eda0 100644 --- a/src/pystencils/backend/platforms/sycl.py +++ b/src/pystencils/backend/platforms/sycl.py @@ -20,7 +20,7 @@ from ..ast.expressions import ( ) from ..extensions.cpp import CppMethodCall -from ..kernelcreation.context import KernelCreationContext +from ..kernelcreation import KernelCreationContext, AstFactory from ..constants import PsConstant from .generic_gpu import GenericGpu, GpuThreadsRange from ..exceptions import MaterializationError @@ -147,6 +147,8 @@ class SyclPlatform(GenericGpu): def _prepend_sparse_translation( self, body: PsBlock, ispace: SparseIterationSpace ) -> tuple[PsBlock, GpuThreadsRange]: + factory = AstFactory(self._ctx) + id_type = PsCustomType("sycl::id< 1 >", const=True) id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type)) @@ -165,7 +167,7 @@ class SyclPlatform(GenericGpu): PsLookup( PsBufferAcc( ispace.index_list.base_pointer, - (sparse_ctr,), + (sparse_ctr, factory.parse_index(0)), ), coord.name, ), diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py index c7657ec51..52ded8ab2 100644 --- a/src/pystencils/boundaries/boundaryhandling.py +++ b/src/pystencils/boundaries/boundaryhandling.py @@ -12,7 +12,7 @@ from pystencils.types import PsIntegerType from pystencils.types.quick import Arr, SInt from pystencils.gpu.gpu_array_handler import GPUArrayHandler from pystencils.field import Field, FieldType -from pystencils.backend.kernelfunction import FieldPointerParam +from pystencils.backend.properties import FieldBasePtr try: # noinspection PyPep8Naming @@ -244,9 +244,9 @@ class BoundaryHandling: for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items(): kwargs[self._field_name] = b[self._field_name] kwargs['indexField'] = idx_arr - data_used_in_kernel = (p.field.name + data_used_in_kernel = (p.fields.pop().name for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters - if isinstance(p, FieldPointerParam) and p.field.name not in kwargs) + if bool(p.get_properties(FieldBasePtr)) and p.fields.pop().name not in kwargs) kwargs.update({name: b[name] for name in data_used_in_kernel}) self._boundary_object_to_boundary_info[b_obj].kernel(**kwargs) @@ -260,9 +260,9 @@ class BoundaryHandling: arguments = kwargs.copy() arguments[self._field_name] = b[self._field_name] arguments['indexField'] = idx_arr - data_used_in_kernel = (p.field.name + data_used_in_kernel = (p.fields.pop().name for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters - if isinstance(p, FieldPointerParam) and p.field.name not in arguments) + if bool(p.get_properties(FieldBasePtr)) and p.fields.pop().name not in arguments) arguments.update({name: b[name] for name in data_used_in_kernel if name not in arguments}) kernel = self._boundary_object_to_boundary_info[b_obj].kernel diff --git a/tests/nbackend/kernelcreation/test_context.py b/tests/nbackend/kernelcreation/test_context.py index ff766e6b5..384fc9315 100644 --- a/tests/nbackend/kernelcreation/test_context.py +++ b/tests/nbackend/kernelcreation/test_context.py @@ -5,7 +5,8 @@ from pystencils import Field, TypedSymbol, FieldType, DynamicType from pystencils.backend.kernelcreation import KernelCreationContext from pystencils.backend.constants import PsConstant -from pystencils.backend.memory import PsSymbol, FieldShape, FieldStride +from pystencils.backend.memory import PsSymbol +from pystencils.backend.properties import FieldShape, FieldStride from pystencils.backend.exceptions import KernelConstraintsError from pystencils.types.quick import SInt, Fp from pystencils.types import deconstify diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py index 914d05594..b1403185c 100644 --- a/tests/nbackend/test_extensions.py +++ b/tests/nbackend/test_extensions.py @@ -3,7 +3,7 @@ import sympy as sp from pystencils import make_slice, Field, Assignment from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace -from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations +from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations, LowerToC from pystencils.backend.literals import PsLiteral from pystencils.backend.emission import CAstPrinter from pystencils.backend.ast.expressions import PsExpression, PsSubscript @@ -46,6 +46,9 @@ def test_literals(): hoist = HoistLoopInvariantDeclarations(ctx) ast = hoist(ast) + lower = LowerToC(ctx) + ast = lower(ast) + assert isinstance(ast, PsBlock) assert len(ast.statements) == 2 assert ast.statements[0] == x_decl diff --git a/tests/nbackend/test_memory.py b/tests/nbackend/test_memory.py index fb2ab340e..5841e0f4f 100644 --- a/tests/nbackend/test_memory.py +++ b/tests/nbackend/test_memory.py @@ -1,8 +1,7 @@ import pytest -from typing import ClassVar from dataclasses import dataclass -from pystencils.backend.memory import PsSymbol, PsSymbolProperty +from pystencils.backend.memory import PsSymbol, PsSymbolProperty, UniqueSymbolProperty def test_properties(): @@ -16,9 +15,8 @@ def test_properties(): s: str @dataclass(frozen=True) - class UniqueProperty(PsSymbolProperty): + class MyUniqueProperty(UniqueSymbolProperty): val: int - _unique: ClassVar[bool] = True s = PsSymbol("s") @@ -36,17 +34,17 @@ def test_properties(): assert s.get_properties(NumbersProperty) == {NumbersProperty(42, 8.71)} - assert not s.get_properties(UniqueProperty) + assert not s.get_properties(MyUniqueProperty) - s.add_property(UniqueProperty(13)) - assert s.get_properties(UniqueProperty) == {UniqueProperty(13)} + s.add_property(MyUniqueProperty(13)) + assert s.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)} # Adding the same one again does not raise - s.add_property(UniqueProperty(13)) - assert s.get_properties(UniqueProperty) == {UniqueProperty(13)} + s.add_property(MyUniqueProperty(13)) + assert s.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)} with pytest.raises(ValueError): - s.add_property(UniqueProperty(14)) + s.add_property(MyUniqueProperty(14)) - s.remove_property(UniqueProperty(13)) - assert not s.get_properties(UniqueProperty) + s.remove_property(MyUniqueProperty(13)) + assert not s.get_properties(MyUniqueProperty) -- GitLab