From a04c241bc996eb82b4e1a86383adc52146282eaf Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 22 Oct 2024 16:42:32 +0200 Subject: [PATCH] A series of bug fixes - Add another canonicalize pass after lowering - Fix typification of array decl RHS - Fix: LowerToC did not descend into PsBufferAcc index args --- src/pystencils/backend/ast/expressions.py | 3 +++ .../backend/kernelcreation/typification.py | 2 ++ src/pystencils/backend/memory.py | 8 +++----- .../backend/transformations/lower_to_c.py | 2 +- src/pystencils/kernelcreation.py | 10 ++++++++-- src/pystencils/types/types.py | 3 ++- tests/nbackend/kernelcreation/test_typification.py | 10 ++++++++++ tests/nbackend/transformations/test_lower_to_c.py | 13 +++++++++++-- 8 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 32d06b633..d73b1faa7 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -451,6 +451,9 @@ class PsLookup(PsExpression, PsLvalue): idx = [0][idx] self._aggregate = failing_cast(PsExpression, c) + def __repr__(self) -> str: + return f"PsLookup({repr(self._aggregate)}, {repr(self._member_name)})" + class PsCall(PsExpression): __match_args__ = ("function", "args") diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 04aeddabe..c8fad68f1 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -344,6 +344,8 @@ class Typifier: decl_tc.apply_dtype( PsArrayType(items_tc.target_type, rhs.shape), rhs ) + else: + decl_tc.infer_dtype(rhs) case PsDeclaration(lhs, rhs) | PsAssignment(lhs, rhs): # Only if the LHS is an untyped symbol, infer its type from the RHS diff --git a/src/pystencils/backend/memory.py b/src/pystencils/backend/memory.py index 6594cafbd..ad28cd3c7 100644 --- a/src/pystencils/backend/memory.py +++ b/src/pystencils/backend/memory.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import ClassVar, Sequence +from typing import Sequence from itertools import chain from dataclasses import dataclass @@ -86,17 +86,15 @@ class PsSymbol: return f"{self._name}: {dtype_str}" def __repr__(self) -> str: - return f"PsSymbol({self._name}, {self._dtype})" + return f"PsSymbol({repr(self._name)}, {repr(self._dtype)})" @dataclass(frozen=True) -class BufferBasePtr(PsSymbolProperty): +class BufferBasePtr(UniqueSymbolProperty): """Symbol acts as a base pointer to a buffer.""" buffer: PsBuffer - _unique: ClassVar[bool] = True - class PsBuffer: """N-dimensional contiguous linearized buffer in heap memory. diff --git a/src/pystencils/backend/transformations/lower_to_c.py b/src/pystencils/backend/transformations/lower_to_c.py index 7fd176d18..ea832355b 100644 --- a/src/pystencils/backend/transformations/lower_to_c.py +++ b/src/pystencils/backend/transformations/lower_to_c.py @@ -65,7 +65,7 @@ class LowerToC: return i summands: list[PsExpression] = [ - maybe_cast(idx) * PsExpression.make(stride) + maybe_cast(cast(PsExpression, self.visit(idx))) * PsExpression.make(stride) for idx, stride in zip(indices, buf.strides, strict=True) ] diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 4ebafece7..57d972389 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -22,6 +22,7 @@ from .backend.transformations import ( EliminateConstants, LowerToC, SelectFunctions, + CanonicalizeSymbols, ) from .backend.kernelfunction import ( create_cpu_kernel_function, @@ -143,12 +144,17 @@ def create_kernel( kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) - erase_anons = LowerToC(ctx) - kernel_ast = cast(PsBlock, erase_anons(kernel_ast)) + # Lowering + lower_to_c = LowerToC(ctx) + kernel_ast = cast(PsBlock, lower_to_c(kernel_ast)) select_functions = SelectFunctions(platform) kernel_ast = cast(PsBlock, select_functions(kernel_ast)) + # Lowering introduces new symbols, which have to be canonicalized + canonicalize = CanonicalizeSymbols(ctx, True) + kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) + if config.target.is_cpu(): return create_cpu_kernel_function( ctx, diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index e6fc4bb78..d3d18720c 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -91,7 +91,8 @@ class PsPointerType(PsDereferencableType): def c_string(self) -> str: base_str = self._base_type.c_string() restrict_str = " RESTRICT" if self._restrict else "" - return f"{base_str} *{restrict_str} {self._const_string()}" + const_str = " const" if self.const else "" + return f"{base_str} *{restrict_str}{const_str}" def __repr__(self) -> str: return f"PsPointerType( {repr(self.base_type)}, const={self.const}, restrict={self.restrict} )" diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 5ea2aa15e..988fa4bb8 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -371,6 +371,7 @@ def test_array_declarations(): decl = typify(decl) assert ctx.get_symbol("arr1").dtype == Arr(Fp(32), (4,)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (4,)) # Array type determined by default-typed symbol arr2 = sp.Symbol("arr2") @@ -378,6 +379,7 @@ def test_array_declarations(): decl = typify(decl) assert ctx.get_symbol("arr2").dtype == Arr(Fp(32), (2, 3)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (2, 3)) # Array type determined by pre-typed symbol q = TypedSymbol("q", Fp(16)) @@ -386,6 +388,14 @@ def test_array_declarations(): decl = typify(decl) assert ctx.get_symbol("arr3").dtype == Arr(Fp(16), (2, 2)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(16), (2, 2)) + + # Array type determined by LHS symbol + arr4 = TypedSymbol("arr4", Arr(Int(16), 4)) + decl = freeze(Assignment(arr4, sp.Tuple(11, 1, 4, 2))) + decl = typify(decl) + + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Int(16), 4) def test_erronous_typing(): diff --git a/tests/nbackend/transformations/test_lower_to_c.py b/tests/nbackend/transformations/test_lower_to_c.py index e7e0dec1d..b557a7493 100644 --- a/tests/nbackend/transformations/test_lower_to_c.py +++ b/tests/nbackend/transformations/test_lower_to_c.py @@ -1,7 +1,7 @@ from functools import reduce from operator import add -from pystencils import fields, Assignment, make_slice, Field +from pystencils import fields, Assignment, make_slice, Field, FieldType from pystencils.types import PsStructType, create_type from pystencils.backend.memory import BufferBasePtr @@ -12,6 +12,7 @@ from pystencils.backend.kernelcreation import ( ) from pystencils.backend.transformations import LowerToC +from pystencils.backend.ast import dfs_preorder from pystencils.backend.ast.expressions import ( PsBufferAcc, PsMemAcc, @@ -90,8 +91,9 @@ def test_lower_anonymous_structs(): ] ) sfield = Field.create_generic("s", spatial_dimensions=1, dtype=stype) + f = Field.create_generic("f", 1, ctx.default_dtype, field_type=FieldType.CUSTOM) - asm = Assignment(sfield.center("val"), 31.2) + asm = Assignment(sfield.center("val"), f.absolute_access((sfield.center("x"),), (0,))) fasm = factory.parse_sympy(asm) @@ -102,6 +104,13 @@ def test_lower_anonymous_structs(): lowered_fasm = lower(fasm.clone()) assert isinstance(lowered_fasm, PsAssignment) + + # Check type of sfield data pointer + for expr in dfs_preorder(lowered_fasm, lambda n: isinstance(n, PsSymbolExpr)): + if expr.symbol.name == sbuf.base_pointer.name: + assert expr.symbol.dtype == create_type("uint8_t * restrict") + + # Check LHS assert isinstance(lowered_fasm.lhs, PsMemAcc) assert isinstance(lowered_fasm.lhs.pointer, PsCast) assert isinstance(lowered_fasm.lhs.pointer.operand, PsAddressOf) -- GitLab