From e12bef27071ab057c2b19c4bffe39c2c74d3f4c8 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 15 Jan 2024 15:09:02 +0100 Subject: [PATCH] basic printing test --- pystencils/nbackend/ast/__init__.py | 2 + pystencils/nbackend/ast/kernelfunction.py | 2 +- pystencils/nbackend/ast/nodes.py | 8 ++-- pystencils/nbackend/c_printer.py | 4 +- pystencils/nbackend/typed_expressions.py | 20 +++++++--- .../nbackend/test_basic_printing.py | 38 +++++++++++++++++++ 6 files changed, 61 insertions(+), 13 deletions(-) create mode 100644 pystencils_tests/nbackend/test_basic_printing.py diff --git a/pystencils/nbackend/ast/__init__.py b/pystencils/nbackend/ast/__init__.py index 95cb7831b..daee7214f 100644 --- a/pystencils/nbackend/ast/__init__.py +++ b/pystencils/nbackend/ast/__init__.py @@ -8,12 +8,14 @@ from .nodes import ( PsDeclaration, PsLoop, ) +from .kernelfunction import PsKernelFunction from .dispatcher import ast_visitor from .transformations import ast_subs __all__ = [ "ast_visitor", + "PsKernelFunction", "PsAstNode", "PsBlock", "PsExpression", diff --git a/pystencils/nbackend/ast/kernelfunction.py b/pystencils/nbackend/ast/kernelfunction.py index 6c9aad854..a12abb45a 100644 --- a/pystencils/nbackend/ast/kernelfunction.py +++ b/pystencils/nbackend/ast/kernelfunction.py @@ -8,7 +8,7 @@ from ...enums import Target class PsKernelFunction(PsAstNode): """A complete pystencils kernel function.""" - __match_args__ = ("block",) + __match_args__ = ("body",) def __init__(self, body: PsBlock, target: Target, name: str = "kernel"): self._body = body diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py index 30b1a4dd6..5418a1a37 100644 --- a/pystencils/nbackend/ast/nodes.py +++ b/pystencils/nbackend/ast/nodes.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod import pymbolic.primitives as pb -from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue +from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant from .util import failing_cast @@ -87,15 +87,15 @@ class PsExpression(PsLeafNode): __match_args__ = ("expression",) - def __init__(self, expr: pb.Expression): + def __init__(self, expr: ExprOrConstant): self._expr = expr @property - def expression(self) -> pb.Expression: + def expression(self) -> ExprOrConstant: return self._expr @expression.setter - def expression(self, expr: pb.Expression): + def expression(self, expr: ExprOrConstant): self._expr = expr diff --git a/pystencils/nbackend/c_printer.py b/pystencils/nbackend/c_printer.py index 7c4bf4b7c..4ca472a29 100644 --- a/pystencils/nbackend/c_printer.py +++ b/pystencils/nbackend/c_printer.py @@ -18,11 +18,11 @@ class CPrinter: def indent(self, line): return " " * self._current_indent_level + line - def print(self, node: PsAstNode): + def print(self, node: PsAstNode) -> str: return self.visit(node) @ast_visitor - def visit(self, node: PsAstNode): + def visit(self, _: PsAstNode) -> str: raise ValueError("Cannot print this node.") @visit.case(PsKernelFunction) diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py index 4fe705615..62c8b7695 100644 --- a/pystencils/nbackend/typed_expressions.py +++ b/pystencils/nbackend/typed_expressions.py @@ -59,7 +59,7 @@ class PsLinearizedArray(PsArray): strides: Tuple[pb.Expression], element_type: PsScalarType, ): - length = reduce(lambda x, y: x * y, shape, 1) + length = reduce(lambda x, y: x * y, shape) super().__init__(name, length, element_type) self._shape = shape @@ -110,9 +110,6 @@ class PsArrayAccess(pb.Subscript): return self._base_ptr.array.element_type -PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess] - - class PsTypedConstant: """Represents typed constants occuring in the pystencils AST. @@ -275,7 +272,11 @@ class PsTypedConstant: return PsTypedConstant(rem, self._dtype) def __neg__(self): - return PsTypedConstant(-self._value, self._dtype) + minus_one = PsTypedConstant(-1, self._dtype) + return pb.Product((minus_one, self)) + + def __bool__(self): + return bool(self._value) def __eq__(self, other: object) -> bool: if not isinstance(other, PsTypedConstant): @@ -287,4 +288,11 @@ class PsTypedConstant: return hash((self._value, self._dtype)) -pb.VALID_CONSTANT_CLASSES += (PsTypedConstant,) +pb.register_constant_class(PsTypedConstant) + + +PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess] +"""Types of expressions that may occur on the left-hand side of assignments.""" + +ExprOrConstant: TypeAlias = pb.Expression | PsTypedConstant +"""Required since `PsTypedConstant` does not derive from `pb.Expression`.""" diff --git a/pystencils_tests/nbackend/test_basic_printing.py b/pystencils_tests/nbackend/test_basic_printing.py new file mode 100644 index 000000000..2394b9287 --- /dev/null +++ b/pystencils_tests/nbackend/test_basic_printing.py @@ -0,0 +1,38 @@ +import pytest + +from pystencils import Target + +from pystencils.nbackend.ast import * +from pystencils.nbackend.typed_expressions import * +from pystencils.nbackend.types.quick import * +from pystencils.nbackend.c_printer import CPrinter + +def test_basic_kernel(): + + u_size = PsTypedVariable("u_length", UInt(32, True)) + u_arr = PsArray("u", u_size, Fp(64)) + u_base = PsArrayBasePointer("u_data", u_arr) + + loop_ctr = PsTypedVariable("ctr", UInt(32)) + one = PsTypedConstant(1, SInt(32)) + + update = PsAssignment( + PsLvalueExpr(PsArrayAccess(u_base, loop_ctr)), + PsExpression(PsArrayAccess(u_base, loop_ctr + one) + PsArrayAccess(u_base, loop_ctr - one)), + ) + + loop = PsLoop( + PsSymbolExpr(loop_ctr), + PsExpression(one), + PsExpression(u_size - one), + PsExpression(one), + PsBlock([update]) + ) + + func = PsKernelFunction(PsBlock([loop]), target=Target.CPU) + + printer = CPrinter() + code = printer.print(func) + + assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1]") >= 0 + -- GitLab