Skip to content
Snippets Groups Projects

Refactor Field Indexing Symbols

Merged Frederik Hennig requested to merge fhennig/refactor-indexing-symbols into v2.0-dev
Files
13
@@ -2,17 +2,23 @@ from __future__ import annotations
from typing import Iterable, Iterator, Any
from itertools import chain, count
from types import EllipsisType
from collections import namedtuple, defaultdict
import re
from ...defaults import DEFAULTS
from ...field import Field, FieldType
from ...sympyextensions.typed_sympy import TypedSymbol
from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType
from ..symbols import PsSymbol
from ..arrays import PsLinearizedArray
from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType, deconstify
from ...types import (
PsType,
PsIntegerType,
PsNumericType,
PsScalarType,
PsStructType,
deconstify,
)
from ..constraints import KernelParamsConstraint
from ..exceptions import PsInternalCompilerError, KernelConstraintsError
@@ -97,7 +103,7 @@ class KernelCreationContext:
@property
def constraints(self) -> tuple[KernelParamsConstraint, ...]:
return tuple(self._constraints)
@property
def metadata(self) -> dict[str, Any]:
return self._metadata
@@ -215,8 +221,27 @@ class KernelCreationContext:
else:
return
arr_shape: list[EllipsisType | int] | None = None
arr_strides: list[EllipsisType | int] | None = None
arr_shape: list[str | int] | None = None
arr_strides: list[str | int] | None = None
def normalize_type(s: TypedSymbol) -> PsIntegerType:
match s.dtype:
case DynamicType.INDEX_TYPE:
return self.index_dtype
case DynamicType.NUMERIC_TYPE:
if isinstance(self.default_dtype, PsIntegerType):
return self.default_dtype
else:
raise KernelConstraintsError(
f"Cannot use non-integer default numeric type {self.default_dtype} "
f"in field indexing symbol {s}."
)
case PsIntegerType():
return deconstify(s.dtype)
case _:
raise KernelConstraintsError(
f"Invalid data type for field indexing symbol {s}: {s.dtype}"
)
# Check field constraints and add to collection
match field.field_type:
@@ -243,7 +268,15 @@ class KernelCreationContext:
"Buffer fields cannot have variable index shape."
)
arr_shape = [..., num_entries]
buffer_len = field.spatial_shape[0]
if isinstance(buffer_len, TypedSymbol):
idx_type = normalize_type(buffer_len)
arr_shape = [buffer_len.name, num_entries]
else:
idx_type = DEFAULTS.index_dtype
arr_shape = [buffer_len, num_entries]
arr_strides = [num_entries, 1]
self._fields_collection.buffer_fields.add(field)
@@ -265,18 +298,25 @@ class KernelCreationContext:
# For non-buffer fields, determine shape and strides
if arr_shape is None:
idx_types = set(
normalize_type(s)
for s in chain(field.shape, field.strides)
if isinstance(s, TypedSymbol)
)
if len(idx_types) > 1:
raise KernelConstraintsError(
f"Multiple incompatible types found in index symbols of field {field}: "
f"{idx_types}"
)
idx_type = idx_types.pop() if len(idx_types) > 0 else self.index_dtype
arr_shape = [
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.shape
(s.name if isinstance(s, TypedSymbol) else s) for s in field.shape
]
arr_strides = [
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.strides
(s.name if isinstance(s, TypedSymbol) else s) for s in field.strides
]
# The frontend doesn't quite agree with itself on how to model
@@ -288,12 +328,13 @@ class KernelCreationContext:
# Add array
assert arr_strides is not None
assert idx_type is not None
assert isinstance(field.dtype, (PsScalarType, PsStructType))
element_type = field.dtype
arr = PsLinearizedArray(
field.name, element_type, arr_shape, arr_strides, self.index_dtype
field.name, element_type, arr_shape, arr_strides, idx_type
)
self._fields_and_arrays[field.name] = FieldArrayPair(field, arr)
Loading