Skip to content
Snippets Groups Projects
Select Git revision
  • 2d925329a8a53ae7a77a402a379dc7261a60a793
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

data_types.py

Blame
  • data_types.py 17.03 KiB
    import ctypes
    import sympy as sp
    import numpy as np
    try:
        import llvmlite.ir as ir
    except ImportError as e:
        ir = None
        _ir_importerror = e
    from sympy.core.cache import cacheit
    
    from pystencils.cache import memorycache
    from pystencils.utils import all_equal
    from sympy.logic.boolalg import Boolean
    
    
    # noinspection PyPep8Naming
    class cast_func(sp.Function):
        is_Atom = True
    
        def __new__(cls, *args, **kwargs):
            # to work in conditions of sp.Piecewise cast_func has to be of type Boolean as well
            # however, a cast_function should only be a boolean if its argument is a boolean, otherwise this leads
            # to problems when for example comparing cast_func's for equality
            #
            # lhs = bitwise_and(a, cast_func(1, 'int'))
            # rhs = cast_func(0, 'int')
            # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
            # -> thus a separate class boolean_cast_func is introduced
            if isinstance(args[0], Boolean):
                cls = boolean_cast_func
            return sp.Function.__new__(cls, *args, **kwargs)
    
        @property
        def canonical(self):
            if hasattr(self.args[0], 'canonical'):
                return self.args[0].canonical
            else:
                raise NotImplementedError()
    
        @property
        def is_commutative(self):
            return self.args[0].is_commutative
    
        @property
        def dtype(self):
            return self.args[1]
    
    
    # noinspection PyPep8Naming
    class boolean_cast_func(cast_func, Boolean):
        pass
    
    
    # noinspection PyPep8Naming
    class vector_memory_access(cast_func):
        nargs = (4,)
    
    
    # noinspection PyPep8Naming
    class pointer_arithmetic_func(sp.Function, Boolean):
        @property
        def canonical(self):
            if hasattr(self.args[0], 'canonical'):
                return self.args[0].canonical
            else:
                raise NotImplementedError()
    
    
    class TypedSymbol(sp.Symbol):
        def __new__(cls, *args, **kwds):
            obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
            return obj
    
        def __new_stage2__(cls, name, dtype):
            obj = super(TypedSymbol, cls).__xnew__(cls, name)
            try:
                obj._dtype = create_type(dtype)
            except TypeError:
                # on error keep the string
                obj._dtype = dtype
            return obj
    
        __xnew__ = staticmethod(__new_stage2__)
        __xnew_cached_ = staticmethod(cacheit(__new_stage2__))
    
        @property
        def dtype(self):
            return self._dtype
    
        def _hashable_content(self):
            return super()._hashable_content(), hash(self._dtype)
    
        def __getnewargs__(self):
            return self.name, self.dtype
    
    
    def create_type(specification):
        """Creates a subclass of Type according to a string or an object of subclass Type.
    
        Args:
            specification: Type object, or a string
    
        Returns:
            Type object, or a new Type object parsed from the string
        """
        if isinstance(specification, Type):
            return specification
        else:
            numpy_dtype = np.dtype(specification)
            if numpy_dtype.fields is None:
                return BasicType(numpy_dtype, const=False)
            else:
                return StructType(numpy_dtype, const=False)
    
    
    @memorycache(maxsize=64)
    def create_composite_type_from_string(specification):
        """Creates a new Type object from a c-like string specification.
    
        Args:
            specification: Specification string
    
        Returns:
            Type object
        """
        specification = specification.lower().split()
        parts = []
        current = []
        for s in specification:
            if s == '*':
                parts.append(current)
                current = [s]
            else:
                current.append(s)
        if len(current) > 0:
            parts.append(current)
            # Parse native part
        base_part = parts.pop(0)
        const = False
        if 'const' in base_part:
            const = True
            base_part.remove('const')
        assert len(base_part) == 1
        if base_part[0][-1] == "*":
            base_part[0] = base_part[0][:-1]
            parts.append('*')
        current_type = BasicType(np.dtype(base_part[0]), const)
        # Parse pointer parts
        for part in parts:
            restrict = False
            const = False
            if 'restrict' in part:
                restrict = True
                part.remove('restrict')
            if 'const' in part:
                const = True
                part.remove("const")
            assert len(part) == 1 and part[0] == '*'
            current_type = PointerType(current_type, const, restrict)
        return current_type
    
    
    def get_base_type(data_type):
        while data_type.base_type is not None:
            data_type = data_type.base_type
        return data_type
    
    
    def to_ctypes(data_type):
        """
        Transforms a given Type into ctypes
        :param data_type: Subclass of Type
        :return: ctypes type object
        """
        if isinstance(data_type, PointerType):
            return ctypes.POINTER(to_ctypes(data_type.base_type))
        elif isinstance(data_type, StructType):
            return ctypes.POINTER(ctypes.c_uint8)
        else:
            return to_ctypes.map[data_type.numpy_dtype]
    
    
    to_ctypes.map = {
        np.dtype(np.int8): ctypes.c_int8,
        np.dtype(np.int16): ctypes.c_int16,
        np.dtype(np.int32): ctypes.c_int32,
        np.dtype(np.int64): ctypes.c_int64,
    
        np.dtype(np.uint8): ctypes.c_uint8,
        np.dtype(np.uint16): ctypes.c_uint16,
        np.dtype(np.uint32): ctypes.c_uint32,
        np.dtype(np.uint64): ctypes.c_uint64,
    
        np.dtype(np.float32): ctypes.c_float,
        np.dtype(np.float64): ctypes.c_double,
    }
    
    
    def ctypes_from_llvm(data_type):
        if not ir:
            raise _ir_importerror
        if isinstance(data_type, ir.PointerType):
            ctype = ctypes_from_llvm(data_type.pointee)
            if ctype is None:
                return ctypes.c_void_p
            else:
                return ctypes.POINTER(ctype)
        elif isinstance(data_type, ir.IntType):
            if data_type.width == 8:
                return ctypes.c_int8
            elif data_type.width == 16:
                return ctypes.c_int16
            elif data_type.width == 32:
                return ctypes.c_int32
            elif data_type.width == 64:
                return ctypes.c_int64
            else:
                raise ValueError("Int width %d is not supported" % data_type.width)
        elif isinstance(data_type, ir.FloatType):
            return ctypes.c_float
        elif isinstance(data_type, ir.DoubleType):
            return ctypes.c_double
        elif isinstance(data_type, ir.VoidType):
            return None  # Void type is not supported by ctypes
        else:
            raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))
    
    
    def to_llvm_type(data_type):
        """
        Transforms a given type into ctypes
        :param data_type: Subclass of Type
        :return: llvmlite type object
        """
        if not ir:
            raise _ir_importerror
        if isinstance(data_type, PointerType):
            return to_llvm_type(data_type.base_type).as_pointer()
        else:
            return to_llvm_type.map[data_type.numpy_dtype]
    
    
    if ir:
        to_llvm_type.map = {
            np.dtype(np.int8): ir.IntType(8),
            np.dtype(np.int16): ir.IntType(16),
            np.dtype(np.int32): ir.IntType(32),
            np.dtype(np.int64): ir.IntType(64),
    
            np.dtype(np.uint8): ir.IntType(8),
            np.dtype(np.uint16): ir.IntType(16),
            np.dtype(np.uint32): ir.IntType(32),
            np.dtype(np.uint64): ir.IntType(64),
    
            np.dtype(np.float32): ir.FloatType(),
            np.dtype(np.float64): ir.DoubleType(),
        }
    
    
    def peel_off_type(dtype, type_to_peel_off):
        while type(dtype) is type_to_peel_off:
            dtype = dtype.base_type
        return dtype
    
    
    def collate_types(types):
        """
        Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
        Uses the collation rules from numpy.
        """
    
        # Pointer arithmetic case i.e. pointer + integer is allowed
        if any(type(t) is PointerType for t in types):
            pointer_type = None
            for t in types:
                if type(t) is PointerType:
                    if pointer_type is not None:
                        raise ValueError("Cannot collate the combination of two pointer types")
                    pointer_type = t
                elif type(t) is BasicType:
                    if not (t.is_int() or t.is_uint()):
                        raise ValueError("Invalid pointer arithmetic")
                else:
                    raise ValueError("Invalid pointer arithmetic")
            return pointer_type
    
        # peel of vector types, if at least one vector type occurred the result will also be the vector type
        vector_type = [t for t in types if type(t) is VectorType]
        if not all_equal(t.width for t in vector_type):
            raise ValueError("Collation failed because of vector types with different width")
        types = [peel_off_type(t, VectorType) for t in types]
    
        # now we should have a list of basic types - struct types are not yet supported
        assert all(type(t) is BasicType for t in types)
    
        if any(t.is_float() for t in types):
            types = tuple(t for t in types if t.is_float())
        # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
        result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
        result = BasicType(result_numpy_type)
        if vector_type:
            result = VectorType(result, vector_type[0].width)
        return result
    
    
    @memorycache(maxsize=2048)
    def get_type_of_expression(expr):
        from pystencils.astnodes import ResolvedFieldAccess
        expr = sp.sympify(expr)
        if isinstance(expr, sp.Integer):
            return create_type("int")
        elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
            return create_type("double")
        elif isinstance(expr, ResolvedFieldAccess):
            return expr.field.dtype
        elif isinstance(expr, TypedSymbol):
            return expr.dtype
        elif isinstance(expr, sp.Symbol):
            raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
        elif isinstance(expr, cast_func):
            return expr.args[1]
        elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
            collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
            collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
            if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
                collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
            return collated_result_type
        elif isinstance(expr, sp.Indexed):
            typed_symbol = expr.base.label
            return typed_symbol.dtype.base_type
        elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
            # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
            result = create_type("bool")
            vec_args = [get_type_of_expression(a) for a in expr.args if isinstance(get_type_of_expression(a), VectorType)]
            if vec_args:
                result = VectorType(result, width=vec_args[0].width)
            return result
        elif isinstance(expr, sp.Pow):
            return get_type_of_expression(expr.args[0])
        elif isinstance(expr, sp.Expr):
            types = tuple(get_type_of_expression(a) for a in expr.args)
            return collate_types(types)
    
        raise NotImplementedError("Could not determine type for", expr, type(expr))
    
    
    class Type(sp.Basic):
        is_Atom = True
    
        def __new__(cls, *args, **kwargs):
            return sp.Basic.__new__(cls)
    
        def _sympystr(self, *args, **kwargs):
            return str(self)
    
    
    class BasicType(Type):
        @staticmethod
        def numpy_name_to_c(name):
            if name == 'float64':
                return 'double'
            elif name == 'float32':
                return 'float'
            elif name.startswith('int'):
                width = int(name[len("int"):])
                return "int%d_t" % (width,)
            elif name.startswith('uint'):
                width = int(name[len("uint"):])
                return "uint%d_t" % (width,)
            elif name == 'bool':
                return 'bool'
            else:
                raise NotImplementedError("Can map numpy to C name for %s" % (name,))
    
        def __init__(self, dtype, const=False):
            self.const = const
            if isinstance(dtype, Type):
                self._dtype = dtype.numpy_dtype
            else:
                self._dtype = np.dtype(dtype)
            assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
            assert self._dtype.hasobject is False
            assert self._dtype.subdtype is None
    
        def __getnewargs__(self):
            return self.numpy_dtype, self.const
    
        @property
        def base_type(self):
            return None
    
        @property
        def numpy_dtype(self):
            return self._dtype
    
        @property
        def item_size(self):
            return 1
    
        def is_int(self):
            return self.numpy_dtype in np.sctypes['int']
    
        def is_float(self):
            return self.numpy_dtype in np.sctypes['float']
    
        def is_uint(self):
            return self.numpy_dtype in np.sctypes['uint']
    
        def is_complex(self):
            return self.numpy_dtype in np.sctypes['complex']
    
        def is_other(self):
            return self.numpy_dtype in np.sctypes['others']
    
        @property
        def base_name(self):
            return BasicType.numpy_name_to_c(str(self._dtype))
    
        def __str__(self):
            result = BasicType.numpy_name_to_c(str(self._dtype))
            if self.const:
                result += " const"
            return result
    
        def __repr__(self):
            return str(self)
    
        def __eq__(self, other):
            if not isinstance(other, BasicType):
                return False
            else:
                return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
    
        def __hash__(self):
            return hash(str(self))
    
    
    class VectorType(Type):
        instruction_set = None
    
        def __init__(self, base_type, width=4):
            self._base_type = base_type
            self.width = width
    
        @property
        def base_type(self):
            return self._base_type
    
        @property
        def item_size(self):
            return self.width * self.base_type.item_size
    
        def __eq__(self, other):
            if not isinstance(other, VectorType):
                return False
            else:
                return (self.base_type, self.width) == (other.base_type, other.width)
    
        def __str__(self):
            if self.instruction_set is None:
                return "%s[%d]" % (self.base_type, self.width)
            else:
                if self.base_type == create_type("int64"):
                    return self.instruction_set['int']
                elif self.base_type == create_type("float64"):
                    return self.instruction_set['double']
                elif self.base_type == create_type("float32"):
                    return self.instruction_set['float']
                elif self.base_type == create_type("bool"):
                    return self.instruction_set['bool']
                else:
                    raise NotImplementedError()
    
        def __hash__(self):
            return hash((self.base_type, self.width))
    
        def __getnewargs__(self):
            return self._base_type, self.width
    
    
    class PointerType(Type):
        def __init__(self, base_type, const=False, restrict=True):
            self._base_type = base_type
            self.const = const
            self.restrict = restrict
    
        def __getnewargs__(self):
            return self.base_type, self.const, self.restrict
    
        @property
        def alias(self):
            return not self.restrict
    
        @property
        def base_type(self):
            return self._base_type
    
        @property
        def item_size(self):
            return self.base_type.item_size
    
        def __eq__(self, other):
            if not isinstance(other, PointerType):
                return False
            else:
                return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
    
        def __str__(self):
            components = [str(self.base_type), '*']
            if self.restrict:
                components.append('RESTRICT')
            if self.const:
                components.append("const")
            return " ".join(components)
    
        def __repr__(self):
            return str(self)
    
        def __hash__(self):
            return hash((self._base_type, self.const, self.restrict))
    
    
    class StructType:
        def __init__(self, numpy_type, const=False):
            self.const = const
            self._dtype = np.dtype(numpy_type)
    
        def __getnewargs__(self):
            return self.numpy_dtype, self.const
    
        @property
        def base_type(self):
            return None
    
        @property
        def numpy_dtype(self):
            return self._dtype
    
        @property
        def item_size(self):
            return self.numpy_dtype.itemsize
    
        def get_element_offset(self, element_name):
            return self.numpy_dtype.fields[element_name][1]
    
        def get_element_type(self, element_name):
            np_element_type = self.numpy_dtype.fields[element_name][0]
            return BasicType(np_element_type, self.const)
    
        def has_element(self, element_name):
            return element_name in self.numpy_dtype.fields
    
        def __eq__(self, other):
            if not isinstance(other, StructType):
                return False
            else:
                return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
    
        def __str__(self):
            # structs are handled byte-wise
            result = "uint8_t"
            if self.const:
                result += " const"
            return result
    
        def __repr__(self):
            return str(self)
    
        def __hash__(self):
            return hash((self.numpy_dtype, self.const))