Skip to content
Snippets Groups Projects
Commit 0cec1fa8 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

various fixes to bugs encountered during waLBerla integration

parent 51a47fb7
No related merge requests found
Pipeline #63919 failed with stages
in 2 minutes and 54 seconds
...@@ -25,6 +25,8 @@ class FieldsInKernel: ...@@ -25,6 +25,8 @@ class FieldsInKernel:
self.custom_fields: set[Field] = set() self.custom_fields: set[Field] = set()
self.buffer_fields: set[Field] = set() self.buffer_fields: set[Field] = set()
self.archetype_field: Field | None = None
def __iter__(self) -> Iterator: def __iter__(self) -> Iterator:
return chain( return chain(
self.domain_fields, self.domain_fields,
......
...@@ -182,6 +182,10 @@ class FreezeExpressions: ...@@ -182,6 +182,10 @@ class FreezeExpressions:
def map_Integer(self, expr: sp.Integer) -> PsConstantExpr: def map_Integer(self, expr: sp.Integer) -> PsConstantExpr:
value = int(expr) value = int(expr)
return PsConstantExpr(PsConstant(value)) return PsConstantExpr(PsConstant(value))
def map_Float(self, expr: sp.Float) -> PsConstantExpr:
value = float(expr) # TODO: check accuracy of evaluation
return PsConstantExpr(PsConstant(value))
def map_Rational(self, expr: sp.Rational) -> PsExpression: def map_Rational(self, expr: sp.Rational) -> PsExpression:
num = PsConstantExpr(PsConstant(expr.numerator)) num = PsConstantExpr(PsConstant(expr.numerator))
......
...@@ -114,11 +114,7 @@ class FullIterationSpace(IterationSpace): ...@@ -114,11 +114,7 @@ class FullIterationSpace(IterationSpace):
) )
] ]
# Determine loop order by permuting dimensions return FullIterationSpace(ctx, dimensions, archetype_field=archetype_field)
loop_order = archetype_field.layout
dimensions = [dimensions[coordinate] for coordinate in loop_order]
return FullIterationSpace(ctx, dimensions)
@staticmethod @staticmethod
def create_from_slice( def create_from_slice(
...@@ -176,18 +172,21 @@ class FullIterationSpace(IterationSpace): ...@@ -176,18 +172,21 @@ class FullIterationSpace(IterationSpace):
) )
] ]
# Determine loop order by permuting dimensions return FullIterationSpace(ctx, dimensions, archetype_field=archetype_field)
loop_order = archetype_field.layout
dimensions = [dimensions[coordinate] for coordinate in loop_order]
return FullIterationSpace(ctx, dimensions)
def __init__(self, ctx: KernelCreationContext, dimensions: Sequence[Dimension]): def __init__(
self,
ctx: KernelCreationContext,
dimensions: Sequence[Dimension],
archetype_field: Field | None = None,
):
super().__init__(tuple(dim.counter for dim in dimensions)) super().__init__(tuple(dim.counter for dim in dimensions))
self._ctx = ctx self._ctx = ctx
self._dimensions = dimensions self._dimensions = dimensions
self._archetype_field = archetype_field
@property @property
def dimensions(self): def dimensions(self):
return self._dimensions return self._dimensions
...@@ -204,6 +203,10 @@ class FullIterationSpace(IterationSpace): ...@@ -204,6 +203,10 @@ class FullIterationSpace(IterationSpace):
def steps(self): def steps(self):
return (dim.step for dim in self._dimensions) return (dim.step for dim in self._dimensions)
@property
def archetype_field(self) -> Field | None:
return self._archetype_field
def actual_iterations(self, dimension: int | None = None) -> PsExpression: def actual_iterations(self, dimension: int | None = None) -> PsExpression:
if dimension is None: if dimension is None:
return reduce( return reduce(
......
...@@ -6,7 +6,7 @@ from .platform import Platform ...@@ -6,7 +6,7 @@ from .platform import Platform
from ..kernelcreation.iteration_space import ( from ..kernelcreation.iteration_space import (
IterationSpace, IterationSpace,
FullIterationSpace, FullIterationSpace,
SparseIterationSpace, SparseIterationSpace
) )
from ..constants import PsConstant from ..constants import PsConstant
...@@ -43,7 +43,15 @@ class GenericCpu(Platform): ...@@ -43,7 +43,15 @@ class GenericCpu(Platform):
def _create_domain_loops( def _create_domain_loops(
self, body: PsBlock, ispace: FullIterationSpace self, body: PsBlock, ispace: FullIterationSpace
) -> PsBlock: ) -> PsBlock:
dimensions = ispace.dimensions dimensions = ispace.dimensions
# Determine loop order by permuting dimensions
archetype_field = ispace.archetype_field
if archetype_field is not None:
loop_order = archetype_field.layout
dimensions = [dimensions[coordinate] for coordinate in loop_order]
outer_block = body outer_block = body
for dimension in dimensions[::-1]: for dimension in dimensions[::-1]:
......
...@@ -15,7 +15,7 @@ from .basic_types import ( ...@@ -15,7 +15,7 @@ from .basic_types import (
deconstify, deconstify,
) )
from .quick import create_type, create_numeric_type from .quick import UserTypeSpec, create_type, create_numeric_type
from .exception import PsTypeError from .exception import PsTypeError
...@@ -34,6 +34,7 @@ __all__ = [ ...@@ -34,6 +34,7 @@ __all__ = [
"PsIeeeFloatType", "PsIeeeFloatType",
"constify", "constify",
"deconstify", "deconstify",
"UserTypeSpec",
"create_type", "create_type",
"create_numeric_type", "create_numeric_type",
"PsTypeError", "PsTypeError",
......
...@@ -463,6 +463,18 @@ class PsBoolType(PsScalarType): ...@@ -463,6 +463,18 @@ class PsBoolType(PsScalarType):
return np.False_ return np.False_
else: else:
raise PsTypeError(f"Cannot create boolean constant from value {value}") raise PsTypeError(f"Cannot create boolean constant from value {value}")
def c_string(self) -> str:
return "bool"
def __eq__(self, other: object) -> bool:
if not isinstance(other, PsBoolType):
return False
return self._base_equal(other)
def __hash__(self) -> int:
return hash(("PsBoolType", self._const))
class PsIntegerType(PsScalarType, ABC): class PsIntegerType(PsScalarType, ABC):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment