From e12bef27071ab057c2b19c4bffe39c2c74d3f4c8 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 15 Jan 2024 15:09:02 +0100
Subject: [PATCH] basic printing test

---
 pystencils/nbackend/ast/__init__.py           |  2 +
 pystencils/nbackend/ast/kernelfunction.py     |  2 +-
 pystencils/nbackend/ast/nodes.py              |  8 ++--
 pystencils/nbackend/c_printer.py              |  4 +-
 pystencils/nbackend/typed_expressions.py      | 20 +++++++---
 .../nbackend/test_basic_printing.py           | 38 +++++++++++++++++++
 6 files changed, 61 insertions(+), 13 deletions(-)
 create mode 100644 pystencils_tests/nbackend/test_basic_printing.py

diff --git a/pystencils/nbackend/ast/__init__.py b/pystencils/nbackend/ast/__init__.py
index 95cb7831b..daee7214f 100644
--- a/pystencils/nbackend/ast/__init__.py
+++ b/pystencils/nbackend/ast/__init__.py
@@ -8,12 +8,14 @@ from .nodes import (
     PsDeclaration,
     PsLoop,
 )
+from .kernelfunction import PsKernelFunction
 
 from .dispatcher import ast_visitor
 from .transformations import ast_subs
 
 __all__ = [
     "ast_visitor",
+    "PsKernelFunction",
     "PsAstNode",
     "PsBlock",
     "PsExpression",
diff --git a/pystencils/nbackend/ast/kernelfunction.py b/pystencils/nbackend/ast/kernelfunction.py
index 6c9aad854..a12abb45a 100644
--- a/pystencils/nbackend/ast/kernelfunction.py
+++ b/pystencils/nbackend/ast/kernelfunction.py
@@ -8,7 +8,7 @@ from ...enums import Target
 class PsKernelFunction(PsAstNode):
     """A complete pystencils kernel function."""
 
-    __match_args__ = ("block",)
+    __match_args__ = ("body",)
 
     def __init__(self, body: PsBlock, target: Target, name: str = "kernel"):
         self._body = body
diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py
index 30b1a4dd6..5418a1a37 100644
--- a/pystencils/nbackend/ast/nodes.py
+++ b/pystencils/nbackend/ast/nodes.py
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 
 import pymbolic.primitives as pb
 
-from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue
+from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant
 from .util import failing_cast
 
 
@@ -87,15 +87,15 @@ class PsExpression(PsLeafNode):
 
     __match_args__ = ("expression",)
 
-    def __init__(self, expr: pb.Expression):
+    def __init__(self, expr: ExprOrConstant):
         self._expr = expr
 
     @property
-    def expression(self) -> pb.Expression:
+    def expression(self) -> ExprOrConstant:
         return self._expr
 
     @expression.setter
-    def expression(self, expr: pb.Expression):
+    def expression(self, expr: ExprOrConstant):
         self._expr = expr
 
 
diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/c_printer.py
index 7c4bf4b7c..4ca472a29 100644
--- a/pystencils/nbackend/c_printer.py
+++ b/pystencils/nbackend/c_printer.py
@@ -18,11 +18,11 @@ class CPrinter:
     def indent(self, line):
         return " " * self._current_indent_level + line
 
-    def print(self, node: PsAstNode):
+    def print(self, node: PsAstNode) -> str:
         return self.visit(node)
 
     @ast_visitor
-    def visit(self, node: PsAstNode):
+    def visit(self, _: PsAstNode) -> str:
         raise ValueError("Cannot print this node.")
     
     @visit.case(PsKernelFunction)
diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py
index 4fe705615..62c8b7695 100644
--- a/pystencils/nbackend/typed_expressions.py
+++ b/pystencils/nbackend/typed_expressions.py
@@ -59,7 +59,7 @@ class PsLinearizedArray(PsArray):
         strides: Tuple[pb.Expression],
         element_type: PsScalarType,
     ):
-        length = reduce(lambda x, y: x * y, shape, 1)
+        length = reduce(lambda x, y: x * y, shape)
         super().__init__(name, length, element_type)
 
         self._shape = shape
@@ -110,9 +110,6 @@ class PsArrayAccess(pb.Subscript):
         return self._base_ptr.array.element_type
 
 
-PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
-
-
 class PsTypedConstant:
     """Represents typed constants occuring in the pystencils AST.
 
@@ -275,7 +272,11 @@ class PsTypedConstant:
             return PsTypedConstant(rem, self._dtype)
 
     def __neg__(self):
-        return PsTypedConstant(-self._value, self._dtype)
+        minus_one = PsTypedConstant(-1, self._dtype)
+        return pb.Product((minus_one, self))
+    
+    def __bool__(self):
+        return bool(self._value)
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, PsTypedConstant):
@@ -287,4 +288,11 @@ class PsTypedConstant:
         return hash((self._value, self._dtype))
 
 
-pb.VALID_CONSTANT_CLASSES += (PsTypedConstant,)
+pb.register_constant_class(PsTypedConstant)
+
+
+PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
+"""Types of expressions that may occur on the left-hand side of assignments."""
+
+ExprOrConstant: TypeAlias = pb.Expression | PsTypedConstant
+"""Required since `PsTypedConstant` does not derive from `pb.Expression`."""
diff --git a/pystencils_tests/nbackend/test_basic_printing.py b/pystencils_tests/nbackend/test_basic_printing.py
new file mode 100644
index 000000000..2394b9287
--- /dev/null
+++ b/pystencils_tests/nbackend/test_basic_printing.py
@@ -0,0 +1,38 @@
+import pytest
+
+from pystencils import Target
+
+from pystencils.nbackend.ast import *
+from pystencils.nbackend.typed_expressions import *
+from pystencils.nbackend.types.quick import *
+from pystencils.nbackend.c_printer import CPrinter
+
+def test_basic_kernel():
+
+    u_size = PsTypedVariable("u_length", UInt(32, True))
+    u_arr = PsArray("u", u_size, Fp(64))
+    u_base = PsArrayBasePointer("u_data", u_arr)
+
+    loop_ctr = PsTypedVariable("ctr", UInt(32))
+    one = PsTypedConstant(1, SInt(32))
+
+    update = PsAssignment(
+        PsLvalueExpr(PsArrayAccess(u_base, loop_ctr)),
+        PsExpression(PsArrayAccess(u_base, loop_ctr + one) + PsArrayAccess(u_base, loop_ctr - one)),
+    )
+
+    loop = PsLoop(
+        PsSymbolExpr(loop_ctr),
+        PsExpression(one),
+        PsExpression(u_size - one),
+        PsExpression(one),
+        PsBlock([update])
+    )
+
+    func = PsKernelFunction(PsBlock([loop]), target=Target.CPU)
+
+    printer = CPrinter()
+    code = printer.print(func)
+
+    assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1]") >= 0
+
-- 
GitLab