Skip to content
Snippets Groups Projects
Commit 941a2a6a authored by Christoph Alt's avatar Christoph Alt
Browse files

Merge branch 'sanedefaults' into 'master'

Sane Defaults for CreateKernelConfig

See merge request !307
parents 8bd3cef5 1dde0c9d
No related branches found
No related tags found
1 merge request!307Sane Defaults for CreateKernelConfig
Pipeline #47301 passed
...@@ -3,13 +3,17 @@ from copy import copy ...@@ -3,13 +3,17 @@ from copy import copy
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from types import MappingProxyType from types import MappingProxyType
from typing import Union, Tuple, List, Dict, Callable, Any from typing import Union, Tuple, List, Dict, Callable, Any, DefaultDict
from pystencils import Target, Backend, Field from pystencils import Target, Backend, Field
from pystencils.typing.typed_sympy import BasicType from pystencils.typing.typed_sympy import BasicType
from pystencils.typing.utilities import collate_types
import numpy as np import numpy as np
# TODO: There exists DTypeLike in NumPy which would be better than type for type hinting, to new at the moment
# from numpy.typing import DTypeLike
# TODO: CreateKernelConfig is bloated think of more classes better usage, factory whatever ... # TODO: CreateKernelConfig is bloated think of more classes better usage, factory whatever ...
# Proposition: CreateKernelConfigs Classes for different targets? # Proposition: CreateKernelConfigs Classes for different targets?
...@@ -30,17 +34,19 @@ class CreateKernelConfig: ...@@ -30,17 +34,19 @@ class CreateKernelConfig:
""" """
Name of the generated function - only important if generated code is written out Name of the generated function - only important if generated code is written out
""" """
# TODO Sane defaults: config should check that the datatype is a Numpy type data_type: Union[type, str, DefaultDict[str, BasicType], Dict[str, BasicType]] = np.float64
# TODO Sane defaults: QoL default_number_float and default_number_int should be data_type if they are not specified
data_type: Union[str, Dict[str, BasicType]] = 'float64'
""" """
Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type Data type used for all untyped symbols (i.e. non-fields), can also be a dict from symbol name to type.
If specified as a dict ideally a defaultdict is used to define a default value for symbols not listed in the
dict. If a plain dict is provided it will be transformed into a defaultdict internally. The default value
will then be specified via type collation then.
""" """
default_number_float: Union[str, np.dtype, BasicType] = 'float64' default_number_float: Union[type, str, BasicType] = None
""" """
Data type used for all untyped floating point numbers (i.e. 0.5) Data type used for all untyped floating point numbers (i.e. 0.5). By default the value of data_type is used.
If data_type is given as a defaultdict its default_factory is used.
""" """
default_number_int: Union[str, np.dtype, BasicType] = 'int64' default_number_int: Union[type, str, BasicType] = np.int64
""" """
Data type used for all untyped integer numbers (i.e. 1) Data type used for all untyped integer numbers (i.e. 1)
""" """
...@@ -133,9 +139,22 @@ class CreateKernelConfig: ...@@ -133,9 +139,22 @@ class CreateKernelConfig:
def __call__(self): def __call__(self):
return BasicType(self.dt) return BasicType(self.dt)
def _check_type(self, dtype_to_check):
if isinstance(dtype_to_check, str) and (dtype_to_check == 'float' or dtype_to_check == 'int'):
self._typing_error()
if isinstance(dtype_to_check, type) and not hasattr(dtype_to_check, "dtype"):
# NumPy-types are also of type 'type'. However, they have more properties
self._typing_error()
@staticmethod
def _typing_error():
raise ValueError("It is not possible to use python types (float, int) for datatypes because these "
"types are ambiguous. For example float will map to double. "
"Also the string version like 'float' is not allowed, e.g. use 'float64' instead")
def __post_init__(self): def __post_init__(self):
# ---- Legacy parameters # ---- Legacy parameters
# TODO Sane defaults: Check for abmigous types like "float", python float, which are dangerous for users
if isinstance(self.target, str): if isinstance(self.target, str):
new_target = Target[self.target.upper()] new_target = Target[self.target.upper()]
warnings.warn(f'Target "{self.target}" as str is deprecated. Use {new_target} instead', warnings.warn(f'Target "{self.target}" as str is deprecated. Use {new_target} instead',
...@@ -150,10 +169,30 @@ class CreateKernelConfig: ...@@ -150,10 +169,30 @@ class CreateKernelConfig:
else: else:
raise NotImplementedError(f'Target {self.target} has no default backend') raise NotImplementedError(f'Target {self.target} has no default backend')
# Normalise data types # Normalise data types
for dtype in [self.data_type, self.default_number_float, self.default_number_int]:
self._check_type(dtype)
if not isinstance(self.data_type, dict): if not isinstance(self.data_type, dict):
dt = copy(self.data_type) # The copy is necessary because BasicType has sympy shinanigans dt = copy(self.data_type) # The copy is necessary because BasicType has sympy shinanigans
self.data_type = defaultdict(self.DataTypeFactory(dt)) self.data_type = defaultdict(self.DataTypeFactory(dt))
if isinstance(self.data_type, dict) and not isinstance(self.data_type, defaultdict):
for dtype in self.data_type.values():
self._check_type(dtype)
dt = collate_types([BasicType(dtype) for dtype in self.data_type.values()])
dtype_dict = self.data_type
self.data_type = defaultdict(self.DataTypeFactory(dt), dtype_dict)
assert isinstance(self.data_type, defaultdict), "At this point data_type must be a defaultdict!"
for dtype in self.data_type.values():
self._check_type(dtype)
self._check_type(self.data_type.default_factory())
if self.default_number_float is None:
self.default_number_float = self.data_type.default_factory()
if not isinstance(self.default_number_float, BasicType): if not isinstance(self.default_number_float, BasicType):
self.default_number_float = BasicType(self.default_number_float) self.default_number_float = BasicType(self.default_number_float)
if not isinstance(self.default_number_int, BasicType): if not isinstance(self.default_number_int, BasicType):
......
...@@ -120,7 +120,10 @@ class BasicType(AbstractType): ...@@ -120,7 +120,10 @@ class BasicType(AbstractType):
return f'{self.c_name}{" const" if self.const else ""}' return f'{self.c_name}{" const" if self.const else ""}'
def __repr__(self): def __repr__(self):
return str(self) return f'BasicType( {str(self)} )'
def _repr_html_(self):
return f'BasicType( {str(self)} )'
def __eq__(self, other): def __eq__(self, other):
return self.dtype_eq(other) and self.const == other.const return self.dtype_eq(other) and self.const == other.const
...@@ -216,6 +219,9 @@ class PointerType(AbstractType): ...@@ -216,6 +219,9 @@ class PointerType(AbstractType):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self): def __hash__(self):
return hash((self._base_type, self.const, self.restrict)) return hash((self._base_type, self.const, self.restrict))
...@@ -273,6 +279,9 @@ class StructType(AbstractType): ...@@ -273,6 +279,9 @@ class StructType(AbstractType):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def _repr_html_(self):
return str(self)
def __hash__(self): def __hash__(self):
return hash((self.numpy_dtype, self.const)) return hash((self.numpy_dtype, self.const))
......
from collections import defaultdict
import numpy as np
import pytest
from pystencils import CreateKernelConfig, Target, Backend
from pystencils.typing import BasicType
def test_config():
# targets
config = CreateKernelConfig(target=Target.CPU)
assert config.target == Target.CPU
assert config.backend == Backend.C
config = CreateKernelConfig(target=Target.GPU)
assert config.target == Target.GPU
assert config.backend == Backend.CUDA
# typing
config = CreateKernelConfig(data_type=np.float64)
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float64')
assert config.default_number_float == BasicType('float64')
assert config.default_number_int == BasicType('int64')
config = CreateKernelConfig(data_type=np.float32)
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float32')
assert config.default_number_float == BasicType('float32')
assert config.default_number_int == BasicType('int64')
config = CreateKernelConfig(data_type=np.float32, default_number_float=np.float64)
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float32')
assert config.default_number_float == BasicType('float64')
assert config.default_number_int == BasicType('int64')
config = CreateKernelConfig(data_type=np.float32, default_number_float=np.float64, default_number_int=np.int16)
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float32')
assert config.default_number_float == BasicType('float64')
assert config.default_number_int == BasicType('int16')
config = CreateKernelConfig(data_type='float64')
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float64')
assert config.default_number_float == BasicType('float64')
assert config.default_number_int == BasicType('int64')
config = CreateKernelConfig(data_type={'a': np.float64, 'b': np.float32})
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float64')
assert config.default_number_float == BasicType('float64')
assert config.default_number_int == BasicType('int64')
config = CreateKernelConfig(data_type={'a': np.float32, 'b': np.int32})
assert isinstance(config.data_type, defaultdict)
assert config.data_type.default_factory() == BasicType('float32')
assert config.default_number_float == BasicType('float32')
assert config.default_number_int == BasicType('int64')
def test_config_python_types():
with pytest.raises(ValueError):
CreateKernelConfig(data_type=float)
def test_config_python_types2():
with pytest.raises(ValueError):
CreateKernelConfig(data_type={'a': float})
def test_config_python_types3():
with pytest.raises(ValueError):
CreateKernelConfig(default_number_float=float)
def test_config_python_types4():
with pytest.raises(ValueError):
CreateKernelConfig(default_number_int=int)
def test_config_python_types5():
with pytest.raises(ValueError):
CreateKernelConfig(data_type="float")
def test_config_python_types6():
with pytest.raises(ValueError):
CreateKernelConfig(default_number_float="float")
def test_config_python_types7():
dtype = defaultdict(lambda: 'float', {'a': np.float64, 'b': np.int64})
with pytest.raises(ValueError):
CreateKernelConfig(data_type=dtype)
def test_config_python_types8():
dtype = defaultdict(lambda: float, {'a': np.float64, 'b': np.int64})
with pytest.raises(ValueError):
CreateKernelConfig(data_type=dtype)
def test_config_python_types9():
dtype = defaultdict(lambda: 'float32', {'a': 'float', 'b': np.int64})
with pytest.raises(ValueError):
CreateKernelConfig(data_type=dtype)
def test_config_python_types10():
dtype = defaultdict(lambda: 'float32', {'a': float, 'b': np.int64})
with pytest.raises(ValueError):
CreateKernelConfig(data_type=dtype)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment