diff --git a/src/pystencils/backend/kernelcreation/context.py b/src/pystencils/backend/kernelcreation/context.py index 39fb8ef6dac855553b7e18d2a688c67ca45fb227..4b4604a2153d7be40f750140566703f6c02b6355 100644 --- a/src/pystencils/backend/kernelcreation/context.py +++ b/src/pystencils/backend/kernelcreation/context.py @@ -75,6 +75,8 @@ class KernelCreationContext: self._symbol_ctr_pattern = re.compile(r"__[0-9]+$") self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0) + # TODO: add list of reduction symbols + self._fields_and_arrays: dict[str, FieldArrayPair] = dict() self._fields_collection = FieldsInKernel() diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 4d75f1ca68dce2e09b97fb1d6e226d7c351e34da..0d1ce72e1db5b958e9765b80b2d711108f4001dd 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -65,6 +65,9 @@ from ..exceptions import PsInputError from ..functions import PsMathFunction, MathFunctions from ..exceptions import FreezeError +import backend.functions +from codegen.properties import ReductionSymbolProperty + ExprLike = ( sp.Expr @@ -188,27 +191,32 @@ class FreezeExpressions: lhs = self.visit(expr.lhs) rhs = self.visit(expr.rhs) - assert isinstance(lhs, PsExpression) assert isinstance(rhs, PsExpression) + assert isinstance(lhs, PsSymbolExpr) match expr.op: - case "+=": + case "+": op = add - case "-=": + init_val = PsConstant(0) + case "-": op = sub - case "*=": + init_val = PsConstant(0) + case "*": op = mul - case "/=": - op = truediv - # TODO: unsure if sp.Min & sp.Max work here - case "min=": + init_val = PsConstant(1) + # TODO: unsure if sp.Min & sp.Max are mapped by map_Min/map_Max afterwards + case "min": op = sp.Min - case "max=": + init_val = backend.functions.NumericLimitsFunctions("min") + case "max": op = sp.Max + init_val = backend.functions.NumericLimitsFunctions("max") case _: raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") - return PsAssignment(lhs, op(lhs.clone(), rhs)) # TODO: PsReducedAssignment? + lhs.symbol.add_property(ReductionSymbolProperty(expr.op, init_val)) + + return PsAssignment(lhs, op(lhs.clone(), rhs)) def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr: symb = self._ctx.get_symbol(spsym.name) diff --git a/src/pystencils/codegen/properties.py b/src/pystencils/codegen/properties.py index d377fb3d35d99b59c4f364cc4d066b736bfd9140..5578d24084b24bd4957649391c9552d8c51557ee 100644 --- a/src/pystencils/codegen/properties.py +++ b/src/pystencils/codegen/properties.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from ..field import Field +from backend.ast.expressions import PsExpression + @dataclass(frozen=True) class PsSymbolProperty: @@ -14,6 +16,14 @@ class UniqueSymbolProperty(PsSymbolProperty): """Base class for unique properties, of which only one instance may be registered at a time.""" +@dataclass(frozen=True) +class ReductionSymbolProperty(UniqueSymbolProperty): + """Symbol acts as a base pointer to a field.""" + + op: str + init_val: PsExpression + + @dataclass(frozen=True) class FieldShape(PsSymbolProperty): """Symbol acts as a shape parameter to a field.""" diff --git a/src/pystencils/sympyextensions/reduction.py b/src/pystencils/sympyextensions/reduction.py index 90ab61ede668b17410afb2adb2f53019ef017260..e2760cc6c85faacedb35eb139afa313feb2b833c 100644 --- a/src/pystencils/sympyextensions/reduction.py +++ b/src/pystencils/sympyextensions/reduction.py @@ -12,13 +12,11 @@ class ReducedAssignment(AssignmentBase): Symbol for binary operation being applied in the assignment, such as "+", "*", etc. """ - binop = None # type: str - - # TODO: initial value + binop = None # type: str @property def op(self): - return self.binop + '=' + return self.binop class AddReducedAssignment(ReducedAssignment):