From ff0e41ee0aab9764ee73b210cdca7e7b8043e2bd Mon Sep 17 00:00:00 2001 From: Michael Kuron <m.kuron@gmx.de> Date: Sat, 24 Aug 2024 11:37:30 +0200 Subject: [PATCH] Fused-multiply-add vectorization --- .../backends/arm_instruction_sets.py | 6 + src/pystencils/backends/cbackend.py | 5 +- .../backends/ppc_instruction_sets.py | 4 + .../backends/riscv_instruction_sets.py | 4 + .../backends/x86_instruction_sets.py | 6 + src/pystencils/cpu/vectorization.py | 5 +- src/pystencils/fast_approximation.py | 143 +++++++++++++++++- tests/test_vec_fma.py | 127 ++++++++++++++++ 8 files changed, 297 insertions(+), 3 deletions(-) create mode 100644 tests/test_vec_fma.py diff --git a/src/pystencils/backends/arm_instruction_sets.py b/src/pystencils/backends/arm_instruction_sets.py index 227224f4e..0a1d6b872 100644 --- a/src/pystencils/backends/arm_instruction_sets.py +++ b/src/pystencils/backends/arm_instruction_sets.py @@ -35,6 +35,8 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): '-': 'sub[0, 1]', '*': 'mul[0, 1]', '/': 'div[0, 1]', + '*+': 'fma[2, 1, 0]', # 2*1+0 + '-*+': 'fms[2, 1, 0]', # -(2*1+0) 'sqrt': 'sqrt[0]', 'loadU': 'ld1[0]', @@ -65,6 +67,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): result['bytes'] = bitwidth // 8 if instruction_set.startswith('sve') or instruction_set == 'sme': base_names['stream'] = 'stnt1[0, 1]' + base_names['*+'] = 'mad[2, 1, 0]' + base_names['*-'] = 'msb[2, 1, 0]' + base_names['-*+'] = 'nmad[2, 1, 0]' + base_names['-*-'] = 'nmsb[2, 1, 0]' prefix = 'sv' suffix = f'_f{bits[data_type]}' elif instruction_set == 'neon': diff --git a/src/pystencils/backends/cbackend.py b/src/pystencils/backends/cbackend.py index 657f60d2f..c00cf520a 100644 --- a/src/pystencils/backends/cbackend.py +++ b/src/pystencils/backends/cbackend.py @@ -17,7 +17,7 @@ from pystencils.typing import ( PointerType, VectorType, CastFunc, create_type, get_type_of_expression, ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol) from pystencils.enums import Backend -from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt +from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt, Fma from pystencils.functions import DivFunc, AddressOf from pystencils.integer_functions import ( bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, @@ -728,6 +728,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): raise ValueError("fast_sqrt is only supported for Taget.GPU") elif isinstance(expr, fast_inv_sqrt): raise ValueError("fast_inv_sqrt is only supported for Taget.GPU") + elif isinstance(expr, Fma): + return self.instruction_set[expr.instruction].format(self._print(expr.args[0]), self._print(expr.args[1]), + self._print(expr.args[2]), **self._kwargs) elif isinstance(expr, vec_any) or isinstance(expr, vec_all): instr = 'any' if isinstance(expr, vec_any) else 'all' expr_type = get_type_of_expression(expr.args[0]) diff --git a/src/pystencils/backends/ppc_instruction_sets.py b/src/pystencils/backends/ppc_instruction_sets.py index d2d319976..79df13c7b 100644 --- a/src/pystencils/backends/ppc_instruction_sets.py +++ b/src/pystencils/backends/ppc_instruction_sets.py @@ -22,6 +22,10 @@ def get_vector_instruction_set_ppc(data_type='double', instruction_set='vsx'): '-': 'sub[0, 1]', '*': 'mul[0, 1]', '/': 'div[0, 1]', + '*+': 'madd[0, 1, 2]', + '*-': 'msub[0, 1, 2]', + '-*+': 'nmsub[0, 1, 2]', + '-*-': 'nmadd[0, 1, 2]', 'sqrt': 'sqrt[0]', 'rsqrt': 'rsqrte[0]', # rsqrt is available too, but not on Clang diff --git a/src/pystencils/backends/riscv_instruction_sets.py b/src/pystencils/backends/riscv_instruction_sets.py index 27f631e7f..20aff7921 100644 --- a/src/pystencils/backends/riscv_instruction_sets.py +++ b/src/pystencils/backends/riscv_instruction_sets.py @@ -27,6 +27,10 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'): '-': 'fsub_vv[0, 1]', '*': 'fmul_vv[0, 1]', '/': 'fdiv_vv[0, 1]', + '*+': 'fmadd[2, 0, 1]', + '*-': 'fmsub[2, 0, 1]', + '-*+': 'fnmadd[2, 0, 1]', + '-*-': 'fnmsub[2, 0, 1]', 'sqrt': 'fsqrt_v[0]', 'loadU': f'le{bits[data_type]}_v[0]', diff --git a/src/pystencils/backends/x86_instruction_sets.py b/src/pystencils/backends/x86_instruction_sets.py index 7f05a403d..34727ec64 100644 --- a/src/pystencils/backends/x86_instruction_sets.py +++ b/src/pystencils/backends/x86_instruction_sets.py @@ -61,6 +61,12 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'): 'maskStoreU': 'mask_storeu[0, 2, 1]' if instruction_set.startswith('avx512') else 'maskstore[0, 2, 1]', } + if instruction_set.startswith('avx'): + base_names['*+'] = 'fmadd[0, 1, 2]' # 0*1+2 + base_names['*-'] = 'fmsub[0, 1, 2]' # 0*1-2 + base_names['-*+'] = 'fnmadd[0, 1, 2]' # -0*1+2 + base_names['-*-'] = 'fnmsub[0, 1, 2]' # -0*1-2 + for comparison_op, constant in comparisons.items(): base_names[comparison_op] = f'cmp[0, 1, {constant}]' diff --git a/src/pystencils/cpu/vectorization.py b/src/pystencils/cpu/vectorization.py index 872f0b3c4..32cb70974 100644 --- a/src/pystencils/cpu/vectorization.py +++ b/src/pystencils/cpu/vectorization.py @@ -12,6 +12,7 @@ from pystencils.typing import (BasicType, PointerType, TypedSymbol, VectorType, from pystencils.functions import DivFunc from pystencils.field import Field from pystencils.integer_functions import modulo_ceil, modulo_floor +from pystencils.fast_approximation import insert_fma, fmadd, fmsub, fnmadd, fnmsub from pystencils.sympyextensions import fast_subs from pystencils.transformations import cut_loop, filtered_tree_iteration, replace_inner_stride_with_one @@ -127,6 +128,8 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best', vector_is = get_vector_instruction_set(default_float_type, instruction_set=instruction_set) kernel_ast.instruction_set = vector_is + kernel_ast = insert_fma(kernel_ast, vector_is.keys()) + if nontemporal and 'cachelineZero' in vector_is: kernel_ast.use_all_written_field_sizes = True strided = 'storeS' in vector_is and 'loadS' in vector_is @@ -264,7 +267,7 @@ def mask_conditionals(loop_body): def insert_vector_casts(ast_node, instruction_set, default_float_type='double'): """Inserts necessary casts from scalar values to vector values.""" - handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs) + handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs, fmadd, fmsub, fnmadd, fnmsub) def is_scalar(expr) -> bool: if hasattr(expr, "dtype"): diff --git a/src/pystencils/fast_approximation.py b/src/pystencils/fast_approximation.py index ab0dc5974..de3dc62c3 100644 --- a/src/pystencils/fast_approximation.py +++ b/src/pystencils/fast_approximation.py @@ -2,7 +2,7 @@ from typing import List, Union import sympy as sp -from pystencils.astnodes import Node +from pystencils.astnodes import Node, ResolvedFieldAccess, Block from pystencils.simp import AssignmentCollection from pystencils.assignment import Assignment @@ -31,6 +31,42 @@ class fast_inv_sqrt(sp.Function): nargs = (1, ) +class Fma(sp.Function): + nargs = (3, ) + + +# noinspection PyPep8Naming +class fmadd(Fma): + """ + Produces FMA instructions for a*b+c + """ + instruction = '*+' + + +# noinspection PyPep8Naming +class fmsub(Fma): + """ + Produces FMA instructions for a*b-c + """ + instruction = '*-' + + +# noinspection PyPep8Naming +class fnmadd(Fma): + """ + Produces FMA instructions for -a*b+c + """ + instruction = '-*+' + + +# noinspection PyPep8Naming +class fnmsub(Fma): + """ + Produces FMA instructions for -a*b-c + """ + instruction = '-*-' + + def _run(term, visitor): if isinstance(term, AssignmentCollection): new_main_assignments = _run(term.main_assignments, visitor) @@ -82,3 +118,108 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti return expr.func(*new_args) if new_args else expr return _run(term, visit) + + +def insert_fma(term, operators): + if '*+' not in operators: + return term + + def flatten(expr): + if isinstance(expr, sp.Add): + new_args = [] + for arg in expr.args: + if arg.func == sp.Add: + new_args += [flatten(a) for a in arg.args] + else: + new_args.append(flatten(arg)) + return sp.Add(*new_args) + elif isinstance(expr, sp.Mul): + new_args = [] + for arg in expr.args: + if arg.func == sp.Mul: + new_args += [flatten(a) for a in arg.args] + else: + new_args.append(flatten(arg)) + return sp.Mul(*new_args) + return expr + + def visit(expr): + if isinstance(expr, ResolvedFieldAccess): + return expr + elif hasattr(expr, 'body'): + old_parent = expr.body.parent if hasattr(expr.body, 'parent') else None + expr.body = visit(expr.body) + if old_parent is not None: + expr.body.parent = old_parent + return expr + elif isinstance(expr, Block): + return Block([visit(a) for a in expr.args]) + elif expr.func == sp.Add: + expr = flatten(expr) + summands = list(expr.args) + if '-*+' in operators: + for summand in expr.args: + if summand.func == sp.Mul and len(summand.args) >= 3 and -1 in summand.args: + summands.remove(summand) + factors = list(summand.args) + factors.remove(-1) + factors = [visit(f) for f in factors] + if not ('-*-' in operators and all(s.func == sp.Mul and -1 in s.args for s in summands)): + summands = [visit(s) for s in summands] + return sp.Add(fnmadd(factors[0], sp.Mul(*factors[1:]), summands[0]), *summands[1:]) + summands = list(expr.args) + if '-*-' in operators: + negative = [s for s in summands if s.func == sp.Mul and -1 in s.args] + positive = [] + if len(negative) > 1: + positive = [] + for summand in negative: + summands.remove(summand) + factors = list(summand.args) + factors.remove(-1) + positive.append(sp.Mul(*factors)) + for summand in positive: + if summand.func == sp.Mul and len(summand.args) >= 2: + positive.remove(summand) + factors = list(summand.args) + positive = [visit(s) for s in positive] + summands = [visit(f) for f in summands] + return sp.Add(fnmsub(factors[0], sp.Mul(*factors[1:]), sp.Add(*positive[:2])), + -sp.Add(*positive[2:]), *summands) + summands = list(expr.args) + if '*-' in operators: + for summand in summands: + if summand.func == sp.Mul and len(summand.args) >= 2 and -1 not in summand.args: + summands.remove(summand) + factors = list(summand.args) + for summand in summands: + if summand.func == sp.Mul and -1 in summand.args: + summands.remove(summand) + subfactors = list(summand.args) + subfactors.remove(-1) + factors = [visit(f) for f in factors] + summands = [visit(s) for s in summands] + subfactors = [visit(f) for f in subfactors] + return sp.Add(fmsub(factors[0], sp.Mul(*factors[1:]), sp.Mul(*subfactors)), *summands) + summands = list(expr.args) + for summand in summands: + if summand.func == sp.Mul and len(summand.args) >= 2: + summands.remove(summand) + factors = list(summand.args) + factors = [visit(f) for f in factors] + summands = [visit(s) for s in summands] + return sp.Add(fmadd(factors[0], sp.Mul(*factors[1:]), summands[0]), *summands[1:]) + return expr + elif expr.func == sp.Mul and -1 in expr.args: + expr = flatten(expr) + factors = list(expr.args) + factors.remove(-1) + factors = [visit(f) for f in factors] + if '-*+' in operators: + return fnmadd(factors[0], sp.Mul(*factors[1:]), 0) + elif '-*-' in operators: + return fnmsub(factors[0], sp.Mul(*factors[1:]), 0) + new_args = [visit(a) for a in expr.args] + return expr.func(*new_args) if new_args else expr + + return _run(term, visit) diff --git a/tests/test_vec_fma.py b/tests/test_vec_fma.py new file mode 100644 index 000000000..357c37de5 --- /dev/null +++ b/tests/test_vec_fma.py @@ -0,0 +1,127 @@ +import pytest +import pystencils as ps +import numpy as np + +from pystencils.backends.simd_instruction_sets import (get_supported_instruction_sets, get_vector_instruction_set) +from pystencils.fast_approximation import fmadd, fmsub, fnmadd, fnmsub, Fma + +supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else [] + + +@pytest.mark.parametrize('dtype', ('float32', 'float64')) +@pytest.mark.parametrize('instruction_set', supported_instruction_sets) +def test_fmadd(instruction_set, dtype): + da = 2 * np.ones((128, 128), dtype=dtype) + db = 3 * np.ones((128, 128), dtype=dtype) + dc = 5 * np.ones((128, 128), dtype=dtype) + dd = np.empty((128, 128), dtype=dtype) + + a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd) + update_rule = [ps.Assignment(d.center(), a.center() * b.center() + c.center())] + + config = ps.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set}) + ast = ps.create_kernel(update_rule, config=config) + + if '*+' in get_vector_instruction_set(dtype, instruction_set): + assert set(e.func for e in ast.atoms(Fma)) == set([fmadd]), "No FMAs found in AST" + + func = ast.compile() + func(a=da, b=db, c=dc, d=dd) + np.testing.assert_equal(dd, da * db + dc) + + +@pytest.mark.parametrize('dtype', ('float32', 'float64')) +@pytest.mark.parametrize('instruction_set', supported_instruction_sets) +def test_fmsub(instruction_set, dtype): + da = 2 * np.ones((128, 128), dtype=dtype) + db = 3 * np.ones((128, 128), dtype=dtype) + dc = 5 * np.ones((128, 128), dtype=dtype) + dd = np.empty((128, 128), dtype=dtype) + + a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd) + update_rule = [ps.Assignment(d.center(), a.center() * b.center() - c.center())] + + config = ps.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set}) + ast = ps.create_kernel(update_rule, config=config) + + if '*-' in get_vector_instruction_set(dtype, instruction_set): + assert set(e.func for e in ast.atoms(Fma)) == set([fmsub]), "No FMAs found in AST" + elif '*+' in get_vector_instruction_set(dtype, instruction_set): + assert set([fmsub]), "No FMAs found in AST" + + func = ast.compile() + func(a=da, b=db, c=dc, d=dd) + np.testing.assert_equal(dd, da * db - dc) + + +@pytest.mark.parametrize('dtype', ('float32', 'float64')) +@pytest.mark.parametrize('instruction_set', supported_instruction_sets) +def test_fnmadd(instruction_set, dtype): + da = 2 * np.ones((128, 128), dtype=dtype) + db = 3 * np.ones((128, 128), dtype=dtype) + dc = 5 * np.ones((128, 128), dtype=dtype) + dd = np.empty((128, 128), dtype=dtype) + + a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd) + update_rule = [ps.Assignment(d.center(), -a.center() * b.center() + c.center())] + + config = ps.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set}) + ast = ps.create_kernel(update_rule, config=config) + + if '-*+' in get_vector_instruction_set(dtype, instruction_set): + assert set(e.func for e in ast.atoms(Fma)) == set([fnmadd]), "No FMAs found in AST" + elif '*+' in get_vector_instruction_set(dtype, instruction_set): + assert set([fmsub]), "No FMAs found in AST" + + func = ast.compile() + func(a=da, b=db, c=dc, d=dd) + np.testing.assert_equal(dd, -da * db + dc) + + +@pytest.mark.parametrize('dtype', ('float32', 'float64')) +@pytest.mark.parametrize('instruction_set', supported_instruction_sets) +def test_fnmsub(instruction_set, dtype): + da = 2 * np.ones((128, 128), dtype=dtype) + db = 3 * np.ones((128, 128), dtype=dtype) + dc = 5 * np.ones((128, 128), dtype=dtype) + dd = np.empty((128, 128), dtype=dtype) + + a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd) + update_rule = [ps.Assignment(d.center(), -a.center() * b.center() - c.center())] + + config = ps.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set}) + ast = ps.create_kernel(update_rule, config=config) + + if '-*-' in get_vector_instruction_set(dtype, instruction_set): + assert set(e.func for e in ast.atoms(Fma)) == set([fnmsub]), "No FMAs found in AST" + elif '*+' in get_vector_instruction_set(dtype, instruction_set): + assert set([fmsub]), "No FMAs found in AST" + + func = ast.compile() + func(a=da, b=db, c=dc, d=dd) + np.testing.assert_equal(dd, -da * db - dc) + + +@pytest.mark.parametrize('dtype', ('float32', 'float64')) +@pytest.mark.parametrize('instruction_set', supported_instruction_sets) +def test_fnm(instruction_set, dtype): + da = 2 * np.ones((128, 128), dtype=dtype) + db = 3 * np.ones((128, 128), dtype=dtype) + dd = np.empty((128, 128), dtype=dtype) + + a, b, d = ps.fields(a=da, b=db, d=dd) + update_rule = [ps.Assignment(d.center(), -a.center() * b.center())] + + config = ps.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set}) + ast = ps.create_kernel(update_rule, config=config) + + if '-*+' in get_vector_instruction_set(dtype, instruction_set): + assert set(e.func for e in ast.atoms(Fma)) == set([fnmadd]), "No FMAs found in AST" + elif '-*-' in get_vector_instruction_set(dtype, instruction_set): + assert set(e.func for e in ast.atoms(Fma)) == set([fnmsub]), "No FMAs found in AST" + else: + assert set(e.func for e in ast.atoms(Fma)) == set(), "Unexpected FMAs found in AST" + + func = ast.compile() + func(a=da, b=db, d=dd) + np.testing.assert_equal(dd, -da * db) -- GitLab