Skip to content
Snippets Groups Projects
Commit a837cb73 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Permit custom numerical data type in BoundaryHandling

parent 6df2c640
No related branches found
No related tags found
1 merge request!460Fix data types in boundary handling. Fix deprecation checks.
Pipeline #77907 passed
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
import sympy as sp import sympy as sp
from pystencils import create_kernel, CreateKernelConfig, Target from pystencils import create_kernel, CreateKernelConfig, Target
from pystencils.types import UserTypeSpec, create_numeric_type
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.boundaries.createindexlist import ( from pystencils.boundaries.createindexlist import (
create_boundary_index_array, numpy_data_type_for_boundary_object) create_boundary_index_array, numpy_data_type_for_boundary_object)
...@@ -84,13 +85,14 @@ class FlagInterface: ...@@ -84,13 +85,14 @@ class FlagInterface:
class BoundaryHandling: class BoundaryHandling:
def __init__(self, data_handling, field_name, stencil, name="boundary_handling", flag_interface=None, def __init__(self, data_handling, field_name, stencil, name="boundary_handling", flag_interface=None,
target: Target = Target.CPU, openmp=True): target: Target = Target.CPU, default_dtype: UserTypeSpec = "float64", openmp=True):
assert data_handling.has_data(field_name) assert data_handling.has_data(field_name)
assert data_handling.dim == len(stencil[0]), "Dimension of stencil and data handling do not match" assert data_handling.dim == len(stencil[0]), "Dimension of stencil and data handling do not match"
self._data_handling = data_handling self._data_handling = data_handling
self._field_name = field_name self._field_name = field_name
self._index_array_name = name + "IndexArrays" self._index_array_name = name + "IndexArrays"
self._target = target self._target = target
self._default_dtype = create_numeric_type(default_dtype)
self._openmp = openmp self._openmp = openmp
self._boundary_object_to_boundary_info = {} self._boundary_object_to_boundary_info = {}
self.stencil = stencil self.stencil = stencil
...@@ -313,8 +315,11 @@ class BoundaryHandling: ...@@ -313,8 +315,11 @@ class BoundaryHandling:
return self._boundary_object_to_boundary_info[boundary_obj].flag return self._boundary_object_to_boundary_info[boundary_obj].flag
def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj): def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj):
return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj, cfg = CreateKernelConfig()
target=self._target, cpu_openmp=self._openmp) cfg.target = self._target
cfg.default_dtype = self._default_dtype
cfg.cpu.openmp.enable = self._openmp
return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj, cfg)
def _create_index_fields(self): def _create_index_fields(self):
dh = self._data_handling dh = self._data_handling
...@@ -452,11 +457,14 @@ class BoundaryOffsetInfo: ...@@ -452,11 +457,14 @@ class BoundaryOffsetInfo:
return sp.Symbol("invdir") return sp.Symbol("invdir")
def create_boundary_kernel(field, index_field, stencil, boundary_functor, target=Target.CPU, **kernel_creation_args): def create_boundary_kernel(field, index_field, stencil, boundary_functor, cfg: CreateKernelConfig):
# TODO: reconsider how to control the index_dtype in boundary kernels # TODO: reconsider how to control the index_dtype in boundary kernels
config = CreateKernelConfig(index_field=index_field, target=target, index_dtype=SInt(32), **kernel_creation_args) config = cfg.copy()
config.index_field = index_field
idx_dtype = SInt(32)
config.index_dtype = idx_dtype
offset_info = BoundaryOffsetInfo(stencil, config.index_dtype) offset_info = BoundaryOffsetInfo(stencil, idx_dtype)
elements = offset_info.get_array_declarations() elements = offset_info.get_array_declarations()
dir_symbol = TypedSymbol("dir", config.index_dtype) dir_symbol = TypedSymbol("dir", config.index_dtype)
elements += [Assignment(dir_symbol, index_field[0]('dir'))] elements += [Assignment(dir_symbol, index_field[0]('dir'))]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment