Skip to content
Snippets Groups Projects

WIP: Revamp the type system

Closed Markus Holzer requested to merge holzer/pystencils:TypeSystem into master
4 files
+ 73
14
Compare changes
  • Side-by-side
  • Inline
Files
4
+ 56
12
@@ -29,11 +29,13 @@ def typed_symbols(names, dtype, *args):
@@ -29,11 +29,13 @@ def typed_symbols(names, dtype, *args):
def type_all_numbers(expr, dtype):
def type_all_numbers(expr, dtype):
 
# TODO: move to pystnecils_walberla
substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)}
substitutions = {a: cast_func(a, dtype) for a in expr.atoms(sp.Number)}
return expr.subs(substitutions)
return expr.subs(substitutions)
def matrix_symbols(names, dtype, rows, cols):
def matrix_symbols(names, dtype, rows, cols):
 
# TODO: check if needed. (lbmpy, walberla)
if isinstance(names, str):
if isinstance(names, str):
names = names.replace(' ', '').split(',')
names = names.replace(' ', '').split(',')
@@ -46,6 +48,7 @@ def matrix_symbols(names, dtype, rows, cols):
@@ -46,6 +48,7 @@ def matrix_symbols(names, dtype, rows, cols):
def assumptions_from_dtype(dtype):
def assumptions_from_dtype(dtype):
 
# TODO: type hints and if dtype is correct type form Numpy
"""Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
"""Derives SymPy assumptions from :class:`BasicType` or a Numpy dtype
Args:
Args:
@@ -76,6 +79,9 @@ def assumptions_from_dtype(dtype):
@@ -76,6 +79,9 @@ def assumptions_from_dtype(dtype):
# noinspection PyPep8Naming
# noinspection PyPep8Naming
class address_of(sp.Function):
class address_of(sp.Function):
 
# TODO: ask Martin
 
# TODO: documentation
 
# TODO: move function to `functions.py`
is_Atom = True
is_Atom = True
def __new__(cls, arg):
def __new__(cls, arg):
@@ -103,6 +109,8 @@ class address_of(sp.Function):
@@ -103,6 +109,8 @@ class address_of(sp.Function):
# noinspection PyPep8Naming
# noinspection PyPep8Naming
class cast_func(sp.Function):
class cast_func(sp.Function):
 
# TODO: documentation
 
# TODO: move function to `functions.py`
is_Atom = True
is_Atom = True
def __new__(cls, *args, **kwargs):
def __new__(cls, *args, **kwargs):
@@ -190,22 +198,30 @@ class cast_func(sp.Function):
@@ -190,22 +198,30 @@ class cast_func(sp.Function):
# noinspection PyPep8Naming
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
class boolean_cast_func(cast_func, Boolean):
 
# TODO: documentation
 
# TODO: move function to `functions.py`
pass
pass
# noinspection PyPep8Naming
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
class vector_memory_access(cast_func):
 
# TODO: documentation
 
# TODO: move function to `functions.py`
# Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
# Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
nargs = (6,)
nargs = (6,)
# noinspection PyPep8Naming
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
class reinterpret_cast_func(cast_func):
 
# TODO: documentation
 
# TODO: move function to `functions.py`
pass
pass
# noinspection PyPep8Naming
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
class pointer_arithmetic_func(sp.Function, Boolean):
 
# TODO: documentation
 
# TODO: move function to `functions.py`
@property
@property
def canonical(self):
def canonical(self):
if hasattr(self.args[0], 'canonical'):
if hasattr(self.args[0], 'canonical'):
@@ -272,6 +288,8 @@ class TypedSymbol(sp.Symbol):
@@ -272,6 +288,8 @@ class TypedSymbol(sp.Symbol):
def create_type(specification):
def create_type(specification):
 
# TODO: HERE
 
# TODO: type hint -> np.type
"""Creates a subclass of Type according to a string or an object of subclass Type.
"""Creates a subclass of Type according to a string or an object of subclass Type.
Args:
Args:
@@ -292,6 +310,7 @@ def create_type(specification):
@@ -292,6 +310,7 @@ def create_type(specification):
@memorycache(maxsize=64)
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
def create_composite_type_from_string(specification):
 
# TODO: can be removed after llvm removla and fix of kernelparameters
"""Creates a new Type object from a c-like string specification.
"""Creates a new Type object from a c-like string specification.
Args:
Args:
@@ -338,12 +357,15 @@ def create_composite_type_from_string(specification):
@@ -338,12 +357,15 @@ def create_composite_type_from_string(specification):
def get_base_type(data_type):
def get_base_type(data_type):
 
# TODO: WTF is this?? DOCS!!!
 
