Skip to content
Snippets Groups Projects
Commit e12bef27 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

basic printing test

parent 25d72d18
No related branches found
No related tags found
No related merge requests found
Pipeline #60142 failed
...@@ -8,12 +8,14 @@ from .nodes import ( ...@@ -8,12 +8,14 @@ from .nodes import (
PsDeclaration, PsDeclaration,
PsLoop, PsLoop,
) )
from .kernelfunction import PsKernelFunction
from .dispatcher import ast_visitor from .dispatcher import ast_visitor
from .transformations import ast_subs from .transformations import ast_subs
__all__ = [ __all__ = [
"ast_visitor", "ast_visitor",
"PsKernelFunction",
"PsAstNode", "PsAstNode",
"PsBlock", "PsBlock",
"PsExpression", "PsExpression",
......
...@@ -8,7 +8,7 @@ from ...enums import Target ...@@ -8,7 +8,7 @@ from ...enums import Target
class PsKernelFunction(PsAstNode): class PsKernelFunction(PsAstNode):
"""A complete pystencils kernel function.""" """A complete pystencils kernel function."""
__match_args__ = ("block",) __match_args__ = ("body",)
def __init__(self, body: PsBlock, target: Target, name: str = "kernel"): def __init__(self, body: PsBlock, target: Target, name: str = "kernel"):
self._body = body self._body = body
......
...@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod ...@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
import pymbolic.primitives as pb import pymbolic.primitives as pb
from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant
from .util import failing_cast from .util import failing_cast
...@@ -87,15 +87,15 @@ class PsExpression(PsLeafNode): ...@@ -87,15 +87,15 @@ class PsExpression(PsLeafNode):
__match_args__ = ("expression",) __match_args__ = ("expression",)
def __init__(self, expr: pb.Expression): def __init__(self, expr: ExprOrConstant):
self._expr = expr self._expr = expr
@property @property
def expression(self) -> pb.Expression: def expression(self) -> ExprOrConstant:
return self._expr return self._expr
@expression.setter @expression.setter
def expression(self, expr: pb.Expression): def expression(self, expr: ExprOrConstant):
self._expr = expr self._expr = expr
......
...@@ -18,11 +18,11 @@ class CPrinter: ...@@ -18,11 +18,11 @@ class CPrinter:
def indent(self, line): def indent(self, line):
return " " * self._current_indent_level + line return " " * self._current_indent_level + line
def print(self, node: PsAstNode): def print(self, node: PsAstNode) -> str:
return self.visit(node) return self.visit(node)
@ast_visitor @ast_visitor
def visit(self, node: PsAstNode): def visit(self, _: PsAstNode) -> str:
raise ValueError("Cannot print this node.") raise ValueError("Cannot print this node.")
@visit.case(PsKernelFunction) @visit.case(PsKernelFunction)
......
...@@ -59,7 +59,7 @@ class PsLinearizedArray(PsArray): ...@@ -59,7 +59,7 @@ class PsLinearizedArray(PsArray):
strides: Tuple[pb.Expression], strides: Tuple[pb.Expression],
element_type: PsScalarType, element_type: PsScalarType,
): ):
length = reduce(lambda x, y: x * y, shape, 1) length = reduce(lambda x, y: x * y, shape)
super().__init__(name, length, element_type) super().__init__(name, length, element_type)
self._shape = shape self._shape = shape
...@@ -110,9 +110,6 @@ class PsArrayAccess(pb.Subscript): ...@@ -110,9 +110,6 @@ class PsArrayAccess(pb.Subscript):
return self._base_ptr.array.element_type return self._base_ptr.array.element_type
PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
class PsTypedConstant: class PsTypedConstant:
"""Represents typed constants occuring in the pystencils AST. """Represents typed constants occuring in the pystencils AST.
...@@ -275,7 +272,11 @@ class PsTypedConstant: ...@@ -275,7 +272,11 @@ class PsTypedConstant:
return PsTypedConstant(rem, self._dtype) return PsTypedConstant(rem, self._dtype)
def __neg__(self): def __neg__(self):
return PsTypedConstant(-self._value, self._dtype) minus_one = PsTypedConstant(-1, self._dtype)
return pb.Product((minus_one, self))
def __bool__(self):
return bool(self._value)
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, PsTypedConstant): if not isinstance(other, PsTypedConstant):
...@@ -287,4 +288,11 @@ class PsTypedConstant: ...@@ -287,4 +288,11 @@ class PsTypedConstant:
return hash((self._value, self._dtype)) return hash((self._value, self._dtype))
pb.VALID_CONSTANT_CLASSES += (PsTypedConstant,) pb.register_constant_class(PsTypedConstant)
PsLvalue: TypeAlias = Union[PsTypedVariable, PsArrayAccess]
"""Types of expressions that may occur on the left-hand side of assignments."""
ExprOrConstant: TypeAlias = pb.Expression | PsTypedConstant
"""Required since `PsTypedConstant` does not derive from `pb.Expression`."""
import pytest
from pystencils import Target
from pystencils.nbackend.ast import *
from pystencils.nbackend.typed_expressions import *
from pystencils.nbackend.types.quick import *
from pystencils.nbackend.c_printer import CPrinter
def test_basic_kernel():
u_size = PsTypedVariable("u_length", UInt(32, True))
u_arr = PsArray("u", u_size, Fp(64))
u_base = PsArrayBasePointer("u_data", u_arr)
loop_ctr = PsTypedVariable("ctr", UInt(32))
one = PsTypedConstant(1, SInt(32))
update = PsAssignment(
PsLvalueExpr(PsArrayAccess(u_base, loop_ctr)),
PsExpression(PsArrayAccess(u_base, loop_ctr + one) + PsArrayAccess(u_base, loop_ctr - one)),
)
loop = PsLoop(
PsSymbolExpr(loop_ctr),
PsExpression(one),
PsExpression(u_size - one),
PsExpression(one),
PsBlock([update])
)
func = PsKernelFunction(PsBlock([loop]), target=Target.CPU)
printer = CPrinter()
code = printer.print(func)
assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1]") >= 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment