Skip to content
Snippets Groups Projects
Commit 45ab4e86 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Try initializing kernel-local reduction variable copy

parent 6a7a251f
No related branches found
No related tags found
1 merge request!438Reduction Support
......@@ -7,6 +7,7 @@ import sympy.core.relational
import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
from ..memory import PsSymbol
from ...assignment import Assignment
from ...simp import AssignmentCollection
from ...sympyextensions import (
......@@ -193,32 +194,37 @@ class FreezeExpressions:
assert isinstance(rhs, PsExpression)
assert isinstance(lhs, PsSymbolExpr)
# create kernel-local copy of lhs symbol to work with
new_lhs_symbol = PsSymbol(f"{lhs.symbol.name}_local", lhs.dtype)
new_lhs = PsSymbolExpr(new_lhs_symbol)
self._ctx.add_symbol(new_lhs_symbol)
# match for reduction operation and set neutral init_val and new rhs (similar to augmented assignment)
new_rhs: PsExpression
init_val: Any # TODO: type?
init_val: PsExpression
match expr.op:
case "+":
init_val = PsConstant(0)
new_rhs = add(lhs.clone(), rhs)
init_val = PsConstantExpr(PsConstant(0))
new_rhs = add(new_lhs.clone(), rhs)
case "-":
init_val = PsConstant(0)
new_rhs = sub(lhs.clone(), rhs)
init_val = PsConstantExpr(PsConstant(0))
new_rhs = sub(new_lhs.clone(), rhs)
case "*":
init_val = PsConstant(1)
new_rhs = mul(lhs.clone(), rhs)
init_val = PsConstantExpr(PsConstant(1))
new_rhs = mul(new_lhs.clone(), rhs)
case "min":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Min), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [lhs.clone(), rhs])
new_rhs = PsCall(PsMathFunction(MathFunctions.Min), [new_lhs.clone(), rhs])
case "max":
init_val = PsCall(PsMathFunction(NumericLimitsFunctions.Max), [])
new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [lhs.clone(), rhs])
new_rhs = PsCall(PsMathFunction(MathFunctions.Max), [new_lhs.clone(), rhs])
case _:
raise FreezeError(f"Unsupported reduced assignment: {expr.op}.")
# set reduction symbol property in context
self._ctx.add_reduction_to_symbol(lhs.symbol, ReductionSymbolProperty(expr.op, init_val))
self._ctx.add_reduction_to_symbol(new_lhs_symbol, ReductionSymbolProperty(expr.op, init_val, lhs.symbol))
return PsAssignment(lhs, new_rhs)
return PsAssignment(new_lhs, new_rhs)
def map_Symbol(self, spsym: sp.Symbol) -> PsSymbolExpr:
symb = self._ctx.get_symbol(spsym.name)
......
......@@ -7,12 +7,13 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
from .kernel import Kernel, GpuKernel, GpuThreadsRange
from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
from .parameters import Parameter
from ..backend.ast.expressions import PsSymbolExpr
from ..types import create_numeric_type, PsIntegerType, PsScalarType
from ..backend.memory import PsSymbol
from ..backend.ast import PsAstNode
from ..backend.ast.structural import PsBlock, PsLoop
from ..backend.ast.structural import PsBlock, PsLoop, PsAssignment
from ..backend.ast.analysis import collect_undefined_symbols, collect_required_headers
from ..backend.kernelcreation import (
KernelCreationContext,
......@@ -152,6 +153,14 @@ class DefaultKernelCreationDriver:
if self._intermediates is not None:
self._intermediates.constants_eliminated = kernel_ast.clone()
# Init local reduction variable copy
# for red, prop in self._ctx.symbols_with_reduction.items():
# kernel_ast.statements = [PsAssignment(PsSymbolExpr(red), prop.init_val)] + kernel_ast.statements
# Write back result to reduction target variable
# for red, prop in self._ctx.symbols_with_reduction.items():
# kernel_ast.statements += [PsAssignment(PsSymbolExpr(prop.orig_symbol), PsSymbolExpr(red))]
# Target-Specific optimizations
if self._cfg.target.is_cpu():
kernel_ast = self._transform_for_cpu(kernel_ast)
......@@ -450,6 +459,7 @@ def _get_function_params(
props: set[PsSymbolProperty] = set()
for prop in symb.properties:
match prop:
# TODO: how to export reduction result (via pointer)?
case FieldShape() | FieldStride():
props.add(prop)
case BufferBasePtr(buf):
......
......@@ -2,7 +2,6 @@ from __future__ import annotations
from dataclasses import dataclass
from ..field import Field
from typing import Any
@dataclass(frozen=True)
......@@ -19,8 +18,12 @@ class UniqueSymbolProperty(PsSymbolProperty):
class ReductionSymbolProperty(UniqueSymbolProperty):
"""Property for symbols specifying the operation and initial value for a reduction."""
from ..backend.memory import PsSymbol
from ..backend.ast.expressions import PsExpression
op: str
init_val: Any # TODO: type?
init_val: PsExpression
orig_symbol: PsSymbol
@dataclass(frozen=True)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment