diff --git a/src/pystencils/backend/ast/expressions.py b/src/pystencils/backend/ast/expressions.py index 32d06b6332d8487b7c57e6d190cb6e8708f84110..d73b1faa758f8ce31312c674712ec89bfd5683ab 100644 --- a/src/pystencils/backend/ast/expressions.py +++ b/src/pystencils/backend/ast/expressions.py @@ -451,6 +451,9 @@ class PsLookup(PsExpression, PsLvalue): idx = [0][idx] self._aggregate = failing_cast(PsExpression, c) + def __repr__(self) -> str: + return f"PsLookup({repr(self._aggregate)}, {repr(self._member_name)})" + class PsCall(PsExpression): __match_args__ = ("function", "args") diff --git a/src/pystencils/backend/kernelcreation/typification.py b/src/pystencils/backend/kernelcreation/typification.py index 04aeddabe32910b6a7c084272f1302dc87cf78ca..c8fad68f106dea78126be9d9ada51e2c57180cd2 100644 --- a/src/pystencils/backend/kernelcreation/typification.py +++ b/src/pystencils/backend/kernelcreation/typification.py @@ -344,6 +344,8 @@ class Typifier: decl_tc.apply_dtype( PsArrayType(items_tc.target_type, rhs.shape), rhs ) + else: + decl_tc.infer_dtype(rhs) case PsDeclaration(lhs, rhs) | PsAssignment(lhs, rhs): # Only if the LHS is an untyped symbol, infer its type from the RHS diff --git a/src/pystencils/backend/memory.py b/src/pystencils/backend/memory.py index 6594cafbdb494c94fa976e66f5a71b33959a81f7..ad28cd3c73f2c9f8f7f18865d20ee1bff5222999 100644 --- a/src/pystencils/backend/memory.py +++ b/src/pystencils/backend/memory.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import ClassVar, Sequence +from typing import Sequence from itertools import chain from dataclasses import dataclass @@ -86,17 +86,15 @@ class PsSymbol: return f"{self._name}: {dtype_str}" def __repr__(self) -> str: - return f"PsSymbol({self._name}, {self._dtype})" + return f"PsSymbol({repr(self._name)}, {repr(self._dtype)})" @dataclass(frozen=True) -class BufferBasePtr(PsSymbolProperty): +class BufferBasePtr(UniqueSymbolProperty): """Symbol acts as a base pointer to a buffer.""" buffer: PsBuffer - _unique: ClassVar[bool] = True - class PsBuffer: """N-dimensional contiguous linearized buffer in heap memory. diff --git a/src/pystencils/backend/transformations/lower_to_c.py b/src/pystencils/backend/transformations/lower_to_c.py index 7fd176d18eb3cbad53000df90ed17a69ca78f4b6..ea832355bb1a53f94fc07cad670f86f98e5f6a2e 100644 --- a/src/pystencils/backend/transformations/lower_to_c.py +++ b/src/pystencils/backend/transformations/lower_to_c.py @@ -65,7 +65,7 @@ class LowerToC: return i summands: list[PsExpression] = [ - maybe_cast(idx) * PsExpression.make(stride) + maybe_cast(cast(PsExpression, self.visit(idx))) * PsExpression.make(stride) for idx, stride in zip(indices, buf.strides, strict=True) ] diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 4ebafece72c8558a913b4edefaa74d3b5b554614..57d97238998050944e43146deba54fcc5dca5abb 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -22,6 +22,7 @@ from .backend.transformations import ( EliminateConstants, LowerToC, SelectFunctions, + CanonicalizeSymbols, ) from .backend.kernelfunction import ( create_cpu_kernel_function, @@ -143,12 +144,17 @@ def create_kernel( kernel_ast = optimize_cpu(ctx, platform, kernel_ast, config.cpu_optim) - erase_anons = LowerToC(ctx) - kernel_ast = cast(PsBlock, erase_anons(kernel_ast)) + # Lowering + lower_to_c = LowerToC(ctx) + kernel_ast = cast(PsBlock, lower_to_c(kernel_ast)) select_functions = SelectFunctions(platform) kernel_ast = cast(PsBlock, select_functions(kernel_ast)) + # Lowering introduces new symbols, which have to be canonicalized + canonicalize = CanonicalizeSymbols(ctx, True) + kernel_ast = cast(PsBlock, canonicalize(kernel_ast)) + if config.target.is_cpu(): return create_cpu_kernel_function( ctx, diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index e6fc4bb78ddfe18a5ac572700ec7d59d97fd84cf..d3d18720cf1ff3c4af14f6c276da52098adfbdd2 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -91,7 +91,8 @@ class PsPointerType(PsDereferencableType): def c_string(self) -> str: base_str = self._base_type.c_string() restrict_str = " RESTRICT" if self._restrict else "" - return f"{base_str} *{restrict_str} {self._const_string()}" + const_str = " const" if self.const else "" + return f"{base_str} *{restrict_str}{const_str}" def __repr__(self) -> str: return f"PsPointerType( {repr(self.base_type)}, const={self.const}, restrict={self.restrict} )" diff --git a/tests/nbackend/kernelcreation/test_typification.py b/tests/nbackend/kernelcreation/test_typification.py index 5ea2aa15e5230a24afb28662813123f39e2882cf..988fa4bb8b10c2c243abfd3a171657ad6bf5e418 100644 --- a/tests/nbackend/kernelcreation/test_typification.py +++ b/tests/nbackend/kernelcreation/test_typification.py @@ -371,6 +371,7 @@ def test_array_declarations(): decl = typify(decl) assert ctx.get_symbol("arr1").dtype == Arr(Fp(32), (4,)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (4,)) # Array type determined by default-typed symbol arr2 = sp.Symbol("arr2") @@ -378,6 +379,7 @@ def test_array_declarations(): decl = typify(decl) assert ctx.get_symbol("arr2").dtype == Arr(Fp(32), (2, 3)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(32), (2, 3)) # Array type determined by pre-typed symbol q = TypedSymbol("q", Fp(16)) @@ -386,6 +388,14 @@ def test_array_declarations(): decl = typify(decl) assert ctx.get_symbol("arr3").dtype == Arr(Fp(16), (2, 2)) + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Fp(16), (2, 2)) + + # Array type determined by LHS symbol + arr4 = TypedSymbol("arr4", Arr(Int(16), 4)) + decl = freeze(Assignment(arr4, sp.Tuple(11, 1, 4, 2))) + decl = typify(decl) + + assert decl.lhs.dtype == decl.rhs.dtype == Arr(Int(16), 4) def test_erronous_typing(): diff --git a/tests/nbackend/transformations/test_lower_to_c.py b/tests/nbackend/transformations/test_lower_to_c.py index e7e0dec1de1d2fa05415557f95addafc79c89637..b557a7493f9a84cb13b511e8fca1f898823bc9bb 100644 --- a/tests/nbackend/transformations/test_lower_to_c.py +++ b/tests/nbackend/transformations/test_lower_to_c.py @@ -1,7 +1,7 @@ from functools import reduce from operator import add -from pystencils import fields, Assignment, make_slice, Field +from pystencils import fields, Assignment, make_slice, Field, FieldType from pystencils.types import PsStructType, create_type from pystencils.backend.memory import BufferBasePtr @@ -12,6 +12,7 @@ from pystencils.backend.kernelcreation import ( ) from pystencils.backend.transformations import LowerToC +from pystencils.backend.ast import dfs_preorder from pystencils.backend.ast.expressions import ( PsBufferAcc, PsMemAcc, @@ -90,8 +91,9 @@ def test_lower_anonymous_structs(): ] ) sfield = Field.create_generic("s", spatial_dimensions=1, dtype=stype) + f = Field.create_generic("f", 1, ctx.default_dtype, field_type=FieldType.CUSTOM) - asm = Assignment(sfield.center("val"), 31.2) + asm = Assignment(sfield.center("val"), f.absolute_access((sfield.center("x"),), (0,))) fasm = factory.parse_sympy(asm) @@ -102,6 +104,13 @@ def test_lower_anonymous_structs(): lowered_fasm = lower(fasm.clone()) assert isinstance(lowered_fasm, PsAssignment) + + # Check type of sfield data pointer + for expr in dfs_preorder(lowered_fasm, lambda n: isinstance(n, PsSymbolExpr)): + if expr.symbol.name == sbuf.base_pointer.name: + assert expr.symbol.dtype == create_type("uint8_t * restrict") + + # Check LHS assert isinstance(lowered_fasm.lhs, PsMemAcc) assert isinstance(lowered_fasm.lhs.pointer, PsCast) assert isinstance(lowered_fasm.lhs.pointer.operand, PsAddressOf)