diff --git a/src/pystencils/backend/kernelcreation/ast_factory.py b/src/pystencils/backend/kernelcreation/ast_factory.py index 83c406b0a99d52cee9599f321d6c32477f6dbf8a..d5695be93a7a89134f3bc7a12f623006768579d1 100644 --- a/src/pystencils/backend/kernelcreation/ast_factory.py +++ b/src/pystencils/backend/kernelcreation/ast_factory.py @@ -37,6 +37,10 @@ class AstFactory: self._freeze = FreezeExpressions(ctx) self._typify = Typifier(ctx) + @overload + def parse_sympy(self, sp_obj: sp.Symbol) -> PsSymbolExpr: + pass + @overload def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression: pass diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py index 29744c384e03131784c08857c494bb7f83e7f0bd..9ce2f661d840641d28774134070fc7050e90e6d1 100644 --- a/tests/nbackend/kernelcreation/test_domain_kernels.py +++ b/tests/nbackend/kernelcreation/test_domain_kernels.py @@ -59,5 +59,3 @@ def test_filter_kernel_fixedsize(): expected[1:-1, 1:-1].fill(18.0) np.testing.assert_allclose(dst_arr, expected) - -test_filter_kernel() \ No newline at end of file diff --git a/tests/nbackend/test_ast_nodes.py b/tests/nbackend/test_ast.py similarity index 100% rename from tests/nbackend/test_ast_nodes.py rename to tests/nbackend/test_ast.py diff --git a/tests/nbackend/test_code_printing.py b/tests/nbackend/test_code_printing.py index 1fc6821d7b530a8b8e10b0298b641219bd31a53a..4c83e6e995f0823f81a4627e93d38256f648d28c 100644 --- a/tests/nbackend/test_code_printing.py +++ b/tests/nbackend/test_code_printing.py @@ -1,49 +1,16 @@ from pystencils import Target -from pystencils.backend.ast.expressions import PsExpression, PsArrayAccess +from pystencils.backend.ast.expressions import PsExpression from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock from pystencils.backend.kernelfunction import KernelFunction from pystencils.backend.symbols import PsSymbol from pystencils.backend.constants import PsConstant +from pystencils.backend.literals import PsLiteral from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer from pystencils.types.quick import Fp, SInt, UInt, Bool from pystencils.backend.emission import CAstPrinter -# def test_basic_kernel(): - -# u_arr = PsLinearizedArray("u", Fp(64), (..., ), (1, )) -# u_size = PsExpression.make(u_arr.shape[0]) -# u_base = PsArrayBasePointer("u_data", u_arr) - -# loop_ctr = PsExpression.make(PsSymbol("ctr", UInt(32))) -# one = PsExpression.make(PsConstant(1, SInt(32))) - -# update = PsAssignment( -# PsArrayAccess(u_base, loop_ctr), -# PsArrayAccess(u_base, loop_ctr + one) + PsArrayAccess(u_base, loop_ctr - one), -# ) - -# loop = PsLoop( -# loop_ctr, -# one, -# u_size - one, -# one, -# PsBlock([update]) -# ) - -# func = KernelFunction(PsBlock([loop]), Target.CPU, "kernel", set()) - -# printer = CAstPrinter() -# code = printer(func) - -# paramlist = func.get_parameters().params -# params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist) - -# assert code.find("(" + params_str + ")") >= 0 -# assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1];") >= 0 - - def test_arithmetic_precedence(): (a, b, c, d, e, f) = [PsExpression.make(PsSymbol(x, Fp(64))) for x in "abcdef"] cprint = CAstPrinter() diff --git a/tests/nbackend/test_extensions.py b/tests/nbackend/test_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..75726a3512b1291531ecf73a61af22258d17003c --- /dev/null +++ b/tests/nbackend/test_extensions.py @@ -0,0 +1,59 @@ + +import sympy as sp + +from pystencils import make_slice, Field, Assignment +from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace +from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations +from pystencils.backend.literals import PsLiteral +from pystencils.backend.emission import CAstPrinter +from pystencils.backend.ast.expressions import PsExpression, PsSubscript +from pystencils.backend.ast.structural import PsBlock, PsDeclaration +from pystencils.types.quick import Arr, Int + + +def test_literals(): + ctx = KernelCreationContext() + factory = AstFactory(ctx) + + f = Field.create_generic("f", 3) + x = sp.Symbol("x") + + cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64, const=True), 3))) + global_constant = PsExpression.make(PsLiteral("C", ctx.default_dtype)) + + loop_slice = make_slice[ + 0:PsSubscript(cells, factory.parse_index(0)), + 0:PsSubscript(cells, factory.parse_index(1)), + 0:PsSubscript(cells, factory.parse_index(2)), + ] + + ispace = FullIterationSpace.create_from_slice(ctx, loop_slice) + ctx.set_iteration_space(ispace) + + x_decl = PsDeclaration(factory.parse_sympy(x), global_constant) + + loop_body = PsBlock([ + x_decl, + factory.parse_sympy(Assignment(f.center(), x)) + ]) + + loops = factory.loops_from_ispace(ispace, loop_body) + ast = PsBlock([loops]) + + canon = CanonicalizeSymbols(ctx) + ast = canon(ast) + + hoist = HoistLoopInvariantDeclarations(ctx) + ast = hoist(ast) + + assert isinstance(ast, PsBlock) + assert len(ast.statements) == 2 + assert ast.statements[0] == x_decl + + code = CAstPrinter()(ast) + print(code) + + assert "const double x = C;" in code + assert "CELLS[0]" in code + assert "CELLS[1]" in code + assert "CELLS[2]" in code