diff --git a/src/pystencils/codegen/gpu_indexing.py b/src/pystencils/codegen/gpu_indexing.py index bcffbbcaebb5a0222cb6f864755e4694fa9c674d..d5f0aead2cd962c0ca59dbc04f1049b2d5a5af6a 100644 --- a/src/pystencils/codegen/gpu_indexing.py +++ b/src/pystencils/codegen/gpu_indexing.py @@ -20,7 +20,7 @@ from ..backend.kernelcreation import ( ) from ..backend.platforms.cuda import ThreadMapping -from ..backend.ast.expressions import PsExpression +from ..backend.ast.expressions import PsExpression, PsIntDiv from math import prod from ..utils import ceil_to_multiple @@ -605,9 +605,10 @@ class GpuIndexing: # -> round block size in fastest moving dimension up to multiple of warp size rounded_block_size: PsExpression if self._assume_warp_aligned_block_size: - rounded_block_size = ceil_to_multiple( - work_items[0], - PsExpression.make(PsConstant(self._hw_props.warp_size, work_items[0].dtype))) + warp_size = self._ast_factory.parse_index(self._hw_props.warp_size) + rounded_block_size = self._ast_factory.parse_index( + PsIntDiv(work_items[0].clone() + warp_size.clone() - self._ast_factory.parse_index(1), + warp_size.clone()) * warp_size.clone()) else: rounded_block_size = work_items[0]