Skip to content
Snippets Groups Projects
Commit f721e167 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

literal printing and header collection

parent a3843b10
No related branches found
No related tags found
No related merge requests found
Pipeline #60466 canceled
from typing import cast
from typing import cast, Any
from functools import reduce
from pymbolic.primitives import Variable
from pymbolic.mapper import Collector
from pymbolic.mapper.dependency import DependencyMapper
from .kernelfunction import PsKernelFunction
from .nodes import PsAstNode, PsExpression, PsAssignment, PsDeclaration, PsLoop, PsBlock
from ..typed_expressions import PsTypedVariable
from ..typed_expressions import PsTypedVariable, PsTypedConstant
from ..exceptions import PsMalformedAstException, PsInternalCompilerError
......@@ -24,12 +27,12 @@ class UndefinedVariablesCollector:
include_cses=False,
)
def collect(self, node: PsAstNode) -> set[PsTypedVariable]:
def __call__(self, node: PsAstNode) -> set[PsTypedVariable]:
"""Returns all `PsTypedVariable`s that occur in the given AST without being defined prior to their usage."""
match node:
case PsKernelFunction(block):
return self.collect(block)
return self(block)
case PsExpression(expr):
variables: set[Variable] = self._pb_dep_mapper(expr)
......@@ -43,22 +46,22 @@ class UndefinedVariablesCollector:
return cast(set[PsTypedVariable], variables)
case PsAssignment(lhs, rhs):
return self.collect(lhs) | self.collect(rhs)
return self(lhs) | self(rhs)
case PsBlock(statements):
undefined_vars: set[PsTypedVariable] = set()
for stmt in statements[::-1]:
undefined_vars -= self.declared_variables(stmt)
undefined_vars |= self.collect(stmt)
undefined_vars |= self(stmt)
return undefined_vars
case PsLoop(ctr, start, stop, step, body):
undefined_vars = (
self.collect(start)
| self.collect(stop)
| self.collect(step)
| self.collect(body)
self(start)
| self(stop)
| self(step)
| self(body)
)
undefined_vars.remove(ctr.symbol)
return undefined_vars
......@@ -82,3 +85,34 @@ class UndefinedVariablesCollector:
raise PsInternalCompilerError(
f"Don't know how to collect declared variables from {unknown}"
)
def collect_undefined_variables(node: PsAstNode) -> set[PsTypedVariable]:
return UndefinedVariablesCollector()(node)
class RequiredHeadersCollector(Collector):
"""Collect all header files required by a given AST for correct compilation.
Required headers can currently only be defined in subclasses of `PsAbstractType`.
"""
def __call__(self, node: PsAstNode) -> set[str]:
match node:
case PsExpression(expr):
return self.rec(expr)
case node:
return reduce(set.union, (self(c) for c in node.children()), set())
def map_typed_variable(self, var: PsTypedVariable) -> set[str]:
return var.dtype.required_headers
def map_constant(self, cst: Any):
if not isinstance(cst, PsTypedConstant):
raise PsMalformedAstException("Untyped constant encountered in expression.")
return cst.dtype.required_headers
def collect_required_headers(node: PsAstNode) -> set[str]:
return RequiredHeadersCollector()(node)
......@@ -129,9 +129,9 @@ class PsKernelFunction(PsAstNode):
This function performs a full traversal of the AST.
To improve performance, make sure to cache the result if necessary.
"""
from .analysis import UndefinedVariablesCollector
from .collectors import collect_undefined_variables
params_set = UndefinedVariablesCollector().collect(self)
params_set = collect_undefined_variables(self)
params_list = sorted(params_set, key=lambda p: p.name)
arrays = set(p.array for p in params_list if isinstance(p, PsArrayBasePointer))
......@@ -140,5 +140,6 @@ class PsKernelFunction(PsAstNode):
)
def get_required_headers(self) -> set[str]:
# TODO: Headers from types, vectorizer, ...
return set()
# To Do: Headers from target/instruction set/...
from .collectors import collect_required_headers
return collect_required_headers(self)
from __future__ import annotations
from typing import TypeAlias, Any
from sys import intern
import pymbolic.primitives as pb
......@@ -16,6 +17,7 @@ class PsTypedVariable(pb.Variable):
init_arg_names: tuple[str, ...] = ("name", "dtype")
__match_args__ = ("name", "dtype")
mapper_method = intern("map_typed_variable")
def __init__(self, name: str, dtype: PsAbstractType):
super(PsTypedVariable, self).__init__(name)
......@@ -98,8 +100,12 @@ class PsTypedConstant:
self._dtype = constify(dtype)
self._value = self._dtype.create_constant(value)
@property
def dtype(self) -> PsNumericType:
return self._dtype
def __str__(self) -> str:
return str(self._value)
return self._dtype.create_literal(self._value)
def __repr__(self) -> str:
return f"PsTypedConstant( {self._value}, {repr(self._dtype)} )"
......
......@@ -32,6 +32,15 @@ class PsAbstractType(ABC):
def const(self) -> bool:
return self._const
# -------------------------------------------------------------------------------------------
# Optional Info
# -------------------------------------------------------------------------------------------
@property
def required_headers(self) -> set[str]:
"""The set of header files required when this type occurs in generated code."""
return set()
# -------------------------------------------------------------------------------------------
# Internal virtual operations
# -------------------------------------------------------------------------------------------
......@@ -154,6 +163,14 @@ class PsNumericType(PsAbstractType, ABC):
PsTypeError: If the given value cannot be interpreted in this type.
"""
@abstractmethod
def create_literal(self, value: Any) -> str:
"""Create a C numerical literal for a constant of this type.
Raises:
PsTypeError: If the given value's type is not the numeric type's compiler-internal representation.
"""
@abstractmethod
def is_int(self) -> bool:
...
......@@ -185,7 +202,7 @@ class PsScalarType(PsNumericType, ABC):
def is_float(self) -> bool:
return isinstance(self, PsIeeeFloatType)
@property
@abstractmethod
def itemsize(self) -> int:
......@@ -202,6 +219,7 @@ class PsIntegerType(PsScalarType, ABC):
__match_args__ = ("width",)
SUPPORTED_WIDTHS = (8, 16, 32, 64)
NUMPY_TYPES: dict[int, type] = dict()
def __init__(self, width: int, signed: bool = True, const: bool = False):
if width not in self.SUPPORTED_WIDTHS:
......@@ -221,11 +239,19 @@ class PsIntegerType(PsScalarType, ABC):
@property
def signed(self) -> bool:
return self._signed
@property
def itemsize(self) -> int:
return self.width // 8
def create_literal(self, value: Any) -> str:
np_dtype = self.NUMPY_TYPES[self._width]
if not isinstance(value, np_dtype):
raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
unsigned_suffix = "" if self.signed else "u"
# TODO: cast literal to correct type?
return str(value) + unsigned_suffix
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsIntegerType):
return False
......@@ -329,11 +355,29 @@ class PsIeeeFloatType(PsScalarType):
@property
def width(self) -> int:
return self._width
@property
def itemsize(self) -> int:
return self.width // 8
@property
def required_headers(self) -> set[str]:
if self._width == 16:
return {'"half_precision.h"'}
else:
return set()
def create_literal(self, value: Any) -> str:
np_dtype = self.NUMPY_TYPES[self._width]
if not isinstance(value, np_dtype):
raise PsTypeError(f"Given value {value} is not of required type {np_dtype}")
match self.width:
case 16: return f"((half) {value})" # see include/half_precision.h
case 32: return f"{value}f"
case 64: return str(value)
case _: assert False, "unreachable code"
def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment