Skip to content
Snippets Groups Projects
Commit c73deaf6 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Fix mypy errors and move binop mapping function

parent 3daaa5e5
No related branches found
No related tags found
1 merge request!438Reduction Support
Pipeline #72917 failed
......@@ -45,6 +45,7 @@ from .sympyextensions.reduction import (
MinReducedssignment,
MaxReducedssignment
)
from .binop_mapping import binop_str_to_expr
__all__ = [
"Field",
......@@ -75,6 +76,7 @@ __all__ = [
"inspect",
"AssignmentCollection",
"Assignment",
"binop_str_to_expr",
"AddAugmentedAssignment",
"AddReducedAssignment",
"SubReducedAssignment",
......
......@@ -13,7 +13,7 @@ from ...sympyextensions import (
integer_functions,
ConditionalFieldAccess,
)
from ...sympyextensions.binop_mapping import binop_str_to_expr
from ...binop_mapping import binop_str_to_expr
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc, DynamicType
from ...sympyextensions.pointers import AddressOf, mem_acc
from ...sympyextensions.reduction import ReducedAssignment
......@@ -185,6 +185,8 @@ class FreezeExpressions:
orig_lhs_symb = lhs.symbol
dtype = rhs.dtype # TODO: kernel with (implicit) up/downcasts?
assert isinstance(dtype, PsNumericType)
# replace original symbol with pointer-based type used for export
orig_lhs_symb_as_ptr = PsSymbol(orig_lhs_symb.name, PsPointerType(dtype))
......@@ -196,7 +198,6 @@ class FreezeExpressions:
new_rhs: PsExpression = binop_str_to_expr(expr.op, new_lhs.clone(), rhs)
# match for reduction operation and set neutral init_val
new_rhs: PsExpression
init_val: PsExpression
match expr.op:
case "+":
......
from operator import truediv, mul, sub, add
from ..backend.ast.expressions import PsCall, PsExpression
from ..backend.exceptions import FreezeError
from ..backend.functions import MathFunctions, PsMathFunction
from .backend.ast.expressions import PsExpression, PsCall
from .backend.exceptions import FreezeError
from .backend.functions import PsMathFunction, MathFunctions
_available_operator_interface: set[str] = {'+', '-', '*', '/'}
......
......@@ -7,8 +7,8 @@ from .config import CreateKernelConfig, OpenMpConfig, VectorizationConfig, AUTO
from .kernel import Kernel, GpuKernel, GpuThreadsRange
from .properties import PsSymbolProperty, FieldShape, FieldStride, FieldBasePtr, ReductionPointerVariable
from .parameters import Parameter
from ..binop_mapping import binop_str_to_expr
from ..backend.ast.expressions import PsSymbolExpr, PsMemAcc, PsConstantExpr
from ..sympyextensions.binop_mapping import binop_str_to_expr
from ..types import create_numeric_type, PsIntegerType, PsScalarType
......@@ -155,14 +155,14 @@ class DefaultKernelCreationDriver:
self._intermediates.constants_eliminated = kernel_ast.clone()
# Init local reduction variable copy
for local_red, prop in self._ctx.local_reduction_symbols.items():
kernel_ast.statements = [PsDeclaration(PsSymbolExpr(local_red), prop.init_val)] + kernel_ast.statements
for local_red, local_prop in self._ctx.local_reduction_symbols.items():
kernel_ast.statements = [PsDeclaration(PsSymbolExpr(local_red), local_prop.init_val)] + kernel_ast.statements
# Write back result to reduction target variable
for red_ptr, prop in self._ctx.reduction_pointer_symbols.items():
for red_ptr, ptr_prop in self._ctx.reduction_pointer_symbols.items():
ptr_access = PsMemAcc(PsSymbolExpr(red_ptr), PsConstantExpr(PsConstant(0, self._ctx.index_dtype)))
kernel_ast.statements += [PsAssignment(
ptr_access, binop_str_to_expr(prop.op, ptr_access, PsSymbolExpr(prop.local_symbol)))]
ptr_access, binop_str_to_expr(ptr_prop.op, ptr_access, PsSymbolExpr(ptr_prop.local_symbol)))]
# Target-Specific optimizations
if self._cfg.target.is_cpu():
......
......@@ -2,7 +2,6 @@ from .astnodes import ConditionalFieldAccess
from .typed_sympy import TypedSymbol, CastFunc
from .pointers import mem_acc
from .reduction import reduced_assign
from .binop_mapping import binop_str_to_expr
from .math import (
prod,
......@@ -36,7 +35,6 @@ from .math import (
__all__ = [
"ConditionalFieldAccess",
"reduced_assign",
"binop_str_to_expr",
"TypedSymbol",
"CastFunc",
"mem_acc",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment