Skip to content
Snippets Groups Projects

Fix: Wrong fString in Cuda Backend

Merged Markus Holzer requested to merge holzer/pystencils:Fix_Wrong_fString into master
2 files
+ 13
8
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -33,10 +33,10 @@ class CudaBackend(CBackend):
@@ -33,10 +33,10 @@ class CudaBackend(CBackend):
super().__init__(sympy_printer, signature_only, dialect='cuda')
super().__init__(sympy_printer, signature_only, dialect='cuda')
def _print_SharedMemoryAllocation(self, node):
def _print_SharedMemoryAllocation(self, node):
code = "__shared__ {dtype} {name}[{num_elements}];"
dtype = node.symbol.dtype
return code.format(dtype=node.symbol.dtype,
name = self.sympy_printer.doprint(node.symbol.name)
name=self.sympy_printer.doprint(node.symbol.name),
num_elements = '*'.join([str(s) for s in node.shared_mem.shape])
num_elements='*'.join([str(s) for s in node.shared_mem.shape]))
return f"__shared__ {dtype} {name}[{num_elements}];"
@staticmethod
@staticmethod
def _print_ThreadBlockSynchronization(node):
def _print_ThreadBlockSynchronization(node):
@@ -45,6 +45,7 @@ class CudaBackend(CBackend):
@@ -45,6 +45,7 @@ class CudaBackend(CBackend):
def _print_TextureDeclaration(self, node):
def _print_TextureDeclaration(self, node):
 
# TODO: use fStrings here
if node.texture.field.dtype.numpy_dtype.itemsize > 4:
if node.texture.field.dtype.numpy_dtype.itemsize > 4:
code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
str(node.texture.field.dtype),
str(node.texture.field.dtype),
@@ -96,9 +97,13 @@ class CudaSympyPrinter(CustomSympyPrinter):
@@ -96,9 +97,13 @@ class CudaSympyPrinter(CustomSympyPrinter):
def _print_Function(self, expr):
def _print_Function(self, expr):
if isinstance(expr, fast_division):
if isinstance(expr, fast_division):
return "__fdividef(%s, %s)" % tuple(self._print(a) for a in expr.args)
assert len(expr.args) == 2, f"__fdividef has two argument, but {len(expr.args)} where given"
 
return f"__fdividef({self._print(expr.args[0])}, {self._print(expr.args[1])})"
elif isinstance(expr, fast_sqrt):
elif isinstance(expr, fast_sqrt):
return "__fsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} where given"
 
return f"__fsqrt_rn({self._print(expr.args[0])})"
elif isinstance(expr, fast_inv_sqrt):
elif isinstance(expr, fast_inv_sqrt):
return "__frsqrt_rn(%s)" % tuple(self._print(a) for a in expr.args)
print(len(expr.args) == 1)
 
assert len(expr.args) == 1, f"__frsqrt_rn has only one argument, but {len(expr.args)} where given"
 
return f"__frsqrt_rn({self._print(expr.args[0])})"
return super()._print_Function(expr)
return super()._print_Function(expr)
Loading