Skip to content
Snippets Groups Projects
Commit ed8b7430 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fixed print Function in CUDA Backend

parent 9ded23ce
1 merge request!159Fix: Wrong fString in Cuda Backend
...@@ -52,7 +52,7 @@ pip install pystencils[interactive] ...@@ -52,7 +52,7 @@ pip install pystencils[interactive]
Without `[interactive]` you get a minimal version with very few dependencies. Without `[interactive]` you get a minimal version with very few dependencies.
All options: All options:
- `gpu`: use this if an Nvidia GPU is available and CUDA is installed - `gpu`: use this if an NVIDIA GPU is available and CUDA is installed
- `opencl`: basic OpenCL support (experimental) - `opencl`: basic OpenCL support (experimental)
- `alltrafos`: pulls in additional dependencies for loop simplification, e.g. libisl - `alltrafos`: pulls in additional dependencies for loop simplification, e.g. libisl
- `bench_db`: functionality to store benchmark results in object databases - `bench_db`: functionality to store benchmark results in object databases
......
...@@ -33,10 +33,11 @@ class CudaBackend(CBackend): ...@@ -33,10 +33,11 @@ class CudaBackend(CBackend):
super().__init__(sympy_printer, signature_only, dialect='cuda') super().__init__(sympy_printer, signature_only, dialect='cuda')
def _print_SharedMemoryAllocation(self, node): def _print_SharedMemoryAllocation(self, node):
code = "__shared__ {dtype} {name}[{num_elements}];" dtype = node.symbol.dtype
return code.format(dtype=node.symbol.dtype, name = self.sympy_printer.doprint(node.symbol.name)
name=self.sympy_printer.doprint(node.symbol.name), num_elements = '*'.join([str(s) for s in node.shared_mem.shape])
num_elements='*'.join([str(s) for s in node.shared_mem.shape])) code = f"__shared__ {dtype} {name}[{num_elements}];"
return code
@staticmethod @staticmethod
def _print_ThreadBlockSynchronization(node): def _print_ThreadBlockSynchronization(node):
...@@ -45,6 +46,7 @@ class CudaBackend(CBackend): ...@@ -45,6 +46,7 @@ class CudaBackend(CBackend):
def _print_TextureDeclaration(self, node): def _print_TextureDeclaration(self, node):
# TODO: use fStrings here
if node.texture.field.dtype.numpy_dtype.itemsize > 4: if node.texture.field.dtype.numpy_dtype.itemsize > 4:
code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % ( code = "texture<fp_tex_%s, cudaTextureType%iD, cudaReadModeElementType> %s;" % (
str(node.texture.field.dtype), str(node.texture.field.dtype),
...@@ -96,9 +98,13 @@ class CudaSympyPrinter(CustomSympyPrinter): ...@@ -96,9 +98,13 @@ class CudaSympyPrinter(CustomSympyPrinter):
def _print_Function(self, expr):
    """Print the fast-math function nodes as CUDA intrinsic calls.

    ``fast_division``, ``fast_sqrt`` and ``fast_inv_sqrt`` expressions are
    emitted as calls to ``__fdividef``, ``__fsqrt_rn`` and ``__frsqrt_rn``
    respectively, with each argument printed through ``self._print``.
    Any other function expression is delegated to the parent printer.

    Raises:
        AssertionError: if one of the fast-math nodes carries a number of
            arguments that does not match the intrinsic it maps to.
    """
    # No output other than the returned string: printing to stdout here
    # would pollute the console on every code generation.
    if isinstance(expr, fast_division):
        assert len(expr.args) == 2, f"__fdividef has two arguments, but {len(expr.args)} were given"
        return f"__fdividef({self._print(expr.args[0])}, {self._print(expr.args[1])})"
    elif isinstance(expr, fast_sqrt):
        assert len(expr.args) == 1, f"__fsqrt_rn has one argument, but {len(expr.args)} were given"
        return f"__fsqrt_rn({self._print(expr.args[0])})"
    elif isinstance(expr, fast_inv_sqrt):
        assert len(expr.args) == 1, f"__frsqrt_rn has one argument, but {len(expr.args)} were given"
        return f"__frsqrt_rn({self._print(expr.args[0])})"
    return super()._print_Function(expr)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment