Skip to content
Snippets Groups Projects
Commit 8a2bf664 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Initial work for introducing reduction capabilities to pystencils

parent 131fa9ca
No related branches found
No related tags found
1 merge request!438Reduction Support
from sympyextensions.reduction import ReducedAssignment
from typing import overload, cast, Any
from functools import reduce
from operator import add, mul, sub, truediv
......@@ -183,6 +184,32 @@ class FreezeExpressions:
return PsAssignment(lhs, op(lhs.clone(), rhs))
def map_ReducedAssignment(self, expr: ReducedAssignment):
lhs = self.visit(expr.lhs)
rhs = self.visit(expr.rhs)
assert isinstance(lhs, PsExpression)
assert isinstance(rhs, PsExpression)
match expr.op:
case "+=":
op = add
case "-=":
op = sub
case "*=":
op = mul
case "/=":
op = truediv
# TODO: unsure if sp.Min & sp.Max work here
case "min=":
op = sp.Min
case "max=":
op = sp.Max
case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
return PsAssignment(lhs, op(lhs.clone(), rhs)) # TODO: PsReducedAssignment?
def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
symb = self._ctx.get_symbol(spsym.name)
return PsSymbolExpr(symb)
......
import itertools
from copy import copy
from sympyextensions import reduced_assign
from sympyextensions.reduction import ReducedAssignment
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp
......@@ -55,8 +58,13 @@ class AssignmentCollection:
subexpressions = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
# filter out reduced assignments
reduced_assignments = [a for a in main_assignments if isinstance(a, ReducedAssignment)]
main_assignments = [a for a in main_assignments if (a not in reduced_assignments)]
self.main_assignments = main_assignments
self.subexpressions = subexpressions
self.reductions = reduced_assignments
if simplification_hints is None:
simplification_hints = {}
......@@ -71,6 +79,11 @@ class AssignmentCollection:
else:
self.subexpression_symbol_generator = subexpression_symbol_generator
def add_reduction(self, lhs: sp.Symbol, op: str, rhs: sp.Expr) -> None:
"""Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
assert lhs not in self.reductions, f"Reduction for lhs {lhs} exists"
self.reductions.append(reduced_assign(lhs, op, rhs))
def add_simplification_hint(self, key: str, value: Any) -> None:
"""Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
assert key not in self.simplification_hints, "This hint already exists"
......
from .astnodes import ConditionalFieldAccess
from .typed_sympy import TypedSymbol, CastFunc
from .pointers import mem_acc
from .reduction import reduced_assign
from .math import (
prod,
......@@ -33,6 +34,7 @@ from .math import (
__all__ = [
"ConditionalFieldAccess",
"reduced_assign",
"TypedSymbol",
"CastFunc",
"mem_acc",
......
from sympy.codegen.ast import AssignmentBase
class ReducedAssignment(AssignmentBase):
"""
Base class for reduced assignments.
Attributes:
===========
binop : str
Symbol for binary operation being applied in the assignment, such as "+",
"*", etc.
"""
binop = None # type: str
# TODO: initial value
@property
def op(self):
return self.binop + '='
class AddReducedAssignment(ReducedAssignment):
binop = '+'
class SubReducedAssignment(ReducedAssignment):
binop = '-'
class MulReducedAssignment(ReducedAssignment):
binop = '*'
class DivReducedAssignment(ReducedAssignment):
binop = '/'
class MinReducedssignment(ReducedAssignment):
binop = 'min'
class MaxReducedssignment(ReducedAssignment):
binop = 'max'
# Mapping from binary op strings to AugmentedAssignment subclasses
reduced_assign_classes = {
cls.binop: cls for cls in [
AddReducedAssignment, SubReducedAssignment, MulReducedAssignment, DivReducedAssignment,
MinReducedssignment, MaxReducedssignment
]
}
def reduced_assign(lhs, op, rhs):
if op not in reduced_assign_classes:
raise ValueError("Unrecognized operator %s" % op)
return reduced_assign_classes[op](lhs, rhs)
\ No newline at end of file
import pytest
import numpy as np
import sympy as sp
import pystencils as ps
from sympyextensions.reduction import reduced_assign
@pytest.mark.parametrize('dtype', ["float64", "float32"])
def test_log(dtype):
a = sp.Symbol("a")
x = ps.fields(f'x: {dtype}[1d]')
# kernel with main assignments and no reduction
main_assignment = ps.AssignmentCollection({x.center(): a})
ast_main = ps.create_kernel(main_assignment, default_dtype=dtype)
code_main = ps.get_code_str(ast_main)
kernel_main = ast_main.compile()
# ps.show_code(ast)
if dtype == "float64":
assert "float" not in code_main
array = np.zeros((10,), dtype=dtype)
kernel_main(x=array, a=100)
assert np.allclose(array, 4.60517019)
# kernel with single reduction assignment
omega = sp.Symbol("omega")
reduction_assignment = reduced_assign(omega, "+", x.center())
ast_reduction = ps.create_kernel(reduction_assignment, default_dtype=dtype)
code_reduction = ps.get_code_str(ast_reduction)
kernel_reduction = ast_reduction.compile()
if dtype == "float64":
assert "float" not in code_reduction
ps.show_code(ast_reduction)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment