diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 135f6fb85866c486626bf965bbc475c9a3accf18..5dc1d7df8260fece0875511ebd1f29b778dfc11f 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -446,7 +446,10 @@ class SfgBasicComposer(SfgIComposer): return SfgSwitchBuilder(switch_arg) def map_field( - self, field: Field, index_provider: IFieldExtraction | SrcField + self, + field: Field, + index_provider: IFieldExtraction | SrcField, + cast_indexing_symbols: bool = True, ) -> SfgDeferredFieldMapping: """Map a pystencils field to a field data structure, from which pointers, sizes and strides should be extracted. @@ -454,8 +457,11 @@ class SfgBasicComposer(SfgIComposer): Args: field: The pystencils field to be mapped src_object: A `IFieldIndexingProvider` object representing a field data structure. + cast_indexing_symbols: Whether to always introduce explicit casts for indexing symbols """ - return SfgDeferredFieldMapping(field, index_provider) + return SfgDeferredFieldMapping( + field, index_provider, cast_indexing_symbols=cast_indexing_symbols + ) def set_param(self, param: VarLike | sp.Symbol, expr: ExprLike): deps = depends(expr) diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index 851a981862cb900442d92624bbaef0dbece68f89..91577435888599d8d494c628b1aac75721419339 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -9,7 +9,7 @@ from abc import ABC, abstractmethod import sympy as sp from pystencils import Field, TypedSymbol -from pystencils.types import deconstify +from pystencils.types import deconstify, PsType from pystencils.backend.kernelfunction import ( FieldPointerParam, FieldShapeParam, @@ -20,7 +20,7 @@ from ..exceptions import SfgException from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements from ..ir.source_components import SfgSymbolLike -from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector +from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector, AugExpr if TYPE_CHECKING: from ..context import SfgContext @@ -244,13 +244,19 @@ class SfgDeferredParamSetter(SfgDeferredNode): class SfgDeferredFieldMapping(SfgDeferredNode): - def __init__(self, psfield: Field, extraction: IFieldExtraction | SrcField): + def __init__( + self, + psfield: Field, + extraction: IFieldExtraction | SrcField, + cast_indexing_symbols: bool = True, + ): self._field = psfield self._extraction: IFieldExtraction = ( extraction if isinstance(extraction, IFieldExtraction) else extraction.get_extraction() ) + self._cast_indexing_symbols = cast_indexing_symbols # type: ignore def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: @@ -298,6 +304,12 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ) ) + def maybe_cast(expr: AugExpr, target_type: PsType) -> AugExpr: + if self._cast_indexing_symbols: + return AugExpr(target_type).bind("{}( {} )", deconstify(target_type).c_string(), expr) + else: + return expr + def get_shape(coord, symb: SfgSymbolLike | int): expr = self._extraction.size(coord) @@ -307,6 +319,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ) if isinstance(symb, SfgSymbolLike): + expr = maybe_cast(expr, symb.dtype) return SfgStatements( f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends ) @@ -322,6 +335,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): ) if isinstance(symb, SfgSymbolLike): + expr = maybe_cast(expr, symb.dtype) return SfgStatements( f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends ) diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 481922e728549e48550e91b562b6099ec9b8c094..32b4754f89e32a6e5d63c0e27aece6eea83c0235 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -251,7 +251,7 @@ class AugExpr: self._bound = expr return self - def _is_bound(self) -> bool: + def is_bound(self) -> bool: return self._bound is not None