diff --git a/pystencils/nbackend/ast/nodes.py b/pystencils/nbackend/ast/nodes.py index 7459de7aadae49948c46a2479f770400e36d8321..adb3e2a010489c1109ac855d28b22c4e5d2932df 100644 --- a/pystencils/nbackend/ast/nodes.py +++ b/pystencils/nbackend/ast/nodes.py @@ -9,6 +9,8 @@ from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue T = TypeVar("T") + + def failing_cast(target: type, obj: T): if not isinstance(obj, target): raise TypeError(f"Casting {obj} to {target} failed.") @@ -62,6 +64,7 @@ class PsBlock(PsAstNode): def set_child(self, idx: int, c: PsAstNode): self._children[idx] = c + class PsLeafNode(PsAstNode): def num_children(self) -> int: return 0 diff --git a/pystencils/nbackend/sympy_mapper.py b/pystencils/nbackend/sympy_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..380ed699d285141feb7028a05d2e1ddce1059fae --- /dev/null +++ b/pystencils/nbackend/sympy_mapper.py @@ -0,0 +1,50 @@ +from pymbolic.interop.sympy import SympyToPymbolicMapper + +from pystencils.typing import TypedSymbol +from pystencils.typing.typed_sympy import SHAPE_DTYPE +from .ast.nodes import PsAssignment, PsSymbolExpr +from .types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType +from .typed_expressions import PsArrayBasePointer, PsLinearizedArray, PsTypedVariable, PsArrayAccess + +CTR_SYMBOLS = [TypedSymbol(f"ctr_{i}", SHAPE_DTYPE) for i in range(3)] + + +class PystencilsToPymbolicMapper(SympyToPymbolicMapper): + def map_Assignment(self, expr): # noqa + lhs = self.rec(expr.lhs) + rhs = self.rec(expr.rhs) + return PsAssignment(lhs, rhs) + + def map_BasicType(self, expr): + width = expr.numpy_dtype.itemsize * 8 + const = expr.const + if expr.is_float(): + return PsIeeeFloatType(width, const) + elif expr.is_uint(): + return PsUnsignedIntegerType(width, const) + elif expr.is_int(): + return PsSignedIntegerType(width, const) + else: + raise (NotImplementedError, "Not supported dtype") + + def map_FieldShapeSymbol(self, expr): + dtype = self.rec(expr.dtype) + return PsTypedVariable(expr.name, dtype) + + def map_TypedSymbol(self, expr): + dtype = self.rec(expr.dtype) + return PsTypedVariable(expr.name, dtype) + + def map_Access(self, expr): + name = expr.field.name + shape = tuple([self.rec(s) for s in expr.field.shape]) + strides = tuple([self.rec(s) for s in expr.field.strides]) + dtype = self.rec(expr.dtype) + + array = PsLinearizedArray(name, shape, strides, dtype) + + ptr = PsArrayBasePointer(expr.name, array) + index = sum([ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)]) + index = self.rec(index) + + return PsSymbolExpr(PsArrayAccess(ptr, index)) diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py index 22b565f701b57a85eee8fb597af0fcd3d2c60323..5d0c0fc891150a16df6c9ecc0e140980e118edb4 100644 --- a/pystencils/nbackend/typed_expressions.py +++ b/pystencils/nbackend/typed_expressions.py @@ -89,9 +89,9 @@ class PsArrayAccess(pb.Subscript): def base_ptr(self): return self._base_ptr - @property - def index(self): - return self._index + # @property + # def index(self): + # return self._index @property def array(self) -> PsArray: diff --git a/pystencils/nbackend/types/basic_types.py b/pystencils/nbackend/types/basic_types.py index 698418e73e3ee5cca2f4f1b7cb31653f0f8da528..dcdb300ce1696fb5b61028ca47c7d632eea7904e 100644 --- a/pystencils/nbackend/types/basic_types.py +++ b/pystencils/nbackend/types/basic_types.py @@ -204,7 +204,7 @@ class PsIeeeFloatType(PsAbstractType): __match_args__ = ("width",) - SUPPORTED_WIDTHS = (32, 64) + SUPPORTED_WIDTHS = (16, 32, 64) def __init__(self, width: int, const: bool = False): if width not in self.SUPPORTED_WIDTHS: