Skip to content
Snippets Groups Projects
Commit 4465a595 authored by Markus Holzer's avatar Markus Holzer
Browse files

Added test case for vectorisation

parent cd3a2f3e
No related branches found
No related tags found
No related merge requests found
...@@ -298,7 +298,7 @@ class CBackend: ...@@ -298,7 +298,7 @@ class CBackend:
return node.get_code(self._dialect, self._vector_instruction_set) return node.get_code(self._dialect, self._vector_instruction_set)
def _print_SourceCodeComment(self, node): def _print_SourceCodeComment(self, node):
return "/* " + node.text + " */" return f"/* {node.text } */"
def _print_EmptyLine(self, node): def _print_EmptyLine(self, node):
return "" return ""
...@@ -316,7 +316,7 @@ class CBackend: ...@@ -316,7 +316,7 @@ class CBackend:
result = f"if ({condition_expr})\n{true_block} " result = f"if ({condition_expr})\n{true_block} "
if node.false_block: if node.false_block:
false_block = self._print_Block(node.false_block) false_block = self._print_Block(node.false_block)
result += "else " + false_block result += f"else {false_block}"
return result return result
...@@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -336,7 +336,7 @@ class CustomSympyPrinter(CCodePrinter):
return self._typed_number(expr.evalf(), get_type_of_expression(expr)) return self._typed_number(expr.evalf(), get_type_of_expression(expr))
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" return f"({self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False))})"
elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0:
return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})" return f"1 / ({self._print(sp.Mul(*([expr.base] * -expr.exp), evaluate=False))})"
else: else:
......
...@@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -104,7 +104,6 @@ class CudaSympyPrinter(CustomSympyPrinter):
assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given" assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__fsqrt_rn({self._print(expr.args[0])})" return f"__fsqrt_rn({self._print(expr.args[0])})"
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
print(len(expr.args) == 1)
assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given" assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} where given"
return f"__frsqrt_rn({self._print(expr.args[0])})" return f"__frsqrt_rn({self._print(expr.args[0])})"
return super()._print_Function(expr) return super()._print_Function(expr)
...@@ -11,9 +11,9 @@ def test_fast_sqrt(): ...@@ -11,9 +11,9 @@ def test_fast_sqrt():
assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1 assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1
assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1 assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1
ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu') ast_gpu = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu')
ast.compile() ast_gpu.compile()
code_str = ps.get_code_str(ast) code_str = ps.get_code_str(ast_gpu)
assert '__fsqrt_rn' in code_str assert '__fsqrt_rn' in code_str
expr = ps.Assignment(sp.Symbol("tmp"), 3 / sp.sqrt(f[0, 0] + f[1, 0])) expr = ps.Assignment(sp.Symbol("tmp"), 3 / sp.sqrt(f[0, 0] + f[1, 0]))
...@@ -21,9 +21,9 @@ def test_fast_sqrt(): ...@@ -21,9 +21,9 @@ def test_fast_sqrt():
ac = ps.AssignmentCollection([expr], []) ac = ps.AssignmentCollection([expr], [])
assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1 assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1
ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu') ast_gpu = ps.create_kernel(insert_fast_sqrts(ac), target='gpu')
ast.compile() ast_gpu.compile()
code_str = ps.get_code_str(ast) code_str = ps.get_code_str(ast_gpu)
assert '__frsqrt_rn' in code_str assert '__frsqrt_rn' in code_str
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment