Skip to content
Snippets Groups Projects
Commit 51e38626 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

test anon struct lowering

parent 035f7c41
No related branches found
No related tags found
1 merge request!421Refactor Field Modelling
......@@ -18,7 +18,7 @@ from ..ast.expressions import (
PsCast,
PsSymbolExpr,
)
from ...types import PsStructType, PsPointerType
from ...types import PsStructType, PsPointerType, PsUnsignedIntegerType
class LowerToC:
......@@ -105,10 +105,20 @@ class LowerToC:
assert isinstance(aggr.pointer, PsSymbolExpr)
bp = aggr.pointer.symbol
bp_type = bp.get_dtype()
assert isinstance(bp_type, PsPointerType)
# Need to keep track of base pointers already seen, since symbols must be unique
if bp not in self._substitutions:
type_erased_bp = PsSymbol(bp.name, bp.dtype)
erased_type = PsPointerType(
PsUnsignedIntegerType(8, const=bp_type.base_type.const),
const=bp_type.const,
restrict=bp_type.restrict,
)
type_erased_bp = PsSymbol(
bp.name,
erased_type
)
type_erased_bp.add_property(BufferBasePtr(buf))
self._substitutions[bp] = type_erased_bp
else:
......
......@@ -18,7 +18,8 @@ from pystencils.backend.ast.expressions import (
PsSymbolExpr,
PsExpression,
PsLookup,
PsAddressOf
PsAddressOf,
PsCast,
)
from pystencils.backend.ast.structural import PsAssignment
......@@ -89,7 +90,7 @@ def test_lower_anonymous_structs():
]
)
sfield = Field.create_generic("s", spatial_dimensions=1, dtype=stype)
asm = Assignment(sfield.center("val"), 31.2)
fasm = factory.parse_sympy(asm)
......@@ -102,14 +103,11 @@ def test_lower_anonymous_structs():
lowered_fasm = lower(fasm.clone())
assert isinstance(lowered_fasm, PsAssignment)
assert isinstance(lowered_fasm.lhs, PsMemAcc)
assert isinstance(
lowered_fasm.lhs.pointer, PsAddressOf
)
assert isinstance(
lowered_fasm.lhs.pointer.operand, PsMemAcc
)
type_erased_pointer = lowered_fasm.lhs.pointer.operand.pointer
assert isinstance(lowered_fasm.lhs.pointer, PsCast)
assert isinstance(lowered_fasm.lhs.pointer.operand, PsAddressOf)
assert isinstance(lowered_fasm.lhs.pointer.operand.operand, PsMemAcc)
type_erased_pointer = lowered_fasm.lhs.pointer.operand.operand.pointer
assert isinstance(type_erased_pointer, PsSymbolExpr)
assert BufferBasePtr(sbuf) in type_erased_pointer.symbol.properties
assert type_erased_pointer.symbol.dtype == create_type("restrict uint8_t *")
assert type_erased_pointer.symbol.dtype == create_type("uint8_t * restrict")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment