From 82997bedc617b08f12fc7c46a4bf0981e5c2fad4 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 18 Nov 2024 16:48:39 +0100 Subject: [PATCH] pystencils API updates & features for sweep gen - Fix type printing after changes in pystencils - Introduce casting of indexing symbols in field mapping - Extend class composer's constructor builder to allow incremental building - Introduce a utility for stripping pointers and refs from a type Squashed commit of the following: commit 6d54f2ca471b07b3b4761d1af5fe03a0267cf27d Author: Frederik Hennig <frederik.hennig@fau.de> Date: Mon Nov 18 16:47:18 2024 +0100 fix a doctest commit 2e54c7a022fe6dd0d4f9984090c6c87c4aeae499 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Fri Nov 15 15:37:49 2024 +0100 Fix data type printing commit 1397bcb25b86815b6bce64cd997ca91747cd4588 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Thu Nov 7 14:51:10 2024 +0100 some minor API changes commit 2ba2fd8d2914957d183e3087d1cd6e65c3ce546a Author: Frederik Hennig <frederik.hennig@fau.de> Date: Wed Nov 6 15:29:36 2024 +0100 Add `parameters` property to SfgClassComposer commit 1a30d20218e40ef9d88d7ae0dd4afac80cb1c96e Author: Frederik Hennig <frederik.hennig@fau.de> Date: Tue Oct 29 17:04:19 2024 +0100 Extend ConstructorBuilder to allow incremental addition of parameters. Fix test cases for PPing. commit d0b8fff973dbe71d3c88f2e437a6a2767ae7cb50 Merge: 2977b58 d3e347f Author: Frederik Hennig <frederik.hennig@fau.de> Date: Tue Oct 29 09:20:23 2024 +0100 Merge branch 'master' into lbwelding-features commit 2977b58c3c6a71353ee51b4c834692e006ef34a6 Author: Frederik Hennig <frederik.hennig@fau.de> Date: Tue Oct 29 09:17:30 2024 +0100 Introduce casts to indexing symbols in field extraction --- src/pystencilssfg/composer/basic_composer.py | 16 +++++--- src/pystencilssfg/composer/class_composer.py | 19 +++++++-- src/pystencilssfg/emission/printers.py | 8 ++-- src/pystencilssfg/extensions/sycl.py | 10 ++--- src/pystencilssfg/ir/postprocessing.py | 30 ++++++++++---- src/pystencilssfg/ir/source_components.py | 4 +- src/pystencilssfg/lang/__init__.py | 3 +- src/pystencilssfg/lang/expressions.py | 43 ++++++++++++-------- src/pystencilssfg/lang/types.py | 12 +++++- tests/ir/test_postprocessing.py | 6 +-- 10 files changed, 101 insertions(+), 50 deletions(-) diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 135f6fb..15177e6 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -385,7 +385,7 @@ class SfgBasicComposer(SfgIComposer): args_str = ", ".join(str(arg) for arg in args) deps: set[SfgVar] = reduce(set.union, (depends(arg) for arg in args), set()) return SfgStatements( - f"{lhs_var.dtype} {lhs_var.name} {{ {args_str} }};", + f"{lhs_var.dtype.c_string()} {lhs_var.name} {{ {args_str} }};", (lhs_var,), deps, ) @@ -412,7 +412,7 @@ class SfgBasicComposer(SfgIComposer): You can look at the expression's dependencies: >>> sorted(expr.depends, key=lambda v: v.name) - [x: float, y: float, z: float] + [x: float32, y: float32, z: float32] If you use an existing expression to create a larger one, the new expression inherits all variables from its parts: @@ -421,7 +421,7 @@ class SfgBasicComposer(SfgIComposer): >>> expr2 x + y * z + w >>> sorted(expr2.depends, key=lambda v: v.name) - [w: float, x: float, y: float, z: float] + [w: float32, x: float32, y: float32, z: float32] """ return AugExpr.format(fmt, *deps, **kwdeps) @@ -446,7 +446,10 @@ class SfgBasicComposer(SfgIComposer): return SfgSwitchBuilder(switch_arg) def map_field( - self, field: Field, index_provider: IFieldExtraction | SrcField + self, + field: Field, + index_provider: IFieldExtraction | SrcField, + cast_indexing_symbols: bool = True, ) -> SfgDeferredFieldMapping: """Map a pystencils field to a field data structure, from which pointers, sizes and strides should be extracted. @@ -454,8 +457,11 @@ class SfgBasicComposer(SfgIComposer): Args: field: The pystencils field to be mapped src_object: A `IFieldIndexingProvider` object representing a field data structure. + cast_indexing_symbols: Whether to always introduce explicit casts for indexing symbols """ - return SfgDeferredFieldMapping(field, index_provider) + return SfgDeferredFieldMapping( + field, index_provider, cast_indexing_symbols=cast_indexing_symbols + ) def set_param(self, param: VarLike | sp.Symbol, expr: ExprLike): deps = depends(expr) diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index bd90678..1f4c486 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -8,6 +8,7 @@ from ..lang import ( VarLike, ExprLike, asvar, + SfgVar, ) from ..ir.source_components import ( @@ -72,16 +73,28 @@ class SfgClassComposer(SfgComposerMixIn): """ def __init__(self, *params: VarLike): - self._params = tuple(asvar(p) for p in params) + self._params = list(asvar(p) for p in params) self._initializers: list[str] = [] self._body: str | None = None - def init(self, var: VarLike): + def add_param(self, param: VarLike, at: int | None = None): + if at is None: + self._params.append(asvar(param)) + else: + self._params.insert(at, asvar(param)) + + @property + def parameters(self) -> list[SfgVar]: + return self._params + + def init(self, var: VarLike | str): """Add an initialization expression to the constructor's initializer list.""" + member = var if isinstance(var, str) else asvar(var) + def init_sequencer(*args: ExprLike): expr = ", ".join(str(arg) for arg in args) - initializer = f"{asvar(var)}{{ {expr} }}" + initializer = f"{member}{{ {expr} }}" self._initializers.append(initializer) return self diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index 9337161..c562bf7 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -66,7 +66,7 @@ class SfgGeneralPrinter: def param_list(self, func: SfgFunction) -> str: params = sorted(list(func.parameters), key=lambda p: p.name) - return ", ".join(f"{param.dtype} {param.name}" for param in params) + return ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) class SfgHeaderPrinter(SfgGeneralPrinter): @@ -113,7 +113,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgFunction) def function(self, func: SfgFunction): params = sorted(list(func.parameters), key=lambda p: p.name) - param_list = ", ".join(f"{param.dtype} {param.name}" for param in params) + param_list = ", ".join(f"{param.dtype.c_string()} {param.name}" for param in params) return f"{func.return_type} {func.name} ( {param_list} );" @visit.case(SfgClass) @@ -149,7 +149,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgConstructor) def sfg_constructor(self, constr: SfgConstructor): code = f"{constr.owning_class.class_name} (" - code += ", ".join(f"{param.dtype} {param.name}" for param in constr.parameters) + code += ", ".join(f"{param.dtype.c_string()} {param.name}" for param in constr.parameters) code += ")\n" if constr.initializers: code += " : " + ", ".join(constr.initializers) + "\n" @@ -161,7 +161,7 @@ class SfgHeaderPrinter(SfgGeneralPrinter): @visit.case(SfgMemberVariable) def sfg_member_var(self, var: SfgMemberVariable): - return f"{var.dtype} {var.name};" + return f"{var.dtype.c_string()} {var.name};" @visit.case(SfgMethod) def sfg_method(self, method: SfgMethod): diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index 4ee4991..3cb0c1c 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -14,7 +14,7 @@ from ..composer import ( SfgComposer, SfgComposerMixIn, ) -from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude, SfgKernelParamVar +from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude from ..ir import ( SfgCallTreeNode, SfgCallTreeLeaf, @@ -73,7 +73,7 @@ class SyclHandler(AugExpr): id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>") - def filter_id(param: SfgKernelParamVar) -> bool: + def filter_id(param: SfgVar) -> bool: return ( isinstance(param.dtype, PsCustomType) and id_regex.search(param.dtype.c_string()) is not None @@ -117,7 +117,7 @@ class SyclGroup(AugExpr): id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>") - def filter_id(param: SfgKernelParamVar) -> bool: + def filter_id(param: SfgVar) -> bool: return ( isinstance(param.dtype, PsCustomType) and id_regex.search(param.dtype.c_string()) is not None @@ -131,7 +131,7 @@ class SyclGroup(AugExpr): comp.map_param( id_param, h_item, - f"{id_param.dtype} {id_param.name} = {h_item}.get_local_id();", + f"{id_param.dtype.c_string()} {id_param.name} = {h_item}.get_local_id();", ), SfgKernelCallNode(kernel), ) @@ -186,7 +186,7 @@ class SfgLambda: def get_code(self, ctx: SfgContext): captures = ", ".join(self._captures) - params = ", ".join(f"{p.dtype} {p.name}" for p in self._params) + params = ", ".join(f"{p.dtype.c_string()} {p.name}" for p in self._params) body = self._tree.get_code(ctx) body = ctx.codestyle.indent(body) rtype = ( diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index c33ec7a..638a55f 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -9,14 +9,14 @@ from abc import ABC, abstractmethod import sympy as sp from pystencils import Field -from pystencils.types import deconstify +from pystencils.types import deconstify, PsType from pystencils.backend.properties import FieldBasePtr, FieldShape, FieldStride from ..exceptions import SfgException from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements from ..ir.source_components import SfgKernelParamVar -from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector +from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector, AugExpr if TYPE_CHECKING: from ..context import SfgContext @@ -233,20 +233,26 @@ class SfgDeferredParamSetter(SfgDeferredNode): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: live_var = ppc.get_live_variable(self._lhs.name) if live_var is not None: - code = f"{live_var.dtype} {live_var.name} = {self._rhs_expr};" + code = f"{live_var.dtype.c_string()} {live_var.name} = {self._rhs_expr};" return SfgStatements(code, (live_var,), tuple(self._depends)) else: return SfgSequence([]) class SfgDeferredFieldMapping(SfgDeferredNode): - def __init__(self, psfield: Field, extraction: IFieldExtraction | SrcField): + def __init__( + self, + psfield: Field, + extraction: IFieldExtraction | SrcField, + cast_indexing_symbols: bool = True, + ): self._field = psfield self._extraction: IFieldExtraction = ( extraction if isinstance(extraction, IFieldExtraction) else extraction.get_extraction() ) + self._cast_indexing_symbols = cast_indexing_symbols def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: # Find field pointer @@ -285,10 +291,16 @@ class SfgDeferredFieldMapping(SfgDeferredNode): expr = self._extraction.ptr() nodes.append( SfgStatements( - f"{ptr.dtype} {ptr.name} {{ {expr} }};", (ptr,), expr.depends + f"{ptr.dtype.c_string()} {ptr.name} {{ {expr} }};", (ptr,), expr.depends ) ) + def maybe_cast(expr: AugExpr, target_type: PsType) -> AugExpr: + if self._cast_indexing_symbols: + return AugExpr(target_type).bind("{}( {} )", deconstify(target_type).c_string(), expr) + else: + return expr + def get_shape(coord, symb: SfgKernelParamVar | str): expr = self._extraction.size(coord) @@ -299,8 +311,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode): if isinstance(symb, SfgKernelParamVar) and symb not in done: done.add(symb) + expr = maybe_cast(expr, symb.dtype) return SfgStatements( - f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends + f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) @@ -315,8 +328,9 @@ class SfgDeferredFieldMapping(SfgDeferredNode): if isinstance(symb, SfgKernelParamVar) and symb not in done: done.add(symb) + expr = maybe_cast(expr, symb.dtype) return SfgStatements( - f"{symb.dtype} {symb.name} {{ {expr} }};", (symb,), expr.depends + f"{symb.dtype.c_string()} {symb.name} {{ {expr} }};", (symb,), expr.depends ) else: return SfgStatements(f"/* {expr} == {symb} */", (), ()) @@ -341,7 +355,7 @@ class SfgDeferredVectorMapping(SfgDeferredNode): expr = self._vector.extract_component(idx) nodes.append( SfgStatements( - f"{param.dtype} {param.name} {{ {expr} }};", + f"{param.dtype.c_string()} {param.name} {{ {expr} }};", (param,), expr.depends, ) diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 4398938..cf4d103 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -163,7 +163,7 @@ class SfgKernelHandle: self._namespace = namespace self._parameters = [SfgKernelParamVar(p) for p in parameters] - self._scalar_params: set[SfgKernelParamVar] = set() + self._scalar_params: set[SfgVar] = set() self._fields: set[Field] = set() for param in self._parameters: @@ -193,7 +193,7 @@ class SfgKernelHandle: return self._parameters @property - def scalar_parameters(self): + def scalar_parameters(self) -> set[SfgVar]: return self._scalar_params @property diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index d67ffa0..b5532bf 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -12,7 +12,7 @@ from .expressions import ( SrcVector, ) -from .types import Ref +from .types import Ref, strip_ptr_ref __all__ = [ "SfgVar", @@ -27,4 +27,5 @@ __all__ = [ "SrcField", "SrcVector", "Ref", + "strip_ptr_ref" ] diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index 481922e..c8ac0f4 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -174,23 +174,30 @@ class AugExpr: """Create a new `AugExpr` by combining existing expressions.""" return AugExpr().bind(fmt, *deps, **kwdeps) - def bind(self, fmt: str, *deps, **kwdeps): - dependencies: set[SfgVar] = set() - - from pystencils.sympyextensions import is_constant - - for expr in chain(deps, kwdeps.values()): - if isinstance(expr, _ExprLike): - dependencies |= depends(expr) - elif isinstance(expr, sp.Expr) and not is_constant(expr): - raise ValueError( - f"Cannot parse SymPy expression as C++ expression: {expr}\n" - " * pystencils-sfg is currently unable to parse non-constant SymPy expressions " - "since they contain symbols without type information." - ) - - code = fmt.format(*deps, **kwdeps) - self._bind(DependentExpression(code, dependencies)) + def bind(self, fmt: str | AugExpr, *deps, **kwdeps): + if isinstance(fmt, AugExpr): + if bool(deps) or bool(kwdeps): + raise ValueError("Binding to another AugExpr does not permit additional arguments") + if fmt._bound is None: + raise ValueError("Cannot rebind to unbound AugExpr.") + self._bind(fmt._bound) + else: + dependencies: set[SfgVar] = set() + + from pystencils.sympyextensions import is_constant + + for expr in chain(deps, kwdeps.values()): + if isinstance(expr, _ExprLike): + dependencies |= depends(expr) + elif isinstance(expr, sp.Expr) and not is_constant(expr): + raise ValueError( + f"Cannot parse SymPy expression as C++ expression: {expr}\n" + " * pystencils-sfg is currently unable to parse non-constant SymPy expressions " + "since they contain symbols without type information." + ) + + code = fmt.format(*deps, **kwdeps) + self._bind(DependentExpression(code, dependencies)) return self def expr(self) -> DependentExpression: @@ -251,7 +258,7 @@ class AugExpr: self._bound = expr return self - def _is_bound(self) -> bool: + def is_bound(self) -> bool: return self._bound is not None diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py index 6f23160..084f1d5 100644 --- a/src/pystencilssfg/lang/types.py +++ b/src/pystencilssfg/lang/types.py @@ -1,5 +1,5 @@ from typing import Any -from pystencils.types import PsType +from pystencils.types import PsType, PsPointerType class Ref(PsType): @@ -24,3 +24,13 @@ class Ref(PsType): def __repr__(self) -> str: return f"Ref({repr(self.base_type)})" + + +def strip_ptr_ref(dtype: PsType): + match dtype: + case Ref(): + return strip_ptr_ref(dtype.base_type) + case PsPointerType(): + return strip_ptr_ref(dtype.base_type) + case _: + return dtype diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py index 3030e12..6a38d91 100644 --- a/tests/ir/test_postprocessing.py +++ b/tests/ir/test_postprocessing.py @@ -109,7 +109,7 @@ def test_field_extraction(): khandle = sfg.kernels.create(set_constant) extraction = TestFieldExtraction("f") - call_tree = make_sequence(sfg.map_field(f, extraction), sfg.call(khandle)) + call_tree = make_sequence(sfg.map_field(f, extraction, cast_indexing_symbols=False), sfg.call(khandle)) pp = CallTreePostProcessing() free_vars = pp.get_live_variables(call_tree) @@ -143,8 +143,8 @@ def test_duplicate_field_shapes(): khandle = sfg.kernels.create(set_constant) call_tree = make_sequence( - sfg.map_field(g, TestFieldExtraction("g")), - sfg.map_field(f, TestFieldExtraction("f")), + sfg.map_field(g, TestFieldExtraction("g"), cast_indexing_symbols=False), + sfg.map_field(f, TestFieldExtraction("f"), cast_indexing_symbols=False), sfg.call(khandle), ) -- GitLab