Skip to content
Snippets Groups Projects
Commit 94e18cf3 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

some refactoring of PsTypedConstant and tests

parent fe808e6b
No related branches found
No related tags found
No related merge requests found
Pipeline #59880 failed
from __future__ import annotations
from functools import reduce
from typing import TypeAlias, Union, Any, Tuple, Callable
import operator
from typing import TypeAlias, Union, Any, Tuple
import pymbolic.primitives as pb
......@@ -117,6 +116,24 @@ PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
class PsTypedConstant:
"""Represents typed constants occuring in the pystencils AST.
Internal Representation of Constants
------------------------------------
Each `PsNumericType` acts as a factory for the code generator's internal representation of that type's
constants. The `PsTypedConstant` class embedds these into the expression trees.
Upon construction, this class's constructor attempts to interpret the given value in the given data type
by passing it to the data type's factory, which in turn may throw an exception if the value's type does
not match.
Operations and Constant Folding
-------------------------------
The `PsTypedConstant` class overrides the basic arithmetic operations for use during a constant folding pass.
Their implementations are very strict regarding types: No implicit conversions take place, and both operands
must always have the exact same type.
The only exception to this rule are the values `0`, `1`, and `-1`, which are promoted to `PsTypedConstant`
(pymbolic injects those at times).
A Note On Divisions
-------------------
......@@ -171,77 +188,74 @@ class PsTypedConstant:
def __repr__(self) -> str:
return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )"
@staticmethod
def _fix(v: Any, dtype: PsNumericType) -> PsTypedConstant:
if not isinstance(v, PsTypedConstant):
return PsTypedConstant(v, dtype)
else:
return v
@staticmethod
def _bin_op(
lhs: PsTypedConstant, rhs: PsTypedConstant, op: Callable[[Any, Any], Any]
) -> PsTypedConstant:
"""Backend for binary operators. Never call directly!"""
if lhs._dtype != rhs._dtype:
def _fix(self, v: Any) -> PsTypedConstant:
"""In binary operations, checks for type equality and, if necessary, promotes the values
`0`, `1` and `-1` to `PsTypedConstant`."""
if not isinstance(v, PsTypedConstant) and v in (0, 1, -1):
return PsTypedConstant(v, self._dtype)
elif v._dtype != self._dtype:
raise PsTypeError(
f"Incompatible operand types in constant folding: {lhs._dtype} and {rhs._dtype}"
f"Incompatible operand types in constant folding: {self._dtype} and {v._dtype}"
)
else:
return v
try:
return PsTypedConstant(op(lhs._value, rhs._value), lhs._dtype)
except PsTypeError:
def _rfix(self, v: Any) -> PsTypedConstant:
"""Same as `_fix`, but for use with the `r...` versions of the binary ops. Only changes the order of the
types in the exception string."""
if not isinstance(v, PsTypedConstant) and v in (0, 1, -1):
return PsTypedConstant(v, self._dtype)
elif v._dtype != self._dtype:
raise PsTypeError(
f"Invalid operation in constant folding: {op.__name__}( {repr(lhs)}, {repr(rhs)} )"
f"Incompatible operand types in constant folding: {v._dtype} and {self._dtype}"
)
else:
return v
def __add__(self, other: Any):
return PsTypedConstant._bin_op(
self, PsTypedConstant._fix(other, self._dtype), operator.add
)
return PsTypedConstant(self._value + self._fix(other)._value, self._dtype)
def __radd__(self, other: Any):
return PsTypedConstant._bin_op(
PsTypedConstant._fix(other, self._dtype), self, operator.add
)
return PsTypedConstant(self._rfix(other)._value + self._value, self._dtype)
def __mul__(self, other: Any):
return PsTypedConstant._bin_op(
self, PsTypedConstant._fix(other, self._dtype), operator.mul
)
return PsTypedConstant(self._value * self._fix(other)._value, self._dtype)
def __rmul__(self, other: Any):
return PsTypedConstant._bin_op(
PsTypedConstant._fix(other, self._dtype), self, operator.mul
)
return PsTypedConstant(self._rfix(other)._value * self._value, self._dtype)
def __sub__(self, other: Any):
return PsTypedConstant._bin_op(
self, PsTypedConstant._fix(other, self._dtype), operator.sub
)
return PsTypedConstant(self._value - self._fix(other)._value, self._dtype)
def __rsub__(self, other: Any):
return PsTypedConstant._bin_op(
PsTypedConstant._fix(other, self._dtype), self, operator.sub
)
return PsTypedConstant(self._rfix(other)._value - self._value, self._dtype)
def __truediv__(self, other: Any):
other2 = PsTypedConstant._fix(other, self._dtype)
if self._dtype.is_float():
return PsTypedConstant._bin_op(self, other2, operator.truediv)
else:
return PsTypedConstant(self._value / self._fix(other)._value, self._dtype)
elif self._dtype.is_uint():
# For unsigned integers, `//` does the correct thing
return PsTypedConstant(self._value // self._fix(other)._value, self._dtype)
elif self._dtype.is_sint():
return NotImplemented # todo: C integer division
else:
return NotImplemented
def __rtruediv__(self, other: Any):
other2 = PsTypedConstant._fix(other, self._dtype)
if self._dtype.is_float():
return PsTypedConstant._bin_op(other2, self, operator.truediv)
else:
return PsTypedConstant(self._rfix(other)._value / self._value, self._dtype)
elif self._dtype.is_uint():
return PsTypedConstant(self._rfix(other)._value // self._value, self._dtype)
elif self._dtype.is_sint():
return NotImplemented # todo: C integer division
else:
return NotImplemented
def __mod__(self, other: Any):
return NotImplemented # todo: C integer division
if self._dtype.is_uint():
return PsTypedConstant(self._value % self._fix(other)._value, self._dtype)
else:
return NotImplemented # todo: C integer division
def __neg__(self):
return PsTypedConstant(-self._value, self._dtype)
......
......@@ -27,6 +27,14 @@ def test_constant_folding_int(width):
assert folder(expr) == PsTypedConstant(-53, SInt(width))
@pytest.mark.parametrize("width", (8, 16, 32, 64))
def test_constant_folding_product(width):
"""
The pymbolic constant folder shows inconsistent behaviour when folding products.
This test both describes the required behaviour and serves as a reminder to fix it.
"""
folder = ConstantFoldingMapper()
expr = pb.Product(
(
PsTypedConstant(2, SInt(width)),
......
......@@ -32,7 +32,7 @@ def test_float_constants(width):
assert a - b == PsTypedConstant(31.5, Fp(width))
assert a / c == PsTypedConstant(16.0, Fp(width))
def test_illegal_ops():
# Cannot interpret negative numbers as unsigned types
with pytest.raises(PsTypeError):
......@@ -53,7 +53,16 @@ def test_illegal_ops():
@pytest.mark.parametrize("width", (8, 16, 32, 64))
def test_integer_division(width):
def test_unsigned_integer_division(width):
a = PsTypedConstant(8, UInt(width))
b = PsTypedConstant(3, UInt(width))
assert a / b == PsTypedConstant(2, UInt(width))
assert a % b == PsTypedConstant(2, UInt(width))
@pytest.mark.parametrize("width", (8, 16, 32, 64))
def test_signed_integer_division(width):
a = PsTypedConstant(-5, SInt(width))
b = PsTypedConstant(2, SInt(width))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment