Skip to content
Snippets Groups Projects

Refactor Field Indexing Symbols

Files

@@ -2,17 +2,23 @@ from __future__ import annotations
@@ -2,17 +2,23 @@ from __future__ import annotations
from typing import Iterable, Iterator, Any
from typing import Iterable, Iterator, Any
from itertools import chain, count
from itertools import chain, count
from types import EllipsisType
from collections import namedtuple, defaultdict
from collections import namedtuple, defaultdict
import re
import re
from ...defaults import DEFAULTS
from ...defaults import DEFAULTS
from ...field import Field, FieldType
from ...field import Field, FieldType
from ...sympyextensions.typed_sympy import TypedSymbol
from ...sympyextensions.typed_sympy import TypedSymbol, DynamicType
from ..symbols import PsSymbol
from ..symbols import PsSymbol
from ..arrays import PsLinearizedArray
from ..arrays import PsLinearizedArray
from ...types import PsType, PsIntegerType, PsNumericType, PsScalarType, PsStructType, deconstify
from ...types import (
 
PsType,
 
PsIntegerType,
 
PsNumericType,
 
PsScalarType,
 
PsStructType,
 
deconstify,
 
)
from ..constraints import KernelParamsConstraint
from ..constraints import KernelParamsConstraint
from ..exceptions import PsInternalCompilerError, KernelConstraintsError
from ..exceptions import PsInternalCompilerError, KernelConstraintsError
@@ -97,7 +103,7 @@ class KernelCreationContext:
@@ -97,7 +103,7 @@ class KernelCreationContext:
@property
@property
def constraints(self) -> tuple[KernelParamsConstraint, ...]:
def constraints(self) -> tuple[KernelParamsConstraint, ...]:
return tuple(self._constraints)
return tuple(self._constraints)
@property
@property
def metadata(self) -> dict[str, Any]:
def metadata(self) -> dict[str, Any]:
return self._metadata
return self._metadata
@@ -215,8 +221,27 @@ class KernelCreationContext:
@@ -215,8 +221,27 @@ class KernelCreationContext:
else:
else:
return
return
arr_shape: list[EllipsisType | int] | None = None
arr_shape: list[str | int] | None = None
arr_strides: list[EllipsisType | int] | None = None
arr_strides: list[str | int] | None = None
 
 
def normalize_type(s: TypedSymbol) -> PsIntegerType:
 
match s.dtype:
 
case DynamicType.INDEX_TYPE:
 
return self.index_dtype
 
case DynamicType.NUMERIC_TYPE:
 
if isinstance(self.default_dtype, PsIntegerType):
 
return self.default_dtype
 
else:
 
raise KernelConstraintsError(
 
f"Cannot use non-integer default numeric type {self.default_dtype} "
 
f"in field indexing symbol {s}."
 
)
 
case PsIntegerType():
 
return deconstify(s.dtype)
 
case _:
 
raise KernelConstraintsError(
 
f"Invalid data type for field indexing symbol {s}: {s.dtype}"
 
)
# Check field constraints and add to collection
# Check field constraints and add to collection
match field.field_type:
match field.field_type:
@@ -243,7 +268,15 @@ class KernelCreationContext:
@@ -243,7 +268,15 @@ class KernelCreationContext:
"Buffer fields cannot have variable index shape."
"Buffer fields cannot have variable index shape."
)
)
arr_shape = [..., num_entries]
buffer_len = field.spatial_shape[0]
 
 
if isinstance(buffer_len, TypedSymbol):
 
idx_type = normalize_type(buffer_len)
 
arr_shape = [buffer_len.name, num_entries]
 
else:
 
idx_type = DEFAULTS.index_dtype
 
arr_shape = [buffer_len, num_entries]
 
arr_strides = [num_entries, 1]
arr_strides = [num_entries, 1]
self._fields_collection.buffer_fields.add(field)
self._fields_collection.buffer_fields.add(field)
@@ -265,18 +298,25 @@ class KernelCreationContext:
@@ -265,18 +298,25 @@ class KernelCreationContext:
# For non-buffer fields, determine shape and strides
# For non-buffer fields, determine shape and strides
if arr_shape is None:
if arr_shape is None:
 
idx_types = set(
 
normalize_type(s)
 
for s in chain(field.shape, field.strides)
 
if isinstance(s, TypedSymbol)
 
)
 
 
if len(idx_types) > 1:
 
raise KernelConstraintsError(
 
f"Multiple incompatible types found in index symbols of field {field}: "
 
f"{idx_types}"
 
)
 
idx_type = idx_types.pop() if len(idx_types) > 0 else self.index_dtype
 
arr_shape = [
arr_shape = [
(
(s.name if isinstance(s, TypedSymbol) else s) for s in field.shape
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.shape
]
]
arr_strides = [
arr_strides = [
(
(s.name if isinstance(s, TypedSymbol) else s) for s in field.strides
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.strides
]
]
# The frontend doesn't quite agree with itself on how to model
# The frontend doesn't quite agree with itself on how to model
@@ -288,12 +328,13 @@ class KernelCreationContext:
@@ -288,12 +328,13 @@ class KernelCreationContext:
# Add array
# Add array
assert arr_strides is not None
assert arr_strides is not None
 
assert idx_type is not None
assert isinstance(field.dtype, (PsScalarType, PsStructType))
assert isinstance(field.dtype, (PsScalarType, PsStructType))
element_type = field.dtype
element_type = field.dtype
arr = PsLinearizedArray(
arr = PsLinearizedArray(
field.name, element_type, arr_shape, arr_strides, self.index_dtype
field.name, element_type, arr_shape, arr_strides, idx_type
)
)
self._fields_and_arrays[field.name] = FieldArrayPair(field, arr)
self._fields_and_arrays[field.name] = FieldArrayPair(field, arr)
Loading