diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index 3e8e8d8e4ea6652b4f331f173266e5e02d23ee49..6aa305a1690025220fbef1228381e57840f356b2 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -39,13 +39,12 @@ from .sympyextensions.typed_sympy import TypedSymbol, DynamicType from .sympyextensions import SymbolCreator from .datahandling import create_data_handling from .sympyextensions.reduction import ( - AddReducedAssignment, - SubReducedAssignment, - MulReducedAssignment, - MinReducedssignment, - MaxReducedssignment + AddReductionAssignment, + SubReductionAssignment, + MulReductionAssignment, + MinReductionAssignment, + MaxReductionAssignment, ) -from .binop_mapping import binop_str_to_expr __all__ = [ "Field", @@ -76,13 +75,12 @@ __all__ = [ "inspect", "AssignmentCollection", "Assignment", - "binop_str_to_expr", "AddAugmentedAssignment", - "AddReducedAssignment", - "SubReducedAssignment", - "MulReducedAssignment", - "MinReducedssignment", - "MaxReducedssignment", + "AddReductionAssignment", + "SubReductionAssignment", + "MulReductionAssignment", + "MinReductionAssignment", + "MaxReductionAssignment", "assignment_from_stencil", "SymbolCreator", "create_data_handling", diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index de272cf44d3d16074774a6f02286e2428cefbe2b..4bf136562458f05c4d34776316ab86e2483bce52 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -13,10 +13,10 @@ from ...sympyextensions import ( integer_functions, ConditionalFieldAccess, ) -from ...binop_mapping import binop_str_to_expr +from ...compound_op_mapping import compound_op_to_expr from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType from ...sympyextensions.pointers import AddressOf, mem_acc -from ...sympyextensions.reduction import ReducedAssignment +from ...sympyextensions.reduction import ReductionAssignment, ReductionOp from ...field import Field, FieldType from .context import KernelCreationContext @@ -173,9 +173,16 @@ class FreezeExpressions: assert isinstance(lhs, PsExpression) assert isinstance(rhs, PsExpression) - return PsAssignment(lhs, binop_str_to_expr(expr.op[0], lhs.clone(), rhs)) + _str_to_compound_op: dict[str, ReductionOp] = { + "+=": ReductionOp.Add, + "-=": ReductionOp.Sub, + "*=": ReductionOp.Mul, + "/=": ReductionOp.Div, + } - def map_ReducedAssignment(self, expr: ReducedAssignment): + return PsAssignment(lhs, compound_op_to_expr(_str_to_compound_op[expr.op], lhs.clone(), rhs)) + + def map_ReductionAssignment(self, expr: ReductionAssignment): assert isinstance(expr.lhs, TypedSymbol) lhs = self.visit(expr.lhs) @@ -197,21 +204,21 @@ class FreezeExpressions: new_lhs = PsSymbolExpr(new_lhs_symb) # get new rhs from augmented assignment - new_rhs: PsExpression = binop_str_to_expr(expr.op, new_lhs.clone(), rhs) + new_rhs: PsExpression = compound_op_to_expr(expr.op, new_lhs.clone(), rhs) # match for reduction operation and set neutral init_val init_val: PsExpression match expr.op: - case "+": + case ReductionOp.Add: init_val = PsConstantExpr(PsConstant(0, dtype)) - case "-": + case ReductionOp.Sub: init_val = PsConstantExpr(PsConstant(0, dtype)) - case "*": + case ReductionOp.Mul: init_val = PsConstantExpr(PsConstant(1, dtype)) - case "min": + case ReductionOp.Min: init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), []) init_val.dtype = dtype - case "max": + case ReductionOp.Max: init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) init_val.dtype = dtype case _: diff --git a/src/pystencils/backend/transformations/add_pragmas.py b/src/pystencils/backend/transformations/add_pragmas.py index 44d1d1edee4c770d5b6f2d1bc9568ad549423b47..f4046d87dada13d13eb1926cda0c3a39d95ff326 100644 --- a/src/pystencils/backend/transformations/add_pragmas.py +++ b/src/pystencils/backend/transformations/add_pragmas.py @@ -115,7 +115,7 @@ class AddOpenMP: if bool(ctx.local_reduction_symbols): for symbol, reduction in ctx.local_reduction_symbols.items(): if isinstance(symbol.dtype, PsScalarType): - pragma_text += f" reduction({reduction.op}: {symbol.name})" + pragma_text += f" reduction({reduction.op.value}: {symbol.name})" else: NotImplementedError("OMP: Reductions for non-scalar data types are not supported yet.") diff --git a/src/pystencils/codegen/driver.py b/src/pystencils/codegen/driver.py index d68bfbcacb7933ff32da32015673a0fb75ac1a92..6e0611a4b70431dbfb04c7d3f59e2fb4ca7fabc6 100644 --- a/src/pystencils/codegen/driver.py +++ b/src/pystencils/codegen/driver.py @@ -7,7 +7,7 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO from .kernel import Kernel, GpuKernel, GpuThreadsRange from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr, ReductionPointerVariable from .parameters import Parameter -from ..binop_mapping import binop_str_to_expr +from ..compound_op_mapping import compound_op_to_expr from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr from ..types import create_numeric_type, PsIntegerType, PsScalarType diff --git a/src/pystencils/codegen/properties.py b/src/pystencils/codegen/properties.py index 1e71c5b9855294cc47aeb06fc54b21cd1e5f31b9..d3c2435ed713ec6eccb55c4afba968a708bac29d 100644 --- a/src/pystencils/codegen/properties.py +++ b/src/pystencils/codegen/properties.py @@ -2,6 +2,7 @@ from __future__ import annotations from dataclasses import dataclass from ..field import Field +from ..sympyextensions.reduction import ReductionOp @dataclass(frozen=True) @@ -21,7 +22,7 @@ class LocalReductionVariable(PsSymbolProperty): from ..backend.memory import PsSymbol from ..backend.ast.expressions import PsExpression - op: str + op: ReductionOp init_val: PsExpression ptr_symbol: PsSymbol @@ -32,7 +33,7 @@ class ReductionPointerVariable(PsSymbolProperty): from ..backend.memory import PsSymbol - op: str + op: ReductionOp local_symbol: PsSymbol diff --git a/src/pystencils/binop_mapping.py b/src/pystencils/compound_op_mapping.py similarity index 65% rename from src/pystencils/binop_mapping.py rename to src/pystencils/compound_op_mapping.py index 060fa40aad732922e6f14d5f62ab04e71a2ff487..eb10b3381d7cc00da6f2a198532ab54519f028b9 100644 --- a/src/pystencils/binop_mapping.py +++ b/src/pystencils/compound_op_mapping.py @@ -1,31 +1,33 @@ +from enum import Enum from operator import truediv, mul, sub, add from .backend.ast.expressions import PsExpression, PsCall from .backend.exceptions import FreezeError from .backend.functions import PsMathFunction, MathFunctions +from .sympyextensions.reduction import ReductionOp -_available_operator_interface: set[str] = {'+', '-', '*', '/'} +_available_operator_interface: set[ReductionOp] = {ReductionOp.Add, ReductionOp.Sub, ReductionOp.Mul, ReductionOp.Div} -def binop_str_to_expr(op: str, op1, op2) -> PsExpression: +def compound_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression: if op in _available_operator_interface: match op: - case "+": + case ReductionOp.Add: operator = add - case "-": + case ReductionOp.Sub: operator = sub - case "*": + case ReductionOp.Mul: operator = mul - case "/": + case ReductionOp.Div: operator = truediv case _: raise FreezeError(f"Found unsupported operation type for compound assignments: {op}.") return operator(op1, op2) else: match op: - case "min": + case ReductionOp.Min: return PsCall(PsMathFunction(MathFunctions.Min), [op1, op2]) - case "max": + case ReductionOp.Max: return PsCall(PsMathFunction(MathFunctions.Max), [op1, op2]) case _: raise FreezeError(f"Found unsupported operation type for compound assignments: {op}.") diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py index 6ab24e936a2355782755badbf835d7e5c3bee73e..eb90f4bedde7729750073b60f4f89e7a5d602d37 100644 --- a/src/pystencils/sympyextensions/__init__.py +++ b/src/pystencils/sympyextensions/__init__.py @@ -1,7 +1,7 @@ from .astnodes import ConditionalFieldAccess from .typed_sympy import TypedSymbol, CastFunc from .pointers import mem_acc -from .reduction import reduced_assign +from .reduction import reduction_assignment, reduction_assignment_from_str, ReductionOp from .math import ( prod, @@ -34,7 +34,9 @@ from .math import ( __all__ = [ "ConditionalFieldAccess", - "reduced_assign", + "reduction_assignment", + "reduction_assignment_from_str", + "ReductionOp", "TypedSymbol", "CastFunc", "mem_acc", diff --git a/src/pystencils/sympyextensions/reduction.py b/src/pystencils/sympyextensions/reduction.py index c9e5bfdfb38b576176d10446697232c7fdd08d64..9d8aecb5bacd3ab0dca032deac94e1ef419b7a4e 100644 --- a/src/pystencils/sympyextensions/reduction.py +++ b/src/pystencils/sympyextensions/reduction.py @@ -1,54 +1,73 @@ +from enum import Enum + from sympy.codegen.ast import AssignmentBase -class ReducedAssignment(AssignmentBase): +class ReductionOp(Enum): + Add = "+" + Sub = "-" + Mul = "*" + Div = "/" + Min = "min" + Max = "max" + + +class ReductionAssignment(AssignmentBase): """ Base class for reduced assignments. Attributes: =========== - binop : str - Symbol for binary operation being applied in the assignment, such as "+", - "*", etc. + binop : CompoundOp + Enum for binary operation being applied in the assignment, such as "Add" for "+", "Sub" for "-", etc. """ - binop = None # type: str + binop = None # type: ReductionOp @property def op(self): return self.binop -class AddReducedAssignment(ReducedAssignment): - binop = '+' +class AddReductionAssignment(ReductionAssignment): + binop = ReductionOp.Add -class SubReducedAssignment(ReducedAssignment): - binop = '-' +class SubReductionAssignment(ReductionAssignment): + binop = ReductionOp.Sub -class MulReducedAssignment(ReducedAssignment): - binop = '*' +class MulReductionAssignment(ReductionAssignment): + binop = ReductionOp.Mul -class MinReducedssignment(ReducedAssignment): - binop = 'min' +class MinReductionAssignment(ReductionAssignment): + binop = ReductionOp.Min -class MaxReducedssignment(ReducedAssignment): - binop = 'max' +class MaxReductionAssignment(ReductionAssignment): + binop = ReductionOp.Max -# Mapping from binary op strings to AugmentedAssignment subclasses -reduced_assign_classes = { +# Mapping from ReductionOp enum to ReductionAssigment classes +_reduction_assignment_classes = { cls.binop: cls for cls in [ - AddReducedAssignment, SubReducedAssignment, MulReducedAssignment, - MinReducedssignment, MaxReducedssignment + AddReductionAssignment, SubReductionAssignment, MulReductionAssignment, + MinReductionAssignment, MaxReductionAssignment ] } +# Mapping from ReductionOp str to ReductionAssigment classes +_reduction_assignment_classes_for_str = { + cls.value: cls for cls in _reduction_assignment_classes +} -def reduced_assign(lhs, op, rhs): - if op not in reduced_assign_classes: + +def reduction_assignment(lhs, op: ReductionOp, rhs): + if op not in _reduction_assignment_classes: raise ValueError("Unrecognized operator %s" % op) - return reduced_assign_classes[op](lhs, rhs) + return _reduction_assignment_classes[op](lhs, rhs) + + +def reduction_assignment_from_str(lhs, op: str, rhs): + return reduction_assignment(lhs, _reduction_assignment_classes_for_str[op], rhs) diff --git a/tests/kernelcreation/test_reduction.py b/tests/kernelcreation/test_reduction.py index 8095f4e1d82e6f11998642ef115f8b7978e9e61d..c84417ac7b98e187106669c1232ab9edc362d0eb 100644 --- a/tests/kernelcreation/test_reduction.py +++ b/tests/kernelcreation/test_reduction.py @@ -1,10 +1,9 @@ import pytest import numpy as np -import sympy as sp import cupy as cp import pystencils as ps -from pystencils.sympyextensions import reduced_assign +from pystencils.sympyextensions import reduction_assignment_from_str INIT_W = 5 INIT_ARR = 2 @@ -28,11 +27,11 @@ def test_reduction(dtype, op): # kernel with reduction assignment - reduction_assignment = reduced_assign(w, op, x.center()) + red_assign = reduction_assignment_from_str(w, op, x.center()) config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail else ps.CreateKernelConfig(cpu_openmp=True) - ast_reduction = ps.create_kernel([reduction_assignment], config, default_dtype=dtype) + ast_reduction = ps.create_kernel([red_assign], config, default_dtype=dtype) # code_reduction = ps.get_code_str(ast_reduction) kernel_reduction = ast_reduction.compile()