Skip to content
Snippets Groups Projects
Commit c1c692a8 authored by Markus Holzer's avatar Markus Holzer
Browse files

First fix

parent e18aa374
No related branches found
No related tags found
No related merge requests found
Pipeline #51423 passed
......@@ -124,12 +124,18 @@ class BlockIndexing(AbstractIndexing):
self._symbolic_shape = [e if isinstance(e, sp.Basic) else None for e in field.spatial_shape]
self._compile_time_block_size = compile_time_block_size
@property
def cuda_indices(self):
block_size = self._block_size if self._compile_time_block_size else BLOCK_DIM
indices = [block_index * bs + thread_idx
for block_index, bs, thread_idx in zip(BLOCK_IDX, block_size, THREAD_IDX)]
return indices[:self._dim]
@property
def coordinates(self):
offsets = _get_start_from_slice(self._iterationSlice)
block_size = self._block_size if self._compile_time_block_size else BLOCK_DIM
coordinates = [block_index * bs + thread_idx + off
for block_index, bs, thread_idx, off in zip(BLOCK_IDX, block_size, THREAD_IDX, offsets)]
coordinates = [c + off for c, off in zip(self.cuda_indices, offsets)]
return coordinates[:self._dim]
......@@ -159,8 +165,13 @@ class BlockIndexing(AbstractIndexing):
def guard(self, kernel_content, arr_shape):
arr_shape = arr_shape[:self._dim]
conditions = [c < end
for c, end in zip(self.coordinates, _get_end_from_slice(self._iterationSlice, arr_shape))]
end = _get_end_from_slice(self._iterationSlice, arr_shape)
conditions = [c < e for c, e in zip(self.coordinates, end)]
for cuda_index, iter_slice in zip(self.cuda_indices, self._iterationSlice):
if iter_slice.step > 1:
conditions.append(sp.Eq(sp.Mod(cuda_index, iter_slice.step), 0))
condition = conditions[0]
for c in conditions[1:]:
condition = sp.And(condition, c)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment