diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 2f46a74211f28dcdbf1902803155bed7b9bcc530..868a7852caa416ff0dcd83212a2ab0cbf96ebe3c 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -1,16 +1,17 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Iterable, Iterator, Any from itertools import chain, count from collections import namedtuple, defaultdict import re +from ..ast.expressions import PsExpression from ...defaults import DEFAULTS from ...field import Field, FieldType +from ...sympyextensions import ReductionOp from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType -from ...codegen.properties import LocalReductionVariable, ReductionPointerVariable - from ..memory import PsSymbol, PsBuffer from ..constants import PsConstant from ...types import ( @@ -46,6 +47,16 @@ class FieldsInKernel: FieldArrayPair = namedtuple("FieldArrayPair", ("field", "array")) +@dataclass(frozen=True) +class ReductionInfo: + + op: ReductionOp + init_val: PsExpression + + orig_symbol: PsSymbol + ptr_symbol: PsSymbol + + class KernelCreationContext: """Manages the translation process from the SymPy frontend to the backend AST, and collects all necessary information for the translation: @@ -77,8 +88,7 @@ class KernelCreationContext: self._symbol_ctr_pattern = re.compile(r"__[0-9]+$") self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0) - self._local_reduction_symbols: dict[PsSymbol, LocalReductionVariable] = dict() - self._reduction_ptr_symbols: dict[PsSymbol, ReductionPointerVariable] = dict() + self._symbols_reduction_info: dict[PsSymbol, ReductionInfo] = dict() self._fields_and_arrays: dict[str, FieldArrayPair] = dict() self._fields_collection = FieldsInKernel() @@ -173,41 +183,17 @@ class KernelCreationContext: self._symbols[old.name] = new - def add_local_reduction_symbol(self, local_symb: PsSymbol, local_var_prop: LocalReductionVariable): - """Adds entry for a symbol and its property to the lookup table for local reduction variables. + def add_symbol_reduction_info(self, local_symb: PsSymbol, reduction_info: ReductionInfo): + """Adds entry for a symbol and its reduction info to its corresponding lookup table. - The symbol ``symbol`` should not have a 'LocalReductionSymbol' property and shall not exist in the symbol table. + The symbol ``symbol`` shall not exist in the symbol table already. """ - if self.find_symbol(local_symb.name) is not None: - raise PsInternalCompilerError( - f"add_local_reduction_symbol: {local_symb.name} already exist in the symbol table" - ) - self.add_symbol(local_symb) - - if local_symb not in self._local_reduction_symbols and not local_symb.get_properties(LocalReductionVariable): - local_symb.add_property(local_var_prop) - self._local_reduction_symbols[local_symb] = local_var_prop - else: + if local_symb in self._symbols_reduction_info: raise PsInternalCompilerError( - f"add_local_reduction_symbol: {local_symb.name} already exists in local reduction table" + f"add_symbol_reduction_info: {local_symb.name} already exist in the symbol table" ) - def add_reduction_ptr_symbol(self, orig_symb: PsSymbol, ptr_symb: PsSymbol, ptr_var_prop: ReductionPointerVariable): - """Replaces reduction symbol with a pointer-based counterpart used for export - and adds the new symbol and its property to the lookup table for pointer-based reduction variables - - The symbol ``ptr_symbol`` should not exist in the symbol table. - """ - self.replace_symbol(orig_symb, ptr_symb) - - if ptr_symb not in self._reduction_ptr_symbols and not ptr_symb.get_properties( - ReductionPointerVariable): - ptr_symb.add_property(ptr_var_prop) - self._reduction_ptr_symbols[ptr_symb] = ptr_var_prop - else: - raise PsInternalCompilerError( - f"add_reduction_ptr_symbol: {ptr_symb.name} already exists in pointer-based reduction variable table " - ) + self._symbols_reduction_info[local_symb] = reduction_info def duplicate_symbol( self, symb: PsSymbol, new_dtype: PsType | None = None @@ -245,14 +231,9 @@ class KernelCreationContext: return self._symbols.values() @property - def local_reduction_symbols(self) -> dict[PsSymbol, LocalReductionVariable]: + def symbols_reduction_info(self) -> dict[PsSymbol, ReductionInfo]: """Return a dictionary holding kernel-local reduction symbols and their reduction properties.""" - return self._local_reduction_symbols - - @property - def reduction_pointer_symbols(self) -> dict[PsSymbol, ReductionPointerVariable]: - """Return a dictionary holding pointer-based reduction symbols and their reduction properties.""" - return self._reduction_ptr_symbols + return self._symbols_reduction_info # Fields and Arrays diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 4bf136562458f05c4d34776316ab86e2483bce52..5bb7f8b088a04e9c9d484896946be7029ab24d0f 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -19,7 +19,7 @@ from ...sympyextensions.pointers import AddressOf, mem_acc from ...sympyextensions.reduction import ReductionAssignment, ReductionOp from ...field import Field, FieldType -from .context import KernelCreationContext +from .context import KernelCreationContext, ReductionInfo from ..ast.structural import ( PsAstNode, @@ -66,8 +66,6 @@ from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions, NumericLimitsFunctions from ..exceptions import FreezeError -from ...codegen.properties import LocalReductionVariable, ReductionPointerVariable - ExprLike = ( sp.Expr @@ -210,25 +208,25 @@ class FreezeExpressions: init_val: PsExpression match expr.op: case ReductionOp.Add: - init_val = PsConstantExpr(PsConstant(0, dtype)) + init_val = PsConstantExpr(PsConstant(0)) case ReductionOp.Sub: - init_val = PsConstantExpr(PsConstant(0, dtype)) + init_val = PsConstantExpr(PsConstant(0)) case ReductionOp.Mul: - init_val = PsConstantExpr(PsConstant(1, dtype)) + init_val = PsConstantExpr(PsConstant(1)) case ReductionOp.Min: init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), []) - init_val.dtype = dtype case ReductionOp.Max: init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) - init_val.dtype = dtype case _: raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") - # set reduction symbol properties (local/pointer variables) in context - self._ctx.add_local_reduction_symbol(new_lhs_symb, - LocalReductionVariable(expr.op, init_val, orig_lhs_symb_as_ptr)) - self._ctx.add_reduction_ptr_symbol(orig_lhs_symb, orig_lhs_symb_as_ptr, - ReductionPointerVariable(expr.op, new_lhs_symb)) + reduction_info = ReductionInfo(expr.op, init_val, orig_lhs_symb, orig_lhs_symb_as_ptr) + + # add new symbol for local copy, replace original copy with pointer counterpart and add reduction info + self._ctx.add_symbol(new_lhs_symb) + self._ctx.add_symbol_reduction_info(new_lhs_symb, reduction_info) + self._ctx.replace_symbol(orig_lhs_symb, orig_lhs_symb_as_ptr) + return PsAssignment(new_lhs, new_rhs) diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index f4046d87dada13d13eb1926cda0c3a39d95ff326..d72008d56c4817dea668a0c1f9dcf2a2e41641a8 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -112,10 +112,10 @@ class AddOpenMP: pragma_text += " parallel" if not omp_params.omit_parallel_construct else "" pragma_text += f" for schedule({omp_params.schedule})" - if bool(ctx.local_reduction_symbols): - for symbol, reduction in ctx.local_reduction_symbols.items(): + if bool(ctx.symbols_reduction_info): + for symbol, reduction_info in ctx.symbols_reduction_info.items(): if isinstance(symbol.dtype, PsScalarType): - pragma_text += f" reduction({reduction.op.value}: {symbol.name})" + pragma_text += f" reduction({reduction_info.op.value}: {symbol.name})" else: NotImplementedError("OMP: Reductions for non-scalar data types are not supported yet.") diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index 6e0611a4b70431dbfb04c7d3f59e2fb4ca7fabc6..ba7df317acab94e30c26025d246bd06d2bea5490 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, replace from .target import Target from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO from .kernel import Kernel, GpuKernel, GpuThreadsRange -from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr, ReductionPointerVariable +from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr from .parameters import Parameter from ..compound_op_mapping import compound_op_to_expr from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr @@ -154,15 +154,21 @@ class DefaultKernelCreationDriver: if self._intermediates is not None: self._intermediates.constants_eliminated = kernel_ast.clone() - # Init local reduction variable copy - for local_red, local_prop in self._ctx.local_reduction_symbols.items(): - kernel_ast.statements = [PsDeclaration(PsSymbolExpr(local_red), local_prop.init_val)] + kernel_ast.statements + # Extensions for reductions + for symbol, reduction_info in self._ctx.symbols_reduction_info.items(): + # Init local reduction variable copy + kernel_ast.statements = [PsDeclaration(PsSymbolExpr(symbol), + reduction_info.init_val)] + kernel_ast.statements - # Write back result to reduction target variable - for red_ptr, ptr_prop in self._ctx.reduction_pointer_symbols.items(): - ptr_access = PsMemAcc(PsSymbolExpr(red_ptr), PsConstantExpr(PsConstant(0, self._ctx.index_dtype))) + # Write back result to reduction target variable + ptr_access = PsMemAcc(PsSymbolExpr(reduction_info.ptr_symbol), + PsConstantExpr(PsConstant(0))) kernel_ast.statements += [PsAssignment( - ptr_access, binop_str_to_expr(ptr_prop.op, ptr_access, PsSymbolExpr(ptr_prop.local_symbol)))] + ptr_access, compound_op_to_expr(reduction_info.op, ptr_access, PsSymbolExpr(symbol)))] + + # TODO: only newly introduced nodes + typify = Typifier(self._ctx) + kernel_ast = typify(kernel_ast) # Target-Specific optimizations if self._cfg.target.is_cpu(): @@ -462,8 +468,6 @@ def _get_function_params( props: set[PsSymbolProperty] = set() for prop in symb.properties: match prop: - case ReductionPointerVariable(): - props.add(prop) case FieldShape() | FieldStride(): props.add(prop) case BufferBasePtr(buf): diff --git a/src/pystencils/codegen/parameters.py b/src/pystencils/codegen/parameters.py index 094553517a56a189f4fa714749b5f7b5761f8e33..e6a513cc7cec928373b11e3ea46e703c27aa2443 100644 --- a/src/pystencils/codegen/parameters.py +++ b/src/pystencils/codegen/parameters.py @@ -8,7 +8,7 @@ from .properties import ( _FieldProperty, FieldShape, FieldStride, - FieldBasePtr, ReductionPointerVariable, + FieldBasePtr, ) from ..types import PsType from ..field import Field @@ -39,9 +39,6 @@ class Parameter: key=lambda f: f.name, ) ) - self._reduction_ptr: Optional[ReductionPointerVariable] = next( - (e for e in self._properties if isinstance(e, ReductionPointerVariable)), None - ) @property def name(self): @@ -82,11 +79,6 @@ class Parameter: """Set of fields associated with this parameter.""" return self._fields - @property - def reduction_pointer(self) -> Optional[ReductionPointerVariable]: - """Reduction pointer associated with this parameter.""" - return self._reduction_ptr - def get_properties( self, prop_type: type[PsSymbolProperty] | tuple[type[PsSymbolProperty], ...] ) -> set[PsSymbolProperty]: @@ -113,10 +105,6 @@ class Parameter: ) return bool(self.get_properties(FieldBasePtr)) - @property - def is_reduction_pointer(self) -> bool: - return bool(self._reduction_ptr) - @property def is_field_stride(self) -> bool: # pragma: no cover warn( diff --git a/src/pystencils/codegen/properties.py b/src/pystencils/codegen/properties.py index d3c2435ed713ec6eccb55c4afba968a708bac29d..d377fb3d35d99b59c4f364cc4d066b736bfd9140 100644 --- a/src/pystencils/codegen/properties.py +++ b/src/pystencils/codegen/properties.py @@ -2,7 +2,6 @@ from __future__ import annotations from dataclasses import dataclass from ..field import Field -from ..sympyextensions.reduction import ReductionOp @dataclass(frozen=True) @@ -15,28 +14,6 @@ class UniqueSymbolProperty(PsSymbolProperty): """Base class for unique properties, of which only one instance may be registered at a time.""" -@dataclass(frozen=True) -class LocalReductionVariable(PsSymbolProperty): - """Property for symbols specifying the operation and initial value for a kernel-local reduction variable.""" - - from ..backend.memory import PsSymbol - from ..backend.ast.expressions import PsExpression - - op: ReductionOp - init_val: PsExpression - ptr_symbol: PsSymbol - - -@dataclass(frozen=True) -class ReductionPointerVariable(PsSymbolProperty): - """Property for pointer-type symbols exporting the reduction result from the kernel.""" - - from ..backend.memory import PsSymbol - - op: ReductionOp - local_symbol: PsSymbol - - @dataclass(frozen=True) class FieldShape(PsSymbolProperty): """Symbol acts as a shape parameter to a field.""" diff --git a/src/pystencils/jit/cpu_extension_module.py b/src/pystencils/jit/cpu_extension_module.py index 6ec62c28d977eb0cdb6a218a6ab7b6f4a818db2a..44185f4ed9522266e23e534f320fd85e3b0bdabd 100644 --- a/src/pystencils/jit/cpu_extension_module.py +++ b/src/pystencils/jit/cpu_extension_module.py @@ -18,7 +18,7 @@ from ..types import ( PsType, PsUnsignedIntegerType, PsSignedIntegerType, - PsIeeeFloatType, + PsIeeeFloatType, PsPointerType, ) from ..types.quick import Fp, SInt, UInt @@ -205,7 +205,7 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ self._array_assoc_var_extractions: dict[Parameter, str] = dict() self._scalar_extractions: dict[Parameter, str] = dict() - self._reduction_ptrs: dict[Parameter, str] = dict() + self._pointer_extractions: dict[Parameter, str] = dict() self._constraint_checks: list[str] = [] @@ -278,9 +278,9 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name - def extract_reduction_ptr(self, param: Parameter) -> str: - if param not in self._reduction_ptrs: - ptr = param.reduction_pointer + def extract_ptr(self, param: Parameter) -> str: + if param not in self._pointer_extractions: + ptr = param.symbol buffer = self.extract_buffer(ptr, param.name, param.dtype) code = f"{param.dtype.c_string()} {param.name} = ({param.dtype}) {buffer}.buf;" @@ -317,10 +317,10 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{ return param.name def extract_parameter(self, param: Parameter): - if param.is_reduction_pointer: - self.extract_reduction_ptr(param) - elif param.is_field_parameter: + if param.is_field_parameter: self.extract_array_assoc_var(param) + elif isinstance(param.dtype, PsPointerType): + self.extract_ptr(param) else: self.extract_scalar(param) diff --git a/tests/kernelcreation/test_reduction.py b/tests/kernelcreation/test_reduction.py index c84417ac7b98e187106669c1232ab9edc362d0eb..69b75e711c1e8c306486f74bc0a43b8daea924b1 100644 --- a/tests/kernelcreation/test_reduction.py +++ b/tests/kernelcreation/test_reduction.py @@ -18,7 +18,7 @@ SOLUTION = { @pytest.mark.parametrize('dtype', ["float64"]) -@pytest.mark.parametrize("op", ["+", "-", "*"]) #, "min", "max"]) # TODO: min/max broken due to error in BasePrinter +@pytest.mark.parametrize("op", ["+", "-", "*", "min", "max"]) def test_reduction(dtype, op): gpu_avail = False