Skip to content
Snippets Groups Projects
Commit 6b8bff09 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Initial work for introducing reduction capabilities to pystencils


Signed-off-by: default avatarzy69guqi <richard.angersbach@fau.de>
parent 4f8e42e6
Branches
1 merge request!438Reduction Support
from sympyextensions.reduction import ReducedAssignment
from typing import overload, cast, Any
from functools import reduce
from operator import add, mul, sub, truediv
......@@ -183,6 +184,32 @@ class FreezeExpressions:
return PsAssignment(lhs, op(lhs.clone(), rhs))
def map_ReducedAssignment(self, expr: ReducedAssignment):
lhs = self.visit(expr.lhs)
rhs = self.visit(expr.rhs)
assert isinstance(lhs, PsExpression)
assert isinstance(rhs, PsExpression)
match expr.op:
case "+=":
op = add
case "-=":
op = sub
case "*=":
op = mul
case "/=":
op = truediv
# TODO: unsure if sp.Min & sp.Max work here
case "min=":
op = sp.Min
case "max=":
op = sp.Max
case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
return PsAssignment(lhs, op(lhs.clone(), rhs)) # TODO: PsReducedAssignment?
def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
symb = self._ctx.get_symbol(spsym.name)
return PsSymbolExpr(symb)
......
import itertools
from copy import copy
from sympyextensions import reduced_assign
from sympyextensions.reduction import ReducedAssignment
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp
......@@ -55,8 +58,13 @@ class AssignmentCollection:
subexpressions = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
# filter out reduced assignments
reduced_assignments = [a for a in main_assignments if isinstance(a, ReducedAssignment)]
main_assignments = [a for a in main_assignments if (a not in reduced_assignments)]
self.main_assignments = main_assignments
self.subexpressions = subexpressions
self.reductions = reduced_assignments
if simplification_hints is None:
simplification_hints = {}
......@@ -71,6 +79,11 @@ class AssignmentCollection:
else:
self.subexpression_symbol_generator = subexpression_symbol_generator
def add_reduction(self, lhs: sp.Symbol, op: str, rhs: sp.Expr) -> None:
"""Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
assert lhs not in self.reductions, f"Reduction for lhs {lhs} exists"
self.reductions.append(reduced_assign(lhs, op, rhs))
def add_simplification_hint(self, key: str, value: Any) -> None:
"""Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
assert key not in self.simplification_hints, "This hint already exists"
......
from .astnodes import ConditionalFieldAccess
from .typed_sympy import TypedSymbol, CastFunc
from .pointers import mem_acc
from .reduction import reduced_assign
from .math import (
prod,
......@@ -33,6 +34,7 @@ from .math import (
__all__ = [
"ConditionalFieldAccess",
"reduced_assign",
"TypedSymbol",
"CastFunc",
"mem_acc",
......
from sympy.codegen.ast import AssignmentBase
class ReducedAssignment(AssignmentBase):
"""
Base class for reduced assignments.
Attributes:
===========
binop : str
Symbol for binary operation being applied in the assignment, such as "+",
"*", etc.
"""
binop = None # type: str
# TODO: initial value
@property
def op(self):
return self.binop + '='
class AddReducedAssignment(ReducedAssignment):
binop = '+'
class SubReducedAssignment(ReducedAssignment):
binop = '-'
class MulReducedAssignment(ReducedAssignment):
binop = '*'
class DivReducedAssignment(ReducedAssignment):
binop = '/'
class MinReducedssignment(ReducedAssignment):
binop = 'min'
class MaxReducedssignment(ReducedAssignment):
binop = 'max'
# Mapping from binary op strings to AugmentedAssignment subclasses
reduced_assign_classes = {
cls.binop: cls for cls in [
AddReducedAssignment, SubReducedAssignment, MulReducedAssignment, DivReducedAssignment,
MinReducedssignment, MaxReducedssignment
]
}
def reduced_assign(lhs, op, rhs):
if op not in reduced_assign_classes:
raise ValueError("Unrecognized operator %s" % op)
return reduced_assign_classes[op](lhs, rhs)
\ No newline at end of file
import pytest
import numpy as np
import sympy as sp
import pystencils as ps
from sympyextensions.reduction import reduced_assign
@pytest.mark.parametrize('dtype', ["float64", "float32"])
def test_log(dtype):
a = sp.Symbol("a")
x = ps.fields(f'x: {dtype}[1d]')
# kernel with main assignments and no reduction
main_assignment = ps.AssignmentCollection({x.center(): a})
ast_main = ps.create_kernel(main_assignment, default_dtype=dtype)
code_main = ps.get_code_str(ast_main)
kernel_main = ast_main.compile()
# ps.show_code(ast)
if dtype == "float64":
assert "float" not in code_main
array = np.zeros((10,), dtype=dtype)
kernel_main(x=array, a=100)
assert np.allclose(array, 4.60517019)
# kernel with single reduction assignment
omega = sp.Symbol("omega")
reduction_assignment = reduced_assign(omega, "+", x.center())
ast_reduction = ps.create_kernel(reduction_assignment, default_dtype=dtype)
code_reduction = ps.get_code_str(ast_reduction)
kernel_reduction = ast_reduction.compile()
if dtype == "float64":
assert "float" not in code_reduction
ps.show_code(ast_reduction)
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment