From e90b0d553f2dca9a79d0798e8dda3db826d04e6f Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Tue, 25 Oct 2022 10:11:13 +0200 Subject: [PATCH] Clean up --- pystencils/config.py | 27 ++++++++++++++------------- pystencils_tests/test_config.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/pystencils/config.py b/pystencils/config.py index 62958d62..d38876bc 100644 --- a/pystencils/config.py +++ b/pystencils/config.py @@ -139,6 +139,14 @@ class CreateKernelConfig: def __call__(self): return BasicType(self.dt) + def _check_type(self, dtype_to_check): + if isinstance(dtype_to_check, str) and (dtype_to_check == 'float' or dtype_to_check == 'int'): + self._typing_error() + + if isinstance(dtype_to_check, type) and not hasattr(dtype_to_check, "dtype"): + # NumPy-types are also of type 'type'. However, they have more properties + self._typing_error() + @staticmethod def _typing_error(): raise ValueError("It is not possible to use python types (float, int) for datatypes because these " @@ -163,30 +171,23 @@ class CreateKernelConfig: # Normalise data types for dtype in [self.data_type, self.default_number_float, self.default_number_int]: - if isinstance(dtype, str) and (dtype == 'float' or dtype == 'int'): - self._typing_error() - - if isinstance(dtype, type): - # NumPy-types are also of type 'type'. However, they have more properties - if not hasattr(dtype, "dtype"): - self._typing_error() + self._check_type(dtype) if not isinstance(self.data_type, dict): dt = copy(self.data_type) # The copy is necessary because BasicType has sympy shinanigans self.data_type = defaultdict(self.DataTypeFactory(dt)) if isinstance(self.data_type, dict) and not isinstance(self.data_type, defaultdict): - if any(isinstance(dtype, str) and (dtype == 'float' or dtype == 'int') - for dtype in self.data_type.values()): - self._typing_error() + for dtype in self.data_type.values(): + self._check_type(dtype) - if any(isinstance(dtype, type) and not hasattr(dtype, "dtype") - for dtype in self.data_type.values()): - self._typing_error() dt = collate_types([BasicType(dtype) for dtype in self.data_type.values()]) dtype_dict = self.data_type self.data_type = defaultdict(self.DataTypeFactory(dt), dtype_dict) + assert isinstance(self.data_type, defaultdict), "At this point data_type must be a defaultdict!" + self._check_type(self.data_type.default_factory()) + if self.default_number_float is None: self.default_number_float = self.data_type.default_factory() diff --git a/pystencils_tests/test_config.py b/pystencils_tests/test_config.py index a409522b..9824113b 100644 --- a/pystencils_tests/test_config.py +++ b/pystencils_tests/test_config.py @@ -88,3 +88,15 @@ def test_config_python_types5(): def test_config_python_types6(): with pytest.raises(ValueError): CreateKernelConfig(default_number_float="float") + + +def test_config_python_types7(): + dtype = defaultdict(lambda: 'float', {'a': np.float64, 'b': np.int64}) + with pytest.raises(ValueError): + CreateKernelConfig(data_type=dtype) + + +def test_config_python_types8(): + dtype = defaultdict(lambda: float, {'a': np.float64, 'b': np.int64}) + with pytest.raises(ValueError): + CreateKernelConfig(data_type=dtype) -- GitLab