diff --git a/.gitignore b/.gitignore index cb75f0cac251842b65cf4a66a34a8dd4171fd1e3..ef18ef29c682c0c471e5ecc4c974c7f9fe602763 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,6 @@ dist .coverage htmlcov coverage.xml + +# scratch +scratch \ No newline at end of file diff --git a/src/sfg_walberla/api.py b/src/sfg_walberla/api.py index b9f78ada79600cbeafe857990c0ed437e1acc28e..8e087b2530b717f741d1b3fabc85d1bdca18d822 100644 --- a/src/sfg_walberla/api.py +++ b/src/sfg_walberla/api.py @@ -8,7 +8,7 @@ from pystencils.types import ( PsPointerType, PsType, ) -from pystencilssfg.lang import IFieldExtraction, AugExpr, SrcField, SrcVector, Ref +from pystencilssfg.lang import IFieldExtraction, AugExpr, SrcField, SrcVector, Ref, ExprLike from pystencilssfg.ir import SfgHeaderInclude @@ -40,6 +40,21 @@ class Vector3(SrcVector): return AugExpr(self._value_type).bind("{}[{}]", self, coordinate) + def __getitem__(self, idx: int | ExprLike): + return AugExpr(self._value_type).bind("{}[{}]", self, idx) + + +class AABB(AugExpr): + def __init__(self): + dtype = PsCustomType("walberla::AABB") + super().__init__(dtype) + + def min(self) -> Vector3: + return Vector3().bind("{}.min()", self) + + def max(self) -> Vector3: + return Vector3().bind("{}.max()", self) + class BlockDataID(AugExpr): def __init__(self): @@ -57,6 +72,9 @@ class IBlockPtr(AugExpr): def getData(self, dtype: str | PsType, id: BlockDataID) -> AugExpr: return AugExpr.format("{}->template getData< {} >({})", self, dtype, id) + def getAABB(self) -> AABB: + return AABB().bind("{}->getAABB()", self) + @property def required_includes(self) -> set[SfgHeaderInclude]: return {SfgHeaderInclude.parse("domain_decomposition/IBlock.h")} diff --git a/src/sfg_walberla/sweep.py b/src/sfg_walberla/sweep.py index 89e4114f99cd5bbc3ad7411e342cba20f337ee82..c1c2c1f403c677d913578f77922f59fd578077c7 100644 --- a/src/sfg_walberla/sweep.py +++ b/src/sfg_walberla/sweep.py @@ -1,4 +1,4 @@ -from typing import Sequence, Iterable, Callable +from typing import Sequence, Iterable from dataclasses import dataclass, replace from itertools import chain from collections import defaultdict @@ -16,12 +16,20 @@ from pystencils import ( AssignmentCollection, CreateKernelConfig, Field, + TypedSymbol, ) from pystencils.types import PsType, constify, deconstify, PsCustomType from pystencilssfg import SfgComposer from pystencilssfg.lang import VarLike, asvar, SfgVar, AugExpr, Ref, SrcVector -from .api import StructuredBlockForest, GhostLayerFieldPtr, IBlockPtr, BlockDataID, Vector2, Vector3 +from .api import ( + StructuredBlockForest, + GhostLayerFieldPtr, + IBlockPtr, + BlockDataID, + Vector2, + Vector3, +) class SweepClassProperties: @@ -129,6 +137,56 @@ class SweepClassProperties: return sfg.constructor() +class BlockforestParameters: + def __init__( + self, + property_cache: SweepClassProperties, + block: IBlockPtr, + cell_interval: AugExpr | None = None, + ): + self._property_cache = property_cache + self._block = block + self._ci = cell_interval + self._extractions: dict[SfgVar, AugExpr] = dict() + self._properties: list[SweepClassProperties.Property] = [] + + @staticmethod + def process(asms: AssignmentCollection) -> AssignmentCollection: + from sympy.core.function import AppliedUndef + + expandable_appls: set[AppliedUndef] = filter( + lambda expr: hasattr(expr, "expansion_func"), asms.atoms(AppliedUndef) + ) + + subs: dict[AppliedUndef, sp.Symbol] = dict() + for appl in expandable_appls: + expansion: sp.Expr = appl.expansion_func(*appl.args) # type: ignore + symb = next(asms.subexpression_symbol_generator) + asms.subexpressions.insert(0, Assignment(symb, expansion)) + subs[appl] = symb + + return asms.new_with_substitutions(subs) + + def filter_params(self, params: set[SfgVar]) -> set[SfgVar]: + from .symbolic import block, cell + + params_filtered = set() + + for param in params: + for coord in range(3): + if param.name == block.aabb_min[coord].name: + self._extractions[param] = self._block.getAABB().min()[coord] + break + elif param.name == block.aabb_max[coord].name: + self._extractions[param] = self._block.getAABB().max()[coord] + break + # TODO: ci + elif param.name == cell.cell_extents[coord].name: + pass # TODO + else: + params_filtered.add(param) + + @dataclass class FieldInfo: field: Field diff --git a/src/sfg_walberla/symbolic.py b/src/sfg_walberla/symbolic.py new file mode 100644 index 0000000000000000000000000000000000000000..8f6c0763ff4d203429d213250984fa173f772202 --- /dev/null +++ b/src/sfg_walberla/symbolic.py @@ -0,0 +1,68 @@ +import sympy as sp + +import inspect + +from pystencils import TypedSymbol, DynamicType +from pystencils.defaults import DEFAULTS +from pystencils.sympyextensions import CastFunc + + +class BlockCoordinates: + aabb_min = tuple( + TypedSymbol(f"block_aabb_min_{i}", DynamicType.NUMERIC_TYPE) for i in range(3) + ) + aabb_max = tuple( + TypedSymbol(f"block_aabb_max_{i}", DynamicType.NUMERIC_TYPE) for i in range(3) + ) + ci_min = tuple( + TypedSymbol(f"cell_interval_min_{i}", DynamicType.INDEX_TYPE) for i in range(3) + ) + ci_max = tuple( + TypedSymbol(f"cell_interval_max_{i}", DynamicType.INDEX_TYPE) for i in range(3) + ) + + +block = BlockCoordinates + + +def expandable(name: str): + def wrap(expansion_func): + nargs = len(inspect.signature(expansion_func).parameters) + return sp.Function(name, nargs=nargs, expansion_func=staticmethod(expansion_func)) + + return wrap + + +class CellCoordinates: + cell_extents = tuple( + TypedSymbol(f"cell_extents_{i}", DynamicType.NUMERIC_TYPE) for i in range(3) + ) + + @expandable("cell.x") + def x(): + return BlockCoordinates.aabb_min[0] + CellCoordinates.cell_extents[ + 0 + ] * CastFunc.as_numeric(block.ci_min[0] + DEFAULTS.spatial_counters[0]) + + @expandable("cell.y") + def y(): + return BlockCoordinates.aabb_min[1] + CellCoordinates.cell_extents[ + 1 + ] * CastFunc.as_numeric(block.ci_min[1] + DEFAULTS.spatial_counters[1]) + + @expandable("cell.z") + def z(): + return BlockCoordinates.aabb_min[2] + CellCoordinates.cell_extents[ + 2 + ] * CastFunc.as_numeric(block.ci_min[2] + DEFAULTS.spatial_counters[2]) + + local_x = sp.Function("cell.local_x", nargs=0, walberla_info="cell.global_center.x") + local_y = sp.Function("cell.local_y", nargs=0, walberla_info="cell.global_center.y") + local_z = sp.Function("cell.local_z", nargs=0, walberla_info="cell.global_center.z") + + dx = sp.Function("cell.dx", nargs=0, walberla_info="cell.dx") + dy = sp.Function("cell.dy", nargs=0, walberla_info="cell.dy") + dz = sp.Function("cell.dz", nargs=0, walberla_info="cell.dz") + + +cell = CellCoordinates