diff --git a/pystencils/gpu/indexing.py b/pystencils/gpu/indexing.py index 54a5acff14fcc82537cf06e1b4e83338b9e0a796..83add2f85898bc966b75d3b0845a8424b8559b7e 100644 --- a/pystencils/gpu/indexing.py +++ b/pystencils/gpu/indexing.py @@ -6,7 +6,7 @@ from typing import Tuple import sympy as sp from sympy.core.cache import cacheit -from pystencils.astnodes import Block, Conditional +from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment from pystencils.typing import TypedSymbol, create_type from pystencils.integer_functions import div_ceil, div_floor from pystencils.sympyextensions import is_integer_sequence, prod @@ -169,7 +169,8 @@ class BlockIndexing(AbstractIndexing): def call_parameters(self, arr_shape): numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape) - widths = [len(range(*s.indices(s.stop))) for s in numeric_iteration_slice] + widths = [s.stop - s.start for s in numeric_iteration_slice] + widths = [w if w > 0 else 1 for w in widths] extend_bs = (1,) * (3 - len(self._block_size)) block_size = self._block_size + extend_bs @@ -192,7 +193,7 @@ class BlockIndexing(AbstractIndexing): def guard(self, kernel_content, arr_shape): arr_shape = arr_shape[:self._dim] numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape) - end = [s.stop for s in numeric_iteration_slice] + end = [s.stop if s.stop != 0 else 1 for s in numeric_iteration_slice] conditions = [c < e for c, e in zip(self.coordinates, end)] for cuda_index, iter_slice in zip(self.cuda_indices, self._iteration_space): diff --git a/pystencils/gpu/kernelcreation.py b/pystencils/gpu/kernelcreation.py index e8ac135c3681b7284ac0e048fa6cc1104f65e640..45a3ae52f28f6c3f847972b6f13f87eb7d06863f 100644 --- a/pystencils/gpu/kernelcreation.py +++ b/pystencils/gpu/kernelcreation.py @@ -69,6 +69,8 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection], else: iteration_space = normalize_slice(iteration_slice, common_shape) + iteration_space = tuple([s if isinstance(s, slice) else slice(s, s, 1) for s in iteration_space]) + indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout) coord_mapping = indexing.coordinates