From fafe58ec98a8a8f9693f65e1fa8a8b7a09e142e1 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Sun, 21 Apr 2024 16:31:38 +0200
Subject: [PATCH] Introduce PsLiteral and PsLiteralExpr.

---
 docs/source/backend/index.rst                 |  2 +-
 .../backend/{symbols.rst => objects.rst}      |  3 ++
 src/pystencils/backend/ast/expressions.py     | 38 +++++++++++++++-
 src/pystencils/backend/emission.py            |  4 ++
 src/pystencils/backend/functions.py           | 13 ++++--
 .../backend/kernelcreation/typification.py    | 12 ++++++
 src/pystencils/backend/literals.py            | 43 +++++++++++++++++++
 .../backend/platforms/generic_gpu.py          | 12 +++---
 .../transformations/eliminate_constants.py    |  7 +--
 .../hoist_loop_invariant_decls.py             |  3 +-
 10 files changed, 122 insertions(+), 15 deletions(-)
 rename docs/source/backend/{symbols.rst => objects.rst} (80%)
 create mode 100644 src/pystencils/backend/literals.py

diff --git a/docs/source/backend/index.rst b/docs/source/backend/index.rst
index df194bde9..e0e914b4d 100644
--- a/docs/source/backend/index.rst
+++ b/docs/source/backend/index.rst
@@ -9,7 +9,7 @@ who wish to customize or extend the behaviour of the code generator in their app
 .. toctree::
     :maxdepth: 1
     
-    symbols
+    objects
     ast
     iteration_space
     translation
diff --git a/docs/source/backend/symbols.rst b/docs/source/backend/objects.rst
similarity index 80%
rename from docs/source/backend/symbols.rst
rename to docs/source/backend/objects.rst
index 66c8c43ba..b0c3af6db 100644
--- a/docs/source/backend/symbols.rst
+++ b/docs/source/backend/objects.rst
@@ -8,5 +8,8 @@ Symbols, Constants and Arrays
 .. autoclass:: pystencils.backend.constants.PsConstant
     :members:
 
+.. autoclass:: pystencils.backend.literals.PsLiteral
+    :members:
+
 .. automodule:: pystencils.backend.arrays
     :members:
diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py
index 0666d9687..7bcf62b97 100644
--- a/src/pystencils/backend/ast/expressions.py
+++ b/src/pystencils/backend/ast/expressions.py
@@ -5,6 +5,7 @@ import operator
 
 from ..symbols import PsSymbol
 from ..constants import PsConstant
+from ..literals import PsLiteral
 from ..arrays import PsLinearizedArray, PsArrayBasePointer
 from ..functions import PsFunction
 from ...types import (
@@ -76,12 +77,19 @@ class PsExpression(PsAstNode, ABC):
     def make(obj: PsConstant) -> PsConstantExpr:
         pass
 
+    @overload
+    @staticmethod
+    def make(obj: PsLiteral) -> PsLiteralExpr:
+        pass
+
     @staticmethod
-    def make(obj: PsSymbol | PsConstant) -> PsSymbolExpr | PsConstantExpr:
+    def make(obj: PsSymbol | PsConstant | PsLiteral) -> PsExpression:
         if isinstance(obj, PsSymbol):
             return PsSymbolExpr(obj)
         elif isinstance(obj, PsConstant):
             return PsConstantExpr(obj)
+        elif isinstance(obj, PsLiteral):
+            return PsLiteralExpr(obj)
         else:
             raise ValueError(f"Cannot make expression out of {obj}")
 
@@ -150,6 +158,34 @@ class PsConstantExpr(PsLeafMixIn, PsExpression):
 
     def __repr__(self) -> str:
         return f"PsConstantExpr({repr(self._constant)})"
+    
+
+class PsLiteralExpr(PsLeafMixIn, PsExpression):
+    __match_args__ = ("literal",)
+
+    def __init__(self, literal: PsLiteral):
+        super().__init__(literal.dtype)
+        self._literal = literal
+
+    @property
+    def literal(self) -> PsLiteral:
+        return self._literal
+
+    @literal.setter
+    def literal(self, lit: PsLiteral):
+        self._literal = lit
+
+    def clone(self) -> PsLiteralExpr:
+        return PsLiteralExpr(self._literal)
+    
+    def structurally_equal(self, other: PsAstNode) -> bool:
+        if not isinstance(other, PsLiteralExpr):
+            return False
+
+        return self._literal == other._literal
+
+    def __repr__(self) -> str:
+        return f"PsLiteralExpr({repr(self._literal)})"
 
 
 class PsSubscript(PsLvalue, PsExpression):
diff --git a/src/pystencils/backend/emission.py b/src/pystencils/backend/emission.py
index 588ac410a..f3d56c6c4 100644
--- a/src/pystencils/backend/emission.py
+++ b/src/pystencils/backend/emission.py
@@ -34,6 +34,7 @@ from .ast.expressions import (
     PsSub,
     PsSubscript,
     PsSymbolExpr,
+    PsLiteralExpr,
     PsVectorArrayAccess,
     PsAnd,
     PsOr,
@@ -245,6 +246,9 @@ class CAstPrinter:
                     )
 
                 return dtype.create_literal(constant.value)
+            
+            case PsLiteralExpr(lit):
+                return lit.text
 
             case PsVectorArrayAccess():
                 raise EmissionError("Cannot print vectorized array accesses")
diff --git a/src/pystencils/backend/functions.py b/src/pystencils/backend/functions.py
index 28c9788d6..e420deaa6 100644
--- a/src/pystencils/backend/functions.py
+++ b/src/pystencils/backend/functions.py
@@ -82,7 +82,7 @@ class CFunction(PsFunction):
         return_type: The function's return type
     """
 
-    __match_args__ = ("name", "argument_types", "return_type")
+    __match_args__ = ("name", "parameter_types", "return_type")
 
     @staticmethod
     def parse(obj) -> CFunction:
@@ -108,17 +108,24 @@ class CFunction(PsFunction):
     def __init__(self, name: str, param_types: Sequence[PsType], return_type: PsType):
         super().__init__(name, len(param_types))
 
-        self._arg_types = tuple(param_types)
+        self._param_types = tuple(param_types)
         self._return_type = return_type
 
     @property
     def parameter_types(self) -> tuple[PsType, ...]:
-        return self._arg_types
+        return self._param_types
 
     @property
     def return_type(self) -> PsType:
         return self._return_type
 
+    def __str__(self) -> str:
+        param_types = ", ".join(str(t) for t in self._param_types)
+        return f"{self._return_type} {self._name}({param_types})"
+
+    def __repr__(self) -> str:
+        return f"CFunction({self._name}, {self._param_types}, {self._return_type})"
+
 
 class PsMathFunction(PsFunction):
     """Homogenously typed mathematical functions."""
diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py
index d933004fc..95fa0e36c 100644
--- a/src/pystencils/backend/kernelcreation/typification.py
+++ b/src/pystencils/backend/kernelcreation/typification.py
@@ -39,6 +39,7 @@ from ..ast.expressions import (
     PsLookup,
     PsSubscript,
     PsSymbolExpr,
+    PsLiteralExpr,
     PsRel,
     PsNeg,
     PsNot,
@@ -158,6 +159,14 @@ class TypeContext:
                             f"  Constant type: {c.dtype}\n"
                             f"    Target type: {self._target_type}"
                         )
+                
+                case PsLiteralExpr(lit):
+                    if not self._compatible(lit.dtype):
+                        raise TypificationError(
+                            f"Type mismatch at literal {lit}: Literal type did not match the context's target type\n"
+                            f"   Literal type: {lit.dtype}\n"
+                            f"    Target type: {self._target_type}"
+                        )
 
                 case PsSymbolExpr(symb):
                     assert symb.dtype is not None
@@ -356,6 +365,9 @@ class Typifier:
                 else:
                     tc.infer_dtype(expr)
 
+            case PsLiteralExpr(lit):
+                tc.apply_dtype(lit.dtype, expr)
+
             case PsArrayAccess(bptr, idx):
                 tc.apply_dtype(bptr.array.element_type, expr)
 
diff --git a/src/pystencils/backend/literals.py b/src/pystencils/backend/literals.py
new file mode 100644
index 000000000..dc7504f52
--- /dev/null
+++ b/src/pystencils/backend/literals.py
@@ -0,0 +1,43 @@
+from __future__ import annotations
+from ..types import PsType, constify
+
+
+class PsLiteral:
+    """Representation of literal code.
+    
+    Instances of this class represent code literals inside the AST.
+    These literals are not to be confused with C literals; the name `Literal` refers to the fact that
+    the code generator takes them "literally", printing them as they are.
+
+    Each literal has to be annotated with a type, and is considered constant within the scope of a kernel.
+    Instances of `PsLiteral` are immutable.
+    """
+
+    __match_args__ = ("text", "dtype")
+
+    def __init__(self, text: str, dtype: PsType) -> None:
+        self._text = text
+        self._dtype = constify(dtype)
+
+    @property
+    def text(self) -> str:
+        return self._text
+    
+    @property
+    def dtype(self) -> PsType:
+        return self._dtype
+    
+    def __str__(self) -> str:
+        return f"{self._text}: {self._dtype}"
+    
+    def __repr__(self) -> str:
+        return f"PsLiteral({repr(self._text)}, {repr(self._dtype)})"
+    
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, PsLiteral):
+            return False
+        
+        return self._text == other._text and self._dtype == other._dtype
+    
+    def __hash__(self) -> int:
+        return hash((PsLiteral, self._text, self._dtype))
diff --git a/src/pystencils/backend/platforms/generic_gpu.py b/src/pystencils/backend/platforms/generic_gpu.py
index ef4861aa7..3ef64c73c 100644
--- a/src/pystencils/backend/platforms/generic_gpu.py
+++ b/src/pystencils/backend/platforms/generic_gpu.py
@@ -11,26 +11,26 @@ from ..kernelcreation.iteration_space import (
 from ..ast.structural import PsBlock, PsConditional
 from ..ast.expressions import (
     PsExpression,
-    PsSymbolExpr,
+    PsLiteralExpr,
     PsAdd,
 )
 from ..ast.expressions import PsLt, PsAnd
 from ...types import PsSignedIntegerType
-from ..symbols import PsSymbol
+from ..literals import PsLiteral
 
 int32 = PsSignedIntegerType(width=32, const=False)
 
 BLOCK_IDX = [
-    PsSymbolExpr(PsSymbol(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z")
+    PsLiteralExpr(PsLiteral(f"blockIdx.{coord}", int32)) for coord in ("x", "y", "z")
 ]
 THREAD_IDX = [
-    PsSymbolExpr(PsSymbol(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z")
+    PsLiteralExpr(PsLiteral(f"threadIdx.{coord}", int32)) for coord in ("x", "y", "z")
 ]
 BLOCK_DIM = [
-    PsSymbolExpr(PsSymbol(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z")
+    PsLiteralExpr(PsLiteral(f"blockDim.{coord}", int32)) for coord in ("x", "y", "z")
 ]
 GRID_DIM = [
-    PsSymbolExpr(PsSymbol(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z")
+    PsLiteralExpr(PsLiteral(f"gridDim.{coord}", int32)) for coord in ("x", "y", "z")
 ]
 
 
diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py
index 3b9f0b700..ddfa33f08 100644
--- a/src/pystencils/backend/transformations/eliminate_constants.py
+++ b/src/pystencils/backend/transformations/eliminate_constants.py
@@ -9,6 +9,7 @@ from ..ast.expressions import (
     PsExpression,
     PsConstantExpr,
     PsSymbolExpr,
+    PsLiteralExpr,
     PsBinOp,
     PsAdd,
     PsSub,
@@ -159,8 +160,8 @@ class EliminateConstants:
         Returns:
             (transformed_expr, is_const): The tranformed expression, and a flag indicating whether it is constant
         """
-        #   Return constants as they are
-        if isinstance(expr, PsConstantExpr):
+        #   Return constants and literals as they are
+        if isinstance(expr, (PsConstantExpr, PsLiteralExpr)):
             return expr, True
 
         #   Shortcut symbols
@@ -317,7 +318,7 @@ class EliminateConstants:
         #   If required, extract constant subexpressions
         if self._extract_constant_exprs:
             for i, (child, is_const) in enumerate(subtree_results):
-                if is_const and not isinstance(child, PsConstantExpr):
+                if is_const and not isinstance(child, (PsConstantExpr, PsLiteralExpr)):
                     replacement = ecc.extract_expression(child)
                     expr.set_child(i, replacement)
 
diff --git a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
index 5824239e4..cb9c9e920 100644
--- a/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
+++ b/src/pystencils/backend/transformations/hoist_loop_invariant_decls.py
@@ -7,6 +7,7 @@ from ..ast.expressions import (
     PsExpression,
     PsSymbolExpr,
     PsConstantExpr,
+    PsLiteralExpr,
     PsCall,
     PsDeref,
     PsSubscript,
@@ -40,7 +41,7 @@ class HoistContext:
                     symbol in self.invariant_symbols
                 )
 
-            case PsConstantExpr():
+            case PsConstantExpr() | PsLiteralExpr():
                 return True
 
             case PsCall(func):
-- 
GitLab