diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 44ee170775982916e34ab6da6656461962763cd4..65be23065ea0bacf721af684ce7e67dc6a789f48 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -1,3 +1,4 @@ +from sympyextensions.reduction import ReducedAssignment from typing import overload, cast, Any from functools import reduce from operator import add, mul, sub, truediv @@ -183,6 +184,32 @@ class FreezeExpressions: return PsAssignment(lhs, op(lhs.clone(), rhs)) + def map_ReducedAssignment(self, expr: ReducedAssignment): + lhs = self.visit(expr.lhs) + rhs = self.visit(expr.rhs) + + assert isinstance(lhs, PsExpression) + assert isinstance(rhs, PsExpression) + + match expr.op: + case "+=": + op = add + case "-=": + op = sub + case "*=": + op = mul + case "/=": + op = truediv + # TODO: unsure if sp.Min & sp.Max work here + case "min=": + op = sp.Min + case "max=": + op = sp.Max + case _: + raise FreezeError(f"Unsupported reduced assignment: {expr.op}.") + + return PsAssignment(lhs, op(lhs.clone(), rhs)) # TODO: PsReducedAssignment? + def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr: symb = self._ctx.get_symbol(spsym.name) return PsSymbolExpr(symb) diff --git a/src/pystencils/simp/assignment_collection.py b/src/pystencils/simp/assignment_collection.py index f1ba8715431d96fb2a09a01e45872def421fe94f..4de3e8dc69663721dac224f3e39f52f1ebb78c47 100644 --- a/src/pystencils/simp/assignment_collection.py +++ b/src/pystencils/simp/assignment_collection.py @@ -1,5 +1,8 @@ import itertools from copy import copy + +from sympyextensions import reduced_assign +from sympyextensions.reduction import ReducedAssignment from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union import sympy as sp @@ -55,8 +58,13 @@ class AssignmentCollection: subexpressions = list(itertools.chain.from_iterable( [(a if isinstance(a, Iterable) else [a]) for a in subexpressions])) + # filter out reduced assignments + reduced_assignments = [a for a in main_assignments if isinstance(a, ReducedAssignment)] + main_assignments = [a for a in main_assignments if (a not in reduced_assignments)] + self.main_assignments = main_assignments self.subexpressions = subexpressions + self.reductions = reduced_assignments if simplification_hints is None: simplification_hints = {} @@ -71,6 +79,11 @@ class AssignmentCollection: else: self.subexpression_symbol_generator = subexpression_symbol_generator + def add_reduction(self, lhs: sp.Symbol, op: str, rhs: sp.Expr) -> None: + """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet.""" + assert lhs not in self.reductions, f"Reduction for lhs {lhs} exists" + self.reductions.append(reduced_assign(lhs, op, rhs)) + def add_simplification_hint(self, key: str, value: Any) -> None: """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet.""" assert key not in self.simplification_hints, "This hint already exists" diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py index 7431416c9eb9bcd4433dab76c32fb1b755501105..6ab24e936a2355782755badbf835d7e5c3bee73e 100644 --- a/src/pystencils/sympyextensions/__init__.py +++ b/src/pystencils/sympyextensions/__init__.py @@ -1,6 +1,7 @@ from .astnodes import ConditionalFieldAccess from .typed_sympy import TypedSymbol, CastFunc from .pointers import mem_acc +from .reduction import reduced_assign from .math import ( prod, @@ -33,6 +34,7 @@ from .math import ( __all__ = [ "ConditionalFieldAccess", + "reduced_assign", "TypedSymbol", "CastFunc", "mem_acc", diff --git a/src/pystencils/sympyextensions/reduction.py b/src/pystencils/sympyextensions/reduction.py new file mode 100644 index 0000000000000000000000000000000000000000..aa947c1d2360b91685c9ee2db769742607f98883 --- /dev/null +++ b/src/pystencils/sympyextensions/reduction.py @@ -0,0 +1,57 @@ +from sympy.codegen.ast import AssignmentBase + + +class ReducedAssignment(AssignmentBase): + """ + Base class for reduced assignments. + + Attributes: + =========== + + binop : str + Symbol for binary operation being applied in the assignment, such as "+", + "*", etc. + """ + binop = None # type: str + + # TODO: initial value + + @property + def op(self): + return self.binop + '=' + + +class AddReducedAssignment(ReducedAssignment): + binop = '+' + +class SubReducedAssignment(ReducedAssignment): + binop = '-' + + +class MulReducedAssignment(ReducedAssignment): + binop = '*' + + +class DivReducedAssignment(ReducedAssignment): + binop = '/' + + +class MinReducedssignment(ReducedAssignment): + binop = 'min' + +class MaxReducedssignment(ReducedAssignment): + binop = 'max' + + +# Mapping from binary op strings to AugmentedAssignment subclasses +reduced_assign_classes = { + cls.binop: cls for cls in [ + AddReducedAssignment, SubReducedAssignment, MulReducedAssignment, DivReducedAssignment, + MinReducedssignment, MaxReducedssignment + ] +} + +def reduced_assign(lhs, op, rhs): + if op not in reduced_assign_classes: + raise ValueError("Unrecognized operator %s" % op) + return reduced_assign_classes[op](lhs, rhs) \ No newline at end of file diff --git a/tests/kernelcreation/test_reduction.py b/tests/kernelcreation/test_reduction.py new file mode 100644 index 0000000000000000000000000000000000000000..47509e267be0aac8612d231788a4d7d707d1075f --- /dev/null +++ b/tests/kernelcreation/test_reduction.py @@ -0,0 +1,44 @@ +import pytest +import numpy as np +import sympy as sp + +import pystencils as ps +from sympyextensions.reduction import reduced_assign + + +@pytest.mark.parametrize('dtype', ["float64", "float32"]) +def test_log(dtype): + a = sp.Symbol("a") + x = ps.fields(f'x: {dtype}[1d]') + + # kernel with main assignments and no reduction + + main_assignment = ps.AssignmentCollection({x.center(): a}) + + ast_main = ps.create_kernel(main_assignment, default_dtype=dtype) + code_main = ps.get_code_str(ast_main) + kernel_main = ast_main.compile() + + # ps.show_code(ast) + + if dtype == "float64": + assert "float" not in code_main + + array = np.zeros((10,), dtype=dtype) + kernel_main(x=array, a=100) + assert np.allclose(array, 4.60517019) + + # kernel with single reduction assignment + + omega = sp.Symbol("omega") + + reduction_assignment = reduced_assign(omega, "+", x.center()) + + ast_reduction = ps.create_kernel(reduction_assignment, default_dtype=dtype) + code_reduction = ps.get_code_str(ast_reduction) + kernel_reduction = ast_reduction.compile() + + if dtype == "float64": + assert "float" not in code_reduction + + ps.show_code(ast_reduction) \ No newline at end of file