Skip to content
Snippets Groups Projects
Commit 935beb55 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Move kernel AST modifications for reductions to distinct function

parent 8394f0f4
No related branches found
No related tags found
1 merge request!438Reduction Support
...@@ -19,6 +19,7 @@ from .properties import PsSymbolProperty, FieldBasePtr ...@@ -19,6 +19,7 @@ from .properties import PsSymbolProperty, FieldBasePtr
from .parameters import Parameter from .parameters import Parameter
from .functions import Lambda from .functions import Lambda
from .gpu_indexing import GpuIndexing, GpuLaunchConfiguration from .gpu_indexing import GpuIndexing, GpuLaunchConfiguration
from ..backend.kernelcreation.context import ReductionInfo
from ..field import Field from ..field import Field
from ..types import PsIntegerType, PsScalarType from ..types import PsIntegerType, PsScalarType
...@@ -192,28 +193,7 @@ class DefaultKernelCreationDriver: ...@@ -192,28 +193,7 @@ class DefaultKernelCreationDriver:
# Extensions for reductions # Extensions for reductions
for symbol, reduction_info in self._ctx.symbols_reduction_info.items(): for symbol, reduction_info in self._ctx.symbols_reduction_info.items():
typify = Typifier(self._ctx) self._modify_kernel_ast_for_reductions(symbol, reduction_info, kernel_ast)
symbol_expr = typify(PsSymbolExpr(symbol))
ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol))
init_val = typify(reduction_info.init_val)
ptr_access = PsMemAcc(
ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))
)
write_back_ptr = PsCall(
PsReductionFunction(
ReductionFunctions.WriteBackToPtr, reduction_info.op
),
[ptr_symbol_expr, symbol_expr],
)
# declare and init local copy with neutral element
prepend_ast = [PsDeclaration(symbol_expr, init_val)]
# write back result to reduction target variable
append_ast = [PsAssignment(ptr_access, write_back_ptr)]
kernel_ast.statements = prepend_ast + kernel_ast.statements
kernel_ast.statements += append_ast
# Target-Specific optimizations # Target-Specific optimizations
if self._target.is_cpu(): if self._target.is_cpu():
...@@ -315,6 +295,35 @@ class DefaultKernelCreationDriver: ...@@ -315,6 +295,35 @@ class DefaultKernelCreationDriver:
return kernel_body return kernel_body
def _modify_kernel_ast_for_reductions(self,
symbol: PsSymbol,
reduction_info: ReductionInfo,
kernel_ast: PsBlock):
# typify local symbol and write-back pointer expressions and initial value
typify = Typifier(self._ctx)
symbol_expr = typify(PsSymbolExpr(symbol))
ptr_symbol_expr = typify(PsSymbolExpr(reduction_info.ptr_symbol))
init_val = typify(reduction_info.init_val)
ptr_access = PsMemAcc(
ptr_symbol_expr, PsConstantExpr(PsConstant(0, self._ctx.index_dtype))
)
write_back_ptr = PsCall(
PsReductionFunction(
ReductionFunctions.WriteBackToPtr, reduction_info.op
),
[ptr_symbol_expr, symbol_expr],
)
# declare and init local copy with neutral element
prepend_ast = [PsDeclaration(symbol_expr, init_val)]
# write back result to reduction target variable
append_ast = [PsAssignment(ptr_access, write_back_ptr)]
# modify AST
kernel_ast.statements = prepend_ast + kernel_ast.statements
kernel_ast.statements += append_ast
def _transform_for_cpu(self, kernel_ast: PsBlock) -> PsBlock: def _transform_for_cpu(self, kernel_ast: PsBlock) -> PsBlock:
canonicalize = CanonicalizeSymbols(self._ctx, True) canonicalize = CanonicalizeSymbols(self._ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment