From 919ed1c7ce8a1e6abedc6cc75e9810926ad78db3 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Fri, 26 Apr 2024 17:36:09 +0200 Subject: [PATCH] Add literals test. Cleaned up some other tests. --- .../backend/kernelcreation/ast_factory.py | 4 ++ .../kernelcreation/test_domain_kernels.py | 2 - .../{test_ast_nodes.py => test_ast.py} | 0 tests/nbackend/test_code_printing.py | 37 +----------- tests/nbackend/test_extensions.py | 59 +++++++++++++++++++ 5 files changed, 65 insertions(+), 37 deletions(-) rename tests/nbackend/{test_ast_nodes.py => test_ast.py} (100%) create mode 100644 tests/nbackend/test_extensions.py diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index 83c406b0a..d5695be93 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -37,6 +37,10 @@ class AstFactory: self._freeze = FreezeExpressions(ctx) self._typify = Typifier(ctx) + @overload + def parse_sympy(self, sp_obj: sp.Symbol) -> PsSymbolExpr: + pass + @overload def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression: pass diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py index 29744c384..9ce2f661d 100644 --- a/tests/nbackend/kernelcreation/test_domain_kernels.py +++ b/tests/nbackend/kernelcreation/test_domain_kernels.py @@ -59,5 +59,3 @@ def test_filter_kernel_fixedsize(): expected[1:-1, 1:-1].fill(18.0) np.testing.assert_allclose(dst_arr, expected) - -test_filter_kernel() \ No newline at end of file diff --git a/tests/nbackend/test_ast_nodes.py b/tests/nbackend/test_ast.py similarity index 100% rename from tests/nbackend/test_ast_nodes.py rename to tests/nbackend/test_ast.py diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index 1fc6821d7..4c83e6e99 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -1,49 +1,16 @@ from pystencils import Target -from pystencils.backend.ast.expressions import PsExpression, PsArrayAccess +from pystencils.backend.ast.expressions import PsExpression from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock from pystencils.backend.kernelfunction import KernelFunction from pystencils.backend.symbols import PsSymbol from pystencils.backend.constants import PsConstant +from pystencils.backend.literals import PsLiteral from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer from pystencils.types.quick import Fp, SInt, UInt, Bool from pystencils.backend.emission import CAstPrinter -# def test_basic_kernel(): - -# u_arr = PsLinearizedArray("u", Fp(64), (..., ), (1, )) -# u_size = PsExpression.make(u_arr.shape[0]) -# u_base = PsArrayBasePointer("u_data", u_arr) - -# loop_ctr = PsExpression.make(PsSymbol("ctr", UInt(32))) -# one = PsExpression.make(PsConstant(1, SInt(32))) - -# update = PsAssignment( -# PsArrayAccess(u_base, loop_ctr), -# PsArrayAccess(u_base, loop_ctr + one) + PsArrayAccess(u_base, loop_ctr - one), -# ) - -# loop = PsLoop( -# loop_ctr, -# one, -# u_size - one, -# one, -# PsBlock([update]) -# ) - -# func = KernelFunction(PsBlock([loop]), Target.CPU, "kernel", set()) - -# printer = CAstPrinter() -# code = printer(func) - -# paramlist = func.get_parameters().params -# params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist) - -# assert code.find("(" + params_str + ")") >= 0 -# assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1];") >= 0 - - def test_arithmetic_precedence(): (a, b, c, d, e, f) = [PsExpression.make(PsSymbol(x, Fp(64))) for x in "abcdef"] cprint = CAstPrinter() diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py new file mode 100644 index 000000000..75726a351 --- /dev/null +++ b/tests/nbackend/test_extensions.py @@ -0,0 +1,59 @@ + +import sympy as sp + +from pystencils import make_slice, Field, Assignment +from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace +from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations +from pystencils.backend.literals import PsLiteral +from pystencils.backend.emission import CAstPrinter +from pystencils.backend.ast.expressions import PsExpression, PsSubscript +from pystencils.backend.ast.structural import PsBlock, PsDeclaration +from pystencils.types.quick import Arr, Int + + +def test_literals(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + f = Field.create_generic("f", 3) + x = sp.Symbol("x") + + cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64, const=True), 3))) + global_constant = PsExpression.make(PsLiteral("C", ctx.default_dtype)) + + loop_slice = make_slice[ + 0:PsSubscript(cells, factory.parse_index(0)), + 0:PsSubscript(cells, factory.parse_index(1)), + 0:PsSubscript(cells, factory.parse_index(2)), + ] + + ispace = FullIterationSpace.create_from_slice(ctx, loop_slice) + ctx.set_iteration_space(ispace) + + x_decl = PsDeclaration(factory.parse_sympy(x), global_constant) + + loop_body = PsBlock([ + x_decl, + factory.parse_sympy(Assignment(f.center(), x)) + ]) + + loops = factory.loops_from_ispace(ispace, loop_body) + ast = PsBlock([loops]) + + canon = CanonicalizeSymbols(ctx) + ast = canon(ast) + + hoist = HoistLoopInvariantDeclarations(ctx) + ast = hoist(ast) + + assert isinstance(ast, PsBlock) + assert len(ast.statements) == 2 + assert ast.statements[0] == x_decl + + code = CAstPrinter()(ast) + print(code) + + assert "const double x = C;" in code + assert "CELLS[0]" in code + assert "CELLS[1]" in code + assert "CELLS[2]" in code -- GitLab