Skip to content
Snippets Groups Projects
Commit a04c241b authored by Frederik Hennig's avatar Frederik Hennig
Browse files

A series of bug fixes

 - Add another canonicalize pass after lowering
 - Fix typification of array decl RHS
 - Fix: LowerToC did not descend into PsBufferAcc index args
parent 8532ccdb
No related branches found
No related tags found
1 merge request!421Refactor Field Modelling
Pipeline #69775 passed
......@@ -451,6 +451,9 @@ class PsLookup(PsExpression, PsLvalue):
idx = [0][idx]
self._aggregate = failing_cast(PsExpression, c)
def __repr__(self) -> str:
return f"PsLookup({repr(self._aggregate)}, {repr(self._member_name)})"
class PsCall(PsExpression):
__match_args__ = ("function", "args")
......
......@@ -344,6 +344,8 @@ class Typifier:
decl_tc.apply_dtype(
PsArrayType(items_tc.target_type, rhs.shape), rhs
)
else:
decl_tc.infer_dtype(rhs)
case PsDeclaration(lhs, rhs) | PsAssignment(lhs, rhs):
# Only if the LHS is an untyped symbol, infer its type from the RHS
......
from __future__ import annotations
from typing import ClassVar, Sequence
from typing import Sequence
from itertools import chain
from dataclasses import dataclass
......@@ -86,17 +86,15 @@ class PsSymbol:
return f"{self._name}: {dtype_str}"
def __repr__(self) -> str:
return f"PsSymbol({self._name}, {self._dtype})"
return f"PsSymbol({repr(self._name)}, {repr(self._dtype)})"
@dataclass(frozen=True)
class BufferBasePtr(PsSymbolProperty):
class BufferBasePtr(UniqueSymbolProperty):
"""Symbol acts as a base pointer to a buffer."""
buffer: PsBuffer
_unique: ClassVar[bool] = True
class PsBuffer:
"""N-dimensional contiguous linearized buffer in heap memory.
......
......@@ -65,7 +65,7 @@ class LowerToC:
return i
summands: list[PsExpression] = [
maybe_cast(idx) * PsExpression.make(stride)
maybe_cast(cast(PsExpression, self.visit(idx))) * PsExpression.make(stride)
for idx, stride in zip(indices, buf.strides, strict=True)
]
......
......@@ -22,6 +22,7 @@ from .backend.transformations import (
EliminateConstants,
LowerToC,
SelectFunctions,
CanonicalizeSymbols,
)
from .backend.kernelfunction import (
create_cpu_kernel_function,
......@@ -143,12 +144,17 @@ def create_kernel(
kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim)
erase_anons = LowerToC(ctx)
kernel_ast = cast(PsBlock, erase_anons(kernel_ast))
# Lowering
lower_to_c = LowerToC(ctx)
kernel_ast = cast(PsBlock, lower_to_c(kernel_ast))
select_functions = SelectFunctions(platform)
kernel_ast = cast(PsBlock, select_functions(kernel_ast))
# Lowering introduces new symbols, which have to be canonicalized
canonicalize = CanonicalizeSymbols(ctx, True)
kernel_ast = cast(PsBlock, canonicalize(kernel_ast))
if config.target.is_cpu():
return create_cpu_kernel_function(
ctx,
......
......@@ -91,7 +91,8 @@ class PsPointerType(PsDereferencableType):
def c_string(self) -> str:
base_str = self._base_type.c_string()
restrict_str = " RESTRICT" if self._restrict else ""
return f"{base_str} *{restrict_str} {self._const_string()}"
const_str = " const" if self.const else ""
return f"{base_str} *{restrict_str}{const_str}"
def __repr__(self) -> str:
return f"PsPointerType( {repr(self.base_type)}, const={self.const}, restrict={self.restrict} )"
......
......@@ -371,6 +371,7 @@ def test_array_declarations():
decl = typify(decl)
assert ctx.get_symbol("arr1").dtype == Arr(Fp(32), (4,))
assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (4,))
# Array type determined by default-typed symbol
arr2 = sp.Symbol("arr2")
......@@ -378,6 +379,7 @@ def test_array_declarations():
decl = typify(decl)
assert ctx.get_symbol("arr2").dtype == Arr(Fp(32), (2, 3))
assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (2, 3))
# Array type determined by pre-typed symbol
q = TypedSymbol("q", Fp(16))
......@@ -386,6 +388,14 @@ def test_array_declarations():
decl = typify(decl)
assert ctx.get_symbol("arr3").dtype == Arr(Fp(16), (2, 2))
assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(16), (2, 2))
# Array type determined by LHS symbol
arr4 = TypedSymbol("arr4", Arr(Int(16), 4))
decl = freeze(Assignment(arr4, sp.Tuple(11, 1, 4, 2)))
decl = typify(decl)
assert decl.lhs.dtype == decl.rhs.dtype == Arr(Int(16), 4)
def test_erronous_typing():
......
from functools import reduce
from operator import add
from pystencils import fields, Assignment, make_slice, Field
from pystencils import fields, Assignment, make_slice, Field, FieldType
from pystencils.types import PsStructType, create_type
from pystencils.backend.memory import BufferBasePtr
......@@ -12,6 +12,7 @@ from pystencils.backend.kernelcreation import (
)
from pystencils.backend.transformations import LowerToC
from pystencils.backend.ast import dfs_preorder
from pystencils.backend.ast.expressions import (
PsBufferAcc,
PsMemAcc,
......@@ -90,8 +91,9 @@ def test_lower_anonymous_structs():
]
)
sfield = Field.create_generic("s", spatial_dimensions=1, dtype=stype)
f = Field.create_generic("f", 1, ctx.default_dtype, field_type=FieldType.CUSTOM)
asm = Assignment(sfield.center("val"), 31.2)
asm = Assignment(sfield.center("val"), f.absolute_access((sfield.center("x"),), (0,)))
fasm = factory.parse_sympy(asm)
......@@ -102,6 +104,13 @@ def test_lower_anonymous_structs():
lowered_fasm = lower(fasm.clone())
assert isinstance(lowered_fasm, PsAssignment)
# Check type of sfield data pointer
for expr in dfs_preorder(lowered_fasm, lambda n: isinstance(n, PsSymbolExpr)):
if expr.symbol.name == sbuf.base_pointer.name:
assert expr.symbol.dtype == create_type("uint8_t * restrict")
# Check LHS
assert isinstance(lowered_fasm.lhs, PsMemAcc)
assert isinstance(lowered_fasm.lhs.pointer, PsCast)
assert isinstance(lowered_fasm.lhs.pointer.operand, PsAddressOf)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment