Skip to content
Snippets Groups Projects
Commit 3e0daa67 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Propagate properties of reduction pointer symbols to kernel parameters

parent c6eedfcd
No related branches found
No related tags found
1 merge request!438Reduction Support
......@@ -5,7 +5,7 @@ from dataclasses import dataclass, replace
from .target import Target
from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
from .kernel import Kernel, GpuKernel, GpuThreadsRange
from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr
from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr, ReductionPointerVariable
from .parameters import Parameter
from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr
......@@ -461,7 +461,8 @@ def _get_function_params(
props: set[PsSymbolProperty] = set()
for prop in symb.properties:
match prop:
# TODO: how to export reduction result (via pointer)?
case ReductionPointerVariable():
props.add(prop)
case FieldShape() | FieldStride():
props.add(prop)
case BufferBasePtr(buf):
......
......@@ -13,7 +13,7 @@ from ..codegen import (
Kernel,
Parameter,
)
from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride
from ..codegen.properties import FieldBasePtr, FieldShape, FieldStride, ReductionPointerVariable
from ..types import (
PsType,
PsUnsignedIntegerType,
......@@ -265,7 +265,10 @@ if( !kwargs || !PyDict_Check(kwargs) ) {{
return self._array_buffers[field]
def extract_scalar(self, param: Parameter) -> str:
if param not in self._scalar_extractions:
if any(isinstance(e, ReductionPointerVariable) for e in param.properties):
# TODO: implement
pass
elif param not in self._scalar_extractions:
extract_func = self._scalar_extractor(param.dtype)
code = self.TMPL_EXTRACT_SCALAR.format(
name=param.name,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment