diff --git a/src/pystencils/backend/transformations/lower_to_c.py b/src/pystencils/backend/transformations/lower_to_c.py index 8fd6f89ba967b211dfd80812aafee08ea5dd1f76..7fd176d18eb3cbad53000df90ed17a69ca78f4b6 100644 --- a/src/pystencils/backend/transformations/lower_to_c.py +++ b/src/pystencils/backend/transformations/lower_to_c.py @@ -18,7 +18,7 @@ from ..ast.expressions import ( PsCast, PsSymbolExpr, ) -from ...types import PsStructType, PsPointerType +from ...types import PsStructType, PsPointerType, PsUnsignedIntegerType class LowerToC: @@ -105,10 +105,20 @@ class LowerToC: assert isinstance(aggr.pointer, PsSymbolExpr) bp = aggr.pointer.symbol + bp_type = bp.get_dtype() + assert isinstance(bp_type, PsPointerType) # Need to keep track of base pointers already seen, since symbols must be unique if bp not in self._substitutions: - type_erased_bp = PsSymbol(bp.name, bp.dtype) + erased_type = PsPointerType( + PsUnsignedIntegerType(8, const=bp_type.base_type.const), + const=bp_type.const, + restrict=bp_type.restrict, + ) + type_erased_bp = PsSymbol( + bp.name, + erased_type + ) type_erased_bp.add_property(BufferBasePtr(buf)) self._substitutions[bp] = type_erased_bp else: diff --git a/tests/nbackend/transformations/test_lower_to_c.py b/tests/nbackend/transformations/test_lower_to_c.py index 1c8a2c8f320fe9ae6d232acc2c98236153f00d96..e7e0dec1de1d2fa05415557f95addafc79c89637 100644 --- a/tests/nbackend/transformations/test_lower_to_c.py +++ b/tests/nbackend/transformations/test_lower_to_c.py @@ -18,7 +18,8 @@ from pystencils.backend.ast.expressions import ( PsSymbolExpr, PsExpression, PsLookup, - PsAddressOf + PsAddressOf, + PsCast, ) from pystencils.backend.ast.structural import PsAssignment @@ -89,7 +90,7 @@ def test_lower_anonymous_structs(): ] ) sfield = Field.create_generic("s", spatial_dimensions=1, dtype=stype) - + asm = Assignment(sfield.center("val"), 31.2) fasm = factory.parse_sympy(asm) @@ -102,14 +103,11 @@ def test_lower_anonymous_structs(): lowered_fasm = lower(fasm.clone()) assert isinstance(lowered_fasm, PsAssignment) assert isinstance(lowered_fasm.lhs, PsMemAcc) - assert isinstance( - lowered_fasm.lhs.pointer, PsAddressOf - ) - assert isinstance( - lowered_fasm.lhs.pointer.operand, PsMemAcc - ) - type_erased_pointer = lowered_fasm.lhs.pointer.operand.pointer - + assert isinstance(lowered_fasm.lhs.pointer, PsCast) + assert isinstance(lowered_fasm.lhs.pointer.operand, PsAddressOf) + assert isinstance(lowered_fasm.lhs.pointer.operand.operand, PsMemAcc) + type_erased_pointer = lowered_fasm.lhs.pointer.operand.operand.pointer + assert isinstance(type_erased_pointer, PsSymbolExpr) assert BufferBasePtr(sbuf) in type_erased_pointer.symbol.properties - assert type_erased_pointer.symbol.dtype == create_type("restrict uint8_t *") + assert type_erased_pointer.symbol.dtype == create_type("uint8_t * restrict")