Skip to content
Snippets Groups Projects
Commit e31a9028 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'fhennig/fix-constant-elim' into 'v2.0-dev'

Fixes to Constant Elimination Pass

See merge request !427
parents 722fa3ce 474b5ab2
Branches
No related tags found
1 merge request!427Fixes to Constant Elimination Pass
Pipeline #70222 passed
......@@ -17,9 +17,21 @@ def emit_ir(ir: PsAstNode):
class IRAstPrinter(BasePrinter):
"""Print the IR AST as pseudo-code.
def __init__(self, indent_width=3):
This printer produces a complete pseudocode representation of a pystencils AST.
Other than the `CAstPrinter`, the `IRAstPrinter` is capable of emitting code for
each node defined in `ast <pystencils.backend.ast>`.
It is furthermore configurable w.r.t. the level of detail it should emit.
Args:
indent_width: Number of spaces with which to indent lines in each nested block.
annotate_constants: If ``True`` (the default), annotate all constant literals with their data type.
"""
def __init__(self, indent_width=3, annotate_constants: bool = True):
super().__init__(indent_width)
self._annotate_constants = annotate_constants
def visit(self, node: PsAstNode, pc: PrinterCtx) -> str:
match node:
......@@ -66,7 +78,10 @@ class IRAstPrinter(BasePrinter):
return f"{symb.name}: {self._type_str(symb.dtype)}"
def _constant_literal(self, constant: PsConstant) -> str:
if self._annotate_constants:
return f"[{constant.value}: {self._deconst_type_str(constant.dtype)}]"
else:
return str(constant.value)
def _type_str(self, dtype: PsType | None):
if dtype is None:
......
......@@ -45,7 +45,7 @@ from ...types import (
PsBoolType,
PsScalarType,
PsVectorType,
PsTypeError,
constify
)
......@@ -57,9 +57,9 @@ class ECContext:
self._ctx = ctx
self._extracted_constants: dict[AstEqWrapper, PsSymbol] = dict()
from ..emission import CAstPrinter
from ..emission import IRAstPrinter
self._printer = CAstPrinter(0)
self._printer = IRAstPrinter(indent_width=0, annotate_constants=False)
@property
def extractions(self) -> Iterable[tuple[PsSymbol, PsExpression]]:
......@@ -89,10 +89,7 @@ class ECContext:
if expr_wrapped not in self._extracted_constants:
symb_name = self._get_symb_name(expr)
try:
symb = self._ctx.get_symbol(symb_name, dtype)
except PsTypeError:
symb = self._ctx.get_symbol(f"{symb_name}_{dtype.c_string()}", dtype)
symb = self._ctx.get_new_symbol(symb_name, constify(dtype))
self._extracted_constants[expr_wrapped] = symb
else:
......@@ -133,6 +130,10 @@ class EliminateConstants:
def __call__(self, node: PsExpression) -> PsExpression:
pass
@overload
def __call__(self, node: PsBlock) -> PsBlock:
pass
@overload
def __call__(self, node: PsAstNode) -> PsAstNode:
pass
......
......@@ -106,9 +106,13 @@ class PsArrayType(PsDereferencableType):
"""
def __init__(
self, element_type: PsType, shape: SupportsIndex | Sequence[SupportsIndex], const: bool = False
self,
element_type: PsType,
shape: SupportsIndex | Sequence[SupportsIndex],
const: bool = False,
):
from operator import index
if isinstance(shape, SupportsIndex):
shape = (index(shape),)
else:
......@@ -396,12 +400,13 @@ class PsVectorType(PsNumericType):
return np.dtype((self._scalar_type.numpy_dtype, (self._vector_entries,)))
def create_constant(self, value: Any) -> Any:
if (
isinstance(value, np.ndarray)
and value.dtype == self.scalar_type.numpy_dtype
and value.shape == (self._vector_entries,)
):
return value.copy()
if isinstance(value, np.ndarray):
if value.shape != (self._vector_entries,):
raise PsTypeError(
f"Cannot create constant of vector type {self} from array of shape {value.shape}"
)
return np.array([self._scalar_type.create_constant(v) for v in value])
element = self._scalar_type.create_constant(value)
return np.array(
......
from typing import Any
import pytest
import numpy as np
import sympy as sp
from pystencils.backend.kernelcreation import KernelCreationContext, Typifier
from pystencils import TypedSymbol, Assignment
from pystencils.backend.kernelcreation import (
KernelCreationContext,
Typifier,
AstFactory,
)
from pystencils.backend.ast.structural import PsBlock, PsDeclaration
from pystencils.backend.ast.expressions import PsExpression, PsConstantExpr
from pystencils.backend.memory import PsSymbol
from pystencils.backend.constants import PsConstant
......@@ -17,15 +24,16 @@ from pystencils.backend.ast.expressions import (
PsTernary,
PsRem,
PsIntDiv,
PsCast
)
from pystencils.types.quick import Int, Fp, Bool
from pystencils.types import PsVectorType, create_numeric_type
from pystencils.types import PsVectorType, create_numeric_type, constify, create_type
class Exprs:
def __init__(self, mode: str):
self._mode = mode
self.mode = mode
if mode == "scalar":
self._itype = Int(32)
......@@ -49,7 +57,7 @@ class Exprs:
self.true = PsExpression.make(PsConstant(True, self._btype))
self.false = PsExpression.make(PsConstant(False, self._btype))
def __call__(self, val) -> Any:
def __call__(self, val) -> PsExpression:
match val:
case int():
return PsExpression.make(PsConstant(val, self._itype))
......@@ -311,3 +319,101 @@ def test_fold_vectors():
)
result = elim(expr)
assert result.structurally_equal(e(np.array([True, True, False, True])))
def test_fold_casts(exprs):
e = exprs
ctx = KernelCreationContext()
typify = Typifier(ctx)
elim = EliminateConstants(ctx, fold_floats=True)
target_type = create_type("float16")
if e.mode == "vector":
target_type = PsVectorType(target_type, 4)
expr = typify(PsCast(target_type, e(41.2)))
result = elim(expr)
assert isinstance(result, PsConstantExpr)
np.testing.assert_equal(result.constant.value, e(41.2).constant.value.astype("float16"))
def test_extract_constant_subexprs():
ctx = KernelCreationContext(default_dtype=create_numeric_type("float64"))
factory = AstFactory(ctx)
elim = EliminateConstants(ctx, extract_constant_exprs=True)
x, y, z = sp.symbols("x, y, z")
q, w = TypedSymbol("q", "float32"), TypedSymbol("w", "float32")
block = PsBlock(
[
factory.parse_sympy(Assignment(x, sp.Rational(3, 2))),
factory.parse_sympy(Assignment(y, x + sp.Rational(7, 4))),
factory.parse_sympy(Assignment(z, y - sp.Rational(12, 5))),
factory.parse_sympy(Assignment(q, w + sp.Rational(7, 4))),
factory.parse_sympy(Assignment(z, y - sp.Rational(12, 5) + z * sp.sin(41))),
]
)
result = elim(block)
assert len(result.statements) == 9
c_symb = ctx.find_symbol("__c_3_0o2_0")
assert c_symb is None
c_symb = ctx.find_symbol("__c_7_0o4_0")
assert c_symb is not None
assert c_symb.dtype == constify(ctx.default_dtype)
c_symb = ctx.find_symbol("__c_s12_0o5_0")
assert c_symb is not None
assert c_symb.dtype == constify(ctx.default_dtype)
# Make sure symbol was duplicated
c_symb = ctx.find_symbol("__c_7_0o4_0__0")
assert c_symb is not None
assert c_symb.dtype == constify(create_numeric_type("float32"))
c_symb = ctx.find_symbol("__c_sin_41_0_")
assert c_symb is not None
assert c_symb.dtype == constify(ctx.default_dtype)
def test_extract_vector_constants():
ctx = KernelCreationContext(default_dtype=create_numeric_type("float64"))
factory = AstFactory(ctx)
typify = Typifier(ctx)
elim = EliminateConstants(ctx, extract_constant_exprs=True)
vtype = PsVectorType(ctx.default_dtype, 8)
x, y, z = TypedSymbol("x", vtype), TypedSymbol("y", vtype), TypedSymbol("z", vtype)
num = typify.typify_expression(
PsExpression.make(
PsConstant(np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))
),
vtype,
)[0]
denom = typify.typify_expression(PsExpression.make(PsConstant(3.0)), vtype)[0]
vconstant = num / denom
block = PsBlock(
[
factory.parse_sympy(Assignment(x, y - sp.Rational(3, 2))),
PsDeclaration(
factory.parse_sympy(z),
typify(factory.parse_sympy(y) + num / denom),
),
]
)
result = elim(block)
assert len(result.statements) == 4
assert isinstance(result.statements[1], PsDeclaration)
assert result.statements[1].rhs.structurally_equal(vconstant)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment