Skip to content
Snippets Groups Projects
Commit f71ce708 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Enforce usage of typed symbols for reductions

parent c73deaf6
No related branches found
No related tags found
1 merge request!438Reduction Support
...@@ -176,6 +176,8 @@ class FreezeExpressions: ...@@ -176,6 +176,8 @@ class FreezeExpressions:
return PsAssignment(lhs, binop_str_to_expr(expr.op[0], lhs.clone(), rhs)) return PsAssignment(lhs, binop_str_to_expr(expr.op[0], lhs.clone(), rhs))
def map_ReducedAssignment(self, expr: ReducedAssignment): def map_ReducedAssignment(self, expr: ReducedAssignment):
assert isinstance(expr.lhs, TypedSymbol)
lhs = self.visit(expr.lhs) lhs = self.visit(expr.lhs)
rhs = self.visit(expr.rhs) rhs = self.visit(expr.rhs)
...@@ -183,7 +185,7 @@ class FreezeExpressions: ...@@ -183,7 +185,7 @@ class FreezeExpressions:
assert isinstance(lhs, PsSymbolExpr) assert isinstance(lhs, PsSymbolExpr)
orig_lhs_symb = lhs.symbol orig_lhs_symb = lhs.symbol
dtype = rhs.dtype # TODO: kernel with (implicit) up/downcasts? dtype = lhs.dtype
assert isinstance(dtype, PsNumericType) assert isinstance(dtype, PsNumericType)
......
...@@ -24,7 +24,7 @@ def test_reduction(dtype, op): ...@@ -24,7 +24,7 @@ def test_reduction(dtype, op):
gpu_avail = False gpu_avail = False
x = ps.fields(f'x: {dtype}[1d]') x = ps.fields(f'x: {dtype}[1d]')
w = sp.Symbol("w") w = ps.TypedSymbol("w", dtype)
# kernel with reduction assignment # kernel with reduction assignment
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment