From 919ed1c7ce8a1e6abedc6cc75e9810926ad78db3 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Fri, 26 Apr 2024 17:36:09 +0200
Subject: [PATCH] Add literals test. Cleaned up some other tests.

---
 .../backend/kernelcreation/ast_factory.py     |  4 ++
 .../kernelcreation/test_domain_kernels.py     |  2 -
 .../{test_ast_nodes.py => test_ast.py}        |  0
 tests/nbackend/test_code_printing.py          | 37 +-----------
 tests/nbackend/test_extensions.py             | 59 +++++++++++++++++++
 5 files changed, 65 insertions(+), 37 deletions(-)
 rename tests/nbackend/{test_ast_nodes.py => test_ast.py} (100%)
 create mode 100644 tests/nbackend/test_extensions.py

diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py
index 83c406b0a..d5695be93 100644
--- a/src/pystencils/backend/kernelcreation/ast_factory.py
+++ b/src/pystencils/backend/kernelcreation/ast_factory.py
@@ -37,6 +37,10 @@ class AstFactory:
         self._freeze = FreezeExpressions(ctx)
         self._typify = Typifier(ctx)
 
+    @overload
+    def parse_sympy(self, sp_obj: sp.Symbol) -> PsSymbolExpr:
+        pass
+
     @overload
     def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression:
         pass
diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py
index 29744c384..9ce2f661d 100644
--- a/tests/nbackend/kernelcreation/test_domain_kernels.py
+++ b/tests/nbackend/kernelcreation/test_domain_kernels.py
@@ -59,5 +59,3 @@ def test_filter_kernel_fixedsize():
     expected[1:-1, 1:-1].fill(18.0)
 
     np.testing.assert_allclose(dst_arr, expected)
-
-test_filter_kernel()
\ No newline at end of file
diff --git a/tests/nbackend/test_ast_nodes.py b/tests/nbackend/test_ast.py
similarity index 100%
rename from tests/nbackend/test_ast_nodes.py
rename to tests/nbackend/test_ast.py
diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py
index 1fc6821d7..4c83e6e99 100644
--- a/tests/nbackend/test_code_printing.py
+++ b/tests/nbackend/test_code_printing.py
@@ -1,49 +1,16 @@
 from pystencils import Target
 
-from pystencils.backend.ast.expressions import PsExpression, PsArrayAccess
+from pystencils.backend.ast.expressions import PsExpression
 from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock
 from pystencils.backend.kernelfunction import KernelFunction
 from pystencils.backend.symbols import PsSymbol
 from pystencils.backend.constants import PsConstant
+from pystencils.backend.literals import PsLiteral
 from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer
 from pystencils.types.quick import Fp, SInt, UInt, Bool
 from pystencils.backend.emission import CAstPrinter
 
 
-# def test_basic_kernel():
-
-#     u_arr = PsLinearizedArray("u", Fp(64), (..., ), (1, ))
-#     u_size = PsExpression.make(u_arr.shape[0])
-#     u_base = PsArrayBasePointer("u_data", u_arr)
-
-#     loop_ctr = PsExpression.make(PsSymbol("ctr", UInt(32)))
-#     one = PsExpression.make(PsConstant(1, SInt(32)))
-
-#     update = PsAssignment(
-#         PsArrayAccess(u_base, loop_ctr),
-#         PsArrayAccess(u_base, loop_ctr + one) + PsArrayAccess(u_base, loop_ctr - one),
-#     )
-
-#     loop = PsLoop(
-#         loop_ctr,
-#         one,
-#         u_size - one,
-#         one,
-#         PsBlock([update])
-#     )
-
-#     func = KernelFunction(PsBlock([loop]), Target.CPU, "kernel", set())
-
-#     printer = CAstPrinter()
-#     code = printer(func)
-
-#     paramlist = func.get_parameters().params
-#     params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist)
-
-#     assert code.find("(" + params_str + ")") >= 0
-#     assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1];") >= 0
-
-
 def test_arithmetic_precedence():
     (a, b, c, d, e, f) = [PsExpression.make(PsSymbol(x, Fp(64))) for x in "abcdef"]
     cprint = CAstPrinter()
diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py
new file mode 100644
index 000000000..75726a351
--- /dev/null
+++ b/tests/nbackend/test_extensions.py
@@ -0,0 +1,59 @@
+
+import sympy as sp
+
+from pystencils import make_slice, Field, Assignment
+from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace
+from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
+from pystencils.backend.literals import PsLiteral
+from pystencils.backend.emission import CAstPrinter
+from pystencils.backend.ast.expressions import PsExpression, PsSubscript
+from pystencils.backend.ast.structural import PsBlock, PsDeclaration
+from pystencils.types.quick import Arr, Int
+
+
+def test_literals():
+    ctx = KernelCreationContext()
+    factory = AstFactory(ctx)
+
+    f = Field.create_generic("f", 3)
+    x = sp.Symbol("x")
+    
+    cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64, const=True), 3)))
+    global_constant = PsExpression.make(PsLiteral("C", ctx.default_dtype))
+
+    loop_slice = make_slice[
+        0:PsSubscript(cells, factory.parse_index(0)),
+        0:PsSubscript(cells, factory.parse_index(1)),
+        0:PsSubscript(cells, factory.parse_index(2)),
+    ]
+
+    ispace = FullIterationSpace.create_from_slice(ctx, loop_slice)
+    ctx.set_iteration_space(ispace)
+    
+    x_decl = PsDeclaration(factory.parse_sympy(x), global_constant)
+
+    loop_body = PsBlock([
+        x_decl,
+        factory.parse_sympy(Assignment(f.center(), x))
+    ])
+
+    loops = factory.loops_from_ispace(ispace, loop_body)
+    ast = PsBlock([loops])
+    
+    canon = CanonicalizeSymbols(ctx)
+    ast = canon(ast)
+
+    hoist = HoistLoopInvariantDeclarations(ctx)
+    ast = hoist(ast)
+
+    assert isinstance(ast, PsBlock)
+    assert len(ast.statements) == 2
+    assert ast.statements[0] == x_decl
+
+    code = CAstPrinter()(ast)
+    print(code)
+
+    assert "const double x = C;" in code
+    assert "CELLS[0]" in code
+    assert "CELLS[1]" in code
+    assert "CELLS[2]" in code
-- 
GitLab