From 558a0f20e082370a0bccd20b96a647e3536bc31e Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Wed, 15 Jan 2025 12:59:36 +0100
Subject: [PATCH] Expose new reduced assignments to pystencils interface

---
 src/pystencils/__init__.py             | 14 ++++++++++++++
 tests/kernelcreation/test_reduction.py |  4 ++--
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py
index 6cb375b61..eecd929cf 100644
--- a/src/pystencils/__init__.py
+++ b/src/pystencils/__init__.py
@@ -38,6 +38,14 @@ from .simp import AssignmentCollection
 from .sympyextensions.typed_sympy import TypedSymbol, DynamicType
 from .sympyextensions import SymbolCreator
 from .datahandling import create_data_handling
+from .sympyextensions.reduction import (
+    AddReducedAssignment,
+    SubReducedAssignment,
+    MulReducedAssignment,
+    DivReducedAssignment,
+    MinReducedssignment,
+    MaxReducedssignment
+)
 
 __all__ = [
     "Field",
@@ -69,6 +77,12 @@ __all__ = [
     "AssignmentCollection",
     "Assignment",
     "AddAugmentedAssignment",
+    "AddReducedAssignment",
+    "SubReducedAssignment",
+    "MulReducedAssignment",
+    "DivReducedAssignment",
+    "MinReducedssignment",
+    "MaxReducedssignment",
     "assignment_from_stencil",
     "SymbolCreator",
     "create_data_handling",
diff --git a/tests/kernelcreation/test_reduction.py b/tests/kernelcreation/test_reduction.py
index 47509e267..f8c2b1870 100644
--- a/tests/kernelcreation/test_reduction.py
+++ b/tests/kernelcreation/test_reduction.py
@@ -3,7 +3,7 @@ import numpy as np
 import sympy as sp
 
 import pystencils as ps
-from sympyextensions.reduction import reduced_assign
+from pystencils import AddReducedAssignment
 
 
 @pytest.mark.parametrize('dtype', ["float64", "float32"])
@@ -32,7 +32,7 @@ def test_log(dtype):
 
     omega = sp.Symbol("omega")
 
-    reduction_assignment = reduced_assign(omega, "+", x.center())
+    reduction_assignment = AddReducedAssignment(omega, x.center())
 
     ast_reduction = ps.create_kernel(reduction_assignment, default_dtype=dtype)
     code_reduction = ps.get_code_str(ast_reduction)
-- 
GitLab