From 575322bc5a9dc5563f0f3b109f2dcd5ad1a55e24 Mon Sep 17 00:00:00 2001 From: Daniel Bauer <daniel.j.bauer@fau.de> Date: Tue, 1 Oct 2024 09:18:46 +0200 Subject: [PATCH] Improve freezing of additions --- .../backend/kernelcreation/freeze.py | 22 ++++------ tests/nbackend/kernelcreation/test_freeze.py | 42 +++++++++++++++++-- 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index b626ffb65..fa936e506 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -191,27 +191,19 @@ class FreezeExpressions: def map_Add(self, expr: sp.Add) -> PsExpression: # TODO: think about numerically sensible ways of freezing sums and products - signs: list[int] = [] - for summand in expr.args: - if summand.is_negative: - signs.append(-1) - elif isinstance(summand, sp.Mul) and any( - factor.is_negative for factor in summand.args - ): - signs.append(-1) - else: - signs.append(1) frozen_expr = self.visit_expr(expr.args[0]) - for sign, arg in zip(signs[1:], expr.args[1:]): - if sign == -1: - arg = -arg + for summand in expr.args[1:]: + if isinstance(summand, sp.Mul) and any( + factor == -1 for factor in summand.args + ): + summand = -summand op = sub else: op = add - frozen_expr = op(frozen_expr, self.visit_expr(arg)) + frozen_expr = op(frozen_expr, self.visit_expr(summand)) return frozen_expr @@ -272,7 +264,7 @@ class FreezeExpressions: def map_TypedSymbol(self, expr: TypedSymbol): dtype = expr.dtype - + match dtype: case DynamicType.NUMERIC_TYPE: dtype = self._ctx.default_dtype diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 9467bdd8e..2761d9b3f 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -1,7 +1,14 @@ import sympy as sp import pytest -from pystencils import Assignment, fields, create_type, create_numeric_type, TypedSymbol, DynamicType +from pystencils import ( + Assignment, + fields, + create_type, + create_numeric_type, + TypedSymbol, + DynamicType, +) from pystencils.sympyextensions import CastFunc from pystencils.backend.ast.structural import ( @@ -30,6 +37,9 @@ from pystencils.backend.ast.expressions import ( PsCall, PsCast, PsConstantExpr, + PsAdd, + PsMul, + PsSub, ) from pystencils.backend.constants import PsConstant from pystencils.backend.functions import PsMathFunction, MathFunctions @@ -277,11 +287,11 @@ def test_dynamic_types(): p, q = [TypedSymbol(n, DynamicType.INDEX_TYPE) for n in "pq"] expr = freeze(x + y) - + assert ctx.get_symbol("x").dtype == ctx.default_dtype assert ctx.get_symbol("y").dtype == ctx.default_dtype - expr = freeze(p - q) + expr = freeze(p - q) assert ctx.get_symbol("p").dtype == ctx.index_dtype assert ctx.get_symbol("q").dtype == ctx.index_dtype @@ -309,3 +319,29 @@ def test_cast_func(): expr = freeze(CastFunc(42, create_type("int16"))) assert expr.structurally_equal(PsConstantExpr(PsConstant(42, create_type("int16")))) + + +def test_add_sub(): + ctx = KernelCreationContext() + freeze = FreezeExpressions(ctx) + + x = sp.Symbol("x") + y = sp.Symbol("y", negative=True) + + x2 = PsExpression.make(ctx.get_symbol("x")) + y2 = PsExpression.make(ctx.get_symbol("y")) + + two = PsExpression.make(PsConstant(2)) + minus_two = PsExpression.make(PsConstant(-2)) + + expr = freeze(x + y) + assert expr.structurally_equal(PsAdd(x2, y2)) + + expr = freeze(x - y) + assert expr.structurally_equal(PsSub(x2, y2)) + + expr = freeze(x + 2 * y) + assert expr.structurally_equal(PsAdd(x2, PsMul(two, y2))) + + expr = freeze(x - 2 * y) + assert expr.structurally_equal(PsAdd(x2, PsMul(minus_two, y2))) -- GitLab