# TODO: Can be removed after removal of kerncraft and fix in FieldPointer Symbol
while data_type.base_type is not None:
while data_type.base_type is not None:
data_type = data_type.base_type
data_type = data_type.base_type
return data_type
return data_type
def to_ctypes(data_type):
def to_ctypes(data_type):
 
# TODO: can be removed with llvm
"""
"""
Transforms a given Type into ctypes
Transforms a given Type into ctypes
:param data_type: Subclass of Type
:param data_type: Subclass of Type
@@ -356,7 +378,7 @@ def to_ctypes(data_type):
@@ -356,7 +378,7 @@ def to_ctypes(data_type):
else:
else:
return to_ctypes.map[data_type.numpy_dtype]
return to_ctypes.map[data_type.numpy_dtype]
# TODO: can be removed with llvm
to_ctypes.map = {
to_ctypes.map = {
np.dtype(np.int8): ctypes.c_int8,
np.dtype(np.int8): ctypes.c_int8,
np.dtype(np.int16): ctypes.c_int16,
np.dtype(np.int16): ctypes.c_int16,
@@ -374,6 +396,7 @@ to_ctypes.map = {
@@ -374,6 +396,7 @@ to_ctypes.map = {
def ctypes_from_llvm(data_type):
def ctypes_from_llvm(data_type):
 
# TODO can be removed with LLVM
if not ir:
if not ir:
raise _ir_importerror
raise _ir_importerror
if isinstance(data_type, ir.PointerType):
if isinstance(data_type, ir.PointerType):
@@ -404,6 +427,7 @@ def ctypes_from_llvm(data_type):
@@ -404,6 +427,7 @@ def ctypes_from_llvm(data_type):
def to_llvm_type(data_type, nvvm_target=False):
def to_llvm_type(data_type, nvvm_target=False):
 
# TODO: can be removed with LLVM
"""
"""
Transforms a given type into ctypes
Transforms a given type into ctypes
:param data_type: Subclass of Type
:param data_type: Subclass of Type
@@ -417,6 +441,7 @@ def to_llvm_type(data_type, nvvm_target=False):
@@ -417,6 +441,7 @@ def to_llvm_type(data_type, nvvm_target=False):
return to_llvm_type.map[data_type.numpy_dtype]
return to_llvm_type.map[data_type.numpy_dtype]
 
# TODO: can be removed with LLVM
if ir:
if ir:
to_llvm_type.map = {
to_llvm_type.map = {
np.dtype(np.int8): ir.IntType(8),
np.dtype(np.int8): ir.IntType(8),
@@ -435,16 +460,19 @@ if ir:
@@ -435,16 +460,19 @@ if ir:
def peel_off_type(dtype, type_to_peel_off):
def peel_off_type(dtype, type_to_peel_off):
 
# TODO: WTF is this??? DOCS!!!
 
# TODO: used only once.... can be a lambda there
while type(dtype) is type_to_peel_off:
while type(dtype) is type_to_peel_off:
dtype = dtype.base_type
dtype = dtype.base_type
return dtype
return dtype
 
############################# This is basically our type system ########################################################
def collate_types(types,
def collate_types(types,
forbid_collation_to_complex=False,
forbid_collation_to_complex=False, # TODO: type system shouldn't need this!!!
forbid_collation_to_float=False,
forbid_collation_to_float=False, # TODO: type system shouldn't need this!!!
default_float_type='float64',
default_float_type='float64', # TODO: AST leaves should be typed. Expressions should be able to find out correct type
default_int_type='int64'):
default_int_type='int64'): # TODO: AST leaves should be typed. Expressions should be able to find out correct type
"""
"""
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy.
Uses the collation rules from numpy.
@@ -495,9 +523,9 @@ def collate_types(types,
@@ -495,9 +523,9 @@ def collate_types(types,
@memorycache_if_hashable(maxsize=2048)
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
def get_type_of_expression(expr,
default_float_type='double',
default_float_type='double', # TODO: we shouldn't need to have default. AST leaves should have a type
default_int_type='int',
default_int_type='int', # TODO: we shouldn't need to have default. AST leaves should have a type
symbol_type_dict=None):
symbol_type_dict=None): # TODO: we shouldn't need to have default. AST leaves should have a type
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.cpu.vectorization import vec_all, vec_any
@@ -582,6 +610,7 @@ def get_type_of_expression(expr,
@@ -582,6 +610,7 @@ def get_type_of_expression(expr,
return create_type(default_float_type)
return create_type(default_float_type)
raise NotImplementedError("Could not determine type for", expr, type(expr))
raise NotImplementedError("Could not determine type for", expr, type(expr))
 
############################# End This is basically our type system ##################################################
sympy_version = sp.__version__.split('.')
sympy_version = sp.__version__.split('.')
@@ -614,6 +643,8 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
@@ -614,6 +643,8 @@ if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
class Type(sp.Atom):
class Type(sp.Atom):
 
# TODO: why is our type system dependent on sympy???
 
# TODO: ask Martin
def __new__(cls, *args, **kwargs):
def __new__(cls, *args, **kwargs):
return sp.Basic.__new__(cls)
return sp.Basic.__new__(cls)
@@ -622,8 +653,15 @@ class Type(sp.Atom):
@@ -622,8 +653,15 @@ class Type(sp.Atom):
class BasicType(Type):
class BasicType(Type):
 
# TODO: check if Type inheritance is needed
 
# TODO: should be a sensible interface to np.dtype
 
# TODO: read numpy docs (Jan)
@staticmethod
@staticmethod
def numpy_name_to_c(name):
def numpy_name_to_c(name):
 
# TODO: this should be a free function
 
# TODO: also check if numpy has this functionality
 
# TODO: docs!!!
 
# TODO: is this C?
if name == 'float64':
if name == 'float64':
return 'double'
return 'double'
elif name == 'float32':
elif name == 'float32':
@@ -644,9 +682,10 @@ class BasicType(Type):
@@ -644,9 +682,10 @@ class BasicType(Type):
raise NotImplementedError(f"Can map numpy to C name for {name}")
raise NotImplementedError(f"Can map numpy to C name for {name}")
def __init__(self, dtype, const=False):
def __init__(self, dtype, const=False):
 
# TODO: type hints
self.const = const
self.const = const
if isinstance(dtype, Type):
if isinstance(dtype, Type):
self._dtype = dtype.numpy_dtype
self._dtype = dtype.numpy_dtype # TODO: wtf?
else:
else:
self._dtype = np.dtype(dtype)
self._dtype = np.dtype(dtype)
assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
@@ -660,7 +699,7 @@ class BasicType(Type):
@@ -660,7 +699,7 @@ class BasicType(Type):
return (self.numpy_dtype, self.const), {}
return (self.numpy_dtype, self.const), {}
@property
@property
def base_type(self):
def base_type(self): # TODO: what is base_type?
return None
return None
@property
@property
@@ -672,7 +711,7 @@ class BasicType(Type):
@@ -672,7 +711,7 @@ class BasicType(Type):
return getattr(sympy.codegen.ast, str(self.numpy_dtype))
return getattr(sympy.codegen.ast, str(self.numpy_dtype))
@property
@property
def item_size(self):
def item_size(self): # TODO: what is this?
return 1
return 1
def is_int(self):
def is_int(self):
@@ -691,7 +730,7 @@ class BasicType(Type):
@@ -691,7 +730,7 @@ class BasicType(Type):
return self.numpy_dtype in np.sctypes['others']
return self.numpy_dtype in np.sctypes['others']
@property
@property
def base_name(self):
def base_name(self): # TODO: name of the function is highly confusing
return BasicType.numpy_name_to_c(str(self._dtype))
return BasicType.numpy_name_to_c(str(self._dtype))
def __str__(self):
def __str__(self):
@@ -714,6 +753,7 @@ class BasicType(Type):
@@ -714,6 +753,7 @@ class BasicType(Type):
class VectorType(Type):
class VectorType(Type):
 
# TODO: check with rest
instruction_set = None
instruction_set = None
def __init__(self, base_type, width=4):
def __init__(self, base_type, width=4):
@@ -760,6 +800,7 @@ class VectorType(Type):
@@ -760,6 +800,7 @@ class VectorType(Type):
class PointerType(Type):
class PointerType(Type):
 
# TODO: rename to FieldType
def __init__(self, base_type, const=False, restrict=True):
def __init__(self, base_type, const=False, restrict=True):
self._base_type = base_type
self._base_type = base_type
self.const = const
self.const = const
@@ -805,6 +846,7 @@ class PointerType(Type):
@@ -805,6 +846,7 @@ class PointerType(Type):
class StructType:
class StructType:
 
# TODO: Docs. This is a struct. A list of types (with C offsets)
def __init__(self, numpy_type, const=False):
def __init__(self, numpy_type, const=False):
self.const = const
self.const = const
self._dtype = np.dtype(numpy_type)
self._dtype = np.dtype(numpy_type)
@@ -858,6 +900,8 @@ class StructType:
@@ -858,6 +900,8 @@ class StructType:
class TypedImaginaryUnit(TypedSymbol):
class TypedImaginaryUnit(TypedSymbol):
 
# TODO: why is this an extra class???
 
# TODO: remove?
def __new__(cls, *args, **kwds):
def __new__(cls, *args, **kwds):
obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
return obj
return obj
Loading