Skip to content
Snippets Groups Projects
Commit 03b9e02a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Merge branch 'backend-rework' into fhennig/sycl

parents b85a6763 be07c320
No related branches found
No related tags found
1 merge request!384Fundamental GPU Support
Pipeline #66337 passed
...@@ -6,6 +6,7 @@ from .structural import ( ...@@ -6,6 +6,7 @@ from .structural import (
PsAstNode, PsAstNode,
PsBlock, PsBlock,
PsComment, PsComment,
PsConditional,
PsDeclaration, PsDeclaration,
PsExpression, PsExpression,
PsLoop, PsLoop,
...@@ -56,6 +57,12 @@ class UndefinedSymbolsCollector: ...@@ -56,6 +57,12 @@ class UndefinedSymbolsCollector:
undefined_vars.discard(ctr.symbol) undefined_vars.discard(ctr.symbol)
return undefined_vars return undefined_vars
case PsConditional(cond, branch_true, branch_false):
undefined_vars = self(cond) | self(branch_true)
if branch_false is not None:
undefined_vars |= self(branch_false)
return undefined_vars
case PsComment(): case PsComment():
return set() return set()
...@@ -86,6 +93,7 @@ class UndefinedSymbolsCollector: ...@@ -86,6 +93,7 @@ class UndefinedSymbolsCollector:
PsAssignment() PsAssignment()
| PsBlock() | PsBlock()
| PsComment() | PsComment()
| PsConditional()
| PsExpression() | PsExpression()
| PsLoop() | PsLoop()
| PsStatement() | PsStatement()
......
...@@ -12,6 +12,7 @@ from ..ast.structural import ( ...@@ -12,6 +12,7 @@ from ..ast.structural import (
PsDeclaration, PsDeclaration,
PsAssignment, PsAssignment,
PsComment, PsComment,
PsStatement,
) )
from ..ast.expressions import PsExpression, PsSymbolExpr from ..ast.expressions import PsExpression, PsSymbolExpr
...@@ -99,6 +100,9 @@ class CanonicalClone: ...@@ -99,6 +100,9 @@ class CanonicalClone:
self._replace_symbols(expr_clone, cc) self._replace_symbols(expr_clone, cc)
return cast(Node_T, expr_clone) return cast(Node_T, expr_clone)
case PsStatement(expr):
return cast(Node_T, PsStatement(self.visit(expr, cc)))
case _: case _:
raise PsInternalCompilerError( raise PsInternalCompilerError(
f"Don't know how to canonically clone {type(node)}" f"Don't know how to canonically clone {type(node)}"
......
...@@ -4,7 +4,7 @@ from ..kernelcreation import KernelCreationContext, Typifier ...@@ -4,7 +4,7 @@ from ..kernelcreation import KernelCreationContext, Typifier
from ..kernelcreation.ast_factory import AstFactory, IndexParsable from ..kernelcreation.ast_factory import AstFactory, IndexParsable
from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration from ..ast.structural import PsLoop, PsBlock, PsConditional, PsDeclaration
from ..ast.expressions import PsExpression, PsConstantExpr, PsLt from ..ast.expressions import PsExpression, PsConstantExpr, PsGe, PsLt
from ..constants import PsConstant from ..constants import PsConstant
from .canonical_clone import CanonicalClone, CloneContext from .canonical_clone import CanonicalClone, CloneContext
...@@ -48,7 +48,9 @@ class ReshapeLoops: ...@@ -48,7 +48,9 @@ class ReshapeLoops:
peeled_ctr = self._factory.parse_index( peeled_ctr = self._factory.parse_index(
cc.get_replacement(loop.counter.symbol) cc.get_replacement(loop.counter.symbol)
) )
peeled_idx = self._typify(loop.start + PsExpression.make(PsConstant(i))) peeled_idx = self._elim_constants(
self._typify(loop.start + PsExpression.make(PsConstant(i)) * loop.step)
)
counter_decl = PsDeclaration(peeled_ctr, peeled_idx) counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
peeled_block = self._canon_clone.visit(loop.body, cc) peeled_block = self._canon_clone.visit(loop.body, cc)
...@@ -65,11 +67,71 @@ class ReshapeLoops: ...@@ -65,11 +67,71 @@ class ReshapeLoops:
peeled_iters.append(peeled_block) peeled_iters.append(peeled_block)
loop.start = self._elim_constants( loop.start = self._elim_constants(
self._typify(loop.start + PsExpression.make(PsConstant(num_iterations))) self._typify(
loop.start + PsExpression.make(PsConstant(num_iterations)) * loop.step
)
) )
return peeled_iters, loop return peeled_iters, loop
def peel_loop_back(
self, loop: PsLoop, num_iterations: int, omit_range_check: bool = False
) -> tuple[PsLoop, Sequence[PsBlock]]:
"""Peel off iterations from the back of a loop.
Removes ``num_iterations`` from the back of the given loop and returns them as a sequence of
independent blocks.
Args:
loop: The loop node from which to peel iterations
num_iterations: The number of iterations to peel off
omit_range_check: If set to `True`, assume that the peeled-off iterations will always
be executed, and omit their enclosing conditional.
Returns:
Tuple containing the modified loop and the peeled-off iterations (sequence of blocks).
"""
if not (
isinstance(loop.step, PsConstantExpr) and loop.step.constant.value == 1
):
raise NotImplementedError(
"Peeling iterations from the back of loops is only implemented"
"for loops with unit step. Implementation is deferred until"
"loop range canonicalization is available (also needed for the"
"vectorizer)."
)
peeled_iters: list[PsBlock] = []
for i in range(num_iterations)[::-1]:
cc = CloneContext(self._ctx)
cc.symbol_decl(loop.counter.symbol)
peeled_ctr = self._factory.parse_index(
cc.get_replacement(loop.counter.symbol)
)
peeled_idx = self._typify(loop.stop - PsExpression.make(PsConstant(i + 1)))
counter_decl = PsDeclaration(peeled_ctr, peeled_idx)
peeled_block = self._canon_clone.visit(loop.body, cc)
if omit_range_check:
peeled_block.statements = [counter_decl] + peeled_block.statements
else:
iter_condition = PsGe(peeled_ctr, loop.start)
peeled_block.statements = [
counter_decl,
PsConditional(iter_condition, PsBlock(peeled_block.statements)),
]
peeled_iters.append(peeled_block)
loop.stop = self._elim_constants(
self._typify(loop.stop - PsExpression.make(PsConstant(num_iterations)))
)
return loop, peeled_iters
def cut_loop( def cut_loop(
self, loop: PsLoop, cutting_points: Sequence[IndexParsable] self, loop: PsLoop, cutting_points: Sequence[IndexParsable]
) -> Sequence[PsLoop | PsBlock]: ) -> Sequence[PsLoop | PsBlock]:
......
...@@ -483,8 +483,7 @@ class PsIntegerType(PsScalarType, ABC): ...@@ -483,8 +483,7 @@ class PsIntegerType(PsScalarType, ABC):
if not isinstance(value, np_dtype): if not isinstance(value, np_dtype):
raise PsTypeError(f"Given value {value} is not of required type {np_dtype}") raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
unsigned_suffix = "" if self.signed else "u" unsigned_suffix = "" if self.signed else "u"
# TODO: cast literal to correct type? return f"(({self._c_type_without_const()}) {value}{unsigned_suffix})"
return str(value) + unsigned_suffix
def create_constant(self, value: Any) -> Any: def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width] np_type = self.NUMPY_TYPES[self._width]
...@@ -499,9 +498,12 @@ class PsIntegerType(PsScalarType, ABC): ...@@ -499,9 +498,12 @@ class PsIntegerType(PsScalarType, ABC):
raise PsTypeError(f"Could not interpret {value} as {repr(self)}") raise PsTypeError(f"Could not interpret {value} as {repr(self)}")
def c_string(self) -> str: def _c_type_without_const(self) -> str:
prefix = "" if self._signed else "u" prefix = "" if self._signed else "u"
return f"{self._const_string()}{prefix}int{self._width}_t" return f"{prefix}int{self._width}_t"
def c_string(self) -> str:
return f"{self._const_string()}{self._c_type_without_const()}"
def __repr__(self) -> str: def __repr__(self) -> str:
return f"PsIntegerType( width={self.width}, signed={self.signed}, const={self.const} )" return f"PsIntegerType( width={self.width}, signed={self.signed}, const={self.const} )"
......
...@@ -54,6 +54,6 @@ def test_literals(): ...@@ -54,6 +54,6 @@ def test_literals():
print(code) print(code)
assert "const double x = C;" in code assert "const double x = C;" in code
assert "CELLS[0]" in code assert "CELLS[((int64_t) 0)]" in code
assert "CELLS[1]" in code assert "CELLS[((int64_t) 1)]" in code
assert "CELLS[2]" in code assert "CELLS[((int64_t) 2)]" in code
...@@ -8,8 +8,13 @@ from pystencils.backend.kernelcreation import ( ...@@ -8,8 +8,13 @@ from pystencils.backend.kernelcreation import (
) )
from pystencils.backend.transformations import ReshapeLoops from pystencils.backend.transformations import ReshapeLoops
from pystencils.backend.ast.structural import PsDeclaration, PsBlock, PsLoop, PsConditional from pystencils.backend.ast.structural import (
from pystencils.backend.ast.expressions import PsConstantExpr, PsLt PsDeclaration,
PsBlock,
PsLoop,
PsConditional,
)
from pystencils.backend.ast.expressions import PsConstantExpr, PsGe, PsLt
def test_loop_cutting(): def test_loop_cutting():
...@@ -43,10 +48,12 @@ def test_loop_cutting(): ...@@ -43,10 +48,12 @@ def test_loop_cutting():
x_decl = subloop.statements[1] x_decl = subloop.statements[1]
assert isinstance(x_decl, PsDeclaration) assert isinstance(x_decl, PsDeclaration)
assert x_decl.declared_symbol.name == "x__0" assert x_decl.declared_symbol.name == "x__0"
subloop = subloops[1] subloop = subloops[1]
assert isinstance(subloop, PsLoop) assert isinstance(subloop, PsLoop)
assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1 assert (
isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 1
)
assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3 assert isinstance(subloop.stop, PsConstantExpr) and subloop.stop.constant.value == 3
x_decl = subloop.body.statements[0] x_decl = subloop.body.statements[0]
...@@ -55,7 +62,9 @@ def test_loop_cutting(): ...@@ -55,7 +62,9 @@ def test_loop_cutting():
subloop = subloops[2] subloop = subloops[2]
assert isinstance(subloop, PsLoop) assert isinstance(subloop, PsLoop)
assert isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3 assert (
isinstance(subloop.start, PsConstantExpr) and subloop.start.constant.value == 3
)
assert subloop.stop.structurally_equal(loop.stop) assert subloop.stop.structurally_equal(loop.stop)
...@@ -67,19 +76,23 @@ def test_loop_peeling(): ...@@ -67,19 +76,23 @@ def test_loop_peeling():
x, y, z = sp.symbols("x, y, z") x, y, z = sp.symbols("x, y, z")
f = Field.create_generic("f", 1, index_shape=(2,)) f = Field.create_generic("f", 1, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f) ispace = FullIterationSpace.create_from_slice(
ctx, slice(2, 11, 3), archetype_field=f
)
ctx.set_iteration_space(ispace) ctx.set_iteration_space(ispace)
loop_body = PsBlock([ loop_body = PsBlock(
factory.parse_sympy(Assignment(x, 2 * z)), [
factory.parse_sympy(Assignment(f.center(0), x + y)), factory.parse_sympy(Assignment(x, 2 * z)),
]) factory.parse_sympy(Assignment(f.center(0), x + y)),
]
)
loop = factory.loops_from_ispace(ispace, loop_body) loop = factory.loops_from_ispace(ispace, loop_body)
num_iters = 3 num_iters = 2
peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters) peeled_iters, peeled_loop = reshape.peel_loop_front(loop, num_iters)
assert len(peeled_iters) == 3 assert len(peeled_iters) == num_iters
for i, iter in enumerate(peeled_iters): for i, iter in enumerate(peeled_iters):
assert isinstance(iter, PsBlock) assert isinstance(iter, PsBlock)
...@@ -87,6 +100,8 @@ def test_loop_peeling(): ...@@ -87,6 +100,8 @@ def test_loop_peeling():
ctr_decl = iter.statements[0] ctr_decl = iter.statements[0]
assert isinstance(ctr_decl, PsDeclaration) assert isinstance(ctr_decl, PsDeclaration)
assert ctr_decl.declared_symbol.name == f"ctr_0__{i}" assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
ctr_value = {0: 2, 1: 5}[i]
assert ctr_decl.rhs.structurally_equal(factory.parse_index(ctr_value))
cond = iter.statements[1] cond = iter.statements[1]
assert isinstance(cond, PsConditional) assert isinstance(cond, PsConditional)
...@@ -96,6 +111,53 @@ def test_loop_peeling(): ...@@ -96,6 +111,53 @@ def test_loop_peeling():
assert isinstance(subblock.statements[0], PsDeclaration) assert isinstance(subblock.statements[0], PsDeclaration)
assert subblock.statements[0].declared_symbol.name == f"x__{i}" assert subblock.statements[0].declared_symbol.name == f"x__{i}"
assert peeled_loop.start.structurally_equal(factory.parse_index(num_iters)) assert peeled_loop.start.structurally_equal(factory.parse_index(8))
assert peeled_loop.stop.structurally_equal(loop.stop) assert peeled_loop.stop.structurally_equal(loop.stop)
assert peeled_loop.body.structurally_equal(loop.body) assert peeled_loop.body.structurally_equal(loop.body)
def test_loop_peeling_back():
ctx = KernelCreationContext()
factory = AstFactory(ctx)
reshape = ReshapeLoops(ctx)
x, y, z = sp.symbols("x, y, z")
f = Field.create_generic("f", 1, index_shape=(2,))
ispace = FullIterationSpace.create_from_slice(ctx, make_slice[:], archetype_field=f)
ctx.set_iteration_space(ispace)
loop_body = PsBlock(
[
factory.parse_sympy(Assignment(x, 2 * z)),
factory.parse_sympy(Assignment(f.center(0), x + y)),
]
)
loop = factory.loops_from_ispace(ispace, loop_body)
num_iters = 3
peeled_loop, peeled_iters = reshape.peel_loop_back(loop, num_iters)
assert len(peeled_iters) == 3
for i, iter in enumerate(peeled_iters):
assert isinstance(iter, PsBlock)
ctr_decl = iter.statements[0]
assert isinstance(ctr_decl, PsDeclaration)
assert ctr_decl.declared_symbol.name == f"ctr_0__{i}"
cond = iter.statements[1]
assert isinstance(cond, PsConditional)
assert cond.condition.structurally_equal(PsGe(ctr_decl.lhs, loop.start))
subblock = cond.branch_true
assert isinstance(subblock.statements[0], PsDeclaration)
assert subblock.statements[0].declared_symbol.name == f"x__{i}"
assert peeled_loop.start.structurally_equal(loop.start)
assert peeled_loop.stop.structurally_equal(
factory.loops_from_ispace(ispace, loop_body).stop
- factory.parse_index(num_iters)
)
assert peeled_loop.body.structurally_equal(loop.body)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment