From 63f396e7768d985bcd1c78cfd48220d1239e72f5 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 18 Nov 2024 12:49:51 +0100 Subject: [PATCH] Minor bugfixes: - allow `np.integer` args to float-type constant creation - fix target checking in datahandling - re-enable OpenMP for boundary handling kernels --- src/pystencils/boundaries/boundaryhandling.py | 2 +- src/pystencils/datahandling/serial_datahandling.py | 7 +++++-- src/pystencils/types/types.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py index 5c0869c29..fe8dd7d00 100644 --- a/src/pystencils/boundaries/boundaryhandling.py +++ b/src/pystencils/boundaries/boundaryhandling.py @@ -314,7 +314,7 @@ class BoundaryHandling: def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj): return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj, - target=self._target,) # cpu_openmp=self._openmp) TODO: replace + target=self._target, cpu_openmp=self._openmp) def _create_index_fields(self): dh = self._data_handling diff --git a/src/pystencils/datahandling/serial_datahandling.py b/src/pystencils/datahandling/serial_datahandling.py index 8521dda10..6a5ce5730 100644 --- a/src/pystencils/datahandling/serial_datahandling.py +++ b/src/pystencils/datahandling/serial_datahandling.py @@ -291,7 +291,10 @@ class SerialDataHandling(DataHandling): def synchronization_function(self, names, stencil=None, target=None, functor=None, **_): if target is None: target = self.default_target - assert target in (Target.CPU, Target.GPU) + + if not (target.is_cpu() or target == Target.CUDA): + raise ValueError(f"Unsupported target: {target}") + if not hasattr(names, '__len__') or type(names) is str: names = [names] @@ -325,7 +328,7 @@ class SerialDataHandling(DataHandling): values_per_cell = values_per_cell[0] if len(filtered_stencil) > 0: - if target == Target.CPU: + if target.is_cpu(): if functor is None: from pystencils.slicing import get_periodic_boundary_functor functor = get_periodic_boundary_functor diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py index ae751992d..7645a452f 100644 --- a/src/pystencils/types/types.py +++ b/src/pystencils/types/types.py @@ -683,7 +683,7 @@ class PsIeeeFloatType(PsScalarType): def create_constant(self, value: Any) -> Any: np_type = self.NUMPY_TYPES[self._width] - if isinstance(value, (int, float, np.floating)): + if isinstance(value, (int, float, np.integer, np.floating)): finfo = np.finfo(np_type) # type: ignore if value < finfo.min or value > finfo.max: raise PsTypeError( -- GitLab