-
Michael Kuron authoredMichael Kuron authored
data_types.py 26.88 KiB
import ctypes
from collections import defaultdict
from functools import partial
from typing import Tuple
import numpy as np
import sympy as sp
import sympy.codegen.ast
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean, BooleanFunction
import pystencils
from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal
try:
import llvmlite.ir as ir
except ImportError as e:
ir = None
_ir_importerror = e
def typed_symbols(names, dtype, *args):
    """Create one or several :class:`TypedSymbol` objects sharing one data type.

    Args:
        names: symbol name specification, forwarded unchanged to ``sp.symbols``
        dtype: data type that is attached to every created symbol
        *args: further positional arguments forwarded to ``sp.symbols``

    Returns:
        A tuple of :class:`TypedSymbol` when ``sp.symbols`` produced a tuple,
        otherwise a single :class:`TypedSymbol`.
    """
    symbols = sp.symbols(names, *args)
    # Test against the builtin ``tuple``: the ``typing.Tuple`` alias is meant for
    # annotations, is deprecated, and only works in isinstance() by delegation.
    if isinstance(symbols, tuple):
        return tuple(TypedSymbol(str(s), dtype) for s in symbols)
    return TypedSymbol(str(symbols), dtype)
def type_all_numbers(expr, dtype):
    """Return ``expr`` with every ``sp.Number`` atom wrapped in a :class:`cast_func` to ``dtype``."""
    number_casts = {}
    for number in expr.atoms(sp.Number):
        number_casts[number] = cast_func(number, dtype)
    return expr.subs(number_casts)
def matrix_symbols(names, dtype, rows, cols):
    """Create ``rows`` x ``cols`` matrices filled with typed symbols.

    Args:
        names: iterable of matrix names, or a single comma separated string
               (blanks inside the string are ignored)
        dtype: data type of every matrix entry
        rows: number of matrix rows
        cols: number of matrix columns

    Returns:
        Tuple with one ``sp.Matrix`` per name; entries are numbered row by row.
    """
    if isinstance(names, str):
        names = names.replace(' ', '').split(',')

    def build_matrix(name):
        # "<name>:<count>" is expanded by sp.symbols into <name>0 ... <name>(count-1)
        entries = typed_symbols("%s:%i" % (name, rows * cols), dtype)
        return sp.Matrix(rows, cols, lambda i, j: entries[i * cols + j])

    return tuple(build_matrix(name) for name in names)
def assumptions_from_dtype(dtype):
    """Derive SymPy assumptions from a :class:`BasicType` or a Numpy dtype.

    Args:
        dtype (BasicType, np.dtype): object exposing ``numpy_dtype``, or anything
            Numpy's ``issubdtype`` accepts as a data type

    Returns:
        A dict of SymPy assumptions; empty if ``dtype`` cannot be classified
    """
    numpy_dtype = getattr(dtype, 'numpy_dtype', dtype)

    derived = {}
    try:
        is_integer = np.issubdtype(numpy_dtype, np.integer)
        if is_integer:
            derived['integer'] = True
        if np.issubdtype(numpy_dtype, np.unsignedinteger):
            derived['negative'] = False
        if is_integer or np.issubdtype(numpy_dtype, np.floating):
            derived['real'] = True
    except Exception:
        # best effort: specifications Numpy does not understand yield no assumptions
        pass
    return derived
# noinspection PyPep8Naming
class address_of(sp.Function):
    """Function node standing for the address of its single argument."""

    is_Atom = True

    def __new__(cls, arg):
        return sp.Function.__new__(cls, arg)

    @property
    def canonical(self):
        """The ``canonical`` attribute of the argument; not implemented if it has none."""
        operand = self.args[0]
        if not hasattr(operand, 'canonical'):
            raise NotImplementedError()
        return operand.canonical

    @property
    def is_commutative(self):
        """Commutativity is inherited from the argument."""
        return self.args[0].is_commutative

    @property
    def dtype(self):
        """Restrict pointer to the argument's dtype, or to ``'void'`` if the argument has no dtype."""
        operand = self.args[0]
        pointee = operand.dtype if hasattr(operand, 'dtype') else 'void'
        return PointerType(pointee, restrict=True)
# noinspection PyPep8Naming
class cast_func(sp.Function):
    """Function node that casts its first argument to the type given as second argument.

    The arguments are ``(expr, dtype, *other_args)``. A ``dtype`` that is not
    already a :class:`Type` is converted with :func:`create_type`.
    """

    is_Atom = True

    def __new__(cls, *args, **kwargs):
        # Fewer than two arguments fail here with a ValueError from unpacking.
        expr, dtype, *other_args = args
        if not isinstance(dtype, Type):
            dtype = create_type(dtype)
        # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
        # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
        # to problems when for example comparing cast_func's for equality
        #
        # lhs = bitwise_and(a, cast_func(1, 'int'))
        # rhs = cast_func(0, 'int')
        # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
        # -> thus a separate class boolean_cast_func is introduced
        if isinstance(expr, Boolean) and (not isinstance(expr, TypedSymbol) or expr.dtype == BasicType(bool)):
            cls = boolean_cast_func

        return sp.Function.__new__(cls, expr, dtype, *other_args, **kwargs)

    @property
    def canonical(self):
        """The ``canonical`` attribute of the casted expression; not implemented if it has none."""
        if hasattr(self.args[0], 'canonical'):
            return self.args[0].canonical
        else:
            raise NotImplementedError()

    @property
    def is_commutative(self):
        """Commutativity is inherited from the casted expression."""
        return self.args[0].is_commutative

    def _eval_evalf(self, *args, **kwargs):
        # Numerical evaluation ignores the cast and evaluates the wrapped expression.
        return self.args[0].evalf()

    @property
    def dtype(self):
        """Target type of the cast (second argument)."""
        return self.args[1]

    @property
    def is_integer(self):
        """
        Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

        For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer
        else:
            return super().is_integer

    @property
    def is_negative(self):
        """
        ``False`` for unsigned integer target types, otherwise the SymPy default.

        See :func:`cast_func.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            if np.issubdtype(self.dtype.numpy_dtype, np.unsignedinteger):
                return False

        return super().is_negative

    @property
    def is_nonnegative(self):
        """
        ``True`` if :attr:`is_negative` is known to be ``False``, otherwise the SymPy default.

        See :func:`cast_func.is_integer`
        """
        if self.is_negative is False:
            return True
        else:
            return super().is_nonnegative

    @property
    def is_real(self):
        """
        Truthy for integer and floating point target types, otherwise the SymPy default.

        See :func:`cast_func.is_integer`
        """
        if hasattr(self.dtype, 'numpy_dtype'):
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or \
                np.issubdtype(self.dtype.numpy_dtype, np.floating) or \
                super().is_real
        else:
            return super().is_real
# noinspection PyPep8Naming
class boolean_cast_func(cast_func, Boolean):
    """Variant of :class:`cast_func` that is additionally a SymPy ``Boolean``."""
# noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """:class:`cast_func` subclass with exactly six arguments.

    Arguments are: read/write expression, type, aligned, nontemporal, mask (or none), stride
    """

    nargs = (6,)
# noinspection PyPep8Naming
class reinterpret_cast_func(cast_func):
    """Distinct :class:`cast_func` subclass; adds no behavior of its own in this module."""
# noinspection PyPep8Naming
class pointer_arithmetic_func(sp.Function, Boolean):
    """Function node that is also a SymPy ``Boolean`` and forwards ``canonical`` to its first argument."""

    @property
    def canonical(self):
        """The ``canonical`` attribute of the first argument; not implemented if it has none."""
        operand = self.args[0]
        if not hasattr(operand, 'canonical'):
            raise NotImplementedError()
        return operand.canonical
class TypedSymbol(sp.Symbol):
    """SymPy symbol that additionally carries a data type, available as :attr:`dtype`."""

    def __new__(cls, *args, **kwds):
        # Construction is delegated to the cached second stage defined below.
        obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, name, dtype, **kwargs):
        # Assumptions derived from the dtype can be overridden by explicit keyword arguments.
        assumptions = assumptions_from_dtype(dtype)
        assumptions.update(kwargs)
        obj = super(TypedSymbol, cls).__xnew__(cls, name, **assumptions)
        try:
            obj._dtype = create_type(dtype)
        except (TypeError, ValueError):
            # on error keep the string
            obj._dtype = dtype
        return obj

    # Both aliases must be created after __new_stage2__ is defined; the second one is
    # wrapped in sympy's ``cacheit``.
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    @property
    def dtype(self):
        """Data type stored at construction (a type object, or the raw specification if conversion failed)."""
        return self._dtype

    def _hashable_content(self):
        # The dtype takes part in hashing in addition to the content of the base class.
        return super()._hashable_content(), hash(self._dtype)

    def __getnewargs__(self):
        # Positional constructor arguments: name and dtype.
        return self.name, self.dtype

    def __getnewargs_ex__(self):
        # Same as __getnewargs__, plus the assumptions as keyword arguments.
        return (self.name, self.dtype), self.assumptions0

    @property
    def canonical(self):
        """A typed symbol is its own canonical form."""
        return self

    @property
    def reversed(self):
        """Returns the symbol itself."""
        return self

    @property
    def headers(self):
        """List of headers required by this symbol.

        Contains ``'"cuda_complex.hpp"'`` if the dtype, or the dtype's ``base_type``,
        has a complex floating point Numpy dtype; otherwise the list is empty.
        """
        headers = []
        try:
            if np.issubdtype(self.dtype.numpy_dtype, np.complexfloating):
                headers.append('"cuda_complex.hpp"')
        except Exception:
            # dtype without a usable numpy_dtype: nothing to add
            pass
        try:
            if np.issubdtype(self.dtype.base_type.numpy_dtype, np.complexfloating):
                headers.append('"cuda_complex.hpp"')
        except Exception:
            # dtype without a base_type that has a usable numpy_dtype: nothing to add
            pass

        return headers
def create_type(specification):
    """Turn a type specification into a type object.

    Args:
        specification: :class:`Type` object, or anything accepted by ``np.dtype``

    Returns:
        ``specification`` itself if it already is a :class:`Type`; otherwise a non-const
        :class:`BasicType`, or a non-const :class:`StructType` for Numpy dtypes with fields
    """
    if isinstance(specification, Type):
        return specification

    numpy_dtype = np.dtype(specification)
    type_class = BasicType if numpy_dtype.fields is None else StructType
    return type_class(numpy_dtype, const=False)
@memorycache(maxsize=64)
def create_composite_type_from_string(specification):
    """Creates a new Type object from a c-like string specification.

    The specification is lower-cased and split at whitespace. The tokens before the
    first stand-alone ``*`` describe the base type (a Numpy type name, optionally
    with ``const``); every stand-alone ``*`` starts one pointer level that may be
    followed by ``const`` and/or ``restrict``. A ``*`` glued to the base type name
    (e.g. ``double*``) is treated as the innermost pointer level.

    Args:
        specification: Specification string

    Returns:
        Type object
    """
    specification = specification.lower().split()

    # Group the tokens: each stand-alone '*' closes the current group and opens a new one.
    parts = []
    current = []
    for s in specification:
        if s == '*':
            parts.append(current)
            current = [s]
        else:
            current.append(s)
    if len(current) > 0:
        parts.append(current)

    # Parse native part
    base_part = parts.pop(0)
    const = False
    if 'const' in base_part:
        const = True
        base_part.remove('const')
    assert len(base_part) == 1
    if base_part[0][-1] == "*":
        base_part[0] = base_part[0][:-1]
        # The star written directly after the base type binds tighter than every
        # following pointer level, so it goes to the front - as a token list like
        # all other parts, not as a bare string.
        parts.insert(0, ['*'])
    current_type = BasicType(np.dtype(base_part[0]), const)

    # Parse pointer parts, innermost pointer first
    for part in parts:
        restrict = False
        const = False
        if 'restrict' in part:
            restrict = True
            part.remove('restrict')
        if 'const' in part:
            const = True
            part.remove("const")
        assert len(part) == 1 and part[0] == '*'
        current_type = PointerType(current_type, const, restrict)
    return current_type
def get_base_type(data_type):
    """Follow the ``base_type`` chain of ``data_type`` and return the innermost type.

    The innermost type is the first one whose ``base_type`` is ``None``.
    """
    innermost = data_type
    while True:
        wrapped = innermost.base_type
        if wrapped is None:
            return innermost
        innermost = wrapped
def to_ctypes(data_type):
    """
    Transforms a given Type into ctypes

    Pointer types become ctypes pointers to the converted base type, struct types
    become pointers to ``uint8``, everything else is looked up in ``to_ctypes.map``.

    :param data_type: Subclass of Type
    :return: ctypes type object
    """
    if isinstance(data_type, PointerType):
        pointee = to_ctypes(data_type.base_type)
        return ctypes.POINTER(pointee)
    if isinstance(data_type, StructType):
        return ctypes.POINTER(ctypes.c_uint8)
    return to_ctypes.map[data_type.numpy_dtype]


# Integer types share their name between Numpy and ctypes (np.int8 <-> ctypes.c_int8).
to_ctypes.map = {
    np.dtype(getattr(np, type_name)): getattr(ctypes, 'c_' + type_name)
    for type_name in ('int8', 'int16', 'int32', 'int64',
                      'uint8', 'uint16', 'uint32', 'uint64')
}
to_ctypes.map[np.dtype(np.float32)] = ctypes.c_float
to_ctypes.map[np.dtype(np.float64)] = ctypes.c_double
def ctypes_from_llvm(data_type):
    """Map an llvmlite IR type to the corresponding ctypes type.

    Returns ``None`` for the void type and ``ctypes.c_void_p`` for pointers to void.
    Re-raises the stored import error if llvmlite could not be imported.
    """
    if not ir:
        raise _ir_importerror

    if isinstance(data_type, ir.PointerType):
        pointee_ctype = ctypes_from_llvm(data_type.pointee)
        return ctypes.c_void_p if pointee_ctype is None else ctypes.POINTER(pointee_ctype)

    if isinstance(data_type, ir.IntType):
        ctypes_by_width = {
            8: ctypes.c_int8,
            16: ctypes.c_int16,
            32: ctypes.c_int32,
            64: ctypes.c_int64,
        }
        width = data_type.width
        if width not in ctypes_by_width:
            raise ValueError("Int width %d is not supported" % width)
        return ctypes_by_width[width]

    if isinstance(data_type, ir.FloatType):
        return ctypes.c_float
    if isinstance(data_type, ir.DoubleType):
        return ctypes.c_double
    if isinstance(data_type, ir.VoidType):
        return None  # Void type is not supported by ctypes

    raise NotImplementedError(f'Data type {type(data_type)} of {data_type} is not supported yet')
def to_llvm_type(data_type, nvvm_target=False):
    """
    Transforms a given type into an llvmlite IR type

    :param data_type: Subclass of Type
    :param nvvm_target: if true, the outermost pointer is created with ``as_pointer(1)``
                        instead of ``as_pointer(0)``
    :return: llvmlite type object
    """
    if not ir:
        raise _ir_importerror

    if not isinstance(data_type, PointerType):
        return to_llvm_type.map[data_type.numpy_dtype]

    address_space = 1 if nvvm_target else 0
    return to_llvm_type(data_type.base_type).as_pointer(address_space)


if ir:
    # Signed and unsigned integers of equal width map to the same LLVM integer type.
    to_llvm_type.map = {
        np.dtype(f'{prefix}{bits}'): ir.IntType(bits)
        for prefix in ('int', 'uint')
        for bits in (8, 16, 32, 64)
    }
    to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
    to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()
def peel_off_type(dtype, type_to_peel_off):
    """Strip all outer layers whose exact class is ``type_to_peel_off``.

    Subclasses of ``type_to_peel_off`` are not peeled off; each layer is removed by
    stepping to its ``base_type``.
    """
    current = dtype
    while True:
        if type(current) is not type_to_peel_off:
            return current
        current = current.base_type
def collate_types(types,
                  forbid_collation_to_complex=False,
                  forbid_collation_to_float=False,
                  default_float_type='float64',
                  default_int_type='int64'):
    """
    Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
    Uses the collation rules from numpy.

    Args:
        types: sequence of type objects
        forbid_collation_to_complex: drop complex types first; if nothing is left the
                                     default float type is returned
        forbid_collation_to_float: drop floating point types first; if nothing is left the
                                   default int type is returned
        default_float_type: specification of the fallback float type
        default_int_type: specification of the fallback int type

    Raises:
        ValueError: for invalid pointer arithmetic or vector types of different width
    """
    if forbid_collation_to_complex:
        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)]
        if not types:
            return create_type(default_float_type)

    if forbid_collation_to_float:
        types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)]
        if not types:
            return create_type(default_int_type)

    # Pointer arithmetic: exactly one pointer combined with integers only
    if any(type(t) is PointerType for t in types):
        pointer_type = None
        for candidate in types:
            candidate_class = type(candidate)
            if candidate_class is PointerType:
                if pointer_type is not None:
                    raise ValueError("Cannot collate the combination of two pointer types")
                pointer_type = candidate
                continue
            is_integral = candidate_class is BasicType and (candidate.is_int() or candidate.is_uint())
            if not is_integral:
                raise ValueError("Invalid pointer arithmetic")
        return pointer_type

    # Remember vector types: if at least one occurred, the result is a vector type as well
    vector_types = [t for t in types if type(t) is VectorType]
    if not all_equal(t.width for t in vector_types):
        raise ValueError("Collation failed because of vector types with different width")
    scalar_types = [peel_off_type(t, VectorType) for t in types]

    # only basic types may remain here - struct types are not yet supported
    assert all(type(t) is BasicType for t in scalar_types)

    # as soon as one float type is present, only the float types decide the result
    float_types = tuple(t for t in scalar_types if t.is_float())
    if float_types:
        scalar_types = float_types

    # numpy collation -> basic type -> wrapped into a vector type if necessary
    collated_numpy_type = np.result_type(*(t.numpy_dtype for t in scalar_types))
    collated = BasicType(collated_numpy_type)
    if vector_types:
        collated = VectorType(collated, vector_types[0].width)
    return collated
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
                           default_float_type='double',
                           default_int_type='int',
                           symbol_type_dict=None):
    """Determine the data type of a (SymPy) expression.

    The expression is sympified and then dispatched on its kind; the first matching
    branch below decides the result, so the order of the checks matters.

    Args:
        expr: expression to inspect, passed through ``sp.sympify``
        default_float_type: type specification used for rationals, floats and
                            non-integer results; ``'float'`` is treated as ``'float32'``
        default_int_type: type specification used for integers
        symbol_type_dict: mapping from symbol name to type for untyped ``sp.Symbol`` objects

    Returns:
        Type object describing the expression

    Raises:
        ValueError: if an untyped symbol is met and ``symbol_type_dict`` is falsy
        NotImplementedError: if no branch can handle the expression
    """
    # Imported locally rather than at module level.
    from pystencils.astnodes import ResolvedFieldAccess
    from pystencils.cpu.vectorization import vec_all, vec_any

    if default_float_type == 'float':
        default_float_type = 'float32'

    if not symbol_type_dict:
        # NOTE(review): this fresh defaultdict is empty and therefore falsy, so the
        # `if symbol_type_dict:` check in the sp.Symbol branch below still fails and a
        # ValueError is raised for untyped symbols - confirm this is intended.
        symbol_type_dict = defaultdict(lambda: create_type('double'))

    # Recursive calls reuse the settings of this call.
    get_type = partial(get_type_of_expression,
                       default_float_type=default_float_type,
                       default_int_type=default_int_type,
                       symbol_type_dict=symbol_type_dict)

    expr = sp.sympify(expr)
    if isinstance(expr, sp.Integer):
        return create_type(default_int_type)
    elif expr.is_real is False:
        # Known to be non-real: use the dtype Numpy produces when the default float
        # type is multiplied by 1j.
        return create_type((np.zeros((1,), default_float_type) * 1j).dtype)
    elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
        return create_type(default_float_type)
    elif isinstance(expr, ResolvedFieldAccess):
        return expr.field.dtype
    elif isinstance(expr, pystencils.field.Field.AbstractAccess):
        return expr.field.dtype
    elif isinstance(expr, TypedSymbol):
        return expr.dtype
    elif isinstance(expr, sp.Symbol):
        if symbol_type_dict:
            return symbol_type_dict[expr.name]
        else:
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
    elif isinstance(expr, cast_func):
        return expr.args[1]
    elif isinstance(expr, (vec_any, vec_all)):
        return create_type("bool")
    elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
        # Each Piecewise argument is an (expression, condition) pair.
        collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
        collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
        # A vector condition turns a scalar result into a vector of the same width.
        if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
            collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
        return collated_result_type
    elif isinstance(expr, sp.Indexed):
        typed_symbol = expr.base.label
        return typed_symbol.dtype.base_type
    elif isinstance(expr, (Boolean, BooleanFunction)):
        # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
        result = create_type("bool")
        vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
        if vec_args:
            result = VectorType(result, width=vec_args[0].width)
        return result
    elif isinstance(expr, sp.Pow):
        base_type = get_type(expr.args[0])
        # Integer exponents keep the type of the base; others collate with the default float type.
        if expr.exp.is_integer:
            return base_type
        else:
            return collate_types([create_type(default_float_type), base_type])
    elif isinstance(expr, (sp.Sum, sp.Product)):
        return get_type(expr.args[0])
    elif isinstance(expr, sp.Expr):
        expr: sp.Expr
        if expr.args:
            types = tuple(get_type(a) for a in expr.args)
            # collate_types checks numpy_dtype in the special cases
            if any(not hasattr(t, 'numpy_dtype') for t in types):
                forbid_collation_to_complex = False
                forbid_collation_to_float = False
            else:
                forbid_collation_to_complex = expr.is_real is True
                forbid_collation_to_float = expr.is_integer is True
            return collate_types(
                types,
                forbid_collation_to_complex=forbid_collation_to_complex,
                forbid_collation_to_float=forbid_collation_to_float,
                default_float_type=default_float_type,
                default_int_type=default_int_type)
        else:
            # Expression without arguments: decide by its integer assumption.
            if expr.is_integer:
                return create_type(default_int_type)
            else:
                return create_type(default_float_type)

    raise NotImplementedError("Could not determine type for", expr, type(expr))
# Pickling support for newer SymPy: the patches below are applied when
# major * 100 + minor of the installed SymPy version is at least 109.
# NOTE(review): int() fails if the minor version component is not purely numeric - confirm
# which version strings have to be supported.
sympy_version = sp.__version__.split('.')
if int(sympy_version[0]) * 100 + int(sympy_version[1]) >= 109:
    # __setstate__ would bypass the constructor, so we remove it
    # (sp.Number keeps the original __getstate__, sp.Basic loses it)
    sp.Number.__getstate__ = sp.Basic.__getstate__
    del sp.Basic.__getstate__

    class FunctorWithStoredKwargs:
        """Callable that calls ``func`` with the given positional and the stored keyword arguments."""

        def __init__(self, func, **kwargs):
            self.func = func
            self.kwargs = kwargs

        def __call__(self, *args):
            return self.func(*args, **self.kwargs)

    # __reduce_ex__ would strip kwargs, so we override it
    def basic_reduce_ex(self, protocol):
        # Prefer __getnewargs_ex__ so that keyword arguments take part in reconstruction.
        if hasattr(self, '__getnewargs_ex__'):
            args, kwargs = self.__getnewargs_ex__()
        else:
            args, kwargs = self.__getnewargs__(), {}
        if hasattr(self, '__getstate__'):
            state = self.__getstate__()
        else:
            state = None
        # (callable, args, state): the callable re-applies the keyword arguments.
        return FunctorWithStoredKwargs(type(self), **kwargs), args, state

    # sp.Number keeps the original __reduce_ex__; sp.Basic gets the replacement.
    sp.Number.__reduce_ex__ = sp.Basic.__reduce_ex__
    sp.Basic.__reduce_ex__ = basic_reduce_ex
class Type(sp.Atom):
    """Base class of the type objects in this module; a SymPy atom created without arguments."""

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are not forwarded to SymPy.
        instance = sp.Basic.__new__(cls)
        return instance

    def _sympystr(self, *args, **kwargs):
        # SymPy's string printer shows the same text as str().
        return str(self)
class BasicType(Type):
    """Scalar data type backed by a non-structured Numpy dtype, optionally ``const``."""

    @staticmethod
    def numpy_name_to_c(name):
        """Translate the name of a Numpy dtype into the corresponding C type name.

        Args:
            name: Numpy dtype name, e.g. ``'float64'`` or ``'uint8'``

        Returns:
            C type name, e.g. ``'double'`` or ``'uint8_t'``

        Raises:
            NotImplementedError: if there is no mapping for ``name``
        """
        if name == 'float64':
            return 'double'
        elif name == 'float32':
            return 'float'
        elif name == 'complex64':
            return 'ComplexFloat'
        elif name == 'complex128':
            return 'ComplexDouble'
        elif name.startswith('int'):
            width = int(name[len("int"):])
            return "int%d_t" % (width,)
        elif name.startswith('uint'):
            width = int(name[len("uint"):])
            return "uint%d_t" % (width,)
        elif name == 'bool':
            return 'bool'
        else:
            raise NotImplementedError(f"Cannot map numpy to C name for {name}")

    def __init__(self, dtype, const=False):
        """
        Args:
            dtype: another type object (its ``numpy_dtype`` is taken over) or anything
                   accepted by ``np.dtype``
            const: whether the type is const qualified
        """
        self.const = const
        if isinstance(dtype, Type):
            self._dtype = dtype.numpy_dtype
        else:
            self._dtype = np.dtype(dtype)
        assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
        assert self._dtype.hasobject is False
        assert self._dtype.subdtype is None

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    def __getnewargs_ex__(self):
        return (self.numpy_dtype, self.const), {}

    @property
    def base_type(self):
        """Basic types do not wrap another type."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def sympy_dtype(self):
        """Attribute of ``sympy.codegen.ast`` named like the Numpy dtype."""
        return getattr(sympy.codegen.ast, str(self.numpy_dtype))

    @property
    def item_size(self):
        return 1

    # The predicates below use the dtype's kind character. ``np.sctypes``, which was
    # used before, was removed in NumPy 2.0 and raises AttributeError there.

    def is_int(self):
        """True for signed and unsigned integers."""
        return self.numpy_dtype.kind in ('i', 'u')

    def is_float(self):
        """True for real floating point types."""
        return self.numpy_dtype.kind == 'f'

    def is_uint(self):
        """True for unsigned integers."""
        return self.numpy_dtype.kind == 'u'

    def is_complex(self):
        """True for complex floating point types."""
        return self.numpy_dtype.kind == 'c'

    def is_other(self):
        """True for bool, object, bytes, str and void dtypes."""
        return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

    @property
    def base_name(self):
        """C name of the type without qualifiers."""
        return BasicType.numpy_name_to_c(str(self._dtype))

    def __str__(self):
        result = BasicType.numpy_name_to_c(str(self._dtype))
        if self.const:
            result += " const"
        return result

    def __repr__(self):
        return str(self)

    def __eq__(self, other):
        if not isinstance(other, BasicType):
            return False
        else:
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

    def __hash__(self):
        return hash(str(self))
class VectorType(Type):
    """Type made of ``width`` items of a base type."""

    # When set, __str__ looks up the printed name under the keys 'int', 'double',
    # 'float' and 'bool'; when None a generic "<base>[<width>]" form is printed.
    instruction_set = None

    def __init__(self, base_type, width=4):
        self._base_type = base_type
        self.width = width

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        """Width times the item size of the base type."""
        return self.width * self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, VectorType):
            return (self.base_type, self.width) == (other.base_type, other.width)
        return False

    def __str__(self):
        if self.instruction_set is None:
            return "%s[%s]" % (self.base_type, self.width)

        # instruction set key -> base type specifications printed under that key
        name_table = (
            ('int', ("int64", "int32")),
            ('double', ("float64",)),
            ('float', ("float32",)),
            ('bool', ("bool",)),
        )
        for key, specifications in name_table:
            if any(self.base_type == create_type(spec) for spec in specifications):
                return self.instruction_set[key]
        raise NotImplementedError()

    def __hash__(self):
        return hash((self.base_type, self.width))

    def __getnewargs__(self):
        return self._base_type, self.width

    def __getnewargs_ex__(self):
        return (self._base_type, self.width), {}
class PointerType(Type):
    """Pointer to a base type with optional ``const`` and ``restrict`` qualifiers."""

    def __init__(self, base_type, const=False, restrict=True):
        self._base_type = base_type
        self.const = const
        self.restrict = restrict

    def __getnewargs__(self):
        return self.base_type, self.const, self.restrict

    def __getnewargs_ex__(self):
        return (self.base_type, self.const, self.restrict), {}

    @property
    def alias(self):
        """Opposite of ``restrict``."""
        return not self.restrict

    @property
    def base_type(self):
        return self._base_type

    @property
    def item_size(self):
        """Item size of the base type."""
        return self.base_type.item_size

    def __eq__(self, other):
        if isinstance(other, PointerType):
            return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
        return False

    def __str__(self):
        # base type, star, then the enabled qualifiers in fixed order
        qualifiers = [
            text
            for text, enabled in (('RESTRICT', self.restrict), ("const", self.const))
            if enabled
        ]
        return " ".join([str(self.base_type), '*', *qualifiers])

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self._base_type, self.const, self.restrict))
class StructType:
    """Type wrapping a Numpy dtype with named fields, optionally ``const``."""

    def __init__(self, numpy_type, const=False):
        self.const = const
        self._dtype = np.dtype(numpy_type)

    def __getnewargs__(self):
        return self.numpy_dtype, self.const

    def __getnewargs_ex__(self):
        return (self.numpy_dtype, self.const), {}

    @property
    def base_type(self):
        """Struct types do not wrap another type."""
        return None

    @property
    def numpy_dtype(self):
        return self._dtype

    @property
    def item_size(self):
        """Item size of the underlying Numpy dtype."""
        return self.numpy_dtype.itemsize

    def get_element_offset(self, element_name):
        """Second entry of the Numpy field description of ``element_name``."""
        field_description = self.numpy_dtype.fields[element_name]
        return field_description[1]

    def get_element_type(self, element_name):
        """:class:`BasicType` of the element; it takes over the constness of the struct."""
        field_description = self.numpy_dtype.fields[element_name]
        return BasicType(field_description[0], self.const)

    def has_element(self, element_name):
        return element_name in self.numpy_dtype.fields

    def __eq__(self, other):
        if isinstance(other, StructType):
            return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
        return False

    def __str__(self):
        # structs are handled byte-wise
        return "uint8_t const" if self.const else "uint8_t"

    def __repr__(self):
        return str(self)

    def __hash__(self):
        return hash((self.numpy_dtype, self.const))
class TypedImaginaryUnit(TypedSymbol):
    """Typed symbol named ``"_i"`` that is created with the assumption ``imaginary=True``."""

    def __new__(cls, *args, **kwds):
        # Construction is delegated to the cached second stage defined below.
        obj = TypedImaginaryUnit.__xnew_cached_(cls, *args, **kwds)
        return obj

    def __new_stage2__(cls, dtype):
        # Only the dtype is configurable; name and assumption are fixed.
        obj = super(TypedImaginaryUnit, cls).__xnew__(cls,
                                                      "_i",
                                                      dtype,
                                                      imaginary=True)
        return obj

    # Class attribute replacing the ``headers`` property of TypedSymbol.
    headers = ['"cuda_complex.hpp"']

    # Both aliases must be created after __new_stage2__ is defined; the second one is
    # wrapped in sympy's ``cacheit``.
    __xnew__ = staticmethod(__new_stage2__)
    __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

    def __getnewargs__(self):
        # The dtype is the only constructor argument.
        return (self.dtype,)

    def __getnewargs_ex__(self):
        return (self.dtype,), {}