Something went wrong on our end. Please try again!
Draft: Loop counter dependent kernels: Vector casts and smaller fixes
Compare changes
+ 36
− 42
@@ -14,7 +14,7 @@ from sympy.functions.elementary.hyperbolic import HyperbolicFunction
@@ -14,7 +14,7 @@ from sympy.functions.elementary.hyperbolic import HyperbolicFunction
@@ -107,7 +107,7 @@ def get_headers(ast_node: Node) -> Set[str]:
@@ -107,7 +107,7 @@ def get_headers(ast_node: Node) -> Set[str]:
@@ -330,8 +330,8 @@ class CBackend:
@@ -330,8 +330,8 @@ class CBackend:
@@ -610,24 +610,20 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -610,24 +610,20 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -638,12 +634,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -638,12 +634,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -652,16 +643,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -652,16 +643,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
@@ -681,7 +671,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -681,7 +671,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -690,17 +681,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -690,17 +681,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -710,7 +702,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -710,7 +702,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -757,15 +748,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -757,15 +748,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
# special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]):
@@ -777,19 +766,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -777,19 +766,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
func = self.instruction_set['-' + suffix] if summand.sign == -1 else self.instruction_set['+' + suffix]
@@ -798,8 +790,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -798,8 +790,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -813,7 +805,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -813,7 +805,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -821,6 +813,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -821,6 +813,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -855,19 +849,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -855,19 +849,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):