Skip to content
Snippets Groups Projects
Select Git revision
  • 443527ae8f93b686ac5499bd45af6e1b52bc4f08
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

main.c

Blame
  • data_types.py 22.42 KiB
    import ctypes
    from collections import defaultdict
    from functools import partial
    from typing import Tuple
    
    import numpy as np
    import sympy as sp
    import sympy.codegen.ast
    from sympy.core.cache import cacheit
    from sympy.logic.boolalg import Boolean
    
    import pystencils
    from pystencils.cache import memorycache, memorycache_if_hashable
    from pystencils.utils import all_equal
    
    try:
        import llvmlite.ir as ir
    except ImportError as e:
        ir = None
        _ir_importerror = e
    
    
    def typed_symbols(names, dtype, *args):
        """Create one or several :class:`TypedSymbol` objects that share a data type.

        Args:
            names: symbol name specification, handed to ``sp.symbols``
            dtype: data type given to every created symbol
            *args: forwarded unchanged to ``sp.symbols``

        Returns:
            A tuple of TypedSymbols if ``sp.symbols`` produced a tuple,
            otherwise a single TypedSymbol
        """
        symbols = sp.symbols(names, *args)
        # Runtime check against the builtin type; typing.Tuple is only meant for annotations
        if isinstance(symbols, tuple):
            return tuple(TypedSymbol(str(s), dtype) for s in symbols)
        return TypedSymbol(str(symbols), dtype)
    
    
    def matrix_symbols(names, dtype, rows, cols):
        """Create ``rows x cols`` matrices whose entries are typed symbols.

        Args:
            names: comma separated string of matrix names, or an iterable of names
            dtype: data type of every matrix entry
            rows: number of matrix rows
            cols: number of matrix columns

        Returns:
            Tuple with one ``sp.Matrix`` per name; entry ``(i, j)`` of the matrix named ``n``
            is the symbol with index ``i * cols + j`` from the range specification ``n:<rows*cols>``
        """
        if isinstance(names, str):
            names = names.replace(' ', '').split(',')

        matrices = []
        for n in names:
            symbols = typed_symbols("%s:%i" % (n, rows * cols), dtype)
            # A range that expands to exactly one name (1x1 matrix) comes back as a bare
            # symbol instead of a tuple; wrap it so that the indexing below keeps working.
            if not isinstance(symbols, tuple):
                symbols = (symbols,)
            matrices.append(sp.Matrix(rows, cols, lambda i, j: symbols[i * cols + j]))

        return tuple(matrices)
    
    
    def assumptions_from_dtype(dtype):
        """Translate a data type into the SymPy assumptions it implies.

        Args:
            dtype (BasicType, np.dtype): object with a ``numpy_dtype`` attribute, or something
                                         numpy can interpret as a data type
        Returns:
            A dict of SymPy assumptions; empty if ``dtype`` cannot be classified
        """
        numpy_type = dtype.numpy_dtype if hasattr(dtype, 'numpy_dtype') else dtype

        derived = {}
        try:
            is_integer = np.issubdtype(numpy_type, np.integer)
            if is_integer:
                derived['integer'] = True

            if np.issubdtype(numpy_type, np.unsignedinteger):
                derived['negative'] = False

            if is_integer or np.issubdtype(numpy_type, np.floating):
                derived['real'] = True
        except Exception:
            # best effort: types numpy cannot interpret simply yield no assumptions
            pass

        return derived
    
    
    # noinspection PyPep8Naming
    class address_of(sp.Function):
        """Symbolic function with a single argument: the expression whose address is taken."""
        is_Atom = True

        def __new__(cls, arg):
            return sp.Function.__new__(cls, arg)

        @property
        def canonical(self):
            """``canonical`` of the wrapped argument; NotImplementedError if it offers none."""
            target = self.args[0]
            if not hasattr(target, 'canonical'):
                raise NotImplementedError()
            return target.canonical

        @property
        def is_commutative(self):
            """Commutativity is taken over from the wrapped argument."""
            return self.args[0].is_commutative

        @property
        def dtype(self):
            """Restrict pointer to the argument's ``dtype``, or to ``'void'`` if it has none."""
            target = self.args[0]
            pointee = target.dtype if hasattr(target, 'dtype') else 'void'
            return PointerType(pointee, restrict=True)
    
    
    # noinspection PyPep8Naming
    class cast_func(sp.Function):
        """Symbolic type cast: ``args[0]`` is the cast expression, ``args[1]`` the target type."""
        is_Atom = True

        def __new__(cls, *args, **kwargs):
            expr, dtype, *other_args = args
            if not isinstance(dtype, Type):
                dtype = create_type(dtype)
            # Inside the conditions of sp.Piecewise a cast has to be a Boolean as well.
            # It must only be one if its argument is a Boolean though, otherwise comparing
            # casts for equality goes wrong:
            #
            # lhs = bitwise_and(a, cast_func(1, 'int'))
            # rhs = cast_func(0, 'int')
            # print( sp.Ne(lhs, rhs) ) # would give true if all cast_funcs are booleans
            # -> hence the separate class boolean_cast_func
            node_class = boolean_cast_func if isinstance(expr, Boolean) else cls

            return sp.Function.__new__(node_class, expr, dtype, *other_args, **kwargs)

        @property
        def canonical(self):
            """``canonical`` of the cast expression; NotImplementedError if it offers none."""
            inner = self.args[0]
            if not hasattr(inner, 'canonical'):
                raise NotImplementedError()
            return inner.canonical

        @property
        def is_commutative(self):
            """Commutativity is taken over from the cast expression."""
            return self.args[0].is_commutative

        def _eval_evalf(self, *args, **kwargs):
            # numeric evaluation ignores the cast and evaluates the inner expression
            return self.args[0].evalf()

        @property
        def dtype(self):
            """Target type of the cast."""
            return self.args[1]

        @property
        def is_integer(self):
            """
            Uses Numpy type hierarchy to determine :func:`sympy.Expr.is_integer` predicate

            For reference: Numpy type hierarchy https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
            """
            if not hasattr(self.dtype, 'numpy_dtype'):
                return super().is_integer
            return np.issubdtype(self.dtype.numpy_dtype, np.integer) or super().is_integer

        @property
        def is_negative(self):
            """
            False for unsigned integer target types, otherwise the inherited predicate
            """
            target = self.dtype
            if hasattr(target, 'numpy_dtype') and np.issubdtype(target.numpy_dtype, np.unsignedinteger):
                return False

            return super().is_negative

        @property
        def is_nonnegative(self):
            """
            True if :attr:`is_negative` is known to be False, otherwise the inherited predicate
            """
            if self.is_negative is not False:
                return super().is_nonnegative
            return True

        @property
        def is_real(self):
            """
            True for integer and floating point target types, otherwise the inherited predicate
            """
            if not hasattr(self.dtype, 'numpy_dtype'):
                return super().is_real
            numpy_dtype = self.dtype.numpy_dtype
            return np.issubdtype(numpy_dtype, np.integer) or \
                np.issubdtype(numpy_dtype, np.floating) or \
                super().is_real
    
    
    # noinspection PyPep8Naming
    class boolean_cast_func(cast_func, Boolean):
        """Cast that is additionally a sympy ``Boolean``.

        ``cast_func.__new__`` switches to this class when the expression being cast
        is itself a ``Boolean``.
        """
        pass
    
    
    # noinspection PyPep8Naming
    class vector_memory_access(cast_func):
        """``cast_func`` variant declared with four arguments.

        NOTE(review): the meaning of the two arguments after expression and type is not
        visible in this file - confirm against the code that creates these nodes.
        """
        # sympy argument-count declaration: exactly four arguments
        nargs = (4,)
    
    
    # noinspection PyPep8Naming
    class reinterpret_cast_func(cast_func):
        """Distinct node class with exactly the behaviour of :class:`cast_func`.

        NOTE(review): presumably printed as a reinterpreting cast by the code generator -
        confirm against the backend that handles this class.
        """
        pass
    
    
    # noinspection PyPep8Naming
    class pointer_arithmetic_func(sp.Function, Boolean):
        """Symbolic function node that is also a sympy ``Boolean``."""

        @property
        def canonical(self):
            """``canonical`` of the first argument; NotImplementedError if it offers none."""
            first = self.args[0]
            if not hasattr(first, 'canonical'):
                raise NotImplementedError()
            return first.canonical
    
    
    class TypedSymbol(sp.Symbol):
        """A sympy symbol that additionally carries a data type, available as ``dtype``."""

        def __new__(cls, *args, **kwds):
            # construction goes through the cached variant of __new_stage2__
            obj = TypedSymbol.__xnew_cached_(cls, *args, **kwds)
            return obj

        def __new_stage2__(cls, name, dtype, *args, **kwargs):
            # assumptions implied by the data type (integer / non-negative / real) are
            # passed to the sympy symbol constructor together with user supplied kwargs
            assumptions = assumptions_from_dtype(dtype)
            obj = super(TypedSymbol, cls).__xnew__(cls, name, *args, **assumptions, **kwargs)
            try:
                obj._dtype = create_type(dtype)
            except (TypeError, ValueError):
                # the specification could not be turned into a type object:
                # store it unchanged
                obj._dtype = dtype
            return obj

        # Both names wrap the plain function defined above, so they have to be assigned
        # after __new_stage2__; the second one is wrapped in sympy's cacheit.
        __xnew__ = staticmethod(__new_stage2__)
        __xnew_cached_ = staticmethod(cacheit(__new_stage2__))

        @property
        def dtype(self):
            # type object, or the raw specification if create_type rejected it
            return self._dtype

        def _hashable_content(self):
            # the hash of the data type is part of the content, next to the
            # content of the plain symbol
            return super()._hashable_content(), hash(self._dtype)

        def __getnewargs__(self):
            # positional constructor arguments: name and data type
            return self.name, self.dtype

        @property
        def canonical(self):
            # a typed symbol is its own canonical form
            return self

        @property
        def reversed(self):
            # returns the symbol itself
            return self
    
    
    def create_type(specification):
        """Return the type object described by ``specification``.

        Args:
            specification: a :class:`Type` or :class:`StructType` instance, or anything
                           ``np.dtype`` accepts (e.g. a dtype string or a numpy dtype)

        Returns:
            ``specification`` itself if it already is a type object, otherwise a new non-const
            :class:`BasicType` (dtype without fields) or :class:`StructType` (dtype with fields)
        """
        # StructType does not derive from Type and therefore has to be listed explicitly;
        # otherwise an existing struct type would be handed to np.dtype, which cannot
        # interpret it.
        if isinstance(specification, (Type, StructType)):
            return specification

        numpy_dtype = np.dtype(specification)
        if numpy_dtype.fields is None:
            return BasicType(numpy_dtype, const=False)
        return StructType(numpy_dtype, const=False)
    
    
    @memorycache(maxsize=64)
    def create_composite_type_from_string(specification):
        """Creates a new Type object from a c-like string specification.

        The specification consists of a base type understood by ``np.dtype``, optionally
        qualified with ``const``, followed by any number of pointer levels. Each pointer level
        is a ``*`` optionally followed by ``restrict`` and/or ``const``, which qualify that
        pointer. The string is lower-cased before parsing and a ``*`` may be attached to its
        neighbouring token (``"double*"``) or stand alone (``"double *"``).

        Args:
            specification: Specification string

        Returns:
            Type object: a :class:`BasicType`, wrapped in one :class:`PointerType` per ``*``

        Raises:
            AssertionError: if a part contains tokens other than the expected ones
        """
        # Put spaces around every '*' so that attached stars are tokenized exactly like
        # free-standing ones. This keeps pointer levels in their written order.
        tokens = specification.lower().replace('*', ' * ').split()

        # Split the tokens into the base-type part and one part per pointer level:
        # every '*' starts a new part, which also collects the qualifiers following it.
        parts = []
        current = []
        for token in tokens:
            if token == '*':
                parts.append(current)
                current = [token]
            else:
                current.append(token)
        if len(current) > 0:
            parts.append(current)

        # Parse native part
        base_part = parts.pop(0)
        const = False
        if 'const' in base_part:
            const = True
            base_part.remove('const')
        assert len(base_part) == 1
        current_type = BasicType(np.dtype(base_part[0]), const)

        # Parse pointer parts, innermost pointer level first
        for part in parts:
            restrict = False
            const = False
            if 'restrict' in part:
                restrict = True
                part.remove('restrict')
            if 'const' in part:
                const = True
                part.remove("const")
            # only the star itself may be left over
            assert len(part) == 1 and part[0] == '*'
            current_type = PointerType(current_type, const, restrict)
        return current_type
    
    
    def get_base_type(data_type):
        """Follow the ``base_type`` chain and return the innermost type (whose ``base_type`` is None)."""
        innermost = data_type
        while True:
            wrapped = innermost.base_type
            if wrapped is None:
                return innermost
            innermost = wrapped
    
    
    def to_ctypes(data_type):
        """
        Maps a type object onto the corresponding ctypes type
        :param data_type: Subclass of Type
        :return: ctypes type object
        """
        if isinstance(data_type, PointerType):
            pointee = to_ctypes(data_type.base_type)
            return ctypes.POINTER(pointee)
        if isinstance(data_type, StructType):
            # struct types map to a pointer to bytes
            return ctypes.POINTER(ctypes.c_uint8)
        return to_ctypes.map[data_type.numpy_dtype]


    # lookup table numpy dtype -> ctypes scalar type; the fixed-width integer types
    # carry the same name in numpy and (prefixed with 'c_') in ctypes
    to_ctypes.map = {
        np.dtype(type_name): getattr(ctypes, 'c_' + type_name)
        for type_name in ('int8', 'int16', 'int32', 'int64',
                          'uint8', 'uint16', 'uint32', 'uint64')
    }
    to_ctypes.map[np.dtype(np.float32)] = ctypes.c_float
    to_ctypes.map[np.dtype(np.float64)] = ctypes.c_double
    
    
    def ctypes_from_llvm(data_type):
        """Maps an llvmlite IR type onto the corresponding ctypes type.

        Returns None for the IR void type; re-raises the stored import error if
        llvmlite is not available.
        """
        if not ir:
            raise _ir_importerror

        if isinstance(data_type, ir.PointerType):
            pointee_ctype = ctypes_from_llvm(data_type.pointee)
            # a void pointee has no ctypes counterpart, so void* becomes c_void_p
            return ctypes.c_void_p if pointee_ctype is None else ctypes.POINTER(pointee_ctype)

        if isinstance(data_type, ir.IntType):
            ctypes_by_width = {
                8: ctypes.c_int8,
                16: ctypes.c_int16,
                32: ctypes.c_int32,
                64: ctypes.c_int64,
            }
            width = data_type.width
            if width not in ctypes_by_width:
                raise ValueError("Int width %d is not supported" % data_type.width)
            return ctypes_by_width[width]

        if isinstance(data_type, ir.FloatType):
            return ctypes.c_float
        if isinstance(data_type, ir.DoubleType):
            return ctypes.c_double
        if isinstance(data_type, ir.VoidType):
            return None  # Void type is not supported by ctypes

        raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type))
    
    
    def to_llvm_type(data_type, nvvm_target=False):
        """
        Transforms a given type into an llvmlite IR type
        :param data_type: Subclass of Type
        :param nvvm_target: if true, a pointer type is created in address space 1 instead of 0
        :return: llvmlite type object
        """
        if not ir:
            raise _ir_importerror
        if not isinstance(data_type, PointerType):
            return to_llvm_type.map[data_type.numpy_dtype]
        address_space = 1 if nvvm_target else 0
        return to_llvm_type(data_type.base_type).as_pointer(address_space)


    if ir:
        # signed and unsigned integers of the same width share one IR integer type
        to_llvm_type.map = {
            np.dtype('%s%d' % (prefix, bits)): ir.IntType(bits)
            for prefix in ('int', 'uint')
            for bits in (8, 16, 32, 64)
        }
        to_llvm_type.map[np.dtype(np.float32)] = ir.FloatType()
        to_llvm_type.map[np.dtype(np.float64)] = ir.DoubleType()
    
    
    def peel_off_type(dtype, type_to_peel_off):
        """Strip leading wrappers of exactly the class ``type_to_peel_off`` via ``base_type``."""
        current = dtype
        while True:
            if type(current) is not type_to_peel_off:
                return current
            current = current.base_type
    
    
    def collate_types(types, forbid_collation_to_float=False):
        """
        Determines the "common type" of a sequence of types, e.g. (float, double, float) -> double,
        following the collation rules of numpy.

        With ``forbid_collation_to_float`` set, float types are dropped before collation and
        'int32' is returned if nothing is left. A single pointer combined with integer types
        collates to the pointer type; vector types of equal width yield a vector result.
        """

        if forbid_collation_to_float:
            types = [t for t in types if not (hasattr(t, 'is_float') and t.is_float())]
            if not types:
                return create_type('int32')

        # Pointer arithmetic case i.e. pointer + integer is allowed
        if any(type(t) is PointerType for t in types):
            found_pointer = None
            for candidate in types:
                if type(candidate) is PointerType:
                    if found_pointer is not None:
                        raise ValueError("Cannot collate the combination of two pointer types")
                    found_pointer = candidate
                    continue
                # every other operand has to be an integer BasicType
                if type(candidate) is not BasicType:
                    raise ValueError("Invalid pointer arithmetic")
                if not (candidate.is_int() or candidate.is_uint()):
                    raise ValueError("Invalid pointer arithmetic")
            return found_pointer

        # strip vector types; if at least one occurred the result becomes a vector type too
        vector_types = [t for t in types if type(t) is VectorType]
        if not all_equal(t.width for t in vector_types):
            raise ValueError("Collation failed because of vector types with different width")
        scalar_types = [peel_off_type(t, VectorType) for t in types]

        # only basic types may remain - struct types are not yet supported
        assert all(type(t) is BasicType for t in scalar_types)

        # as soon as a float takes part, only the float types decide the result
        if any(t.is_float() for t in scalar_types):
            scalar_types = tuple(t for t in scalar_types if t.is_float())

        # numpy collation -> type object -> wrapped in a vector type if necessary
        common = BasicType(np.result_type(*(t.numpy_dtype for t in scalar_types)))
        if vector_types:
            return VectorType(common, vector_types[0].width)
        return common
    
    
    @memorycache_if_hashable(maxsize=2048)
    def get_type_of_expression(expr,
                               default_float_type='double',
                               default_int_type='int',
                               symbol_type_dict=None):
        """Determine the data type of a (sympy) expression.

        The expression is sympified and classified by a chain of ``isinstance`` checks; the
        order of the checks matters because several of the tested classes derive from others.

        Args:
            expr: expression to inspect
            default_float_type: type specification used for ``sp.Rational`` / ``sp.Float`` and for
                                argument-less expressions that are not known to be integer
            default_int_type: type specification used for ``sp.Integer`` and for argument-less
                              expressions that are known to be integer
            symbol_type_dict: mapping from symbol name to type, consulted for plain ``sp.Symbol``

        Returns:
            type object of the expression

        Raises:
            ValueError: for a plain ``sp.Symbol`` when ``symbol_type_dict`` is falsy
            NotImplementedError: if no rule matches the expression

        NOTE(review): when no ``symbol_type_dict`` is passed, an empty ``defaultdict`` is
        created below. An empty dict is falsy, so the ``if symbol_type_dict`` test in the
        ``sp.Symbol`` branch fails and ValueError is raised - the 'double' default of the
        defaultdict appears to be unreachable. Confirm which behaviour is intended.
        """
        # imported locally rather than at module level
        from pystencils.astnodes import ResolvedFieldAccess
        from pystencils.cpu.vectorization import vec_all, vec_any

        if not symbol_type_dict:
            symbol_type_dict = defaultdict(lambda: create_type('double'))

        # recursive calls on sub-expressions reuse the settings of this call
        get_type = partial(get_type_of_expression,
                           default_float_type=default_float_type,
                           default_int_type=default_int_type,
                           symbol_type_dict=symbol_type_dict)

        expr = sp.sympify(expr)
        if isinstance(expr, sp.Integer):
            return create_type(default_int_type)
        elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
            return create_type(default_float_type)
        elif isinstance(expr, ResolvedFieldAccess):
            return expr.field.dtype
        elif isinstance(expr, pystencils.field.Field.AbstractAccess):
            return expr.field.dtype
        elif isinstance(expr, TypedSymbol):
            # has to be tested before the generic sp.Symbol case below
            return expr.dtype
        elif isinstance(expr, sp.Symbol):
            if symbol_type_dict:
                return symbol_type_dict[expr.name]
            else:
                raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
        elif isinstance(expr, cast_func):
            # a cast has the type it casts to
            return expr.args[1]
        elif isinstance(expr, (vec_any, vec_all)):
            return create_type("bool")
        elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
            # each Piecewise argument is a pair: a[0] is the value, a[1] the condition
            collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
            collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
            # vector conditions turn a scalar result into a vector of the same width
            if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
                collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
            return collated_result_type
        elif isinstance(expr, sp.Indexed):
            # the element type is the base type of the indexed symbol's type
            typed_symbol = expr.base.label
            return typed_symbol.dtype.base_type
        elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
            # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
            result = create_type("bool")
            vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
            if vec_args:
                result = VectorType(result, width=vec_args[0].width)
            return result
        elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
            # only the first argument determines the type
            return get_type(expr.args[0])
        elif isinstance(expr, sp.Expr):
            expr: sp.Expr
            if expr.args:
                # generic expression: common type of all arguments
                types = tuple(get_type(a) for a in expr.args)
                return collate_types(types)
            else:
                if expr.is_integer:
                    return create_type(default_int_type)
                else:
                    return create_type(default_float_type)

        raise NotImplementedError("Could not determine type for", expr, type(expr))
    
    
    class Type(sp.Basic):
        """Base class of the sympy-based type objects in this file (an atomic ``sp.Basic``)."""
        is_Atom = True

        def __new__(cls, *args, **kwargs):
            # constructor arguments are not forwarded to sp.Basic.__new__; the subclasses
            # in this file read them in their __init__
            return sp.Basic.__new__(cls)

        def _sympystr(self, *args, **kwargs):
            # string hook for sympy printing: delegates to the plain str() of the type
            return str(self)
    
    
    class BasicType(Type):
        """Scalar type described by a plain numpy dtype plus a ``const`` qualifier."""

        @staticmethod
        def numpy_name_to_c(name):
            """Translate the name of a numpy dtype into the matching C type name.

            Raises:
                NotImplementedError: if there is no C name for ``name``
            """
            if name == 'float64':
                return 'double'
            elif name == 'float32':
                return 'float'
            elif name.startswith('int'):
                width = int(name[len("int"):])
                return "int%d_t" % (width,)
            elif name.startswith('uint'):
                width = int(name[len("uint"):])
                return "uint%d_t" % (width,)
            elif name == 'bool':
                return 'bool'
            else:
                raise NotImplementedError("Can't map numpy to C name for %s" % (name,))

        def __init__(self, dtype, const=False):
            """
            Args:
                dtype: another type object (its ``numpy_dtype`` is taken over) or anything
                       ``np.dtype`` accepts
                const: whether the type is const qualified
            """
            self.const = const
            if isinstance(dtype, Type):
                self._dtype = dtype.numpy_dtype
            else:
                self._dtype = np.dtype(dtype)
            # only plain scalar dtypes are allowed: no fields, no objects, no sub-arrays
            assert self._dtype.fields is None, "Tried to initialize NativeType with a structured type"
            assert self._dtype.hasobject is False
            assert self._dtype.subdtype is None

        def __getnewargs__(self):
            return self.numpy_dtype, self.const

        @property
        def base_type(self):
            """A basic type wraps no other type."""
            return None

        @property
        def numpy_dtype(self):
            return self._dtype

        @property
        def sympy_dtype(self):
            """Object of the same name as the numpy dtype in ``sympy.codegen.ast``."""
            return getattr(sympy.codegen.ast, str(self.numpy_dtype))

        @property
        def item_size(self):
            return 1

        # The predicates below classify by the dtype's ``kind`` character. The former
        # ``np.sctypes`` lookup table no longer exists in NumPy 2.0.

        def is_int(self):
            """True for signed as well as unsigned integer types."""
            return self.numpy_dtype.kind in ('i', 'u')

        def is_float(self):
            """True for floating point types."""
            return self.numpy_dtype.kind == 'f'

        def is_uint(self):
            """True for unsigned integer types."""
            return self.numpy_dtype.kind == 'u'

        def is_complex(self):
            """True for complex floating point types."""
            return self.numpy_dtype.kind == 'c'

        def is_other(self):
            """True for bool, object, byte-string, unicode-string and void types."""
            return self.numpy_dtype.kind in ('b', 'O', 'S', 'U', 'V')

        @property
        def base_name(self):
            """C name of the type without qualifiers."""
            return BasicType.numpy_name_to_c(str(self._dtype))

        def __str__(self):
            result = BasicType.numpy_name_to_c(str(self._dtype))
            if self.const:
                result += " const"
            return result

        def __repr__(self):
            return str(self)

        def __eq__(self, other):
            if not isinstance(other, BasicType):
                return False
            else:
                return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)

        def __hash__(self):
            return hash(str(self))
    
    
    class VectorType(Type):
        """Vector of ``width`` elements of a base type."""

        # When set, a mapping with the keys 'int', 'double', 'float' and 'bool' that
        # provides the string form of the vector type; None selects a generic form.
        instruction_set = None

        def __init__(self, base_type, width=4):
            self._base_type = base_type
            self.width = width

        def __getnewargs__(self):
            return self._base_type, self.width

        @property
        def base_type(self):
            """Type of a single vector element."""
            return self._base_type

        @property
        def item_size(self):
            """Item size of the base type times the vector width."""
            return self.width * self.base_type.item_size

        def __eq__(self, other):
            if isinstance(other, VectorType):
                return (self.base_type, self.width) == (other.base_type, other.width)
            return False

        def __hash__(self):
            return hash((self.base_type, self.width))

        def __str__(self):
            if self.instruction_set is None:
                return "%s[%d]" % (self.base_type, self.width)

            # element type -> key inside the instruction set mapping, tested in this order
            lookup_order = (
                ("int64", 'int'),
                ("float64", 'double'),
                ("float32", 'float'),
                ("bool", 'bool'),
            )
            for element_type, key in lookup_order:
                if self.base_type == create_type(element_type):
                    return self.instruction_set[key]
            raise NotImplementedError()
    
    
    class PointerType(Type):
        """Pointer to a base type, optionally ``const`` and/or ``restrict`` qualified."""

        def __init__(self, base_type, const=False, restrict=True):
            self._base_type = base_type
            self.const = const
            self.restrict = restrict

        def __getnewargs__(self):
            return self.base_type, self.const, self.restrict

        @property
        def base_type(self):
            """Type the pointer points to."""
            return self._base_type

        @property
        def alias(self):
            """Negation of the ``restrict`` flag."""
            return not self.restrict

        @property
        def item_size(self):
            """Item size of the pointed-to type."""
            return self.base_type.item_size

        def __eq__(self, other):
            if isinstance(other, PointerType):
                return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict)
            return False

        def __hash__(self):
            return hash((self._base_type, self.const, self.restrict))

        def __str__(self):
            # "<base> *" followed by the qualifiers that are set
            qualifiers = []
            if self.restrict:
                qualifiers.append('RESTRICT')
            if self.const:
                qualifiers.append("const")
            return " ".join([str(self.base_type), '*'] + qualifiers)

        def __repr__(self):
            return str(self)
    
    
    class StructType:
        """Type described by a numpy dtype (looked up by field name) plus a ``const`` qualifier."""

        def __init__(self, numpy_type, const=False):
            self._dtype = np.dtype(numpy_type)
            self.const = const

        def __getnewargs__(self):
            return self.numpy_dtype, self.const

        @property
        def numpy_dtype(self):
            return self._dtype

        @property
        def base_type(self):
            """A struct type wraps no other type."""
            return None

        @property
        def item_size(self):
            """Size of one struct in bytes, as reported by numpy."""
            return self.numpy_dtype.itemsize

        def has_element(self, element_name):
            """Whether the struct has a field called ``element_name``."""
            return element_name in self.numpy_dtype.fields

        def get_element_offset(self, element_name):
            """Second entry of the numpy field description, i.e. the byte offset of the field."""
            field_info = self.numpy_dtype.fields[element_name]
            return field_info[1]

        def get_element_type(self, element_name):
            """BasicType of the named field; inherits the const qualifier of the struct."""
            field_info = self.numpy_dtype.fields[element_name]
            return BasicType(field_info[0], self.const)

        def __eq__(self, other):
            if isinstance(other, StructType):
                return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const)
            return False

        def __hash__(self):
            return hash((self.numpy_dtype, self.const))

        def __str__(self):
            # structs are handled byte-wise
            return "uint8_t const" if self.const else "uint8_t"

        def __repr__(self):
            return str(self)