Skip to content
Snippets Groups Projects
Commit 3f32ceca authored by Frederik Hennig's avatar Frederik Hennig
Browse files

adapt JIT compilers to changed kernelfunction interface. Fix sparse index translations.

parent 2f637986
No related branches found
No related tags found
1 merge request!421Refactor Field Modelling
Pipeline #69763 failed
......@@ -266,7 +266,7 @@ class PsBufferAcc(PsLvalue, PsExpression):
return PsBufferAcc(self._base_ptr.symbol, [i.clone() for i in self._index])
def __repr__(self) -> str:
return f"PsArrayAccess({repr(self._base_ptr)}, {repr(self._index)})"
return f"PsBufferAcc({repr(self._base_ptr)}, {repr(self._index)})"
class PsSubscript(PsLvalue, PsExpression):
......
......@@ -13,11 +13,8 @@ from ..exceptions import PsInternalCompilerError
from ..kernelfunction import (
KernelFunction,
KernelParameter,
FieldParameter,
FieldShapeParam,
FieldStrideParam,
FieldPointerParam,
)
from ..properties import FieldBasePtr, FieldShape, FieldStride
from ..constraints import KernelParamsConstraint
from ...types import (
PsType,
......@@ -209,7 +206,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
self._array_extractions: dict[Field, str] = dict()
self._array_frees: dict[Field, str] = dict()
self._array_assoc_var_extractions: dict[FieldParameter, str] = dict()
self._array_assoc_var_extractions: dict[KernelParameter, str] = dict()
self._scalar_extractions: dict[KernelParameter, str] = dict()
self._constraint_checks: list[str] = []
......@@ -282,31 +279,34 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return param.name
def extract_array_assoc_var(self, param: FieldParameter) -> str:
def extract_array_assoc_var(self, param: KernelParameter) -> str:
if param not in self._array_assoc_var_extractions:
field = param.field
field = param.fields.pop()
buffer = self.extract_field(field)
match param:
case FieldPointerParam():
code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;"
case FieldShapeParam():
coord = param.coordinate
code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];"
case FieldStrideParam():
coord = param.coordinate
code = (
f"{param.dtype} {param.name} = "
f"{buffer}.strides[{coord}] / {field.dtype.itemsize};"
)
case _:
assert False, "unreachable code"
code: str | None = None
for prop in param.properties:
match prop:
case FieldBasePtr():
code = f"{param.dtype} {param.name} = ({param.dtype}) {buffer}.buf;"
break
case FieldShape(_, coord):
code = f"{param.dtype} {param.name} = {buffer}.shape[{coord}];"
break
case FieldStride(_, coord):
code = (
f"{param.dtype} {param.name} = "
f"{buffer}.strides[{coord}] / {field.dtype.itemsize};"
)
break
assert code is not None
self._array_assoc_var_extractions[param] = code
return param.name
def extract_parameter(self, param: KernelParameter):
if isinstance(param, FieldParameter):
if param.is_field_parameter:
self.extract_array_assoc_var(param)
else:
self.extract_scalar(param)
......
......@@ -16,11 +16,9 @@ from .jit import JitBase, JitError, KernelWrapper
from ..kernelfunction import (
KernelFunction,
GpuKernelFunction,
FieldPointerParam,
FieldShapeParam,
FieldStrideParam,
KernelParameter,
)
from ..properties import FieldShape, FieldStride, FieldBasePtr
from ..emission import emit_code
from ...types import PsStructType
......@@ -98,8 +96,8 @@ class CupyKernelWrapper(KernelWrapper):
field_shapes = set()
index_shapes = set()
def check_shape(field_ptr: FieldPointerParam, arr: cp.ndarray):
field = field_ptr.field
def check_shape(field_ptr: KernelParameter, arr: cp.ndarray):
field = field_ptr.fields.pop()
if field.has_fixed_shape:
expected_shape = tuple(int(s) for s in field.shape)
......@@ -118,7 +116,7 @@ class CupyKernelWrapper(KernelWrapper):
if isinstance(field.dtype, PsStructType):
assert expected_strides[-1] == 1
expected_strides = expected_strides[:-1]
actual_strides = tuple(s // arr.dtype.itemsize for s in arr.strides)
if expected_strides != actual_strides:
raise ValueError(
......@@ -149,28 +147,38 @@ class CupyKernelWrapper(KernelWrapper):
arr: cp.ndarray
for kparam in self._kfunc.parameters:
match kparam:
case FieldPointerParam(_, dtype, field):
arr = kwargs[field.name]
if arr.dtype != field.dtype.numpy_dtype:
raise JitError(
f"Data type mismatch at array argument {field.name}:"
f"Expected {field.dtype}, got {arr.dtype}"
)
check_shape(kparam, arr)
args.append(arr)
case FieldShapeParam(name, dtype, field, coord):
arr = kwargs[field.name]
add_arg(name, arr.shape[coord], dtype)
case FieldStrideParam(name, dtype, field, coord):
arr = kwargs[field.name]
add_arg(name, arr.strides[coord] // arr.dtype.itemsize, dtype)
case KernelParameter(name, dtype):
val: Any = kwargs[name]
add_arg(name, val, dtype)
if kparam.is_field_parameter:
# Determine field-associated data to pass in
for prop in kparam.properties:
match prop:
case FieldBasePtr(field):
arr = kwargs[field.name]
if arr.dtype != field.dtype.numpy_dtype:
raise JitError(
f"Data type mismatch at array argument {field.name}:"
f"Expected {field.dtype}, got {arr.dtype}"
)
check_shape(kparam, arr)
args.append(arr)
break
case FieldShape(field, coord):
arr = kwargs[field.name]
add_arg(kparam.name, arr.shape[coord], kparam.dtype)
break
case FieldStride(field, coord):
arr = kwargs[field.name]
add_arg(
kparam.name,
arr.strides[coord] // arr.dtype.itemsize,
kparam.dtype,
)
break
else:
# scalar parameter
val: Any = kwargs[kparam.name]
add_arg(kparam.name, val, kparam.dtype)
# Determine launch grid
from ..ast.expressions import evaluate_expression
......
......@@ -83,9 +83,9 @@ class KernelParameter:
return set(p.field for p in filter(lambda p: isinstance(p, _FieldProperty), self.properties)) # type: ignore
def get_properties(
self, prop_type: type[PsSymbolProperty]
self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...]
) -> set[PsSymbolProperty]:
"""Retrieve all properties of the given type attached to this parameter"""
"""Retrieve all properties of the given type(s) attached to this parameter"""
return set(filter(lambda p: isinstance(p, prop_type), self._properties))
@property
......@@ -94,11 +94,6 @@ class KernelParameter:
@property
def is_field_parameter(self) -> bool:
warn(
"`is_field_parameter` is deprecated and will be removed in a future version of pystencils. "
"Check `param.fields` for emptiness instead.",
DeprecationWarning,
)
return bool(self.fields)
@property
......
......@@ -7,6 +7,7 @@ from ..kernelcreation import (
IterationSpace,
FullIterationSpace,
SparseIterationSpace,
AstFactory
)
from ..kernelcreation.context import KernelCreationContext
......@@ -159,6 +160,7 @@ class CudaPlatform(GenericGpu):
def _prepend_sparse_translation(
self, body: PsBlock, ispace: SparseIterationSpace
) -> tuple[PsBlock, GpuThreadsRange]:
factory = AstFactory(self._ctx)
ispace.sparse_counter.dtype = constify(ispace.sparse_counter.get_dtype())
sparse_ctr = PsExpression.make(ispace.sparse_counter)
......@@ -173,7 +175,7 @@ class CudaPlatform(GenericGpu):
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
(sparse_ctr,),
(sparse_ctr, factory.parse_index(0)),
),
coord.name,
),
......
......@@ -124,13 +124,15 @@ class GenericCpu(Platform):
return PsBlock([loops])
def _create_sparse_loop(self, body: PsBlock, ispace: SparseIterationSpace):
factory = AstFactory(self._ctx)
mappings = [
PsDeclaration(
PsSymbolExpr(ctr),
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
(PsExpression.make(ispace.sparse_counter),),
(PsExpression.make(ispace.sparse_counter), factory.parse_index(0)),
),
coord.name,
),
......
......@@ -20,7 +20,7 @@ from ..ast.expressions import (
)
from ..extensions.cpp import CppMethodCall
from ..kernelcreation.context import KernelCreationContext
from ..kernelcreation import KernelCreationContext, AstFactory
from ..constants import PsConstant
from .generic_gpu import GenericGpu, GpuThreadsRange
from ..exceptions import MaterializationError
......@@ -147,6 +147,8 @@ class SyclPlatform(GenericGpu):
def _prepend_sparse_translation(
self, body: PsBlock, ispace: SparseIterationSpace
) -> tuple[PsBlock, GpuThreadsRange]:
factory = AstFactory(self._ctx)
id_type = PsCustomType("sycl::id< 1 >", const=True)
id_symbol = PsExpression.make(self._ctx.get_symbol("id", id_type))
......@@ -165,7 +167,7 @@ class SyclPlatform(GenericGpu):
PsLookup(
PsBufferAcc(
ispace.index_list.base_pointer,
(sparse_ctr,),
(sparse_ctr, factory.parse_index(0)),
),
coord.name,
),
......
......@@ -12,7 +12,7 @@ from pystencils.types import PsIntegerType
from pystencils.types.quick import Arr, SInt
from pystencils.gpu.gpu_array_handler import GPUArrayHandler
from pystencils.field import Field, FieldType
from pystencils.backend.kernelfunction import FieldPointerParam
from pystencils.backend.properties import FieldBasePtr
try:
# noinspection PyPep8Naming
......@@ -244,9 +244,9 @@ class BoundaryHandling:
for b_obj, idx_arr in b[self._index_array_name].boundary_object_to_index_list.items():
kwargs[self._field_name] = b[self._field_name]
kwargs['indexField'] = idx_arr
data_used_in_kernel = (p.field.name
data_used_in_kernel = (p.fields.pop().name
for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
if isinstance(p, FieldPointerParam) and p.field.name not in kwargs)
if bool(p.get_properties(FieldBasePtr)) and p.fields.pop().name not in kwargs)
kwargs.update({name: b[name] for name in data_used_in_kernel})
self._boundary_object_to_boundary_info[b_obj].kernel(**kwargs)
......@@ -260,9 +260,9 @@ class BoundaryHandling:
arguments = kwargs.copy()
arguments[self._field_name] = b[self._field_name]
arguments['indexField'] = idx_arr
data_used_in_kernel = (p.field.name
data_used_in_kernel = (p.fields.pop().name
for p in self._boundary_object_to_boundary_info[b_obj].kernel.parameters
if isinstance(p, FieldPointerParam) and p.field.name not in arguments)
if bool(p.get_properties(FieldBasePtr)) and p.fields.pop().name not in arguments)
arguments.update({name: b[name] for name in data_used_in_kernel if name not in arguments})
kernel = self._boundary_object_to_boundary_info[b_obj].kernel
......
......@@ -5,7 +5,8 @@ from pystencils import Field, TypedSymbol, FieldType, DynamicType
from pystencils.backend.kernelcreation import KernelCreationContext
from pystencils.backend.constants import PsConstant
from pystencils.backend.memory import PsSymbol, FieldShape, FieldStride
from pystencils.backend.memory import PsSymbol
from pystencils.backend.properties import FieldShape, FieldStride
from pystencils.backend.exceptions import KernelConstraintsError
from pystencils.types.quick import SInt, Fp
from pystencils.types import deconstify
......
......@@ -3,7 +3,7 @@ import sympy as sp
from pystencils import make_slice, Field, Assignment
from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace
from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations, LowerToC
from pystencils.backend.literals import PsLiteral
from pystencils.backend.emission import CAstPrinter
from pystencils.backend.ast.expressions import PsExpression, PsSubscript
......@@ -46,6 +46,9 @@ def test_literals():
hoist = HoistLoopInvariantDeclarations(ctx)
ast = hoist(ast)
lower = LowerToC(ctx)
ast = lower(ast)
assert isinstance(ast, PsBlock)
assert len(ast.statements) == 2
assert ast.statements[0] == x_decl
......
import pytest
from typing import ClassVar
from dataclasses import dataclass
from pystencils.backend.memory import PsSymbol, PsSymbolProperty
from pystencils.backend.memory import PsSymbol, PsSymbolProperty, UniqueSymbolProperty
def test_properties():
......@@ -16,9 +15,8 @@ def test_properties():
s: str
@dataclass(frozen=True)
class UniqueProperty(PsSymbolProperty):
class MyUniqueProperty(UniqueSymbolProperty):
val: int
_unique: ClassVar[bool] = True
s = PsSymbol("s")
......@@ -36,17 +34,17 @@ def test_properties():
assert s.get_properties(NumbersProperty) == {NumbersProperty(42, 8.71)}
assert not s.get_properties(UniqueProperty)
assert not s.get_properties(MyUniqueProperty)
s.add_property(UniqueProperty(13))
assert s.get_properties(UniqueProperty) == {UniqueProperty(13)}
s.add_property(MyUniqueProperty(13))
assert s.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)}
# Adding the same one again does not raise
s.add_property(UniqueProperty(13))
assert s.get_properties(UniqueProperty) == {UniqueProperty(13)}
s.add_property(MyUniqueProperty(13))
assert s.get_properties(MyUniqueProperty) == {MyUniqueProperty(13)}
with pytest.raises(ValueError):
s.add_property(UniqueProperty(14))
s.add_property(MyUniqueProperty(14))
s.remove_property(UniqueProperty(13))
assert not s.get_properties(UniqueProperty)
s.remove_property(MyUniqueProperty(13))
assert not s.get_properties(MyUniqueProperty)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment