Skip to content
Snippets Groups Projects
Commit e94c4980 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Adapt reduction assignment interface and employ enums instead of strings for...

Adapt reduction assignment interface and employ enums instead of strings for the binary operation employed
parent f71ce708
No related branches found
No related tags found
1 merge request!438Reduction Support
...@@ -39,13 +39,12 @@ from .sympyextensions.typed_sympy import TypedSymbol, DynamicType ...@@ -39,13 +39,12 @@ from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
from .sympyextensions import SymbolCreator from .sympyextensions import SymbolCreator
from .datahandling import create_data_handling from .datahandling import create_data_handling
from .sympyextensions.reduction import ( from .sympyextensions.reduction import (
AddReducedAssignment, AddReductionAssignment,
SubReducedAssignment, SubReductionAssignment,
MulReducedAssignment, MulReductionAssignment,
MinReducedssignment, MinReductionAssignment,
MaxReducedssignment MaxReductionAssignment,
) )
from .binop_mapping import binop_str_to_expr
__all__ = [ __all__ = [
"Field", "Field",
...@@ -76,13 +75,12 @@ __all__ = [ ...@@ -76,13 +75,12 @@ __all__ = [
"inspect", "inspect",
"AssignmentCollection", "AssignmentCollection",
"Assignment", "Assignment",
"binop_str_to_expr",
"AddAugmentedAssignment", "AddAugmentedAssignment",
"AddReducedAssignment", "AddReductionAssignment",
"SubReducedAssignment", "SubReductionAssignment",
"MulReducedAssignment", "MulReductionAssignment",
"MinReducedssignment", "MinReductionAssignment",
"MaxReducedssignment", "MaxReductionAssignment",
"assignment_from_stencil", "assignment_from_stencil",
"SymbolCreator", "SymbolCreator",
"create_data_handling", "create_data_handling",
......
...@@ -13,10 +13,10 @@ from ...sympyextensions import ( ...@@ -13,10 +13,10 @@ from ...sympyextensions import (
integer_functions, integer_functions,
ConditionalFieldAccess, ConditionalFieldAccess,
) )
from ...binop_mapping import binop_str_to_expr from ...compound_op_mapping import compound_op_to_expr
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType
from ...sympyextensions.pointers import AddressOf, mem_acc from ...sympyextensions.pointers import AddressOf, mem_acc
from ...sympyextensions.reduction import ReducedAssignment from ...sympyextensions.reduction import ReductionAssignment, ReductionOp
from ...field import Field, FieldType from ...field import Field, FieldType
from .context import KernelCreationContext from .context import KernelCreationContext
...@@ -173,9 +173,16 @@ class FreezeExpressions: ...@@ -173,9 +173,16 @@ class FreezeExpressions:
assert isinstance(lhs, PsExpression) assert isinstance(lhs, PsExpression)
assert isinstance(rhs, PsExpression) assert isinstance(rhs, PsExpression)
return PsAssignment(lhs, binop_str_to_expr(expr.op[0], lhs.clone(), rhs)) _str_to_compound_op: dict[str, ReductionOp] = {
"+=": ReductionOp.Add,
"-=": ReductionOp.Sub,
"*=": ReductionOp.Mul,
"/=": ReductionOp.Div,
}
def map_ReducedAssignment(self, expr: ReducedAssignment): return PsAssignment(lhs, compound_op_to_expr(_str_to_compound_op[expr.op], lhs.clone(), rhs))
def map_ReductionAssignment(self, expr: ReductionAssignment):
assert isinstance(expr.lhs, TypedSymbol) assert isinstance(expr.lhs, TypedSymbol)
lhs = self.visit(expr.lhs) lhs = self.visit(expr.lhs)
...@@ -197,21 +204,21 @@ class FreezeExpressions: ...@@ -197,21 +204,21 @@ class FreezeExpressions:
new_lhs = PsSymbolExpr(new_lhs_symb) new_lhs = PsSymbolExpr(new_lhs_symb)
# get new rhs from augmented assignment # get new rhs from augmented assignment
new_rhs: PsExpression = binop_str_to_expr(expr.op, new_lhs.clone(), rhs) new_rhs: PsExpression = compound_op_to_expr(expr.op, new_lhs.clone(), rhs)
# match for reduction operation and set neutral init_val # match for reduction operation and set neutral init_val
init_val: PsExpression init_val: PsExpression
match expr.op: match expr.op:
case "+": case ReductionOp.Add:
init_val = PsConstantExpr(PsConstant(0, dtype)) init_val = PsConstantExpr(PsConstant(0, dtype))
case "-": case ReductionOp.Sub:
init_val = PsConstantExpr(PsConstant(0, dtype)) init_val = PsConstantExpr(PsConstant(0, dtype))
case "*": case ReductionOp.Mul:
init_val = PsConstantExpr(PsConstant(1, dtype)) init_val = PsConstantExpr(PsConstant(1, dtype))
case "min": case ReductionOp.Min:
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), []) init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
init_val.dtype = dtype init_val.dtype = dtype
case "max": case ReductionOp.Max:
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), []) init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
init_val.dtype = dtype init_val.dtype = dtype
case _: case _:
......
...@@ -115,7 +115,7 @@ class AddOpenMP: ...@@ -115,7 +115,7 @@ class AddOpenMP:
if bool(ctx.local_reduction_symbols): if bool(ctx.local_reduction_symbols):
for symbol, reduction in ctx.local_reduction_symbols.items(): for symbol, reduction in ctx.local_reduction_symbols.items():
if isinstance(symbol.dtype, PsScalarType): if isinstance(symbol.dtype, PsScalarType):
pragma_text += f" reduction({reduction.op}: {symbol.name})" pragma_text += f" reduction({reduction.op.value}: {symbol.name})"
else: else:
NotImplementedError("OMP: Reductions for non-scalar data types are not supported yet.") NotImplementedError("OMP: Reductions for non-scalar data types are not supported yet.")
......
...@@ -7,7 +7,7 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO ...@@ -7,7 +7,7 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
from .kernel import Kernel, GpuKernel, GpuThreadsRange from .kernel import Kernel, GpuKernel, GpuThreadsRange
from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr, ReductionPointerVariable from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr, ReductionPointerVariable
from .parameters import Parameter from .parameters import Parameter
from ..binop_mapping import binop_str_to_expr from ..compound_op_mapping import compound_op_to_expr
from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr
from ..types import create_numeric_type, PsIntegerType, PsScalarType from ..types import create_numeric_type, PsIntegerType, PsScalarType
......
...@@ -2,6 +2,7 @@ from __future__ import annotations ...@@ -2,6 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from ..field import Field from ..field import Field
from ..sympyextensions.reduction import ReductionOp
@dataclass(frozen=True) @dataclass(frozen=True)
...@@ -21,7 +22,7 @@ class LocalReductionVariable(PsSymbolProperty): ...@@ -21,7 +22,7 @@ class LocalReductionVariable(PsSymbolProperty):
from ..backend.memory import PsSymbol from ..backend.memory import PsSymbol
from ..backend.ast.expressions import PsExpression from ..backend.ast.expressions import PsExpression
op: str op: ReductionOp
init_val: PsExpression init_val: PsExpression
ptr_symbol: PsSymbol ptr_symbol: PsSymbol
...@@ -32,7 +33,7 @@ class ReductionPointerVariable(PsSymbolProperty): ...@@ -32,7 +33,7 @@ class ReductionPointerVariable(PsSymbolProperty):
from ..backend.memory import PsSymbol from ..backend.memory import PsSymbol
op: str op: ReductionOp
local_symbol: PsSymbol local_symbol: PsSymbol
......
from enum import Enum
from operator import truediv, mul, sub, add from operator import truediv, mul, sub, add
from .backend.ast.expressions import PsExpression, PsCall from .backend.ast.expressions import PsExpression, PsCall
from .backend.exceptions import FreezeError from .backend.exceptions import FreezeError
from .backend.functions import PsMathFunction, MathFunctions from .backend.functions import PsMathFunction, MathFunctions
from .sympyextensions.reduction import ReductionOp
_available_operator_interface: set[str] = {'+', '-', '*', '/'} _available_operator_interface: set[ReductionOp] = {ReductionOp.Add, ReductionOp.Sub, ReductionOp.Mul, ReductionOp.Div}
def binop_str_to_expr(op: str, op1, op2) -> PsExpression: def compound_op_to_expr(op: ReductionOp, op1, op2) -> PsExpression:
if op in _available_operator_interface: if op in _available_operator_interface:
match op: match op:
case "+": case ReductionOp.Add:
operator = add operator = add
case "-": case ReductionOp.Sub:
operator = sub operator = sub
case "*": case ReductionOp.Mul:
operator = mul operator = mul
case "/": case ReductionOp.Div:
operator = truediv operator = truediv
case _: case _:
raise FreezeError(f"Found unsupported operation type for compound assignments: {op}.") raise FreezeError(f"Found unsupported operation type for compound assignments: {op}.")
return operator(op1, op2) return operator(op1, op2)
else: else:
match op: match op:
case "min": case ReductionOp.Min:
return PsCall(PsMathFunction(MathFunctions.Min), [op1, op2]) return PsCall(PsMathFunction(MathFunctions.Min), [op1, op2])
case "max": case ReductionOp.Max:
return PsCall(PsMathFunction(MathFunctions.Max), [op1, op2]) return PsCall(PsMathFunction(MathFunctions.Max), [op1, op2])
case _: case _:
raise FreezeError(f"Found unsupported operation type for compound assignments: {op}.") raise FreezeError(f"Found unsupported operation type for compound assignments: {op}.")
from .astnodes import ConditionalFieldAccess from .astnodes import ConditionalFieldAccess
from .typed_sympy import TypedSymbol, CastFunc from .typed_sympy import TypedSymbol, CastFunc
from .pointers import mem_acc from .pointers import mem_acc
from .reduction import reduced_assign from .reduction import reduction_assignment, reduction_assignment_from_str, ReductionOp
from .math import ( from .math import (
prod, prod,
...@@ -34,7 +34,9 @@ from .math import ( ...@@ -34,7 +34,9 @@ from .math import (
__all__ = [ __all__ = [
"ConditionalFieldAccess", "ConditionalFieldAccess",
"reduced_assign", "reduction_assignment",
"reduction_assignment_from_str",
"ReductionOp",
"TypedSymbol", "TypedSymbol",
"CastFunc", "CastFunc",
"mem_acc", "mem_acc",
......
from enum import Enum
from sympy.codegen.ast import AssignmentBase from sympy.codegen.ast import AssignmentBase
class ReducedAssignment(AssignmentBase): class ReductionOp(Enum):
Add = "+"
Sub = "-"
Mul = "*"
Div = "/"
Min = "min"
Max = "max"
class ReductionAssignment(AssignmentBase):
""" """
Base class for reduced assignments. Base class for reduced assignments.
Attributes: Attributes:
=========== ===========
binop : str binop : CompoundOp
Symbol for binary operation being applied in the assignment, such as "+", Enum for binary operation being applied in the assignment, such as "Add" for "+", "Sub" for "-", etc.
"*", etc.
""" """
binop = None # type: str binop = None # type: ReductionOp
@property @property
def op(self): def op(self):
return self.binop return self.binop
class AddReducedAssignment(ReducedAssignment): class AddReductionAssignment(ReductionAssignment):
binop = '+' binop = ReductionOp.Add
class SubReducedAssignment(ReducedAssignment): class SubReductionAssignment(ReductionAssignment):
binop = '-' binop = ReductionOp.Sub
class MulReducedAssignment(ReducedAssignment): class MulReductionAssignment(ReductionAssignment):
binop = '*' binop = ReductionOp.Mul
class MinReducedssignment(ReducedAssignment): class MinReductionAssignment(ReductionAssignment):
binop = 'min' binop = ReductionOp.Min
class MaxReducedssignment(ReducedAssignment): class MaxReductionAssignment(ReductionAssignment):
binop = 'max' binop = ReductionOp.Max
# Mapping from binary op strings to AugmentedAssignment subclasses # Mapping from ReductionOp enum to ReductionAssigment classes
reduced_assign_classes = { _reduction_assignment_classes = {
cls.binop: cls for cls in [ cls.binop: cls for cls in [
AddReducedAssignment, SubReducedAssignment, MulReducedAssignment, AddReductionAssignment, SubReductionAssignment, MulReductionAssignment,
MinReducedssignment, MaxReducedssignment MinReductionAssignment, MaxReductionAssignment
] ]
} }
# Mapping from ReductionOp str to ReductionAssigment classes
_reduction_assignment_classes_for_str = {
cls.value: cls for cls in _reduction_assignment_classes
}
def reduced_assign(lhs, op, rhs):
if op not in reduced_assign_classes: def reduction_assignment(lhs, op: ReductionOp, rhs):
if op not in _reduction_assignment_classes:
raise ValueError("Unrecognized operator %s" % op) raise ValueError("Unrecognized operator %s" % op)
return reduced_assign_classes[op](lhs, rhs) return _reduction_assignment_classes[op](lhs, rhs)
def reduction_assignment_from_str(lhs, op: str, rhs):
return reduction_assignment(lhs, _reduction_assignment_classes_for_str[op], rhs)
import pytest import pytest
import numpy as np import numpy as np
import sympy as sp
import cupy as cp import cupy as cp
import pystencils as ps import pystencils as ps
from pystencils.sympyextensions import reduced_assign from pystencils.sympyextensions import reduction_assignment_from_str
INIT_W = 5 INIT_W = 5
INIT_ARR = 2 INIT_ARR = 2
...@@ -28,11 +27,11 @@ def test_reduction(dtype, op): ...@@ -28,11 +27,11 @@ def test_reduction(dtype, op):
# kernel with reduction assignment # kernel with reduction assignment
reduction_assignment = reduced_assign(w, op, x.center()) red_assign = reduction_assignment_from_str(w, op, x.center())
config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail else ps.CreateKernelConfig(cpu_openmp=True) config = ps.CreateKernelConfig(target=ps.Target.GPU) if gpu_avail else ps.CreateKernelConfig(cpu_openmp=True)
ast_reduction = ps.create_kernel([reduction_assignment], config, default_dtype=dtype) ast_reduction = ps.create_kernel([red_assign], config, default_dtype=dtype)
# code_reduction = ps.get_code_str(ast_reduction) # code_reduction = ps.get_code_str(ast_reduction)
kernel_reduction = ast_reduction.compile() kernel_reduction = ast_reduction.compile()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment