Skip to content
Snippets Groups Projects
Commit 13569a61 authored by Richard Angersbach's avatar Richard Angersbach
Browse files

Fix lint

parent 71881893
No related branches found
No related tags found
1 merge request!438Reduction Support
...@@ -92,11 +92,9 @@ class PsVecHorizontal(PsBinOp, PsVectorOp): ...@@ -92,11 +92,9 @@ class PsVecHorizontal(PsBinOp, PsVectorOp):
def structurally_equal(self, other: PsAstNode) -> bool: def structurally_equal(self, other: PsAstNode) -> bool:
if not isinstance(other, PsVecHorizontal): if not isinstance(other, PsVecHorizontal):
return False return False
return ( return (super().structurally_equal(other)
super().structurally_equal(other)
and self._lanes == other._lanes and self._lanes == other._lanes
and self._reduction_op == other._reduction_op and self._reduction_op == other._reduction_op)
)
class PsVecMemAcc(PsExpression, PsLvalue, PsVectorOp): class PsVecMemAcc(PsExpression, PsLvalue, PsVectorOp):
......
...@@ -89,7 +89,7 @@ class CudaPlatform(GenericGpu): ...@@ -89,7 +89,7 @@ class CudaPlatform(GenericGpu):
if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max): if isinstance(dtype, PsScalarType) and func in (NumericLimitsFunctions.Min, NumericLimitsFunctions.Max):
assert isinstance(dtype, PsIeeeFloatType) assert isinstance(dtype, PsIeeeFloatType)
defines = { NumericLimitsFunctions.Min: "NEG_INFINITY", NumericLimitsFunctions.Max: "POS_INFINITY" } defines = {NumericLimitsFunctions.Min: "NEG_INFINITY", NumericLimitsFunctions.Max: "POS_INFINITY"}
return PsLiteralExpr(PsLiteral(defines[func], dtype)) return PsLiteralExpr(PsLiteral(defines[func], dtype))
...@@ -170,8 +170,8 @@ class CudaPlatform(GenericGpu): ...@@ -170,8 +170,8 @@ class CudaPlatform(GenericGpu):
case ReductionOp.Sub: case ReductionOp.Sub:
# workaround for unsupported atomicSub: use atomic add # workaround for unsupported atomicSub: use atomic add
# similar to OpenMP reductions: local copies (negative sign) are added at the end # similar to OpenMP reductions: local copies (negative sign) are added at the end
call.function = CFunction(f"atomicAdd", [ptr_expr.dtype, symbol_expr.dtype], call.function = CFunction("atomicAdd", [ptr_expr.dtype, symbol_expr.dtype],
PsCustomType("void")) PsCustomType("void"))
call.args = (ptr_expr, symbol_expr) call.args = (ptr_expr, symbol_expr)
case _: case _:
call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype], call.function = CFunction(f"atomic{op.name}", [ptr_expr.dtype, symbol_expr.dtype],
......
...@@ -240,14 +240,14 @@ class LoopVectorizer: ...@@ -240,14 +240,14 @@ class LoopVectorizer:
) )
return PsBlock( return PsBlock(
simd_init_local_reduction_vars + simd_init_local_reduction_vars
[ + [
simd_stop_decl, simd_stop_decl,
simd_step_decl, simd_step_decl,
simd_loop simd_loop
] + ]
simd_writeback_local_reduction_vars + + simd_writeback_local_reduction_vars
[ + [
trailing_start_decl, trailing_start_decl,
trailing_loop, trailing_loop,
] ]
...@@ -258,13 +258,13 @@ class LoopVectorizer: ...@@ -258,13 +258,13 @@ class LoopVectorizer:
case LoopVectorizer.TrailingItersTreatment.NONE: case LoopVectorizer.TrailingItersTreatment.NONE:
return PsBlock( return PsBlock(
simd_init_local_reduction_vars + simd_init_local_reduction_vars
[ + [
simd_stop_decl, simd_stop_decl,
simd_step_decl, simd_step_decl,
simd_loop, simd_loop,
] + ]
simd_writeback_local_reduction_vars + simd_writeback_local_reduction_vars
) )
@overload @overload
......
from operator import truediv, mul, sub, add
from .backend.ast.expressions import PsExpression, PsCall, PsAdd, PsSub, PsMul, PsDiv from .backend.ast.expressions import PsExpression, PsCall, PsAdd, PsSub, PsMul, PsDiv
from .backend.exceptions import FreezeError from .backend.exceptions import FreezeError
from .backend.functions import PsMathFunction, MathFunctions from .backend.functions import PsMathFunction, MathFunctions
......
...@@ -11,7 +11,6 @@ except ImportError: ...@@ -11,7 +11,6 @@ except ImportError:
from ..codegen import Target from ..codegen import Target
from ..field import FieldType from ..field import FieldType
from ..types import PsType, PsPointerType
from .jit import JitBase, JitError, KernelWrapper from .jit import JitBase, JitError, KernelWrapper
from ..codegen import ( from ..codegen import (
Kernel, Kernel,
...@@ -19,7 +18,7 @@ from ..codegen import ( ...@@ -19,7 +18,7 @@ from ..codegen import (
Parameter, Parameter,
) )
from ..codegen.properties import FieldShape, FieldStride, FieldBasePtr from ..codegen.properties import FieldShape, FieldStride, FieldBasePtr
from ..types import PsStructType, PsPointerType from ..types import PsType, PsStructType, PsPointerType
from ..include import get_pystencils_include_path from ..include import get_pystencils_include_path
......
...@@ -19,14 +19,18 @@ class ReductionAssignment(AssignmentBase): ...@@ -19,14 +19,18 @@ class ReductionAssignment(AssignmentBase):
Attributes: Attributes:
=========== ===========
binop : CompoundOp reduction_op : ReductionOp
Enum for binary operation being applied in the assignment, such as "Add" for "+", "Sub" for "-", etc. Enum for binary operation being applied in the assignment, such as "Add" for "+", "Sub" for "-", etc.
""" """
reduction_op = None # type: ReductionOp _reduction_op = None # type: ReductionOp
@property @property
def reduction_op(self): def reduction_op(self):
return self.reduction_op return self._reduction_op
@reduction_op.setter
def reduction_op(self, op):
self._reduction_op = op
class AddReductionAssignment(ReductionAssignment): class AddReductionAssignment(ReductionAssignment):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment