Skip to content
Snippets Groups Projects
Commit a3843b10 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

slightly extended typifier

parent 03affb6c
No related branches found
No related tags found
No related merge requests found
Pipeline #60440 failed
......@@ -9,8 +9,11 @@ from ..ast import PsBlock
from .context import KernelCreationContext, FullIterationSpace
from .freeze import FreezeExpressions
from .typification import Typifier
# flake8: noqa
def create_domain_kernel(assignments: AssignmentCollection):
# TODO: Assemble configuration
......@@ -46,7 +49,8 @@ def create_domain_kernel(assignments: AssignmentCollection):
# 5. Typify
# Also the same for both types of kernels
# determine_types(kernel_body)
typify = Typifier(ctx)
kernel_body = typify(kernel_body)
# Up to this point, all was target-agnostic, but now the target becomes relevant.
# Here we might hand off the compilation to a target-specific part of the compiler
......
from __future__ import annotations
from typing import TypeVar
import pymbolic.primitives as pb
from pymbolic.mapper import Mapper
from .context import KernelCreationContext
from ..types import PsAbstractType
from ..typed_expressions import PsTypedVariable
from ..ast import PsAstNode, PsExpression, PsAssignment
class TypificationException(Exception):
"""Indicates a fatal error during typification."""
NodeT = TypeVar("NodeT", bound=PsAstNode)
class Typifier(Mapper):
def __init__(self, ctx: KernelCreationContext):
self._ctx = ctx
def __call__(self, expr: pb.Expression) -> tuple[pb.Expression, PsAbstractType]:
return self.rec(expr)
def __call__(self, node: NodeT) -> NodeT:
match node:
case PsExpression(expr):
node.expression, _ = self.rec(expr)
case PsAssignment(lhs, rhs):
lhs, lhs_dtype = self.rec(lhs)
rhs, rhs_dtype = self.rec(rhs)
if lhs_dtype != rhs_dtype:
# todo: (optional) automatic cast insertion?
raise TypificationException(
"Mismatched types in assignment: \n"
f" {lhs} <- {rhs}\n"
f" dtype(lhs) = {lhs_dtype}\n"
f" dtype(rhs) = {rhs_dtype}\n"
)
node.lhs = lhs
node.rhs = rhs
case unknown:
raise NotImplementedError(f"Don't know how to typify {unknown}")
return node
def map_variable(self, var: pb.Variable) -> tuple[pb.Expression, PsAbstractType]:
dtype = NotImplemented # determine variable type
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment