Skip to content
Snippets Groups Projects
Commit 1dde0c9d authored by Markus Holzer's avatar Markus Holzer Committed by Christoph Alt
Browse files

Sane Defaults for CreateKernelConfig

parent 8bd3cef5
Branches
Tags
No related merge requests found
......@@ -3,13 +3,17 @@ from copy import copy
from collections import defaultdict
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Union, Tuple, List, Dict, Callable, Any
from typing import Union, Tuple, List, Dict, Callable, Any, DefaultDict
from pystencils import Target, Backend, Field
from pystencils.typing.typed_sympy import BasicType
from pystencils.typing.utilities import collate_types
import numpy as np
# TODO: There exists DTypeLike in NumPy which would be better than type for type hinting, to new at the moment
# from numpy.typing import DTypeLike
# TODO: CreateKernelConfig is bloated think of more classes better usage, factory whatever ...
# Proposition: CreateKernelConfigs Classes for different targets?
......@@ -30,17 +34,19 @@ class CreateKernelConfig:
"""
Name of the generated function - only important if generated code is written out
"""
# TODO Sane defaults: config should check that the datatype is a Numpy type
# TODO Sane defaults: QoL default_number_float and default_number_int should be data_type if they are not specified
data_type: Union[str, Dict[str, BasicType]] = 'float64'
data_type: Union[type, str, DefaultDict[str, BasicType], Dict[str, BasicType]] = np.float64
"""
Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type
Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type.
If specified as a dict ideally a defaultdict is used to define a default value for symbols not listed in the
dict. If a plain dict is provided it will be transformed into a defaultdict internally. The default value
will then be specified via type collation then.
"""
default_number_float: Union[str, np.dtype, BasicType] = 'float64'
default_number_float: Union[type, str, BasicType] = None
"""
Data type used for all untyped floating point numbers (i.e. 0.5)
Data type used for all untyped floating point numbers (i.e. 0.5). By default the value of data_type is used.
If data_type is given as a defaultdict its default_factory is used.
"""
default_number_int: Union[str, np.dtype, BasicType] = 'int64'
default_number_int: Union[type, str, BasicType] = np.int64
"""
Data type used for all untyped integer numbers (i.e. 1)
"""
......@@ -133,9 +139,22 @@ class CreateKernelConfig:
def __call__(self):
return BasicType(self.dt)
def _check_type(self, dtype_to_check):
if isinstance(dtype_to_check, str) and (dtype_to_check == 'float' or dtype_to_check == 'int'):
self._typing_error()
if isinstance(dtype_to_check, type) and not hasattr(dtype_to_check, "dtype"):
# NumPy-types are also of type 'type'. However, they have more properties
self._typing_error()
@staticmethod
def _typing_error():
raise ValueError("It is not possible to use python types (float, int) for datatypes because these "
"types are ambiguous. For example float will map to double. "
"Also the string version like 'float' is not allowed, e.g. use 'float64' instead")
def __post_init__(self):
# ---- Legacy parameters
# TODO Sane defaults: Check for abmigous types like "float", python float, which are dangerous for users
if isinstance(self.target, str):
new_target = Target[self.target.upper()]
warnings.warn(f'Target "{self.target}" as str is deprecated. Use {new_target} instead',
......@@ -150,10 +169,30 @@ class CreateKernelConfig:
else:
raise NotImplementedError(f'Target {self.target} has no default backend')
# Normalise data types
# Normalise data types
for dtype in [self.data_type, self.default_number_float, self.default_number_int]:
self._check_type(dtype)
if not isinstance(self.data_type, dict):
dt = copy(self.data_type) # The copy is necessary because BasicType has sympy shinanigans
self.data_type = defaultdict(self.DataTypeFactory(dt))
if isinstance(self.data_type, dict) and not isinstance(self.data_type, defaultdict):
for dtype in self.data_type.values():
self._check_type(dtype)
dt = collate_types([BasicType(dtype) for dtype in self.data_type.values()])
dtype_dict = self.data_type
self.data_type = defaultdict(self.DataTypeFactory(dt), dtype_dict)
assert isinstance(self.data_type, defaultdict), "At this point data_type must be a defaultdict!"
for dtype in self.data_type.values():
self._check_type(dtype)
self._check_type(self.data_type.default_factory())
if self.default_number_float is None:
self.default_number_float = self.data_type.default_factory()
if not isinstance(self.default_number_float, BasicType):
self.default_number_float = BasicType(self.default_number_float)
if not isinstance(self.default_number_int, BasicType):
......
......@@ -120,7 +120,10 @@ class BasicType(AbstractType):
return f'{self.c_name}{" const" if self.const else ""}'
def __repr__(self):
return str(self)
return f'BasicType( {str(self)} )'
def _repr_html_(self):
return f'BasicType( {str(self)} )'
def __eq__(self, other):
return self.dtype_eq(other) and self.const == other.const
......@@ -216,6 +219,9 @@ class PointerType(AbstractType):
def __repr__(self):
return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self):
return hash((self._base_type, self.const, self.restrict))
......@@ -273,6 +279,9 @@ class StructType(AbstractType):
def __repr__(self):
return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self):
return hash((self.numpy_dtype, self.const))
......
from collections import defaultdict
import numpy as np
import pytest
from pystencils import CreateKernelConfig, Target, Backend
from pystencils.typing import BasicType
def test_config():
# targets
config = CreateKernelConfig(target=Target.CPU)
assert config.target == Target.CPU
assert config.backend == Backend.C
config = CreateKernelConfig(target=Target.GPU)
assert config.target == Target.GPU
assert config.backend == Backend.CUDA
# typing
config = CreateKernelConfig(data_type=np.float64)
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float64')
assert config.default_number_float == BasicType('float64')
assert config.default_number_int == BasicType('int64')
config = CreateKernelConfig(data_type=np.float32)
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float32')
assert config.default_number_float == BasicType('float32')
assert config.default_number_int == BasicType('int64')
config = CreateKernelConfig(data_type=np.float32, default_number_float=np.float64)
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float32')
assert config.default_number_float == BasicType('float64')
assert config.default_number_int == BasicType('int64')
config = CreateKernelConfig(data_type=np.float32, default_number_float=np.float64, default_number_int=np.int16)
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float32')
assert config.default_number_float == BasicType('float64')
assert config.default_number_int == BasicType('int16')
config = CreateKernelConfig(data_type='float64')
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float64')
assert config.default_number_float == BasicType('float64')
assert config.default_number_int == BasicType('int64')
config = CreateKernelConfig(data_type={'a': np.float64, 'b': np.float32})
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float64')
assert config.default_number_float == BasicType('float64')
assert config.default_number_int == BasicType('int64')
config = CreateKernelConfig(data_type={'a': np.float32, 'b': np.int32})
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float32')
assert config.default_number_float == BasicType('float32')
assert config.default_number_int == BasicType('int64')
def test_config_python_types():
with pytest.raises(ValueError):
CreateKernelConfig(data_type=float)
def test_config_python_types2():
with pytest.raises(ValueError):
CreateKernelConfig(data_type={'a': float})
def test_config_python_types3():
with pytest.raises(ValueError):
CreateKernelConfig(default_number_float=float)
def test_config_python_types4():
with pytest.raises(ValueError):
CreateKernelConfig(default_number_int=int)
def test_config_python_types5():
with pytest.raises(ValueError):
CreateKernelConfig(data_type="float")
def test_config_python_types6():
with pytest.raises(ValueError):
CreateKernelConfig(default_number_float="float")
def test_config_python_types7():
dtype = defaultdict(lambda: 'float', {'a': np.float64, 'b': np.int64})
with pytest.raises(ValueError):
CreateKernelConfig(data_type=dtype)
def test_config_python_types8():
dtype = defaultdict(lambda: float, {'a': np.float64, 'b': np.int64})
with pytest.raises(ValueError):
CreateKernelConfig(data_type=dtype)
def test_config_python_types9():
dtype = defaultdict(lambda: 'float32', {'a': 'float', 'b': np.int64})
with pytest.raises(ValueError):
CreateKernelConfig(data_type=dtype)
def test_config_python_types10():
dtype = defaultdict(lambda: 'float32', {'a': float, 'b': np.int64})
with pytest.raises(ValueError):
CreateKernelConfig(data_type=dtype)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment