Skip to content
Snippets Groups Projects
Commit 2f1c1194 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fix and improved bit mask support

parent 69286c9b
Branches
Tags
1 merge request!292Rebase of pystencils Type System
import sympy as sp
from pystencils.typing import get_type_of_expression
# from pystencils.typing import get_type_of_expression
# noinspection PyPep8Naming
......@@ -22,13 +22,14 @@ class flag_cond(sp.Function):
def __new__(cls, flag_bit, mask_expression, *expressions):
flag_dtype = get_type_of_expression(flag_bit)
if not flag_dtype.is_int():
raise ValueError('Argument flag_bit must be of integer type.')
mask_dtype = get_type_of_expression(mask_expression)
if not mask_dtype.is_int():
raise ValueError('Argument mask_expression must be of integer type.')
# TODO reintroduce checking
# flag_dtype = get_type_of_expression(flag_bit)
# if not flag_dtype.is_int():
# raise ValueError('Argument flag_bit must be of integer type.')
#
# mask_dtype = get_type_of_expression(mask_expression)
# if not mask_dtype.is_int():
# raise ValueError('Argument mask_expression must be of integer type.')
return super().__new__(cls, flag_bit, mask_expression, *expressions)
......
......@@ -175,7 +175,10 @@ class TypeAdder:
raise NotImplementedError('integer_functions')
elif isinstance(expr, flag_cond):
# do not process the arguments to the bit shift - they must remain integers
raise NotImplementedError('flag_cond')
args_types = [self.figure_out_type(a) for a in (expr.args[i] for i in range(2, len(expr.args)))]
collated_type = collate_types([t for _, t in args_types])
new_expressions = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
return expr.func(expr.args[0], expr.args[1], *new_expressions), collated_type
#elif isinstance(expr, sp.Mul):
# raise NotImplementedError('sp.Mul')
# # TODO can we ignore this and move it to general expr handling, i.e. removing Mul?
......
import pytest
import numpy as np
import pystencils as ps
from pystencils import Field, Assignment, create_kernel
from pystencils.bit_masks import flag_cond
def test_flag_condition():
@pytest.mark.parametrize('mask_type', [np.uint8, np.uint16, np.uint32, np.uint64])
def test_flag_condition(mask_type):
f_arr = np.zeros((2, 2, 2), dtype=np.float64)
mask_arr = np.zeros((2, 2), dtype=np.uint64)
mask_arr = np.zeros((2, 2), dtype=mask_type)
mask_arr[0, 1] = (1 << 3)
mask_arr[1, 0] = (1 << 5)
......@@ -16,7 +20,7 @@ def test_flag_condition():
v1 = 42.3
v2 = 39.7
v3 = 119.87
v3 = 119
assignments = [
Assignment(f(0), flag_cond(3, mask(0), v1)),
......@@ -25,6 +29,8 @@ def test_flag_condition():
kernel = create_kernel(assignments).compile()
kernel(f=f_arr, mask=mask_arr)
code = ps.get_code_str(kernel)
assert '119.0' in code
reference = np.zeros((2, 2, 2), dtype=np.float64)
reference[0, 1, 0] = v1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment