From 474b5ab27695a50d1e4eb6c6e21f5ec902ab9cdf Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 12 Nov 2024 08:48:56 +0100 Subject: [PATCH] Fixes to Constant Elimination Pass --- src/pystencils/backend/emission/ir_printer.py | 21 +++- .../transformations/eliminate_constants.py | 15 +-- src/pystencils/types/types.py | 29 +++-- .../test_constant_elimination.py | 114 +++++++++++++++++- 4 files changed, 153 insertions(+), 26 deletions(-) diff --git a/src/pystencils/backend/emission/ir_printer.py b/src/pystencils/backend/emission/ir_printer.py index 0b4a18bd5..4986f1a7f 100644 --- a/src/pystencils/backend/emission/ir_printer.py +++ b/src/pystencils/backend/emission/ir_printer.py @@ -17,9 +17,21 @@ def emit_ir(ir: PsAstNode): class IRAstPrinter(BasePrinter): - - def __init__(self, indent_width=3): + """Print the IR AST as pseudo-code. + + This printer produces a complete pseudocode representation of a pystencils AST. + Other than the `CAstPrinter`, the `IRAstPrinter` is capable of emitting code for + each node defined in `ast <pystencils.backend.ast>`. + It is furthermore configurable w.r.t. the level of detail it should emit. + + Args: + indent_width: Number of spaces with which to indent lines in each nested block. + annotate_constants: If ``True`` (the default), annotate all constant literals with their data type. + """ + + def __init__(self, indent_width=3, annotate_constants: bool = True): super().__init__(indent_width) + self._annotate_constants = annotate_constants def visit(self, node: PsAstNode, pc: PrinterCtx) -> str: match node: @@ -66,7 +78,10 @@ class IRAstPrinter(BasePrinter): return f"{symb.name}: {self._type_str(symb.dtype)}" def _constant_literal(self, constant: PsConstant) -> str: - return f"[{constant.value}: {self._deconst_type_str(constant.dtype)}]" + if self._annotate_constants: + return f"[{constant.value}: {self._deconst_type_str(constant.dtype)}]" + else: + return str(constant.value) def _type_str(self, dtype: PsType | None): if dtype is None: diff --git a/src/pystencils/backend/transformations/eliminate_constants.py b/src/pystencils/backend/transformations/eliminate_constants.py index 961f4a04a..d6f13bf07 100644 --- a/src/pystencils/backend/transformations/eliminate_constants.py +++ b/src/pystencils/backend/transformations/eliminate_constants.py @@ -45,7 +45,7 @@ from ...types import ( PsBoolType, PsScalarType, PsVectorType, - PsTypeError, + constify ) @@ -57,9 +57,9 @@ class ECContext: self._ctx = ctx self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict() - from ..emission import CAstPrinter + from ..emission import IRAstPrinter - self._printer = CAstPrinter(0) + self._printer = IRAstPrinter(indent_width=0, annotate_constants=False) @property def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]: @@ -89,10 +89,7 @@ class ECContext: if expr_wrapped not in self._extracted_constants: symb_name = self._get_symb_name(expr) - try: - symb = self._ctx.get_symbol(symb_name, dtype) - except PsTypeError: - symb = self._ctx.get_symbol(f"{symb_name}_{dtype.c_string()}", dtype) + symb = self._ctx.get_new_symbol(symb_name, constify(dtype)) self._extracted_constants[expr_wrapped] = symb else: @@ -133,6 +130,10 @@ class EliminateConstants: def __call__(self, node: PsExpression) -> PsExpression: pass + @overload + def __call__(self, node: PsBlock) -> PsBlock: + pass + @overload def __call__(self, node: PsAstNode) -> PsAstNode: pass diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index 6e4f65b85..ae751992d 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -100,15 +100,19 @@ class PsPointerType(PsDereferencableType): class PsArrayType(PsDereferencableType): """Multidimensional array of fixed shape. - + The element type of an array is never const; only the array itself can be. If ``element_type`` is const, its constness will be removed. """ def __init__( - self, element_type: PsType, shape: SupportsIndex | Sequence[SupportsIndex], const: bool = False + self, + element_type: PsType, + shape: SupportsIndex | Sequence[SupportsIndex], + const: bool = False, ): from operator import index + if isinstance(shape, SupportsIndex): shape = (index(shape),) else: @@ -116,10 +120,10 @@ class PsArrayType(PsDereferencableType): if not shape or any(s <= 0 for s in shape): raise ValueError(f"Invalid array shape: {shape}") - + if isinstance(element_type, PsArrayType): raise ValueError("Element type of array cannot be another array.") - + element_type = deconstify(element_type) self._shape = shape @@ -137,7 +141,7 @@ class PsArrayType(PsDereferencableType): def shape(self) -> tuple[int, ...]: """Shape of this array""" return self._shape - + @property def dim(self) -> int: """Dimensionality of this array""" @@ -396,12 +400,13 @@ class PsVectorType(PsNumericType): return np.dtype((self._scalar_type.numpy_dtype, (self._vector_entries,))) def create_constant(self, value: Any) -> Any: - if ( - isinstance(value, np.ndarray) - and value.dtype == self.scalar_type.numpy_dtype - and value.shape == (self._vector_entries,) - ): - return value.copy() + if isinstance(value, np.ndarray): + if value.shape != (self._vector_entries,): + raise PsTypeError( + f"Cannot create constant of vector type {self} from array of shape {value.shape}" + ) + + return np.array([self._scalar_type.create_constant(v) for v in value]) element = self._scalar_type.create_constant(value) return np.array( @@ -552,7 +557,7 @@ class PsIntegerType(PsScalarType, ABC): def c_string(self) -> str: return f"{self._const_string()}{self._str_without_const()}_t" - + def __str__(self) -> str: return f"{self._const_string()}{self._str_without_const()}" diff --git a/tests/nbackend/transformations/test_constant_elimination.py b/tests/nbackend/transformations/test_constant_elimination.py index 00df4a8a9..dd0ccf417 100644 --- a/tests/nbackend/transformations/test_constant_elimination.py +++ b/tests/nbackend/transformations/test_constant_elimination.py @@ -1,8 +1,15 @@ from typing import Any import pytest import numpy as np +import sympy as sp -from pystencils.backend.kernelcreation import KernelCreationContext, Typifier +from pystencils import TypedSymbol, Assignment +from pystencils.backend.kernelcreation import ( + KernelCreationContext, + Typifier, + AstFactory, +) +from pystencils.backend.ast.structural import PsBlock, PsDeclaration from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr from pystencils.backend.memory import PsSymbol from pystencils.backend.constants import PsConstant @@ -17,15 +24,16 @@ from pystencils.backend.ast.expressions import ( PsTernary, PsRem, PsIntDiv, + PsCast ) from pystencils.types.quick import Int, Fp, Bool -from pystencils.types import PsVectorType, create_numeric_type +from pystencils.types import PsVectorType, create_numeric_type, constify, create_type class Exprs: def __init__(self, mode: str): - self._mode = mode + self.mode = mode if mode == "scalar": self._itype = Int(32) @@ -49,7 +57,7 @@ class Exprs: self.true = PsExpression.make(PsConstant(True, self._btype)) self.false = PsExpression.make(PsConstant(False, self._btype)) - def __call__(self, val) -> Any: + def __call__(self, val) -> PsExpression: match val: case int(): return PsExpression.make(PsConstant(val, self._itype)) @@ -311,3 +319,101 @@ def test_fold_vectors(): ) result = elim(expr) assert result.structurally_equal(e(np.array([True, True, False, True]))) + + +def test_fold_casts(exprs): + e = exprs + + ctx = KernelCreationContext() + typify = Typifier(ctx) + elim = EliminateConstants(ctx, fold_floats=True) + + target_type = create_type("float16") + if e.mode == "vector": + target_type = PsVectorType(target_type, 4) + + expr = typify(PsCast(target_type, e(41.2))) + result = elim(expr) + + assert isinstance(result, PsConstantExpr) + np.testing.assert_equal(result.constant.value, e(41.2).constant.value.astype("float16")) + + +def test_extract_constant_subexprs(): + ctx = KernelCreationContext(default_dtype=create_numeric_type("float64")) + factory = AstFactory(ctx) + elim = EliminateConstants(ctx, extract_constant_exprs=True) + + x, y, z = sp.symbols("x, y, z") + q, w = TypedSymbol("q", "float32"), TypedSymbol("w", "float32") + + block = PsBlock( + [ + factory.parse_sympy(Assignment(x, sp.Rational(3, 2))), + factory.parse_sympy(Assignment(y, x + sp.Rational(7, 4))), + factory.parse_sympy(Assignment(z, y - sp.Rational(12, 5))), + factory.parse_sympy(Assignment(q, w + sp.Rational(7, 4))), + factory.parse_sympy(Assignment(z, y - sp.Rational(12, 5) + z * sp.sin(41))), + ] + ) + + result = elim(block) + + assert len(result.statements) == 9 + + c_symb = ctx.find_symbol("__c_3_0o2_0") + assert c_symb is None + + c_symb = ctx.find_symbol("__c_7_0o4_0") + assert c_symb is not None + assert c_symb.dtype == constify(ctx.default_dtype) + + c_symb = ctx.find_symbol("__c_s12_0o5_0") + assert c_symb is not None + assert c_symb.dtype == constify(ctx.default_dtype) + + # Make sure symbol was duplicated + c_symb = ctx.find_symbol("__c_7_0o4_0__0") + assert c_symb is not None + assert c_symb.dtype == constify(create_numeric_type("float32")) + + c_symb = ctx.find_symbol("__c_sin_41_0_") + assert c_symb is not None + assert c_symb.dtype == constify(ctx.default_dtype) + + +def test_extract_vector_constants(): + ctx = KernelCreationContext(default_dtype=create_numeric_type("float64")) + factory = AstFactory(ctx) + typify = Typifier(ctx) + elim = EliminateConstants(ctx, extract_constant_exprs=True) + + vtype = PsVectorType(ctx.default_dtype, 8) + x, y, z = TypedSymbol("x", vtype), TypedSymbol("y", vtype), TypedSymbol("z", vtype) + + num = typify.typify_expression( + PsExpression.make( + PsConstant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])) + ), + vtype, + )[0] + + denom = typify.typify_expression(PsExpression.make(PsConstant(3.0)), vtype)[0] + + vconstant = num / denom + + block = PsBlock( + [ + factory.parse_sympy(Assignment(x, y - sp.Rational(3, 2))), + PsDeclaration( + factory.parse_sympy(z), + typify(factory.parse_sympy(y) + num / denom), + ), + ] + ) + + result = elim(block) + + assert len(result.statements) == 4 + assert isinstance(result.statements[1], PsDeclaration) + assert result.statements[1].rhs.structurally_equal(vconstant) -- GitLab