diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 135f6fb85866c486626bf965bbc475c9a3accf18..15177e6d768e3e568d8f4d30785d87344520b4b4 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -385,7 +385,7 @@ class SfgBasicComposer(SfgIComposer): args_str = ", ".join(str(arg) for arg in args) deps: set[SfgVar] = reduce(set.union, (depends(arg) for arg in args), set()) return SfgStatements( - f"{lhs_var.dtype} {lhs_var.name} {{ {args_str} }};", + f"{lhs_var.dtype.c_string()} {lhs_var.name} {{ {args_str} }};", (lhs_var,), deps, ) @@ -412,7 +412,7 @@ class SfgBasicComposer(SfgIComposer): You can look at the expression's dependencies: >>> sorted(expr.depends, key=lambda v: v.name) - [x: float, y: float, z: float] + [x: float32, y: float32, z: float32] If you use an existing expression to create a larger one, the new expression inherits all variables from its parts: @@ -421,7 +421,7 @@ class SfgBasicComposer(SfgIComposer): >>> expr2 x + y * z + w >>> sorted(expr2.depends, key=lambda v: v.name) - [w: float, x: float, y: float, z: float] + [w: float32, x: float32, y: float32, z: float32] """ return AugExpr.format(fmt, *deps, **kwdeps) @@ -446,7 +446,10 @@ class SfgBasicComposer(SfgIComposer): return SfgSwitchBuilder(switch_arg) def map_field( - self, field: Field, index_provider: IFieldExtraction | SrcField + self, + field: Field, + index_provider: IFieldExtraction | SrcField, + cast_indexing_symbols: bool = True, ) -> SfgDeferredFieldMapping: """Map a pystencils field to a field data structure, from which pointers, sizes and strides should be extracted. @@ -454,8 +457,11 @@ class SfgBasicComposer(SfgIComposer): Args: field: The pystencils field to be mapped src_object: A `IFieldIndexingProvider` object representing a field data structure. + cast_indexing_symbols: Whether to always introduce explicit casts for indexing symbols """ - return SfgDeferredFieldMapping(field, index_provider) + return SfgDeferredFieldMapping( + field, index_provider, cast_indexing_symbols=cast_indexing_symbols + ) def set_param(self, param: VarLike | sp.Symbol, expr: ExprLike): deps = depends(expr) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index bd906782d6e074955c60221d9a8f4d9b15bae772..1f4c4865987c1f10ec133f95d11e4bccc1ef8b76 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -8,6 +8,7 @@ from ..lang import ( VarLike, ExprLike, asvar, + SfgVar, ) from ..ir.source_components import ( @@ -72,16 +73,28 @@ class SfgClassComposer(SfgComposerMixIn): """ def __init__(self, *params: VarLike): - self._params = tuple(asvar(p) for p in params) + self._params = list(asvar(p) for p in params) self._initializers: list[str] = [] self._body: str | None = None - def init(self, var: VarLike): + def add_param(self, param: VarLike, at: int | None = None): + if at is None: + self._params.append(asvar(param)) + else: + self._params.insert(at, asvar(param)) + + @property + def parameters(self) -> list[SfgVar]: + return self._params + + def init(self, var: VarLike | str): """Add an initialization expression to the constructor's initializer list.""" + member = var if isinstance(var, str) else asvar(var) + def init_sequencer(*args: ExprLike): expr = ", ".join(str(arg) for arg in args) - initializer = f"{asvar(var)}{{ {expr} }}" + initializer = f"{member}{{ {expr} }}" self._initializers.append(initializer) return self diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index 93371619e302932d97441ff069190ba956b0f04d..c562bf7bca5d59365c3a1cd5b9d77b8f7e7920f5 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -66,7 +66,7 @@ class SfgGeneralPrinter: def param_list(self, func: SfgFunction) -> str: params = sorted(list(func.parameters), key=lambda p: p.name) - return ", ".join(f"{param.dtype} {param.name}" for param in params) + return ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) class SfgHeaderPrinter(SfgGeneralPrinter): @@ -113,7 +113,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgFunction) def function(self, func: SfgFunction): params = sorted(list(func.parameters), key=lambda p: p.name) - param_list = ", ".join(f"{param.dtype} {param.name}" for param in params) + param_list = ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) return f"{func.return_type} {func.name} ( {param_list} );" @visit.case(SfgClass) @@ -149,7 +149,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgConstructor) def sfg_constructor(self, constr: SfgConstructor): code = f"{constr.owning_class.class_name} (" - code += ", ".join(f"{param.dtype} {param.name}" for param in constr.parameters) + code += ", ".join(f"{param.dtype.c_string()} {param.name}" for param in constr.parameters) code += ")\n" if constr.initializers: code += " : " + ", ".join(constr.initializers) + "\n" @@ -161,7 +161,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgMemberVariable) def sfg_member_var(self, var: SfgMemberVariable): - return f"{var.dtype} {var.name};" + return f"{var.dtype.c_string()} {var.name};" @visit.case(SfgMethod) def sfg_method(self, method: SfgMethod): diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index 4ee499126fe795f95284d9c40d02513f53333054..3cb0c1c5e50aa2b9557a176f3c541283641ad530 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -14,7 +14,7 @@ from ..composer import ( SfgComposer, SfgComposerMixIn, ) -from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude, SfgKernelParamVar +from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude from ..ir import ( SfgCallTreeNode, SfgCallTreeLeaf, @@ -73,7 +73,7 @@ class SyclHandler(AugExpr): id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>") - def filter_id(param: SfgKernelParamVar) -> bool: + def filter_id(param: SfgVar) -> bool: return ( isinstance(param.dtype, PsCustomType) and id_regex.search(param.dtype.c_string()) is not None @@ -117,7 +117,7 @@ class SyclGroup(AugExpr): id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>") - def filter_id(param: SfgKernelParamVar) -> bool: + def filter_id(param: SfgVar) -> bool: return ( isinstance(param.dtype, PsCustomType) and id_regex.search(param.dtype.c_string()) is not None @@ -131,7 +131,7 @@ class SyclGroup(AugExpr): comp.map_param( id_param, h_item, - f"{id_param.dtype} {id_param.name} = {h_item}.get_local_id();", + f"{id_param.dtype.c_string()} {id_param.name} = {h_item}.get_local_id();", ), SfgKernelCallNode(kernel), ) @@ -186,7 +186,7 @@ class SfgLambda: def get_code(self, ctx: SfgContext): captures = ", ".join(self._captures) - params = ", ".join(f"{p.dtype} {p.name}" for p in self._params) + params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params) body = self._tree.get_code(ctx) body = ctx.codestyle.indent(body) rtype = ( diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index c33ec7ab282432d6f462c3316032bb8842cd7652..638a55f30f41f26f531a69a346b083dddd901797 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -9,14 +9,14 @@ from abc import ABC, abstractmethod import sympy as sp from pystencils import Field -from pystencils.types import deconstify +from pystencils.types import deconstify, PsType from pystencils.backend.properties import FieldBasePtr, FieldShape, FieldStride from ..exceptions import SfgException from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements from ..ir.source_components import SfgKernelParamVar -from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector +from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector, AugExpr if TYPE_CHECKING: from ..context import SfgContext @@ -233,20 +233,26 @@ class SfgDeferredParamSetter(SfgDeferredNode): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: live_var = ppc.get_live_variable(self._lhs.name) if live_var is not None: - code = f"{live_var.dtype} {live_var.name} = {self._rhs_expr};" + code = f"{live_var.dtype.c_string()} {live_var.name} = {self._rhs_expr};" return SfgStatements(code, (live_var,), tuple(self._depends)) else: return SfgSequence([]) class SfgDeferredFieldMapping(SfgDeferredNode): - def __init__(self, psfield: Field, extraction: IFieldExtraction | SrcField): + def __init__( + self, + psfield: Field, + extraction: IFieldExtraction | SrcField, + cast_indexing_symbols: bool = True, + ): self._field = psfield self._extraction: IFieldExtraction = ( extraction if isinstance(extraction, IFieldExtraction) else extraction.get_extraction() ) + self._cast_indexing_symbols = cast_indexing_symbols def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: # Find field pointer @@ -285,10 +291,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode): expr = self._extraction.ptr() nodes.append( SfgStatements( - f"{ptr.dtype} {ptr.name} {{ {expr} }};", (ptr,), expr.depends + f"{ptr.dtype.c_string()} {ptr.name} {{ {expr} }};", (ptr,), expr.depends ) ) + def maybe_cast(expr: AugExpr, target_type: PsType) -> AugExpr: + if self._cast_indexing_symbols: + return AugExpr(target_type).bind("{}( {} )", deconstify(target_type).c_string(), expr) + else: + return expr + def get_shape(coord, symb: SfgKernelParamVar | str): expr = self._extraction.size(coord) @@ -299,8 +311,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode): if isinstance(symb, SfgKernelParamVar) and symb not in done: done.add(symb) + expr = maybe_cast(expr, symb.dtype) return SfgStatements( - f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends + f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) @@ -315,8 +328,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode): if isinstance(symb, SfgKernelParamVar) and symb not in done: done.add(symb) + expr = maybe_cast(expr, symb.dtype) return SfgStatements( - f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends + f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) @@ -341,7 +355,7 @@ class SfgDeferredVectorMapping(SfgDeferredNode): expr = self._vector.extract_component(idx) nodes.append( SfgStatements( - f"{param.dtype} {param.name} {{ {expr} }};", + f"{param.dtype.c_string()} {param.name} {{ {expr} }};", (param,), expr.depends, ) diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 4398938327242e8e1c1fb2cfc7787c72fa0d3138..cf4d103a2a93e363e73282a31eb0b6523f510c6e 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -163,7 +163,7 @@ class SfgKernelHandle: self._namespace = namespace self._parameters = [SfgKernelParamVar(p) for p in parameters] - self._scalar_params: set[SfgKernelParamVar] = set() + self._scalar_params: set[SfgVar] = set() self._fields: set[Field] = set() for param in self._parameters: @@ -193,7 +193,7 @@ class SfgKernelHandle: return self._parameters @property - def scalar_parameters(self): + def scalar_parameters(self) -> set[SfgVar]: return self._scalar_params @property diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index d67ffa0c845b16c867064365cec8b5af5ffe6a2a..b5532bf77464dbc1e32f8449158ec4ae34418f30 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -12,7 +12,7 @@ from .expressions import ( SrcVector, ) -from .types import Ref +from .types import Ref, strip_ptr_ref __all__ = [ "SfgVar", @@ -27,4 +27,5 @@ __all__ = [ "SrcField", "SrcVector", "Ref", + "strip_ptr_ref" ] diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 481922e728549e48550e91b562b6099ec9b8c094..c8ac0f4cbc95c6af7f5b283852a344c7cc24cb45 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -174,23 +174,30 @@ class AugExpr: """Create a new `AugExpr` by combining existing expressions.""" return AugExpr().bind(fmt, *deps, **kwdeps) - def bind(self, fmt: str, *deps, **kwdeps): - dependencies: set[SfgVar] = set() - - from pystencils.sympyextensions import is_constant - - for expr in chain(deps, kwdeps.values()): - if isinstance(expr, _ExprLike): - dependencies |= depends(expr) - elif isinstance(expr, sp.Expr) and not is_constant(expr): - raise ValueError( - f"Cannot parse SymPy expression as C++ expression: {expr}\n" - " * pystencils-sfg is currently unable to parse non-constant SymPy expressions " - "since they contain symbols without type information." - ) - - code = fmt.format(*deps, **kwdeps) - self._bind(DependentExpression(code, dependencies)) + def bind(self, fmt: str | AugExpr, *deps, **kwdeps): + if isinstance(fmt, AugExpr): + if bool(deps) or bool(kwdeps): + raise ValueError("Binding to another AugExpr does not permit additional arguments") + if fmt._bound is None: + raise ValueError("Cannot rebind to unbound AugExpr.") + self._bind(fmt._bound) + else: + dependencies: set[SfgVar] = set() + + from pystencils.sympyextensions import is_constant + + for expr in chain(deps, kwdeps.values()): + if isinstance(expr, _ExprLike): + dependencies |= depends(expr) + elif isinstance(expr, sp.Expr) and not is_constant(expr): + raise ValueError( + f"Cannot parse SymPy expression as C++ expression: {expr}\n" + " * pystencils-sfg is currently unable to parse non-constant SymPy expressions " + "since they contain symbols without type information." + ) + + code = fmt.format(*deps, **kwdeps) + self._bind(DependentExpression(code, dependencies)) return self def expr(self) -> DependentExpression: @@ -251,7 +258,7 @@ class AugExpr: self._bound = expr return self - def _is_bound(self) -> bool: + def is_bound(self) -> bool: return self._bound is not None diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index 6f23160075050c6dfa33fd17636c2a3f826f263a..084f1d529a020b9796aeb82e208579f6f1aa5724 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -1,5 +1,5 @@ from typing import Any -from pystencils.types import PsType +from pystencils.types import PsType, PsPointerType class Ref(PsType): @@ -24,3 +24,13 @@ class Ref(PsType): def __repr__(self) -> str: return f"Ref({repr(self.base_type)})" + + +def strip_ptr_ref(dtype: PsType): + match dtype: + case Ref(): + return strip_ptr_ref(dtype.base_type) + case PsPointerType(): + return strip_ptr_ref(dtype.base_type) + case _: + return dtype diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 3030e1294d784c88c11e5770c9af0bfe00c21333..6a38d91c5f21a074f09c0a51ca7db39c0e4cba5f 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -109,7 +109,7 @@ def test_field_extraction(): khandle = sfg.kernels.create(set_constant) extraction = TestFieldExtraction("f") - call_tree = make_sequence(sfg.map_field(f, extraction), sfg.call(khandle)) + call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)) pp = CallTreePostProcessing() free_vars = pp.get_live_variables(call_tree) @@ -143,8 +143,8 @@ def test_duplicate_field_shapes(): khandle = sfg.kernels.create(set_constant) call_tree = make_sequence( - sfg.map_field(g, TestFieldExtraction("g")), - sfg.map_field(f, TestFieldExtraction("f")), + sfg.map_field(g, TestFieldExtraction("g"), cast_indexing_symbols=False), + sfg.map_field(f, TestFieldExtraction("f"), cast_indexing_symbols=False), sfg.call(khandle), )