From fafe58ec98a8a8f9693f65e1fa8a8b7a09e142e1 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Sun, 21 Apr 2024 16:31:38 +0200 Subject: [PATCH] Introduce PsLiteral and PsLiteralExpr. --- docs/source/backend/index.rst | 2 +- .../backend/{symbols.rst => objects.rst} | 3 ++ src/pystencils/backend/ast/expressions.py | 38 +++++++++++++++- src/pystencils/backend/emission.py | 4 ++ src/pystencils/backend/functions.py | 13 ++++-- .../backend/kernelcreation/typification.py | 12 ++++++ src/pystencils/backend/literals.py | 43 +++++++++++++++++++ .../backend/platforms/generic_gpu.py | 12 +++--- .../transformations/eliminate_constants.py | 7 +-- .../hoist_loop_invariant_decls.py | 3 +- 10 files changed, 122 insertions(+), 15 deletions(-) rename docs/source/backend/{symbols.rst => objects.rst} (80%) create mode 100644 src/pystencils/backend/literals.py diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst index df194bde9..e0e914b4d 100644 --- a/docs/source/backend/index.rst +++ b/docs/source/backend/index.rst @@ -9,7 +9,7 @@ who wish to customize or extend the behaviour of the code generator in their app .. toctree:: :maxdepth: 1 - symbols + objects ast iteration_space translation diff --git a/docs/source/backend/symbols.rst b/docs/source/backend/objects.rst similarity index 80% rename from docs/source/backend/symbols.rst rename to docs/source/backend/objects.rst index 66c8c43ba..b0c3af6db 100644 --- a/docs/source/backend/symbols.rst +++ b/docs/source/backend/objects.rst @@ -8,5 +8,8 @@ Symbols, Constants and Arrays .. autoclass:: pystencils.backend.constants.PsConstant :members: +.. autoclass:: pystencils.backend.literals.PsLiteral + :members: + .. automodule:: pystencils.backend.arrays :members: diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 0666d9687..7bcf62b97 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -5,6 +5,7 @@ import operator from ..symbols import PsSymbol from ..constants import PsConstant +from ..literals import PsLiteral from ..arrays import PsLinearizedArray, PsArrayBasePointer from ..functions import PsFunction from ...types import ( @@ -76,12 +77,19 @@ class PsExpression(PsAstNode, ABC): def make(obj: PsConstant) -> PsConstantExpr: pass + @overload + @staticmethod + def make(obj: PsLiteral) -> PsLiteralExpr: + pass + @staticmethod - def make(obj: PsSymbol | PsConstant) -> PsSymbolExpr | PsConstantExpr: + def make(obj: PsSymbol | PsConstant | PsLiteral) -> PsExpression: if isinstance(obj, PsSymbol): return PsSymbolExpr(obj) elif isinstance(obj, PsConstant): return PsConstantExpr(obj) + elif isinstance(obj, PsLiteral): + return PsLiteralExpr(obj) else: raise ValueError(f"Cannot make expression out of {obj}") @@ -150,6 +158,34 @@ class PsConstantExpr(PsLeafMixIn, PsExpression): def __repr__(self) -> str: return f"PsConstantExpr({repr(self._constant)})" + + +class PsLiteralExpr(PsLeafMixIn, PsExpression): + __match_args__ = ("literal",) + + def __init__(self, literal: PsLiteral): + super().__init__(literal.dtype) + self._literal = literal + + @property + def literal(self) -> PsLiteral: + return self._literal + + @literal.setter + def literal(self, lit: PsLiteral): + self._literal = lit + + def clone(self) -> PsLiteralExpr: + return PsLiteralExpr(self._literal) + + def structurally_equal(self, other: PsAstNode) -> bool: + if not isinstance(other, PsLiteralExpr): + return False + + return self._literal == other._literal + + def __repr__(self) -> str: + return f"PsLiteralExpr({repr(self._literal)})" class PsSubscript(PsLvalue, PsExpression): diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py index 588ac410a..f3d56c6c4 100644 --- a/src/pystencils/backend/emission.py +++ b/src/pystencils/backend/emission.py @@ -34,6 +34,7 @@ from .ast.expressions import ( PsSub, PsSubscript, PsSymbolExpr, + PsLiteralExpr, PsVectorArrayAccess, PsAnd, PsOr, @@ -245,6 +246,9 @@ class CAstPrinter: ) return dtype.create_literal(constant.value) + + case PsLiteralExpr(lit): + return lit.text case PsVectorArrayAccess(): raise EmissionError("Cannot print vectorized array accesses") diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py index 28c9788d6..e420deaa6 100644 --- a/src/pystencils/backend/functions.py +++ b/src/pystencils/backend/functions.py @@ -82,7 +82,7 @@ class CFunction(PsFunction): return_type: The function's return type """ - __match_args__ = ("name", "argument_types", "return_type") + __match_args__ = ("name", "parameter_types", "return_type") @staticmethod def parse(obj) -> CFunction: @@ -108,17 +108,24 @@ class CFunction(PsFunction): def __init__(self, name: str, param_types: Sequence[PsType], return_type: PsType): super().__init__(name, len(param_types)) - self._arg_types = tuple(param_types) + self._param_types = tuple(param_types) self._return_type = return_type @property def parameter_types(self) -> tuple[PsType, ...]: - return self._arg_types + return self._param_types @property def return_type(self) -> PsType: return self._return_type + def __str__(self) -> str: + param_types = ", ".join(str(t) for t in self._param_types) + return f"{self._return_type} {self._name}({param_types})" + + def __repr__(self) -> str: + return f"CFunction({self._name}, {self._param_types}, {self._return_type})" + class PsMathFunction(PsFunction): """Homogenously typed mathematical functions.""" diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index d933004fc..95fa0e36c 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -39,6 +39,7 @@ from ..ast.expressions import ( PsLookup, PsSubscript, PsSymbolExpr, + PsLiteralExpr, PsRel, PsNeg, PsNot, @@ -158,6 +159,14 @@ class TypeContext: f" Constant type: {c.dtype}\n" f" Target type: {self._target_type}" ) + + case PsLiteralExpr(lit): + if not self._compatible(lit.dtype): + raise TypificationError( + f"Type mismatch at literal {lit}: Literal type did not match the context's target type\n" + f" Literal type: {lit.dtype}\n" + f" Target type: {self._target_type}" + ) case PsSymbolExpr(symb): assert symb.dtype is not None @@ -356,6 +365,9 @@ class Typifier: else: tc.infer_dtype(expr) + case PsLiteralExpr(lit): + tc.apply_dtype(lit.dtype, expr) + case PsArrayAccess(bptr, idx): tc.apply_dtype(bptr.array.element_type, expr) diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py new file mode 100644 index 000000000..dc7504f52 --- /dev/null +++ b/src/pystencils/backend/literals.py @@ -0,0 +1,43 @@ +from __future__ import annotations +from ..types import PsType, constify + + +class PsLiteral: + """Representation of literal code. + + Instances of this class represent code literals inside the AST. + These literals are not to be confused with C literals; the name `Literal` refers to the fact that + the code generator takes them "literally", printing them as they are. + + Each literal has to be annotated with a type, and is considered constant within the scope of a kernel. + Instances of `PsLiteral` are immutable. + """ + + __match_args__ = ("text", "dtype") + + def __init__(self, text: str, dtype: PsType) -> None: + self._text = text + self._dtype = constify(dtype) + + @property + def text(self) -> str: + return self._text + + @property + def dtype(self) -> PsType: + return self._dtype + + def __str__(self) -> str: + return f"{self._text}: {self._dtype}" + + def __repr__(self) -> str: + return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, PsLiteral): + return False + + return self._text == other._text and self._dtype == other._dtype + + def __hash__(self) -> int: + return hash((PsLiteral, self._text, self._dtype)) diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py index ef4861aa7..3ef64c73c 100644 --- a/src/pystencils/backend/platforms/generic_gpu.py +++ b/src/pystencils/backend/platforms/generic_gpu.py @@ -11,26 +11,26 @@ from ..kernelcreation.iteration_space import ( from ..ast.structural import PsBlock, PsConditional from ..ast.expressions import ( PsExpression, - PsSymbolExpr, + PsLiteralExpr, PsAdd, ) from ..ast.expressions import PsLt, PsAnd from ...types import PsSignedIntegerType -from ..symbols import PsSymbol +from ..literals import PsLiteral int32 = PsSignedIntegerType(width=32, const=False) BLOCK_IDX = [ - PsSymbolExpr(PsSymbol(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z") ] THREAD_IDX = [ - PsSymbolExpr(PsSymbol(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z") ] BLOCK_DIM = [ - PsSymbolExpr(PsSymbol(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z") ] GRID_DIM = [ - PsSymbolExpr(PsSymbol(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") + PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z") ] diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 3b9f0b700..ddfa33f08 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -9,6 +9,7 @@ from ..ast.expressions import ( PsExpression, PsConstantExpr, PsSymbolExpr, + PsLiteralExpr, PsBinOp, PsAdd, PsSub, @@ -159,8 +160,8 @@ class EliminateConstants: Returns: (transformed_expr, is_const): The tranformed expression, and a flag indicating whether it is constant """ - # Return constants as they are - if isinstance(expr, PsConstantExpr): + # Return constants and literals as they are + if isinstance(expr, (PsConstantExpr, PsLiteralExpr)): return expr, True # Shortcut symbols @@ -317,7 +318,7 @@ class EliminateConstants: # If required, extract constant subexpressions if self._extract_constant_exprs: for i, (child, is_const) in enumerate(subtree_results): - if is_const and not isinstance(child, PsConstantExpr): + if is_const and not isinstance(child, (PsConstantExpr, PsLiteralExpr)): replacement = ecc.extract_expression(child) expr.set_child(i, replacement) diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py index 5824239e4..cb9c9e920 100644 --- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py +++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py @@ -7,6 +7,7 @@ from ..ast.expressions import ( PsExpression, PsSymbolExpr, PsConstantExpr, + PsLiteralExpr, PsCall, PsDeref, PsSubscript, @@ -40,7 +41,7 @@ class HoistContext: symbol in self.invariant_symbols ) - case PsConstantExpr(): + case PsConstantExpr() | PsLiteralExpr(): return True case PsCall(func): -- GitLab