Skip to content
Snippets Groups Projects
Commit 2c20c8de authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Fixes to constant elimination

 - Add test cases for constant extraction
 - Use `get_new_symbol` to create extracted symbols
 - Create extracted constant symbols with `const` data types
 - Add test case for folding of casts
 - Add config options for IRPrinter
parent 778222bf
No related branches found
No related tags found
1 merge request!427Fixes to Constant Elimination Pass
Pipeline #70181 passed
...@@ -17,9 +17,21 @@ def emit_ir(ir: PsAstNode): ...@@ -17,9 +17,21 @@ def emit_ir(ir: PsAstNode):
class IRAstPrinter(BasePrinter): class IRAstPrinter(BasePrinter):
"""Print the IR AST as pseudo-code.
def __init__(self, indent_width=3):
This printer produces a complete pseudocode representation of a pystencils AST.
Other than the `CAstPrinter`, the `IRAstPrinter` is capable of emitting code for
each node defined in `ast <pystencils.backend.ast>`.
It is furthermore configurable w.r.t. the level of detail it should emit.
Args:
indent_width: Number of spaces with which to indent lines in each nested block.
annotate_constants: If ``True`` (the default), annotate all constant literals with their data type.
"""
def __init__(self, indent_width=3, annotate_constants: bool = True):
super().__init__(indent_width) super().__init__(indent_width)
self._annotate_constants = annotate_constants
def visit(self, node: PsAstNode, pc: PrinterCtx) -> str: def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
match node: match node:
...@@ -66,7 +78,10 @@ class IRAstPrinter(BasePrinter): ...@@ -66,7 +78,10 @@ class IRAstPrinter(BasePrinter):
return f"{symb.name}: {self._type_str(symb.dtype)}" return f"{symb.name}: {self._type_str(symb.dtype)}"
def _constant_literal(self, constant: PsConstant) -> str: def _constant_literal(self, constant: PsConstant) -> str:
return f"[{constant.value}: {self._deconst_type_str(constant.dtype)}]" if self._annotate_constants:
return f"[{constant.value}: {self._deconst_type_str(constant.dtype)}]"
else:
return str(constant.value)
def _type_str(self, dtype: PsType | None): def _type_str(self, dtype: PsType | None):
if dtype is None: if dtype is None:
......
...@@ -45,7 +45,7 @@ from ...types import ( ...@@ -45,7 +45,7 @@ from ...types import (
PsBoolType, PsBoolType,
PsScalarType, PsScalarType,
PsVectorType, PsVectorType,
PsTypeError, constify
) )
...@@ -57,9 +57,9 @@ class ECContext: ...@@ -57,9 +57,9 @@ class ECContext:
self._ctx = ctx self._ctx = ctx
self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict() self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict()
from ..emission import CAstPrinter from ..emission import IRAstPrinter
self._printer = CAstPrinter(0) self._printer = IRAstPrinter(indent_width=0, annotate_constants=False)
@property @property
def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]: def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]:
...@@ -89,10 +89,7 @@ class ECContext: ...@@ -89,10 +89,7 @@ class ECContext:
if expr_wrapped not in self._extracted_constants: if expr_wrapped not in self._extracted_constants:
symb_name = self._get_symb_name(expr) symb_name = self._get_symb_name(expr)
try: symb = self._ctx.get_new_symbol(symb_name, constify(dtype))
symb = self._ctx.get_symbol(symb_name, dtype)
except PsTypeError:
symb = self._ctx.get_symbol(f"{symb_name}_{dtype.c_string()}", dtype)
self._extracted_constants[expr_wrapped] = symb self._extracted_constants[expr_wrapped] = symb
else: else:
...@@ -133,6 +130,10 @@ class EliminateConstants: ...@@ -133,6 +130,10 @@ class EliminateConstants:
def __call__(self, node: PsExpression) -> PsExpression: def __call__(self, node: PsExpression) -> PsExpression:
pass pass
@overload
def __call__(self, node: PsBlock) -> PsBlock:
pass
@overload @overload
def __call__(self, node: PsAstNode) -> PsAstNode: def __call__(self, node: PsAstNode) -> PsAstNode:
pass pass
......
...@@ -100,15 +100,19 @@ class PsPointerType(PsDereferencableType): ...@@ -100,15 +100,19 @@ class PsPointerType(PsDereferencableType):
class PsArrayType(PsDereferencableType): class PsArrayType(PsDereferencableType):
"""Multidimensional array of fixed shape. """Multidimensional array of fixed shape.
The element type of an array is never const; only the array itself can be. The element type of an array is never const; only the array itself can be.
If ``element_type`` is const, its constness will be removed. If ``element_type`` is const, its constness will be removed.
""" """
def __init__( def __init__(
self, element_type: PsType, shape: SupportsIndex | Sequence[SupportsIndex], const: bool = False self,
element_type: PsType,
shape: SupportsIndex | Sequence[SupportsIndex],
const: bool = False,
): ):
from operator import index from operator import index
if isinstance(shape, SupportsIndex): if isinstance(shape, SupportsIndex):
shape = (index(shape),) shape = (index(shape),)
else: else:
...@@ -116,10 +120,10 @@ class PsArrayType(PsDereferencableType): ...@@ -116,10 +120,10 @@ class PsArrayType(PsDereferencableType):
if not shape or any(s <= 0 for s in shape): if not shape or any(s <= 0 for s in shape):
raise ValueError(f"Invalid array shape: {shape}") raise ValueError(f"Invalid array shape: {shape}")
if isinstance(element_type, PsArrayType): if isinstance(element_type, PsArrayType):
raise ValueError("Element type of array cannot be another array.") raise ValueError("Element type of array cannot be another array.")
element_type = deconstify(element_type) element_type = deconstify(element_type)
self._shape = shape self._shape = shape
...@@ -137,7 +141,7 @@ class PsArrayType(PsDereferencableType): ...@@ -137,7 +141,7 @@ class PsArrayType(PsDereferencableType):
def shape(self) -> tuple[int, ...]: def shape(self) -> tuple[int, ...]:
"""Shape of this array""" """Shape of this array"""
return self._shape return self._shape
@property @property
def dim(self) -> int: def dim(self) -> int:
"""Dimensionality of this array""" """Dimensionality of this array"""
...@@ -396,12 +400,13 @@ class PsVectorType(PsNumericType): ...@@ -396,12 +400,13 @@ class PsVectorType(PsNumericType):
return np.dtype((self._scalar_type.numpy_dtype, (self._vector_entries,))) return np.dtype((self._scalar_type.numpy_dtype, (self._vector_entries,)))
def create_constant(self, value: Any) -> Any: def create_constant(self, value: Any) -> Any:
if ( if isinstance(value, np.ndarray):
isinstance(value, np.ndarray) if value.shape != (self._vector_entries,):
and value.dtype == self.scalar_type.numpy_dtype raise PsTypeError(
and value.shape == (self._vector_entries,) f"Cannot create constant of vector type {self} from array of shape {value.shape}"
): )
return value.copy()
return np.array([self._scalar_type.create_constant(v) for v in value])
element = self._scalar_type.create_constant(value) element = self._scalar_type.create_constant(value)
return np.array( return np.array(
...@@ -552,7 +557,7 @@ class PsIntegerType(PsScalarType, ABC): ...@@ -552,7 +557,7 @@ class PsIntegerType(PsScalarType, ABC):
def c_string(self) -> str: def c_string(self) -> str:
return f"{self._const_string()}{self._str_without_const()}_t" return f"{self._const_string()}{self._str_without_const()}_t"
def __str__(self) -> str: def __str__(self) -> str:
return f"{self._const_string()}{self._str_without_const()}" return f"{self._const_string()}{self._str_without_const()}"
......
from typing import Any from typing import Any
import pytest import pytest
import numpy as np import numpy as np
import sympy as sp
from pystencils.backend.kernelcreation import KernelCreationContext, Typifier from pystencils import TypedSymbol, Assignment
from pystencils.backend.kernelcreation import (
KernelCreationContext,
Typifier,
AstFactory,
)
from pystencils.backend.ast.structural import PsBlock, PsDeclaration
from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr
from pystencils.backend.memory import PsSymbol from pystencils.backend.memory import PsSymbol
from pystencils.backend.constants import PsConstant from pystencils.backend.constants import PsConstant
...@@ -17,15 +24,16 @@ from pystencils.backend.ast.expressions import ( ...@@ -17,15 +24,16 @@ from pystencils.backend.ast.expressions import (
PsTernary, PsTernary,
PsRem, PsRem,
PsIntDiv, PsIntDiv,
PsCast
) )
from pystencils.types.quick import Int, Fp, Bool from pystencils.types.quick import Int, Fp, Bool
from pystencils.types import PsVectorType, create_numeric_type from pystencils.types import PsVectorType, create_numeric_type, constify, create_type
class Exprs: class Exprs:
def __init__(self, mode: str): def __init__(self, mode: str):
self._mode = mode self.mode = mode
if mode == "scalar": if mode == "scalar":
self._itype = Int(32) self._itype = Int(32)
...@@ -49,7 +57,7 @@ class Exprs: ...@@ -49,7 +57,7 @@ class Exprs:
self.true = PsExpression.make(PsConstant(True, self._btype)) self.true = PsExpression.make(PsConstant(True, self._btype))
self.false = PsExpression.make(PsConstant(False, self._btype)) self.false = PsExpression.make(PsConstant(False, self._btype))
def __call__(self, val) -> Any: def __call__(self, val) -> PsExpression:
match val: match val:
case int(): case int():
return PsExpression.make(PsConstant(val, self._itype)) return PsExpression.make(PsConstant(val, self._itype))
...@@ -311,3 +319,101 @@ def test_fold_vectors(): ...@@ -311,3 +319,101 @@ def test_fold_vectors():
) )
result = elim(expr) result = elim(expr)
assert result.structurally_equal(e(np.array([True, True, False, True]))) assert result.structurally_equal(e(np.array([True, True, False, True])))
def test_fold_casts(exprs):
e = exprs
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx, fold_floats=True)
target_type = create_type("float16")
if e.mode == "vector":
target_type = PsVectorType(target_type, 4)
expr = typify(PsCast(target_type, e(41.2)))
result = elim(expr)
assert isinstance(result, PsConstantExpr)
np.testing.assert_equal(result.constant.value, e(41.2).constant.value.astype("float16"))
def test_extract_constant_subexprs():
ctx = KernelCreationContext(default_dtype=create_numeric_type("float64"))
factory = AstFactory(ctx)
elim = EliminateConstants(ctx, extract_constant_exprs=True)
x, y, z = sp.symbols("x, y, z")
q, w = TypedSymbol("q", "float32"), TypedSymbol("w", "float32")
block = PsBlock(
[
factory.parse_sympy(Assignment(x, sp.Rational(3, 2))),
factory.parse_sympy(Assignment(y, x + sp.Rational(7, 4))),
factory.parse_sympy(Assignment(z, y - sp.Rational(12, 5))),
factory.parse_sympy(Assignment(q, w + sp.Rational(7, 4))),
factory.parse_sympy(Assignment(z, y - sp.Rational(12, 5) + z * sp.sin(41))),
]
)
result = elim(block)
assert len(result.statements) == 9
c_symb = ctx.find_symbol("__c_3_0o2_0")
assert c_symb is None
c_symb = ctx.find_symbol("__c_7_0o4_0")
assert c_symb is not None
assert c_symb.dtype == constify(ctx.default_dtype)
c_symb = ctx.find_symbol("__c_s12_0o5_0")
assert c_symb is not None
assert c_symb.dtype == constify(ctx.default_dtype)
# Make sure symbol was duplicated
c_symb = ctx.find_symbol("__c_7_0o4_0__0")
assert c_symb is not None
assert c_symb.dtype == constify(create_numeric_type("float32"))
c_symb = ctx.find_symbol("__c_sin_41_0_")
assert c_symb is not None
assert c_symb.dtype == constify(ctx.default_dtype)
def test_extract_vector_constants():
ctx = KernelCreationContext(default_dtype=create_numeric_type("float64"))
factory = AstFactory(ctx)
typify = Typifier(ctx)
elim = EliminateConstants(ctx, extract_constant_exprs=True)
vtype = PsVectorType(ctx.default_dtype, 8)
x, y, z = TypedSymbol("x", vtype), TypedSymbol("y", vtype), TypedSymbol("z", vtype)
num = typify.typify_expression(
PsExpression.make(
PsConstant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))
),
vtype,
)[0]
denom = typify.typify_expression(PsExpression.make(PsConstant(3.0)), vtype)[0]
vconstant = num / denom
block = PsBlock(
[
factory.parse_sympy(Assignment(x, y - sp.Rational(3, 2))),
PsDeclaration(
factory.parse_sympy(z),
typify(factory.parse_sympy(y) + num / denom),
),
]
)
result = elim(block)
assert len(result.statements) == 4
assert isinstance(result.statements[1], PsDeclaration)
assert result.statements[1].rhs.structurally_equal(vconstant)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment