Skip to content
Snippets Groups Projects
Commit 9f501a36 authored by Frederik Hennig's avatar Frederik Hennig Committed by Markus Holzer
Browse files

Bit Flag Conditional

parent 699144dd
No related branches found
No related tags found
No related merge requests found
import sympy as sp
from pystencils.data_types import get_type_of_expression
# noinspection PyPep8Naming
class flag_cond(sp.Function):
"""Evaluates a flag condition on a bit mask, and returns the value of one of two expressions,
depending on whether the flag is set.
Three argument version:
```
flag_cond(flag_bit, mask, expr) = expr if (flag_bit is set in mask) else 0
```
Four argument version:
```
flag_cond(flag_bit, mask, expr_then, expr_else) = expr_then if (flag_bit is set in mask) else expr_else
```
"""
nargs = (3, 4)
def __new__(cls, flag_bit, mask_expression, *expressions):
flag_dtype = get_type_of_expression(flag_bit)
if not flag_dtype.is_int():
raise ValueError('Argument flag_bit must be of integer type.')
mask_dtype = get_type_of_expression(mask_expression)
if not mask_dtype.is_int():
raise ValueError('Argument mask_expression must be of integer type.')
return super().__new__(cls, flag_bit, mask_expression, *expressions)
def to_c(self, print_func):
flag_bit = self.args[0]
mask = self.args[1]
then_expression = self.args[2]
flag_bit_code = print_func(flag_bit)
mask_code = print_func(mask)
then_code = print_func(then_expression)
code = f"(({mask_code}) >> ({flag_bit_code}) & 1) * ({then_code})"
if len(self.args) > 3:
else_expression = self.args[3]
else_code = print_func(else_expression)
code += f" + (({mask_code}) >> ({flag_bit_code}) ^ 1) * ({else_code})"
return code
......@@ -21,6 +21,7 @@ from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice
from pystencils.integer_functions import int_div
from pystencils.bit_masks import flag_cond
class NestedScopes:
......@@ -876,6 +877,10 @@ class KernelConstraintsCheck:
else cast_func(a, arg_type)
for a in new_args]
return rhs.func(*new_args)
elif isinstance(rhs, flag_cond):
# do not process the arguments to the bit shift - they must remain integers
processed_args = (self.process_expression(a) for a in rhs.args[2:])
return flag_cond(rhs.args[0], rhs.args[1], *processed_args)
elif isinstance(rhs, sp.Mul):
new_args = [
self.process_expression(arg, type_constants)
......
import numpy as np
import sympy as sp
from pystencils import Field, Assignment, create_kernel
from pystencils.bit_masks import flag_cond
from pystencils import TypedSymbol
def test_flag_condition():
f_arr = np.zeros((2,2,2), dtype=np.float64)
mask_arr = np.zeros((2,2), dtype=np.uint64)
mask_arr[0,1] = (1<<3)
mask_arr[1,0] = (1<<5)
mask_arr[1,1] = (1<<3) + (1 << 5)
f = Field.create_from_numpy_array('f', f_arr, index_dimensions=1)
mask = Field.create_from_numpy_array('mask', mask_arr)
v1 = 42.3
v2 = 39.7
v3 = 119.87
assignments = [
Assignment(f(0), flag_cond(3, mask(0), v1)),
Assignment(f(1), flag_cond(5, mask(0), v2, v3))
]
kernel = create_kernel(assignments).compile()
kernel(f=f_arr, mask=mask_arr)
reference = np.zeros((2,2,2), dtype=np.float64)
reference[0,1,0] = v1
reference[1,1,0] = v1
reference[0,0,1] = v3
reference[0,1,1] = v3
reference[1,0,1] = v2
reference[1,1,1] = v2
np.testing.assert_array_equal(f_arr, reference)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment