diff --git a/pystencils/integer_functions.py b/pystencils/integer_functions.py index a6b9daf21c061b67fb42884c42f8512a22f5f7de..a0516ac3ab1693f928374539c9be9b9f1d2cc97d 100644 --- a/pystencils/integer_functions.py +++ b/pystencils/integer_functions.py @@ -27,6 +27,27 @@ class IntegerFunctionTwoArgsMixIn(sp.Function): raise ValueError("Integer functions can only be constructed with typed expressions") return super().__new__(cls, *args) +class IntegerFunctionOneArgMixIn(sp.Function): + is_integer = True + + def __new__(cls, arg1): + args = [] + + if isinstance(arg1, sp.Number) or isinstance(arg1, int): + args.append(cast_func(arg1, create_type("int"))) + elif isinstance(arg1, np.generic): + args.append(cast_func(arg1, arg1.dtype)) + else: + args.append(arg1) + + + try: + type = get_type_of_expression(arg1) + if not type.is_int(): + raise ValueError("Argument to integer function is not an int but " + str(type)) + except NotImplementedError: + raise ValueError("Integer functions can only be constructed with typed expressions") + return super().__new__(cls, *args) # noinspection PyPep8Naming class bitwise_xor(IntegerFunctionTwoArgsMixIn): @@ -59,11 +80,11 @@ class int_div(IntegerFunctionTwoArgsMixIn): # noinspection PyPep8Naming -class int_power_of_2(sp.Function): +class int_power_of_2(IntegerFunctionOneArgMixIn): pass # noinspection PyPep8Naming -class post_increment(sp.Function): +class post_increment(IntegerFunctionOneArgMixIn): pass