From 63f396e7768d985bcd1c78cfd48220d1239e72f5 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 18 Nov 2024 12:49:51 +0100
Subject: [PATCH] Minor bugfixes:

 - allow `np.integer` args to float-type constant creation
 - fix target checking in datahandling
 - re-enable OpenMP for boundary handling kernels
---
 src/pystencils/boundaries/boundaryhandling.py      | 2 +-
 src/pystencils/datahandling/serial_datahandling.py | 7 +++++--
 src/pystencils/types/types.py                      | 2 +-
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py
index 5c0869c29..fe8dd7d00 100644
--- a/src/pystencils/boundaries/boundaryhandling.py
+++ b/src/pystencils/boundaries/boundaryhandling.py
@@ -314,7 +314,7 @@ class BoundaryHandling:
 
     def _create_boundary_kernel(self, symbolic_field, symbolic_index_field, boundary_obj):
         return create_boundary_kernel(symbolic_field, symbolic_index_field, self.stencil, boundary_obj,
-                                      target=self._target,)  # cpu_openmp=self._openmp) TODO: replace
+                                      target=self._target, cpu_openmp=self._openmp)
 
     def _create_index_fields(self):
         dh = self._data_handling
diff --git a/src/pystencils/datahandling/serial_datahandling.py b/src/pystencils/datahandling/serial_datahandling.py
index 8521dda10..6a5ce5730 100644
--- a/src/pystencils/datahandling/serial_datahandling.py
+++ b/src/pystencils/datahandling/serial_datahandling.py
@@ -291,7 +291,10 @@ class SerialDataHandling(DataHandling):
     def synchronization_function(self, names, stencil=None, target=None, functor=None, **_):
         if target is None:
             target = self.default_target
-        assert target in (Target.CPU, Target.GPU)
+            
+        if not (target.is_cpu() or target == Target.CUDA):
+            raise ValueError(f"Unsupported target: {target}")
+
         if not hasattr(names, '__len__') or type(names) is str:
             names = [names]
 
@@ -325,7 +328,7 @@ class SerialDataHandling(DataHandling):
                 values_per_cell = values_per_cell[0]
 
             if len(filtered_stencil) > 0:
-                if target == Target.CPU:
+                if target.is_cpu():
                     if functor is None:
                         from pystencils.slicing import get_periodic_boundary_functor
                         functor = get_periodic_boundary_functor
diff --git a/src/pystencils/types/types.py b/src/pystencils/types/types.py
index ae751992d..7645a452f 100644
--- a/src/pystencils/types/types.py
+++ b/src/pystencils/types/types.py
@@ -683,7 +683,7 @@ class PsIeeeFloatType(PsScalarType):
     def create_constant(self, value: Any) -> Any:
         np_type = self.NUMPY_TYPES[self._width]
 
-        if isinstance(value, (int, float, np.floating)):
+        if isinstance(value, (int, float, np.integer, np.floating)):
             finfo = np.finfo(np_type)  # type: ignore
             if value < finfo.min or value > finfo.max:
                 raise PsTypeError(
-- 
GitLab