Skip to content
Snippets Groups Projects
Commit e90b0d55 authored by Markus Holzer's avatar Markus Holzer
Browse files

Clean up

parent 6f9648ee
Branches
Tags last/OpenCL
No related merge requests found
Pipeline #47298 passed
......@@ -139,6 +139,14 @@ class CreateKernelConfig:
def __call__(self):
return BasicType(self.dt)
def _check_type(self, dtype_to_check):
if isinstance(dtype_to_check, str) and (dtype_to_check == 'float' or dtype_to_check == 'int'):
self._typing_error()
if isinstance(dtype_to_check, type) and not hasattr(dtype_to_check, "dtype"):
# NumPy-types are also of type 'type'. However, they have more properties
self._typing_error()
@staticmethod
def _typing_error():
raise ValueError("It is not possible to use python types (float, int) for datatypes because these "
......@@ -163,30 +171,23 @@ class CreateKernelConfig:
# Normalise data types
for dtype in [self.data_type, self.default_number_float, self.default_number_int]:
if isinstance(dtype, str) and (dtype == 'float' or dtype == 'int'):
self._typing_error()
if isinstance(dtype, type):
# NumPy-types are also of type 'type'. However, they have more properties
if not hasattr(dtype, "dtype"):
self._typing_error()
self._check_type(dtype)
if not isinstance(self.data_type, dict):
dt = copy(self.data_type) # The copy is necessary because BasicType has sympy shinanigans
self.data_type = defaultdict(self.DataTypeFactory(dt))
if isinstance(self.data_type, dict) and not isinstance(self.data_type, defaultdict):
if any(isinstance(dtype, str) and (dtype == 'float' or dtype == 'int')
for dtype in self.data_type.values()):
self._typing_error()
for dtype in self.data_type.values():
self._check_type(dtype)
if any(isinstance(dtype, type) and not hasattr(dtype, "dtype")
for dtype in self.data_type.values()):
self._typing_error()
dt = collate_types([BasicType(dtype) for dtype in self.data_type.values()])
dtype_dict = self.data_type
self.data_type = defaultdict(self.DataTypeFactory(dt), dtype_dict)
assert isinstance(self.data_type, defaultdict), "At this point data_type must be a defaultdict!"
self._check_type(self.data_type.default_factory())
if self.default_number_float is None:
self.default_number_float = self.data_type.default_factory()
......
......@@ -88,3 +88,15 @@ def test_config_python_types5():
def test_config_python_types6():
with pytest.raises(ValueError):
CreateKernelConfig(default_number_float="float")
def test_config_python_types7():
dtype = defaultdict(lambda: 'float', {'a': np.float64, 'b': np.int64})
with pytest.raises(ValueError):
CreateKernelConfig(data_type=dtype)
def test_config_python_types8():
dtype = defaultdict(lambda: float, {'a': np.float64, 'b': np.int64})
with pytest.raises(ValueError):
CreateKernelConfig(data_type=dtype)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment