Skip to content
Snippets Groups Projects
Commit 9ded23ce authored by Markus Holzer
Browse files

Fixed Wrong fString in Cuda Backend

parent 82af488a
No related branches found
No related tags found
1 merge request: !159 "Fix: Wrong fString in Cuda Backend"
...@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -98,7 +98,7 @@ class CudaSympyPrinter(CustomSympyPrinter):
if isinstance(expr, fast_division): if isinstance(expr, fast_division):
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args) return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_sqrt): elif isinstance(expr, fast_sqrt):
return f"__fsqrt_rn({tuple(self._print(a) for a in expr.args)})" return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
return f"__frsqrt_rn({tuple(self._print(a) for a in expr.args)})" return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
return super()._print_Function(expr) return super()._print_Function(expr)
...@@ -12,6 +12,7 @@ def test_fast_sqrt(): ...@@ -12,6 +12,7 @@ def test_fast_sqrt():
assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1 assert len(insert_fast_sqrts(expr).atoms(fast_sqrt)) == 1
assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1 assert len(insert_fast_sqrts([expr])[0].atoms(fast_sqrt)) == 1
ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu') ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_sqrts(expr)), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast) code_str = ps.get_code_str(ast)
assert '__fsqrt_rn' in code_str assert '__fsqrt_rn' in code_str
...@@ -21,6 +22,7 @@ def test_fast_sqrt(): ...@@ -21,6 +22,7 @@ def test_fast_sqrt():
ac = ps.AssignmentCollection([expr], []) ac = ps.AssignmentCollection([expr], [])
assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1 assert len(insert_fast_sqrts(ac).main_assignments[0].atoms(fast_inv_sqrt)) == 1
ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu') ast = ps.create_kernel(insert_fast_sqrts(ac), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast) code_str = ps.get_code_str(ast)
assert '__frsqrt_rn' in code_str assert '__frsqrt_rn' in code_str
...@@ -34,5 +36,6 @@ def test_fast_divisions(): ...@@ -34,5 +36,6 @@ def test_fast_divisions():
assert len(insert_fast_divisions(expr).atoms(fast_division)) == 1 assert len(insert_fast_divisions(expr).atoms(fast_division)) == 1
ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_divisions(expr)), target='gpu') ast = ps.create_kernel(ps.Assignment(g[0, 0], insert_fast_divisions(expr)), target='gpu')
ast.compile()
code_str = ps.get_code_str(ast) code_str = ps.get_code_str(ast)
assert '__fdividef' in code_str assert '__fdividef' in code_str
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment