Skip to content
Snippets Groups Projects
Commit 9f9d301c authored by Michael Kuron :mortar_board:
Browse files

Make vec_any/vec_all vectorization actually work

parent 706d0739
No related branches found
No related tags found
1 merge request: !228 Vectorization improvements
Pipeline #31035 passed
...@@ -85,5 +85,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q ...@@ -85,5 +85,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon', q
result['&'] = f'vand{q_reg}_u{bits[data_type]}' + '({0}, {1})' result['&'] = f'vand{q_reg}_u{bits[data_type]}' + '({0}, {1})'
result['|'] = f'vorr{q_reg}_u{bits[data_type]}' + '({0}, {1})' result['|'] = f'vorr{q_reg}_u{bits[data_type]}' + '({0}, {1})'
result['blendv'] = f'vbsl{q_reg}_f{bits[data_type]}' + '({2}, {1}, {0})' result['blendv'] = f'vbsl{q_reg}_f{bits[data_type]}' + '({2}, {1}, {0})'
result['any'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) > 0'
result['all'] = f'vaddlvq_u8(vreinterpretq_u8_u{bits[data_type]}({{0}})) == 16*0xff'
return result return result
...@@ -588,18 +588,17 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -588,18 +588,17 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) return self.instruction_set['rsqrt'].format(self._print(expr.args[0]))
else: else:
return f"({self._print(1 / sp.sqrt(expr.args[0]))})" return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, vec_any): elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
expr_type = get_type_of_expression(expr.args[0]) instr = 'any' if isinstance(expr, vec_any) else 'all'
if type(expr_type) is not VectorType:
return self._print(expr.args[0])
else:
return self.instruction_set['any'].format(self._print(expr.args[0]))
elif isinstance(expr, vec_all):
expr_type = get_type_of_expression(expr.args[0]) expr_type = get_type_of_expression(expr.args[0])
if type(expr_type) is not VectorType: if type(expr_type) is not VectorType:
return self._print(expr.args[0]) return self._print(expr.args[0])
else: else:
return self.instruction_set['all'].format(self._print(expr.args[0])) if isinstance(expr.args[0], sp.Rel):
op = expr.args[0].rel_op
if (instr, op) in self.instruction_set:
return self.instruction_set[(instr, op)].format(*[self._print(a) for a in expr.args[0].args])
return self.instruction_set[instr].format(self._print(expr.args[0]))
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
......
...@@ -41,6 +41,19 @@ def get_vector_instruction_set_ppc(data_type='double', instruction_set='vsx'): ...@@ -41,6 +41,19 @@ def get_vector_instruction_set_ppc(data_type='double', instruction_set='vsx'):
'&': 'and[0, 1]', '&': 'and[0, 1]',
'|': 'or[0, 1]', '|': 'or[0, 1]',
'blendv': 'sel[0, 1, 2]', 'blendv': 'sel[0, 1, 2]',
('any', '=='): 'any_eq[0, 1]',
('any', '!='): 'any_ne[0, 1]',
('any', '<='): 'any_le[0, 1]',
('any', '<'): 'any_lt[0, 1]',
('any', '>='): 'any_ge[0, 1]',
('any', '>'): 'any_gt[0, 1]',
('all', '=='): 'all_eq[0, 1]',
('all', '!='): 'all_ne[0, 1]',
('all', '<='): 'all_le[0, 1]',
('all', '<'): 'all_lt[0, 1]',
('all', '>='): 'all_ge[0, 1]',
('all', '>'): 'all_gt[0, 1]',
} }
bits = {'double': 64, bits = {'double': 64,
...@@ -74,4 +87,7 @@ def get_vector_instruction_set_ppc(data_type='double', instruction_set='vsx'): ...@@ -74,4 +87,7 @@ def get_vector_instruction_set_ppc(data_type='double', instruction_set='vsx'):
result['makeVecConstInt'] = '((' + result['int'] + '){{' + ", ".join(['{0}' for _ in range(intwidth)]) + '}})' result['makeVecConstInt'] = '((' + result['int'] + '){{' + ", ".join(['{0}' for _ in range(intwidth)]) + '}})'
result['makeVecInt'] = '((' + result['int'] + '){{{0}, {1}, {2}, {3}}})' result['makeVecInt'] = '((' + result['int'] + '){{{0}, {1}, {2}, {3}}})'
result['any'] = 'vec_any_ne({0}, ((' + result['bool'] + ') {{' + ", ".join(['0'] * width) + '}}))'
result['all'] = 'vec_all_ne({0}, ((' + result['bool'] + ') {{' + ", ".join(['0'] * width) + '}}))'
return result return result
...@@ -137,11 +137,11 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): ...@@ -137,11 +137,11 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
result['double'] = f"__m{bit_width}d" result['double'] = f"__m{bit_width}d"
result['float'] = f"__m{bit_width}" result['float'] = f"__m{bit_width}"
result['int'] = f"__m{bit_width}i" result['int'] = f"__m{bit_width}i"
result['bool'] = f"__m{bit_width}d" result['bool'] = result[data_type]
result['headers'] = headers[instruction_set] result['headers'] = headers[instruction_set]
result['any'] = f"{pre}_movemask_{suf}({{0}}) > 0" result['any'] = f"{pre}_movemask_{suf}({{0}}) > 0"
result['all'] = f"{pre}_movemask_{suf}({{0}}) == 0xF" result['all'] = f"{pre}_movemask_{suf}({{0}}) == {hex(2**result['width']-1)}"
if instruction_set == 'avx512': if instruction_set == 'avx512':
size = 8 if data_type == 'double' else 16 size = 8 if data_type == 'double' else 16
......
...@@ -4,18 +4,19 @@ import pytest ...@@ -4,18 +4,19 @@ import pytest
import pystencils as ps import pystencils as ps
from pystencils.astnodes import Block, Conditional from pystencils.astnodes import Block, Conditional
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.cpu.vectorization import vec_all, vec_any
supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else []
@pytest.mark.skipif(not get_supported_instruction_sets(), reason='cannot detect CPU instruction set') @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.skipif('neon' in get_supported_instruction_sets(), reason='ARM does not have collective instructions') @pytest.mark.parametrize('dtype', ('float', 'double'))
@pytest.mark.xfail('vsx' in get_supported_instruction_sets(), reason='PPC collective instructions not implemented') def test_vec_any(instruction_set, dtype):
def test_vec_any(): width = get_vector_instruction_set(dtype, instruction_set)['width']
data_arr = np.zeros((15, 15)) data_arr = np.zeros((4*width, 4*width), dtype=np.float64 if dtype == 'double' else np.float32)
data_arr[3:9, 1] = 1.0 data_arr[3:9, 1:3*width-1] = 1.0
data = ps.fields("data: double[2D]", data=data_arr) data = ps.fields(f"data: {dtype}[2D]", data=data_arr)
c = [ c = [
ps.Assignment(sp.Symbol("t1"), vec_any(data.center() > 0.0)), ps.Assignment(sp.Symbol("t1"), vec_any(data.center() > 0.0)),
...@@ -23,25 +24,21 @@ def test_vec_any(): ...@@ -23,25 +24,21 @@ def test_vec_any():
ps.Assignment(data.center(), 2.0) ps.Assignment(data.center(), 2.0)
])) ]))
] ]
instruction_set = get_supported_instruction_sets()[-1]
ast = ps.create_kernel(c, target='cpu', ast = ps.create_kernel(c, target='cpu',
cpu_vectorize_info={'instruction_set': instruction_set}) cpu_vectorize_info={'instruction_set': instruction_set})
kernel = ast.compile() kernel = ast.compile()
kernel(data=data_arr) kernel(data=data_arr)
np.testing.assert_equal(data_arr[3:9, :3*width], 2.0)
width = ast.instruction_set['width']
np.testing.assert_equal(data_arr[3:9, 0:width], 2.0) @pytest.mark.parametrize('instruction_set', supported_instruction_sets)
@pytest.mark.parametrize('dtype', ('float', 'double'))
def test_vec_all(instruction_set, dtype):
width = get_vector_instruction_set(dtype, instruction_set)['width']
data_arr = np.zeros((4*width, 4*width), dtype=np.float64 if dtype == 'double' else np.float32)
data_arr[3:9, 1:3*width-1] = 1.0
@pytest.mark.skipif(not get_supported_instruction_sets(), reason='cannot detect CPU instruction set') data = ps.fields(f"data: {dtype}[2D]", data=data_arr)
@pytest.mark.skipif('neon' in get_supported_instruction_sets(), reason='ARM does not have collective instructions')
@pytest.mark.xfail('vsx' in get_supported_instruction_sets(), reason='PPC collective instructions not implemented')
def test_vec_all():
data_arr = np.zeros((15, 15))
data_arr[3:9, 2:7] = 1.0
data = ps.fields("data: double[2D]", data=data_arr)
c = [ c = [
Conditional(vec_all(data.center() > 0.0), Block([ Conditional(vec_all(data.center() > 0.0), Block([
...@@ -49,14 +46,17 @@ def test_vec_all(): ...@@ -49,14 +46,17 @@ def test_vec_all():
])) ]))
] ]
ast = ps.create_kernel(c, target='cpu', ast = ps.create_kernel(c, target='cpu',
cpu_vectorize_info={'instruction_set': get_supported_instruction_sets()[-1]}) cpu_vectorize_info={'instruction_set': instruction_set})
kernel = ast.compile() kernel = ast.compile()
before = data_arr.copy()
kernel(data=data_arr) kernel(data=data_arr)
np.testing.assert_equal(data_arr, before) np.testing.assert_equal(data_arr[3:9, :1], 0.0)
np.testing.assert_equal(data_arr[3:9, 1:width], 1.0)
np.testing.assert_equal(data_arr[3:9, width:2*width], 2.0)
np.testing.assert_equal(data_arr[3:9, 2*width:3*width-1], 1.0)
np.testing.assert_equal(data_arr[3:9, 3*width-1:], 0.0)
@pytest.mark.skipif(not get_supported_instruction_sets(), reason='cannot detect CPU instruction set') @pytest.mark.skipif(not supported_instruction_sets, reason='cannot detect CPU instruction set')
def test_boolean_before_loop(): def test_boolean_before_loop():
t1, t2 = sp.symbols('t1, t2') t1, t2 = sp.symbols('t1, t2')
f_arr = np.ones((10, 10)) f_arr = np.ones((10, 10))
...@@ -68,7 +68,7 @@ def test_boolean_before_loop(): ...@@ -68,7 +68,7 @@ def test_boolean_before_loop():
ps.Assignment(g[0, 0], ps.Assignment(g[0, 0],
sp.Piecewise((f[0, 0], t1), (42, True))) sp.Piecewise((f[0, 0], t1), (42, True)))
] ]
ast = ps.create_kernel(a, cpu_vectorize_info={'instruction_set': get_supported_instruction_sets()[-1]}) ast = ps.create_kernel(a, cpu_vectorize_info={'instruction_set': supported_instruction_sets[-1]})
kernel = ast.compile() kernel = ast.compile()
kernel(f=f_arr, g=g_arr, t2=1.0) kernel(f=f_arr, g=g_arr, t2=1.0)
print(g) print(g)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment