Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (27)
Showing
with 225 additions and 9 deletions
......@@ -38,6 +38,13 @@ from .simp import AssignmentCollection
from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
from .sympyextensions import SymbolCreator
from .datahandling import create_data_handling
from .sympyextensions.reduction import (
AddReducedAssignment,
SubReducedAssignment,
MulReducedAssignment,
MinReducedssignment,
MaxReducedssignment
)
__all__ = [
"Field",
......@@ -69,6 +76,11 @@ __all__ = [
"AssignmentCollection",
"Assignment",
"AddAugmentedAssignment",
"AddReducedAssignment",
"SubReducedAssignment",
"MulReducedAssignment",
"MinReducedssignment",
"MaxReducedssignment",
"assignment_from_stencil",
"SymbolCreator",
"create_data_handling",
......
......@@ -62,7 +62,7 @@ class UndefinedSymbolsCollector:
case PsAssignment(lhs, rhs):
undefined_vars = self(lhs) | self(rhs)
if isinstance(lhs, PsSymbolExpr):
if isinstance(node, PsDeclaration) and isinstance(lhs, PsSymbolExpr):
undefined_vars.remove(lhs.symbol)
return undefined_vars
......
......@@ -94,17 +94,31 @@ class MathFunctions(Enum):
self.num_args = num_args
class NumericLimitsFunctions(Enum):
"""Numerical limits functions supported by the backend.
Each platform has to materialize these functions to a concrete implementation.
"""
Min = ("min", 0)
Max = ("max", 0)
def __init__(self, func_name, num_args):
self.function_name = func_name
self.num_args = num_args
class PsMathFunction(PsFunction):
"""Homogenously typed mathematical functions."""
__match_args__ = ("func",)
def __init__(self, func: MathFunctions) -> None:
def __init__(self, func: MathFunctions | NumericLimitsFunctions) -> None:
super().__init__(func.function_name, func.num_args)
self._func = func
@property
def func(self) -> MathFunctions:
def func(self) -> MathFunctions | NumericLimitsFunctions:
return self._func
def __str__(self) -> str:
......
......@@ -9,6 +9,8 @@ from ...defaults import DEFAULTS
from ...field import Field, FieldType
from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType
from ...codegen.properties import ReductionSymbolProperty
from ..memory import PsSymbol, PsBuffer
from ..constants import PsConstant
from ...types import (
......@@ -75,6 +77,8 @@ class KernelCreationContext:
self._symbol_ctr_pattern = re.compile(r"__[0-9]+$")
self._symbol_dup_table: defaultdict[str, int] = defaultdict(lambda: 0)
self._symbols_with_reduction: dict[PsSymbol, ReductionSymbolProperty] = dict()
self._fields_and_arrays: dict[str, FieldArrayPair] = dict()
self._fields_collection = FieldsInKernel()
......@@ -168,6 +172,22 @@ class KernelCreationContext:
self._symbols[old.name] = new
def add_reduction_to_symbol(self, symbol: PsSymbol, reduction: ReductionSymbolProperty):
"""Adds a reduction property to a symbol.
The symbol ``symbol`` should not have a reduction property and must exist in the symbol table.
"""
if self.find_symbol(symbol.name) is None:
raise PsInternalCompilerError(
f"add_reduction_to_symbol: {symbol.name} does not exist in the symbol table"
)
if symbol not in self._symbols_with_reduction and not symbol.get_properties(ReductionSymbolProperty):
symbol.add_property(reduction)
self._symbols_with_reduction[symbol] = reduction
else:
raise PsInternalCompilerError(f"add_reduction_to_symbol: {symbol.name} already has a reduction property")
def duplicate_symbol(
self, symb: PsSymbol, new_dtype: PsType | None = None
) -> PsSymbol:
......@@ -203,6 +223,11 @@ class KernelCreationContext:
"""Return an iterable of all symbols listed in the symbol table."""
return self._symbols.values()
@property
def symbols_with_reduction(self) -> dict[PsSymbol, ReductionSymbolProperty]:
"""Return a dictionary holding symbols and their reduction property."""
return self._symbols_with_reduction
# Fields and Arrays
@property
......
......@@ -7,6 +7,7 @@ import sympy.core.relational
import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
from ..memory import PsSymbol
from ...assignment import Assignment
from ...simp import AssignmentCollection
from ...sympyextensions import (
......@@ -15,6 +16,7 @@ from ...sympyextensions import (
)
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType
from ...sympyextensions.pointers import AddressOf, mem_acc
from ...sympyextensions.reduction import ReducedAssignment
from ...field import Field, FieldType
from .context import KernelCreationContext
......@@ -61,9 +63,11 @@ from ..ast.vector import PsVecMemAcc
from ..constants import PsConstant
from ...types import PsNumericType, PsStructType, PsType
from ..exceptions import PsInputError
from ..functions import PsMathFunction, MathFunctions
from ..functions import PsMathFunction, MathFunctions, NumericLimitsFunctions
from ..exceptions import FreezeError
from ...codegen.properties import ReductionSymbolProperty
ExprLike = (
sp.Expr
......@@ -183,6 +187,45 @@ class FreezeExpressions:
return PsAssignment(lhs, op(lhs.clone(), rhs))
def map_ReducedAssignment(self, expr: ReducedAssignment):
lhs = self.visit(expr.lhs)
rhs = self.visit(expr.rhs)
assert isinstance(rhs, PsExpression)
assert isinstance(lhs, PsSymbolExpr)
# create kernel-local copy of lhs symbol to work with
new_lhs_symbol = PsSymbol(f"{lhs.symbol.name}_local", lhs.dtype)
new_lhs = PsSymbolExpr(new_lhs_symbol)
self._ctx.add_symbol(new_lhs_symbol)
# match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment)
new_rhs: PsExpression
init_val: PsExpression
match expr.op:
case "+":
init_val = PsConstantExpr(PsConstant(0))
new_rhs = add(new_lhs.clone(), rhs)
case "-":
init_val = PsConstantExpr(PsConstant(0))
new_rhs = sub(new_lhs.clone(), rhs)
case "*":
init_val = PsConstantExpr(PsConstant(1))
new_rhs = mul(new_lhs.clone(), rhs)
case "min":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [new_lhs.clone(), rhs])
case "max":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [new_lhs.clone(), rhs])
case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
# set reduction symbol property in context
self._ctx.add_reduction_to_symbol(new_lhs_symbol, ReductionSymbolProperty(expr.op, init_val, lhs.symbol))
return PsAssignment(new_lhs, new_rhs)
def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
symb = self._ctx.get_symbol(spsym.name)
return PsSymbolExpr(symb)
......
......@@ -3,8 +3,8 @@ from typing import Sequence
from pystencils.backend.ast.expressions import PsCall
from ..functions import CFunction, PsMathFunction, MathFunctions
from ...types import PsIntegerType, PsIeeeFloatType
from ..functions import CFunction, PsMathFunction, MathFunctions, NumericLimitsFunctions
from ...types import PsIntegerType, PsIeeeFloatType, PsScalarType
from .platform import Platform
from ..exceptions import MaterializationError
......@@ -43,7 +43,7 @@ class GenericCpu(Platform):
@property
def required_headers(self) -> set[str]:
return {"<math.h>"}
return {"<math.h>", "<limits.h>"}
def materialize_iteration_space(
self, body: PsBlock, ispace: IterationSpace
......@@ -62,8 +62,13 @@ class GenericCpu(Platform):
dtype = call.get_dtype()
arg_types = (dtype,) * func.num_args
if isinstance(dtype, PsIeeeFloatType) and dtype.width in (32, 64):
if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max):
cfunc: CFunction
cfunc = CFunction(f"{dtype.c_string()}_{func.function_name}".capitalize(), arg_types, dtype)
call.function = cfunc
return call
if isinstance(dtype, PsIeeeFloatType) and dtype.width in (32, 64):
match func:
case (
MathFunctions.Exp
......
......@@ -10,6 +10,8 @@ from ..ast import PsAstNode
from ..ast.structural import PsBlock, PsLoop, PsPragma
from ..ast.expressions import PsExpression
from ...types import PsScalarType
if TYPE_CHECKING:
from ...codegen.config import OpenMpConfig
......@@ -110,6 +112,13 @@ class AddOpenMP:
pragma_text += " parallel" if not omp_params.omit_parallel_construct else ""
pragma_text += f" for schedule({omp_params.schedule})"
if bool(ctx.symbols_with_reduction):
for symbol, reduction in ctx.symbols_with_reduction.items():
if isinstance(symbol.dtype, PsScalarType):
pragma_text += f" reduction({reduction.op}: {symbol.name})"
else:
NotImplementedError("OMP: Reductions for non-scalar data types are not supported yet.")
if omp_params.num_threads is not None:
pragma_text += f" num_threads({str(omp_params.num_threads)})"
......
......@@ -7,12 +7,13 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
from .kernel import Kernel, GpuKernel, GpuThreadsRange
from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
from .parameters import Parameter
from ..backend.ast.expressions import PsSymbolExpr
from ..types import create_numeric_type, PsIntegerType, PsScalarType
from ..backend.memory import PsSymbol
from ..backend.ast import PsAstNode
from ..backend.ast.structural import PsBlock, PsLoop
from ..backend.ast.structural import PsBlock, PsLoop, PsAssignment
from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers
from ..backend.kernelcreation import (
KernelCreationContext,
......@@ -152,6 +153,14 @@ class DefaultKernelCreationDriver:
if self._intermediates is not None:
self._intermediates.constants_eliminated = kernel_ast.clone()
# Init local reduction variable copy
# for red, prop in self._ctx.symbols_with_reduction.items():
# kernel_ast.statements = [PsAssignment(PsSymbolExpr(red), prop.init_val)] + kernel_ast.statements
# Write back result to reduction target variable
# for red, prop in self._ctx.symbols_with_reduction.items():
# kernel_ast.statements += [PsAssignment(PsSymbolExpr(prop.orig_symbol), PsSymbolExpr(red))]
# Target-Specific optimizations
if self._cfg.target.is_cpu():
kernel_ast = self._transform_for_cpu(kernel_ast)
......@@ -450,6 +459,7 @@ def _get_function_params(
props: set[PsSymbolProperty] = set()
for prop in symb.properties:
match prop:
# TODO: how to export reduction result (via pointer)?
case FieldShape() | FieldStride():
props.add(prop)
case BufferBasePtr(buf):
......
......@@ -14,6 +14,18 @@ class UniqueSymbolProperty(PsSymbolProperty):
"""Base class for unique properties, of which only one instance may be registered at a time."""
@dataclass(frozen=True)
class ReductionSymbolProperty(UniqueSymbolProperty):
"""Property for symbols specifying the operation and initial value for a reduction."""
from ..backend.memory import PsSymbol
from ..backend.ast.expressions import PsExpression
op: str
init_val: PsExpression
orig_symbol: PsSymbol
@dataclass(frozen=True)
class FieldShape(PsSymbolProperty):
"""Symbol acts as a shape parameter to a field."""
......
import itertools
from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp
......
from .astnodes import ConditionalFieldAccess
from .typed_sympy import TypedSymbol, CastFunc
from .pointers import mem_acc
from .reduction import reduced_assign
from .math import (
prod,
......@@ -33,6 +34,7 @@ from .math import (
__all__ = [
"ConditionalFieldAccess",
"reduced_assign",
"TypedSymbol",
"CastFunc",
"mem_acc",
......
from sympy.codegen.ast import AssignmentBase
class ReducedAssignment(AssignmentBase):
"""
Base class for reduced assignments.
Attributes:
===========
binop : str
Symbol for binary operation being applied in the assignment, such as "+",
"*", etc.
"""
binop = None # type: str
@property
def op(self):
return self.binop
class AddReducedAssignment(ReducedAssignment):
binop = '+'
class SubReducedAssignment(ReducedAssignment):
binop = '-'
class MulReducedAssignment(ReducedAssignment):
binop = '*'
class MinReducedssignment(ReducedAssignment):
binop = 'min'
class MaxReducedssignment(ReducedAssignment):
binop = 'max'
# Mapping from binary op strings to AugmentedAssignment subclasses
reduced_assign_classes = {
cls.binop: cls for cls in [
AddReducedAssignment, SubReducedAssignment, MulReducedAssignment,
MinReducedssignment, MaxReducedssignment
]
}
def reduced_assign(lhs, op, rhs):
if op not in reduced_assign_classes:
raise ValueError("Unrecognized operator %s" % op)
return reduced_assign_classes[op](lhs, rhs)
import pytest
import numpy as np
import sympy as sp
import pystencils as ps
from pystencils.sympyextensions import reduced_assign
@pytest.mark.parametrize('dtype', ["float64"])
@pytest.mark.parametrize("op", ["+", "-", "*", "min", "max"])
def test_reduction(dtype, op):
x = ps.fields(f'x: {dtype}[1d]')
w = sp.Symbol("w")
# kernel with reduction assignment
reduction_assignment = reduced_assign(w, op, x.center())
config = ps.CreateKernelConfig(cpu_openmp=True)
ast_reduction = ps.create_kernel([reduction_assignment], config, default_dtype=dtype)
#code_reduction = ps.get_code_str(ast_reduction)
kernel_reduction = ast_reduction.compile()
ps.show_code(ast_reduction)
array = np.ones((10,), dtype=dtype)
kernel_reduction(x=array, w=0)
# TODO: check if "w = #points"
\ No newline at end of file