Skip to content
Snippets Groups Projects
Commit bddd3a37 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

refactor field and array handling in context

parent 0dbf7137
No related merge requests found
from __future__ import annotations
from types import EllipsisType
from ...field import Field, FieldType
from ...sympyextensions.typed_sympy import TypedSymbol, BasicType, StructType
from ..arrays import PsLinearizedArray
......@@ -43,15 +45,18 @@ class KernelCreationContext:
or full iteration space.
"""
def __init__(self,
default_dtype: PsNumericType = PbDefaults.numeric_dtype,
index_dtype: PsIntegerType = PbDefaults.index_dtype):
def __init__(
self,
default_dtype: PsNumericType = PbDefaults.numeric_dtype,
index_dtype: PsIntegerType = PbDefaults.index_dtype,
):
self._default_dtype = default_dtype
self._index_dtype = index_dtype
self._arrays: dict[Field, PsLinearizedArray] = dict()
self._constraints: list[PsKernelConstraint] = []
self._field_arrays: dict[Field, PsLinearizedArray] = dict()
self._fields_collection = FieldsInKernel()
self._ispace: IterationSpace | None = None
@property
......@@ -76,7 +81,22 @@ class KernelCreationContext:
return self._fields_collection
def add_field(self, field: Field):
"""Add the given field to the context's fields collection"""
"""Add the given field to the context's fields collection.
This method adds the passed ``field`` to the context's field collection, which is
accesible through the `fields` member, and creates an array representation of the field,
which is retrievable through `get_array`.
Before adding the field to the collection, various sanity and constraint checks are applied.
"""
if field in self._field_arrays:
# Field was already added
return
arr_shape: list[EllipsisType | int] | None = None
arr_strides: list[EllipsisType | int] | None = None
# Check field constraints and add to collection
match field.field_type:
case FieldType.GENERIC | FieldType.STAGGERED | FieldType.STAGGERED_FLUX:
self._fields_collection.domain_fields.add(field)
......@@ -87,6 +107,23 @@ class KernelCreationContext:
f"Invalid spatial shape of buffer field {field.name}: {field.spatial_dimensions}. "
"Buffer fields must be one-dimensional."
)
if field.index_dimensions > 1:
raise KernelConstraintsError(
f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. "
"Buffer fields can have at most one index dimension."
)
num_entries = field.index_shape[0] if field.index_shape else 1
if not isinstance(num_entries, int):
raise KernelConstraintsError(
f"Invalid index shape of buffer field {field.name}: {field.spatial_dimensions}. "
"Buffer fields cannot have variable index shape."
)
arr_shape = [..., num_entries]
arr_strides = [num_entries, 1]
self._fields_collection.buffer_fields.add(field)
case FieldType.INDEXED:
......@@ -103,48 +140,51 @@ class KernelCreationContext:
case _:
assert False, "unreachable code"
def get_array(self, field: Field) -> PsLinearizedArray:
if field not in self._arrays:
if field.field_type == FieldType.BUFFER:
# Buffers are always contiguous
assert field.spatial_dimensions == 1
assert field.index_dimensions <= 1
num_entries = field.index_shape[0] if field.index_shape else 1
# For non-buffer fields, determine shape and strides
arr_shape = [..., num_entries]
arr_strides = [num_entries, 1]
else:
arr_shape = [
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.shape
]
arr_strides = [
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.strides
]
# The frontend doesn't quite agree with itself on how to model
# fields with trivial index dimensions. Sometimes the index_shape is empty,
# sometimes its (1,). This is canonicalized here.
if not field.index_shape:
arr_shape += [1]
arr_strides += [1]
assert isinstance(field.dtype, (BasicType, StructType))
element_type = make_type(field.dtype.numpy_dtype)
arr = PsLinearizedArray(
field.name, element_type, arr_shape, arr_strides, self.index_dtype
)
self._arrays[field] = arr
return self._arrays[field]
if arr_shape is None:
arr_shape = [
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.shape
]
arr_strides = [
(
Ellipsis if isinstance(s, TypedSymbol) else s
) # TODO: Field should also use ellipsis
for s in field.strides
]
# The frontend doesn't quite agree with itself on how to model
# fields with trivial index dimensions. Sometimes the index_shape is empty,
# sometimes its (1,). This is canonicalized here.
if not field.index_shape:
arr_shape += [1]
arr_strides += [1]
# Add array
assert arr_strides is not None
assert isinstance(field.dtype, (BasicType, StructType))
element_type = make_type(field.dtype.numpy_dtype)
arr = PsLinearizedArray(
field.name, element_type, arr_shape, arr_strides, self.index_dtype
)
self._field_arrays[field] = arr
def get_array(self, field: Field) -> PsLinearizedArray:
"""Retrieve the underlying array for a given field.
If the given field was not previously registered using `add_field`,
this method internally calls `add_field` to check the field for consistency.
"""
if field not in self._field_arrays:
self.add_field(field)
return self._field_arrays[field]
# Iteration Space
......
......@@ -2,7 +2,7 @@ import pytest
from pystencils.field import Field, FieldType
from pystencils.backend.types.quick import *
from pystencils.kernelcreation import (
from pystencils.config import (
CreateKernelConfig,
PsOptionsError,
)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment