Skip to content
Snippets Groups Projects

Refactor gpu indexing

Merged Markus Holzer requested to merge holzer/pystencils:indexing into master
All threads resolved!
3 files
+ 26
13
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 13
7
@@ -6,7 +6,7 @@ from typing import Tuple
@@ -6,7 +6,7 @@ from typing import Tuple
import sympy as sp
import sympy as sp
from sympy.core.cache import cacheit
from sympy.core.cache import cacheit
from pystencils.astnodes import Block, Conditional, LoopOverCoordinate, SympyAssignment
from pystencils.astnodes import Block, Conditional
from pystencils.typing import TypedSymbol, create_type
from pystencils.typing import TypedSymbol, create_type
from pystencils.integer_functions import div_ceil, div_floor
from pystencils.integer_functions import div_ceil, div_floor
from pystencils.sympyextensions import is_integer_sequence, prod
from pystencils.sympyextensions import is_integer_sequence, prod
@@ -170,7 +170,7 @@ class BlockIndexing(AbstractIndexing):
@@ -170,7 +170,7 @@ class BlockIndexing(AbstractIndexing):
def call_parameters(self, arr_shape):
def call_parameters(self, arr_shape):
numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
numeric_iteration_slice = _get_numeric_iteration_slice(self._iteration_space, arr_shape)
widths = [s.stop - s.start for s in numeric_iteration_slice]
widths = [s.stop - s.start for s in numeric_iteration_slice]
widths = [w if w > 0 else 1 for w in widths]
widths = [w if w != 0 else 1 for w in widths]
extend_bs = (1,) * (3 - len(self._block_size))
extend_bs = (1,) * (3 - len(self._block_size))
block_size = self._block_size + extend_bs
block_size = self._block_size + extend_bs
@@ -326,13 +326,19 @@ class LineIndexing(AbstractIndexing):
@@ -326,13 +326,19 @@ class LineIndexing(AbstractIndexing):
def _get_numeric_iteration_slice(iteration_slice, arr_shape):
def _get_numeric_iteration_slice(iteration_slice, arr_shape):
res = []
res = []
for slice_component, shape in zip(iteration_slice, arr_shape):
for slice_component, shape in zip(iteration_slice, arr_shape):
if not isinstance(slice_component.stop, int):
result_slice = slice_component
stop = slice_component.stop
if not isinstance(result_slice.start, int):
 
start = result_slice.start
 
assert len(start.free_symbols) == 1
 
start = start.subs({symbol: shape for symbol in start.free_symbols})
 
result_slice = slice(start, result_slice.stop, result_slice.step)
 
if not isinstance(result_slice.stop, int):
 
stop = result_slice.stop
assert len(stop.free_symbols) == 1
assert len(stop.free_symbols) == 1
stop = stop.subs({symbol: shape for symbol in stop.free_symbols})
stop = stop.subs({symbol: shape for symbol in stop.free_symbols})
res.append(slice(slice_component.start, stop, slice_component.step))
result_slice = slice(result_slice.start, stop, result_slice.step)
else:
assert isinstance(result_slice.step, int)
res.append(slice_component)
res.append(result_slice)
return res
return res
Loading