diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py index 1cc62c168b9daf32e7ea52cf0e3d255f1b684801..4f67435bb34db3cdd7fd200f3e39da89ca0f39b8 100644 --- a/pystencils/typing/utilities.py +++ b/pystencils/typing/utilities.py @@ -12,6 +12,7 @@ from pystencils.cache import memorycache_if_hashable from pystencils.typing.types import BasicType, VectorType, PointerType, create_type from pystencils.typing.cast_functions import CastFunc, PointerArithmeticFunc from pystencils.typing.typed_sympy import TypedSymbol +from pystencils.utils import all_equal def typed_symbols(names, dtype, *args): @@ -33,14 +34,6 @@ def get_base_type(data_type): return data_type -def peel_off_type(dtype, type_to_peel_off): - # TODO: WTF is this??? DOCS!!! - # TODO: used only once.... can be a lambda there - while type(dtype) is type_to_peel_off: - dtype = dtype.base_type - return dtype - - ############################# This is basically our type system ######################################################## def result_type(*args: np.dtype): @@ -83,18 +76,25 @@ def collate_types(types: Sequence[Union[BasicType, VectorType]]): # # peel of vector types, if at least one vector type occurred the result will also be the vector type vector_type = [t for t in types if isinstance(t, VectorType)] - # if not all_equal(t.width for t in vector_type): - # raise ValueError("Collation failed because of vector types with different width") + if not all_equal(t.width for t in vector_type): + raise ValueError("Collation failed because of vector types with different width") + + # TODO: check if this is needed + # def peel_off_type(dtype, type_to_peel_off): + # while type(dtype) is type_to_peel_off: + # dtype = dtype.base_type + # return dtype # types = [peel_off_type(t, VectorType) for t in types] + types = [t.base_type if isinstance(t, VectorType) else t for t in types] + # now we should have a list of basic types - struct types are not yet supported assert all(type(t) is BasicType for t in types) result_numpy_type = result_type(*(t.numpy_dtype for t in types)) result = BasicType(result_numpy_type) if vector_type: - raise NotImplementedError("Vector type not implemented at the moment") - # result = VectorType(result, vector_type[0].width) + result = VectorType(result, vector_type[0].width) return result diff --git a/pystencils/utils.py b/pystencils/utils.py index dc8d35ee64dcfdb0ef6f9f687526fe3379ce8fbd..22d61d0bac6c402e10a7f48a07a55264ec4ddf27 100644 --- a/pystencils/utils.py +++ b/pystencils/utils.py @@ -1,5 +1,6 @@ import os import itertools +from itertools import groupby from collections import Counter from contextlib import contextmanager from tempfile import NamedTemporaryFile @@ -23,13 +24,13 @@ class DotDict(dict): self[key] = value -def all_equal(iterator): - iterator = iter(iterator) - try: - first = next(iterator) - except StopIteration: - return True - return all(first == rest for rest in iterator) +def all_equal(iterable): + """ + Returns ``True`` if all the elements are equal to each other. + Copied from: more-itertools 8.12.0 + """ + g = groupby(iterable) + return next(g, True) and not next(g, False) def recursive_dict_update(d, u): diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py index 478022d32b5d62cc6485d3d01bba32c8e34b0372..63feb7ed761738dc89d6e1fec631a864853af731 100644 --- a/pystencils_tests/test_vectorization.py +++ b/pystencils_tests/test_vectorization.py @@ -30,6 +30,8 @@ def test_vector_type_propagation(instruction_set=instruction_set): ast = ps.create_kernel(update_rule) vectorize(ast, instruction_set=instruction_set) + # ps.show_code(ast) + func = ast.compile() dst = np.zeros_like(arr) func(g=dst, f=arr) @@ -64,6 +66,8 @@ def test_aligned_and_nt_stores(instruction_set=instruction_set, openmp=False): assert ast.instruction_set[instruction].split('{')[0] in ps.get_code_str(ast) kernel = ast.compile() + # ps.show_code(ast) + dh.run_kernel(kernel) np.testing.assert_equal(np.sum(dh.cpu_arrays['f']), np.prod(domain_size))