Skip to content
Snippets Groups Projects
Commit 63f396e7 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Minor bugfixes:

 - allow `np.integer` args to float-type constant creation
 - fix target checking in datahandling
 - re-enable OpenMP for boundary handling kernels
parent 617a9282
No related branches found
No related tags found
3 merge requests!433Consolidate codegen and JIT modules.,!430Jupyter Inspection Framework, Book Theme, and Initial Drafts for Codegen Reference Guides,!429Iteration Slices: Extended GPU support + bugfixes
Pipeline #70373 passed
...@@ -314,7 +314,7 @@ class BoundaryHandling: ...@@ -314,7 +314,7 @@ class BoundaryHandling:
def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj): def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj):
return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj, return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj,
target=self._target,) # cpu_openmp=self._openmp) TODO: replace target=self._target, cpu_openmp=self._openmp)
def _create_index_fields(self): def _create_index_fields(self):
dh = self._data_handling dh = self._data_handling
......
...@@ -291,7 +291,10 @@ class SerialDataHandling(DataHandling): ...@@ -291,7 +291,10 @@ class SerialDataHandling(DataHandling):
def synchronization_function(self, names, stencil=None, target=None, functor=None, **_): def synchronization_function(self, names, stencil=None, target=None, functor=None, **_):
if target is None: if target is None:
target = self.default_target target = self.default_target
assert target in (Target.CPU, Target.GPU)
if not (target.is_cpu() or target == Target.CUDA):
raise ValueError(f"Unsupported target: {target}")
if not hasattr(names, '__len__') or type(names) is str: if not hasattr(names, '__len__') or type(names) is str:
names = [names] names = [names]
...@@ -325,7 +328,7 @@ class SerialDataHandling(DataHandling): ...@@ -325,7 +328,7 @@ class SerialDataHandling(DataHandling):
values_per_cell = values_per_cell[0] values_per_cell = values_per_cell[0]
if len(filtered_stencil) > 0: if len(filtered_stencil) > 0:
if target == Target.CPU: if target.is_cpu():
if functor is None: if functor is None:
from pystencils.slicing import get_periodic_boundary_functor from pystencils.slicing import get_periodic_boundary_functor
functor = get_periodic_boundary_functor functor = get_periodic_boundary_functor
......
...@@ -683,7 +683,7 @@ class PsIeeeFloatType(PsScalarType): ...@@ -683,7 +683,7 @@ class PsIeeeFloatType(PsScalarType):
def create_constant(self, value: Any) -> Any: def create_constant(self, value: Any) -> Any:
np_type = self.NUMPY_TYPES[self._width] np_type = self.NUMPY_TYPES[self._width]
if isinstance(value, (int, float, np.floating)): if isinstance(value, (int, float, np.integer, np.floating)):
finfo = np.finfo(np_type) # type: ignore finfo = np.finfo(np_type) # type: ignore
if value < finfo.min or value > finfo.max: if value < finfo.min or value > finfo.max:
raise PsTypeError( raise PsTypeError(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment