Skip to content
Snippets Groups Projects
Commit e439a228 authored by Markus Holzer's avatar Markus Holzer
Browse files

Add pystencils to pymbolic mapper

parent d2520fd3
Branches
Tags
No related merge requests found
Pipeline #59688 failed with stages
in 2 minutes and 53 seconds
......@@ -9,6 +9,8 @@ from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue
T = TypeVar("T")
def failing_cast(target: type, obj: T):
if not isinstance(obj, target):
raise TypeError(f"Casting {obj} to {target} failed.")
......@@ -62,6 +64,7 @@ class PsBlock(PsAstNode):
def set_child(self, idx: int, c: PsAstNode):
self._children[idx] = c
class PsLeafNode(PsAstNode):
def num_children(self) -> int:
return 0
......
from pymbolic.interop.sympy import SympyToPymbolicMapper
from pystencils.typing import TypedSymbol
from pystencils.typing.typed_sympy import SHAPE_DTYPE
from .ast.nodes import PsAssignment, PsSymbolExpr
from .types import PsSignedIntegerType, PsIeeeFloatType, PsUnsignedIntegerType
from .typed_expressions import PsArrayBasePointer, PsLinearizedArray, PsTypedVariable, PsArrayAccess
CTR_SYMBOLS = [TypedSymbol(f"ctr_{i}", SHAPE_DTYPE) for i in range(3)]
class PystencilsToPymbolicMapper(SympyToPymbolicMapper):
def map_Assignment(self, expr): # noqa
lhs = self.rec(expr.lhs)
rhs = self.rec(expr.rhs)
return PsAssignment(lhs, rhs)
def map_BasicType(self, expr):
width = expr.numpy_dtype.itemsize * 8
const = expr.const
if expr.is_float():
return PsIeeeFloatType(width, const)
elif expr.is_uint():
return PsUnsignedIntegerType(width, const)
elif expr.is_int():
return PsSignedIntegerType(width, const)
else:
raise (NotImplementedError, "Not supported dtype")
def map_FieldShapeSymbol(self, expr):
dtype = self.rec(expr.dtype)
return PsTypedVariable(expr.name, dtype)
def map_TypedSymbol(self, expr):
dtype = self.rec(expr.dtype)
return PsTypedVariable(expr.name, dtype)
def map_Access(self, expr):
name = expr.field.name
shape = tuple([self.rec(s) for s in expr.field.shape])
strides = tuple([self.rec(s) for s in expr.field.strides])
dtype = self.rec(expr.dtype)
array = PsLinearizedArray(name, shape, strides, dtype)
ptr = PsArrayBasePointer(expr.name, array)
index = sum([ctr * stride for ctr, stride in zip(CTR_SYMBOLS, expr.field.strides)])
index = self.rec(index)
return PsSymbolExpr(PsArrayAccess(ptr, index))
......@@ -89,9 +89,9 @@ class PsArrayAccess(pb.Subscript):
def base_ptr(self):
return self._base_ptr
@property
def index(self):
return self._index
# @property
# def index(self):
# return self._index
@property
def array(self) -> PsArray:
......
......@@ -204,7 +204,7 @@ class PsIeeeFloatType(PsAbstractType):
__match_args__ = ("width",)
SUPPORTED_WIDTHS = (32, 64)
SUPPORTED_WIDTHS = (16, 32, 64)
def __init__(self, width: int, const: bool = False):
if width not in self.SUPPORTED_WIDTHS:
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment