From ac4b31c141659373cece8cd63ef6dd465dbd2a4f Mon Sep 17 00:00:00 2001 From: Christoph Alt <christoph.alt@fau.de> Date: Tue, 8 Aug 2023 10:42:39 +0200 Subject: [PATCH] Updated import to new pystencils api --- pystencils_benchmark/benchmark_gpu.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pystencils_benchmark/benchmark_gpu.py b/pystencils_benchmark/benchmark_gpu.py index d68a31d..1c4e24c 100644 --- a/pystencils_benchmark/benchmark_gpu.py +++ b/pystencils_benchmark/benchmark_gpu.py @@ -6,10 +6,10 @@ from jinja2 import Environment, PackageLoader, StrictUndefined from pystencils.backends.cbackend import generate_c, get_headers from pystencils.astnodes import KernelFunction from pystencils.enums import Backend -from pystencils.data_types import get_base_type +from pystencils.typing import get_base_type from pystencils.sympyextensions import prod -from pystencils.transformations import get_common_shape -from pystencils.gpucuda import BlockIndexing +from pystencils.transformations import get_common_field +# from pystencils.gpucuda import BlockIndexing from pystencils_benchmark.enums import Compiler @@ -121,7 +121,7 @@ def kernel_main(kernels_ast: List[KernelFunction], timing: bool = True, cuda_blo fields.append((p.field_name, dtype, elements)) call_parameters.append(p.field_name) - common_shape = get_common_shape(kernel.fields_accessed) + common_shape = get_common_field(kernel.fields_accessed).shape indexing = kernel.indexing block_and_thread_numbers = indexing.call_parameters(common_shape) block_and_thread_numbers['block'] = tuple(int(i) for i in block_and_thread_numbers['block']) -- GitLab