Skip to content
Snippets Groups Projects
Commit 575322bc authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon: Committed by Frederik Hennig
Browse files

Improve freezing of additions

parent 05aa74d2
No related branches found
No related tags found
1 merge request!415Improve freezing of additions
......@@ -191,27 +191,19 @@ class FreezeExpressions:
def map_Add(self, expr: sp.Add) -> PsExpression:
# TODO: think about numerically sensible ways of freezing sums and products
signs: list[int] = []
for summand in expr.args:
if summand.is_negative:
signs.append(-1)
elif isinstance(summand, sp.Mul) and any(
factor.is_negative for factor in summand.args
):
signs.append(-1)
else:
signs.append(1)
frozen_expr = self.visit_expr(expr.args[0])
for sign, arg in zip(signs[1:], expr.args[1:]):
if sign == -1:
arg = -arg
for summand in expr.args[1:]:
if isinstance(summand, sp.Mul) and any(
factor == -1 for factor in summand.args
):
summand = -summand
op = sub
else:
op = add
frozen_expr = op(frozen_expr, self.visit_expr(arg))
frozen_expr = op(frozen_expr, self.visit_expr(summand))
return frozen_expr
......@@ -272,7 +264,7 @@ class FreezeExpressions:
def map_TypedSymbol(self, expr: TypedSymbol):
dtype = expr.dtype
match dtype:
case DynamicType.NUMERIC_TYPE:
dtype = self._ctx.default_dtype
......
import sympy as sp
import pytest
from pystencils import Assignment, fields, create_type, create_numeric_type, TypedSymbol, DynamicType
from pystencils import (
Assignment,
fields,
create_type,
create_numeric_type,
TypedSymbol,
DynamicType,
)
from pystencils.sympyextensions import CastFunc
from pystencils.backend.ast.structural import (
......@@ -30,6 +37,9 @@ from pystencils.backend.ast.expressions import (
PsCall,
PsCast,
PsConstantExpr,
PsAdd,
PsMul,
PsSub,
)
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import PsMathFunction, MathFunctions
......@@ -277,11 +287,11 @@ def test_dynamic_types():
p, q = [TypedSymbol(n, DynamicType.INDEX_TYPE) for n in "pq"]
expr = freeze(x + y)
assert ctx.get_symbol("x").dtype == ctx.default_dtype
assert ctx.get_symbol("y").dtype == ctx.default_dtype
expr = freeze(p - q)
expr = freeze(p - q)
assert ctx.get_symbol("p").dtype == ctx.index_dtype
assert ctx.get_symbol("q").dtype == ctx.index_dtype
......@@ -309,3 +319,29 @@ def test_cast_func():
expr = freeze(CastFunc(42, create_type("int16")))
assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16"))))
def test_add_sub():
ctx = KernelCreationContext()
freeze = FreezeExpressions(ctx)
x = sp.Symbol("x")
y = sp.Symbol("y", negative=True)
x2 = PsExpression.make(ctx.get_symbol("x"))
y2 = PsExpression.make(ctx.get_symbol("y"))
two = PsExpression.make(PsConstant(2))
minus_two = PsExpression.make(PsConstant(-2))
expr = freeze(x + y)
assert expr.structurally_equal(PsAdd(x2, y2))
expr = freeze(x - y)
assert expr.structurally_equal(PsSub(x2, y2))
expr = freeze(x + 2 * y)
assert expr.structurally_equal(PsAdd(x2, PsMul(two, y2)))
expr = freeze(x - 2 * y)
assert expr.structurally_equal(PsAdd(x2, PsMul(minus_two, y2)))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment