Skip to content
Snippets Groups Projects
Commit 919ed1c7 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Add literals test. Cleaned up some other tests.

parent 9bfaa437
No related branches found
No related tags found
1 merge request!378Customizability Extensions: CFunction Signatures and Code Literals
Pipeline #65466 passed
...@@ -37,6 +37,10 @@ class AstFactory: ...@@ -37,6 +37,10 @@ class AstFactory:
self._freeze = FreezeExpressions(ctx) self._freeze = FreezeExpressions(ctx)
self._typify = Typifier(ctx) self._typify = Typifier(ctx)
@overload
def parse_sympy(self, sp_obj: sp.Symbol) -> PsSymbolExpr:
pass
@overload @overload
def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression: def parse_sympy(self, sp_obj: sp.Expr) -> PsExpression:
pass pass
......
...@@ -59,5 +59,3 @@ def test_filter_kernel_fixedsize(): ...@@ -59,5 +59,3 @@ def test_filter_kernel_fixedsize():
expected[1:-1, 1:-1].fill(18.0) expected[1:-1, 1:-1].fill(18.0)
np.testing.assert_allclose(dst_arr, expected) np.testing.assert_allclose(dst_arr, expected)
test_filter_kernel()
\ No newline at end of file
File moved
from pystencils import Target from pystencils import Target
from pystencils.backend.ast.expressions import PsExpression, PsArrayAccess from pystencils.backend.ast.expressions import PsExpression
from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock from pystencils.backend.ast.structural import PsAssignment, PsLoop, PsBlock
from pystencils.backend.kernelfunction import KernelFunction from pystencils.backend.kernelfunction import KernelFunction
from pystencils.backend.symbols import PsSymbol from pystencils.backend.symbols import PsSymbol
from pystencils.backend.constants import PsConstant from pystencils.backend.constants import PsConstant
from pystencils.backend.literals import PsLiteral
from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer from pystencils.backend.arrays import PsLinearizedArray, PsArrayBasePointer
from pystencils.types.quick import Fp, SInt, UInt, Bool from pystencils.types.quick import Fp, SInt, UInt, Bool
from pystencils.backend.emission import CAstPrinter from pystencils.backend.emission import CAstPrinter
# def test_basic_kernel():
# u_arr = PsLinearizedArray("u", Fp(64), (..., ), (1, ))
# u_size = PsExpression.make(u_arr.shape[0])
# u_base = PsArrayBasePointer("u_data", u_arr)
# loop_ctr = PsExpression.make(PsSymbol("ctr", UInt(32)))
# one = PsExpression.make(PsConstant(1, SInt(32)))
# update = PsAssignment(
# PsArrayAccess(u_base, loop_ctr),
# PsArrayAccess(u_base, loop_ctr + one) + PsArrayAccess(u_base, loop_ctr - one),
# )
# loop = PsLoop(
# loop_ctr,
# one,
# u_size - one,
# one,
# PsBlock([update])
# )
# func = KernelFunction(PsBlock([loop]), Target.CPU, "kernel", set())
# printer = CAstPrinter()
# code = printer(func)
# paramlist = func.get_parameters().params
# params_str = ", ".join(f"{p.dtype} {p.name}" for p in paramlist)
# assert code.find("(" + params_str + ")") >= 0
# assert code.find("u_data[ctr] = u_data[ctr + 1] + u_data[ctr - 1];") >= 0
def test_arithmetic_precedence(): def test_arithmetic_precedence():
(a, b, c, d, e, f) = [PsExpression.make(PsSymbol(x, Fp(64))) for x in "abcdef"] (a, b, c, d, e, f) = [PsExpression.make(PsSymbol(x, Fp(64))) for x in "abcdef"]
cprint = CAstPrinter() cprint = CAstPrinter()
......
import sympy as sp
from pystencils import make_slice, Field, Assignment
from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace
from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations
from pystencils.backend.literals import PsLiteral
from pystencils.backend.emission import CAstPrinter
from pystencils.backend.ast.expressions import PsExpression, PsSubscript
from pystencils.backend.ast.structural import PsBlock, PsDeclaration
from pystencils.types.quick import Arr, Int
def test_literals():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
f = Field.create_generic("f", 3)
x = sp.Symbol("x")
cells = PsExpression.make(PsLiteral("CELLS", Arr(Int(64, const=True), 3)))
global_constant = PsExpression.make(PsLiteral("C", ctx.default_dtype))
loop_slice = make_slice[
0:PsSubscript(cells, factory.parse_index(0)),
0:PsSubscript(cells, factory.parse_index(1)),
0:PsSubscript(cells, factory.parse_index(2)),
]
ispace = FullIterationSpace.create_from_slice(ctx, loop_slice)
ctx.set_iteration_space(ispace)
x_decl = PsDeclaration(factory.parse_sympy(x), global_constant)
loop_body = PsBlock([
x_decl,
factory.parse_sympy(Assignment(f.center(), x))
])
loops = factory.loops_from_ispace(ispace, loop_body)
ast = PsBlock([loops])
canon = CanonicalizeSymbols(ctx)
ast = canon(ast)
hoist = HoistLoopInvariantDeclarations(ctx)
ast = hoist(ast)
assert isinstance(ast, PsBlock)
assert len(ast.statements) == 2
assert ast.statements[0] == x_decl
code = CAstPrinter()(ast)
print(code)
assert "const double x = C;" in code
assert "CELLS[0]" in code
assert "CELLS[1]" in code
assert "CELLS[2]" in code
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment