-
Frederik Hennig authoredFrederik Hennig authored
test_freeze.py 1.81 KiB
import sympy as sp
import pymbolic.primitives as pb
from pystencils import Assignment, fields
from pystencils.nbackend.ast import (
PsAssignment,
PsDeclaration,
PsExpression,
PsSymbolExpr,
PsLvalueExpr,
)
from pystencils.nbackend.typed_expressions import PsTypedConstant, PsTypedVariable
from pystencils.nbackend.arrays import PsArrayAccess
from pystencils.nbackend.kernelcreation import (
KernelCreationOptions,
KernelCreationContext,
FreezeExpressions,
FullIterationSpace,
)
def test_freeze_simple():
options = KernelCreationOptions()
ctx = KernelCreationContext(options)
freeze = FreezeExpressions(ctx)
x, y, z = sp.symbols("x, y, z")
asm = Assignment(z, 2 * x + y)
fasm = freeze(asm)
pb_x, pb_y, pb_z = pb.variables("x y z")
assert fasm == PsDeclaration(PsSymbolExpr(pb_z), PsExpression(pb_y + 2 * pb_x))
assert fasm != PsAssignment(PsSymbolExpr(pb_z), PsExpression(pb_y + 2 * pb_x))
def test_freeze_fields():
options = KernelCreationOptions()
ctx = KernelCreationContext(options)
start = PsTypedConstant(0, ctx.index_dtype)
stop = PsTypedConstant(42, ctx.index_dtype)
step = PsTypedConstant(1, ctx.index_dtype)
counter = PsTypedVariable("ctr", ctx.index_dtype)
ispace = FullIterationSpace(
ctx, [FullIterationSpace.Dimension(start, stop, step, counter)]
)
ctx.set_iteration_space(ispace)
freeze = FreezeExpressions(ctx)
f, g = fields("f, g : [1D]")
asm = Assignment(f.center(0), g.center(0))
f_arr = ctx.get_array(f)
g_arr = ctx.get_array(g)
fasm = freeze(asm)
lhs = PsArrayAccess(f_arr.base_pointer, counter * f_arr.strides[0])
rhs = PsArrayAccess(g_arr.base_pointer, counter * g_arr.strides[0])
should = PsAssignment(PsLvalueExpr(lhs), PsExpression(rhs))
assert fasm == should