Skip to content
Snippets Groups Projects
Commit f2260b3b authored by Markus Holzer's avatar Markus Holzer
Browse files

Fix indexing

parent 3e1f6114
No related branches found
No related tags found
2 merge requests!353Draft: Generalise usage of Structs for nested array access,!341Refactor gpu indexing
...@@ -6,7 +6,7 @@ from typing import Tuple ...@@ -6,7 +6,7 @@ from typing import Tuple
import sympy as sp import sympy as sp
from sympy.core.cache import cacheit from sympy.core.cache import cacheit
from pystencils.astnodes import Block, Conditional from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment
from pystencils.typing import TypedSymbol, create_type from pystencils.typing import TypedSymbol, create_type
from pystencils.integer_functions import div_ceil, div_floor from pystencils.integer_functions import div_ceil, div_floor
from pystencils.sympyextensions import is_integer_sequence, prod from pystencils.sympyextensions import is_integer_sequence, prod
...@@ -169,7 +169,8 @@ class BlockIndexing(AbstractIndexing): ...@@ -169,7 +169,8 @@ class BlockIndexing(AbstractIndexing):
def call_parameters(self, arr_shape): def call_parameters(self, arr_shape):
numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape) numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
widths = [len(range(*s.indices(s.stop))) for s in numeric_iteration_slice] widths = [s.stop - s.start for s in numeric_iteration_slice]
widths = [w if w > 0 else 1 for w in widths]
extend_bs = (1,) * (3 - len(self._block_size)) extend_bs = (1,) * (3 - len(self._block_size))
block_size = self._block_size + extend_bs block_size = self._block_size + extend_bs
...@@ -192,7 +193,7 @@ class BlockIndexing(AbstractIndexing): ...@@ -192,7 +193,7 @@ class BlockIndexing(AbstractIndexing):
def guard(self, kernel_content, arr_shape): def guard(self, kernel_content, arr_shape):
arr_shape = arr_shape[:self._dim] arr_shape = arr_shape[:self._dim]
numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape) numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
end = [s.stop for s in numeric_iteration_slice] end = [s.stop if s.stop != 0 else 1 for s in numeric_iteration_slice]
conditions = [c < e for c, e in zip(self.coordinates, end)] conditions = [c < e for c, e in zip(self.coordinates, end)]
for cuda_index, iter_slice in zip(self.cuda_indices, self._iteration_space): for cuda_index, iter_slice in zip(self.cuda_indices, self._iteration_space):
......
...@@ -69,6 +69,8 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection], ...@@ -69,6 +69,8 @@ def create_cuda_kernel(assignments: Union[AssignmentCollection, NodeCollection],
else: else:
iteration_space = normalize_slice(iteration_slice, common_shape) iteration_space = normalize_slice(iteration_slice, common_shape)
iteration_space = tuple([s if isinstance(s, slice) else slice(s, s, 1) for s in iteration_space])
indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout) indexing = indexing_creator(iteration_space=iteration_space, data_layout=common_field.layout)
coord_mapping = indexing.coordinates coord_mapping = indexing.coordinates
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment