diff --git a/pystencils/bit_masks.py b/pystencils/bit_masks.py index 73c18688cc7d34bc23caed66528bae3f148dba65..ad4e967f0131dde5347c20de467afbab4892e149 100644 --- a/pystencils/bit_masks.py +++ b/pystencils/bit_masks.py @@ -1,5 +1,5 @@ import sympy as sp -from pystencils.typing import get_type_of_expression +# from pystencils.typing import get_type_of_expression # noinspection PyPep8Naming @@ -22,13 +22,14 @@ class flag_cond(sp.Function): def __new__(cls, flag_bit, mask_expression, *expressions): - flag_dtype = get_type_of_expression(flag_bit) - if not flag_dtype.is_int(): - raise ValueError('Argument flag_bit must be of integer type.') - - mask_dtype = get_type_of_expression(mask_expression) - if not mask_dtype.is_int(): - raise ValueError('Argument mask_expression must be of integer type.') + # TODO reintroduce checking + # flag_dtype = get_type_of_expression(flag_bit) + # if not flag_dtype.is_int(): + # raise ValueError('Argument flag_bit must be of integer type.') + # + # mask_dtype = get_type_of_expression(mask_expression) + # if not mask_dtype.is_int(): + # raise ValueError('Argument mask_expression must be of integer type.') return super().__new__(cls, flag_bit, mask_expression, *expressions) diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index b620c9c7e1257881522db57b8a1f5aad138a8485..ef81b529d25d68613f47afd8913f6f1eb3c59dc8 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -175,7 +175,10 @@ class TypeAdder: raise NotImplementedError('integer_functions') elif isinstance(expr, flag_cond): # do not process the arguments to the bit shift - they must remain integers - raise NotImplementedError('flag_cond') + args_types = [self.figure_out_type(a) for a in (expr.args[i] for i in range(2, len(expr.args)))] + collated_type = collate_types([t for _, t in args_types]) + new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] + return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type #elif isinstance(expr, sp.Mul): # raise NotImplementedError('sp.Mul') # # TODO can we ignore this and move it to general expr handling, i.e. removing Mul? diff --git a/pystencils_tests/test_bit_masks.py b/pystencils_tests/test_bit_masks.py index 57371976f416abdf52274852666860c3c92dcdf2..2bc5bc7a24e03972f4635a463c3c64fb3c785c36 100644 --- a/pystencils_tests/test_bit_masks.py +++ b/pystencils_tests/test_bit_masks.py @@ -1,10 +1,12 @@ import numpy as np +import pystencils as ps from pystencils import Field, Assignment, create_kernel from pystencils.bit_masks import flag_cond def test_flag_condition(): f_arr = np.zeros((2, 2, 2), dtype=np.float64) + # TODO different uints mask_arr = np.zeros((2, 2), dtype=np.uint64) mask_arr[0, 1] = (1 << 3) @@ -25,6 +27,7 @@ def test_flag_condition(): kernel = create_kernel(assignments).compile() kernel(f=f_arr, mask=mask_arr) + code = ps.get_code_str(kernel) reference = np.zeros((2, 2, 2), dtype=np.float64) reference[0, 1, 0] = v1