Skip to content
Snippets Groups Projects
Commit e40ca9a1 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Lint

parent b93d6992
No related branches found
No related tags found
No related merge requests found
...@@ -213,7 +213,11 @@ class CBackend: ...@@ -213,7 +213,11 @@ class CBackend:
def _print_SympyAssignment(self, node): def _print_SympyAssignment(self, node):
if node.is_declaration: if node.is_declaration:
data_type = "const " + self._print(node.lhs.dtype) + " " if node.is_const else self._print(node.lhs.dtype) + " " if node.is_const:
prefix = 'const '
else:
prefix = ''
data_type = prefix + self._print(node.lhs.dtype) + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs)) self.sympy_printer.doprint(node.rhs))
else: else:
......
from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter
from pystencils.backends.cbackend import generate_c
from pystencils.astnodes import Node
import pystencils.data_types import pystencils.data_types
from pystencils.astnodes import Node
from pystencils.backends.cbackend import generate_c
from pystencils.backends.cuda_backend import CudaBackend, CudaSympyPrinter
def generate_opencl(astnode: Node, signature_only: bool = False) -> str: def generate_opencl(astnode: Node, signature_only: bool = False) -> str:
"""Prints an abstract syntax tree node as CUDA code. """Prints an abstract syntax tree node as CUDA code.
...@@ -27,7 +28,6 @@ class OpenClBackend(CudaBackend): ...@@ -27,7 +28,6 @@ class OpenClBackend(CudaBackend):
super().__init__(sympy_printer, signature_only) super().__init__(sympy_printer, signature_only)
self._dialect = 'opencl' self._dialect = 'opencl'
def _print_Type(self, node): def _print_Type(self, node):
code = super()._print_Type(node) code = super()._print_Type(node)
if isinstance(node, pystencils.data_types.PointerType): if isinstance(node, pystencils.data_types.PointerType):
...@@ -57,4 +57,3 @@ class OpenClSympyPrinter(CudaSympyPrinter): ...@@ -57,4 +57,3 @@ class OpenClSympyPrinter(CudaSympyPrinter):
dimension = self.DIMENSION_MAPPING[dimension] dimension = self.DIMENSION_MAPPING[dimension]
function_name = self.INDEXING_FUNCTION_MAPPING[function_name] function_name = self.INDEXING_FUNCTION_MAPPING[function_name]
return f"{function_name}({dimension})" return f"{function_name}({dimension})"
...@@ -60,8 +60,8 @@ def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argumen ...@@ -60,8 +60,8 @@ def make_python_function(kernel_function_node, opencl_queue, opencl_ctx, argumen
indexing = kernel_function_node.indexing indexing = kernel_function_node.indexing
block_and_thread_numbers = indexing.call_parameters(shape) block_and_thread_numbers = indexing.call_parameters(shape)
block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block']) block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block'])
block_and_thread_numbers['grid'] = tuple(int(b*g) for (b, g) in zip(block_and_thread_numbers['block'], block_and_thread_numbers['grid'] = tuple(int(b * g) for (b, g) in zip(block_and_thread_numbers['block'],
block_and_thread_numbers['grid'])) block_and_thread_numbers['grid']))
args = _build_numpy_argument_list(parameters, full_arguments) args = _build_numpy_argument_list(parameters, full_arguments)
args = [a.data for a in args if hasattr(a, 'data')] args = [a.data for a in args if hasattr(a, 'data')]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment