diff --git a/src/pystencils/backends/arm_instruction_sets.py b/src/pystencils/backends/arm_instruction_sets.py
index 227224f4e65460a291bd3a6cd3309ed3525072fa..0a1d6b872545c453d1279b48560eb6ac041e0221 100644
--- a/src/pystencils/backends/arm_instruction_sets.py
+++ b/src/pystencils/backends/arm_instruction_sets.py
@@ -35,6 +35,8 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         '-': 'sub[0, 1]',
         '*': 'mul[0, 1]',
         '/': 'div[0, 1]',
+        '*+': 'fma[2, 1, 0]',  # 2*1+0
+        '-*+': 'fms[2, 1, 0]',  # -(2*1+0)
         'sqrt': 'sqrt[0]',
 
         'loadU': 'ld1[0]',
@@ -65,6 +67,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
         result['bytes'] = bitwidth // 8
     if instruction_set.startswith('sve') or instruction_set == 'sme':
         base_names['stream'] = 'stnt1[0, 1]'
+        base_names['*+'] = 'mad[2, 1, 0]'
+        base_names['*-'] = 'msb[2, 1, 0]'
+        base_names['-*+'] = 'nmad[2, 1, 0]'
+        base_names['-*-'] = 'nmsb[2, 1, 0]'
         prefix = 'sv'
         suffix = f'_f{bits[data_type]}' 
     elif instruction_set == 'neon':
diff --git a/src/pystencils/backends/cbackend.py b/src/pystencils/backends/cbackend.py
index 657f60d2f16f14a20f81ebfc77414eb31ba0236a..c00cf520a8329ae869bdbdec6cb4d48763c9ae0b 100644
--- a/src/pystencils/backends/cbackend.py
+++ b/src/pystencils/backends/cbackend.py
@@ -17,7 +17,7 @@ from pystencils.typing import (
     PointerType, VectorType, CastFunc, create_type, get_type_of_expression,
     ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
 from pystencils.enums import Backend
-from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
+from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt, Fma
 from pystencils.functions import DivFunc, AddressOf
 from pystencils.integer_functions import (
     bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
@@ -728,6 +728,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
             raise ValueError("fast_sqrt is only supported for Taget.GPU")
         elif isinstance(expr, fast_inv_sqrt):
             raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
+        elif isinstance(expr, Fma):
+            return self.instruction_set[expr.instruction].format(self._print(expr.args[0]), self._print(expr.args[1]),
+                                                                 self._print(expr.args[2]), **self._kwargs)
         elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
             instr = 'any' if isinstance(expr, vec_any) else 'all'
             expr_type = get_type_of_expression(expr.args[0])
diff --git a/src/pystencils/backends/ppc_instruction_sets.py b/src/pystencils/backends/ppc_instruction_sets.py
index d2d3199769363459af4c67cf8a59030531c416a8..79df13c7b14710da303c551ebdb7c724d02748f3 100644
--- a/src/pystencils/backends/ppc_instruction_sets.py
+++ b/src/pystencils/backends/ppc_instruction_sets.py
@@ -22,6 +22,10 @@ def get_vector_instruction_set_ppc(data_type='double', instruction_set='vsx'):
         '-': 'sub[0, 1]',
         '*': 'mul[0, 1]',
         '/': 'div[0, 1]',
+        '*+': 'madd[0, 1, 2]',
+        '*-': 'msub[0, 1, 2]',
+        '-*+': 'nmsub[0, 1, 2]',
+        '-*-': 'nmadd[0, 1, 2]',
         'sqrt': 'sqrt[0]',
         'rsqrt': 'rsqrte[0]',  # rsqrt is available too, but not on Clang
 
diff --git a/src/pystencils/backends/riscv_instruction_sets.py b/src/pystencils/backends/riscv_instruction_sets.py
index 27f631e7f92d25e366bc767c759697ac898f3308..20aff79215a84bfb511cdd2cfc76c933f271c127 100644
--- a/src/pystencils/backends/riscv_instruction_sets.py
+++ b/src/pystencils/backends/riscv_instruction_sets.py
@@ -27,6 +27,10 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
         '-': 'fsub_vv[0, 1]',
         '*': 'fmul_vv[0, 1]',
         '/': 'fdiv_vv[0, 1]',
+        '*+': 'fmadd[2, 0, 1]',
+        '*-': 'fmsub[2, 0, 1]',
+        '-*+': 'fnmadd[2, 0, 1]',
+        '-*-': 'fnmsub[2, 0, 1]',
         'sqrt': 'fsqrt_v[0]',
 
         'loadU': f'le{bits[data_type]}_v[0]',
diff --git a/src/pystencils/backends/x86_instruction_sets.py b/src/pystencils/backends/x86_instruction_sets.py
index 7f05a403d45a8ca890a78b7db92ff147d30dd7dc..34727ec6414f4b0efa3ec1bc66ac9579b58c58e4 100644
--- a/src/pystencils/backends/x86_instruction_sets.py
+++ b/src/pystencils/backends/x86_instruction_sets.py
@@ -61,6 +61,12 @@ def get_vector_instruction_set_x86(data_type='double', instruction_set='avx'):
         'maskStoreU': 'mask_storeu[0, 2, 1]' if instruction_set.startswith('avx512') else 'maskstore[0, 2, 1]',
     }
 
+    if instruction_set.startswith('avx'):
+        base_names['*+'] = 'fmadd[0, 1, 2]'  # 0*1+2
+        base_names['*-'] = 'fmsub[0, 1, 2]'  # 0*1-2
+        base_names['-*+'] = 'fnmadd[0, 1, 2]'  # -0*1+2
+        base_names['-*-'] = 'fnmsub[0, 1, 2]'  # -0*1-2
+
     for comparison_op, constant in comparisons.items():
         base_names[comparison_op] = f'cmp[0, 1, {constant}]'
 
diff --git a/src/pystencils/cpu/vectorization.py b/src/pystencils/cpu/vectorization.py
index 872f0b3c45983a38b5b1be8fd7f425eb422b570d..32cb70974fb9b0b71500003b6e5eb9826332c295 100644
--- a/src/pystencils/cpu/vectorization.py
+++ b/src/pystencils/cpu/vectorization.py
@@ -12,6 +12,7 @@ from pystencils.typing import (BasicType, PointerType, TypedSymbol, VectorType,
 from pystencils.functions import DivFunc
 from pystencils.field import Field
 from pystencils.integer_functions import modulo_ceil, modulo_floor
+from pystencils.fast_approximation import insert_fma, fmadd, fmsub, fnmadd, fnmsub
 from pystencils.sympyextensions import fast_subs
 from pystencils.transformations import cut_loop, filtered_tree_iteration, replace_inner_stride_with_one
 
@@ -127,6 +128,8 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'best',
     vector_is = get_vector_instruction_set(default_float_type, instruction_set=instruction_set)
     kernel_ast.instruction_set = vector_is
 
+    kernel_ast = insert_fma(kernel_ast, vector_is.keys())
+
     if nontemporal and 'cachelineZero' in vector_is:
         kernel_ast.use_all_written_field_sizes = True
     strided = 'storeS' in vector_is and 'loadS' in vector_is
@@ -264,7 +267,7 @@ def mask_conditionals(loop_body):
 def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
     """Inserts necessary casts from scalar values to vector values."""
 
-    handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs)
+    handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs, fmadd, fmsub, fnmadd, fnmsub)
 
     def is_scalar(expr) -> bool:
         if hasattr(expr, "dtype"):
diff --git a/src/pystencils/fast_approximation.py b/src/pystencils/fast_approximation.py
index ab0dc59740e9ec7fcd3e59eb826979cd5350aa3f..de3dc62c3c4cbf9a8a8c5234292dfdd45c38fa21 100644
--- a/src/pystencils/fast_approximation.py
+++ b/src/pystencils/fast_approximation.py
@@ -2,7 +2,7 @@ from typing import List, Union
 
 import sympy as sp
 
-from pystencils.astnodes import Node
+from pystencils.astnodes import Node, ResolvedFieldAccess, Block
 from pystencils.simp import AssignmentCollection
 from pystencils.assignment import Assignment
 
@@ -31,6 +31,42 @@ class fast_inv_sqrt(sp.Function):
     nargs = (1, )
 
 
+class Fma(sp.Function):
+    nargs = (3, )
+
+
+# noinspection PyPep8Naming
+class fmadd(Fma):
+    """
+    Produces FMA instructions for a*b+c
+    """
+    instruction = '*+'
+
+
+# noinspection PyPep8Naming
+class fmsub(Fma):
+    """
+    Produces FMA instructions for a*b-c
+    """
+    instruction = '*-'
+
+
+# noinspection PyPep8Naming
+class fnmadd(Fma):
+    """
+    Produces FMA instructions for -a*b+c
+    """
+    instruction = '-*+'
+
+
+# noinspection PyPep8Naming
+class fnmsub(Fma):
+    """
+    Produces FMA instructions for -a*b-c
+    """
+    instruction = '-*-'
+
+
 def _run(term, visitor):
     if isinstance(term, AssignmentCollection):
         new_main_assignments = _run(term.main_assignments, visitor)
@@ -82,3 +118,108 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti
             return expr.func(*new_args) if new_args else expr
 
     return _run(term, visit)
+
+
+def insert_fma(term, operators):
+    if '*+' not in operators:
+        return term
+
+    def flatten(expr):
+        if isinstance(expr, sp.Add):
+            new_args = []
+            for arg in expr.args:
+                if arg.func == sp.Add:
+                    new_args += [flatten(a) for a in arg.args]
+                else:
+                    new_args.append(flatten(arg))
+            return sp.Add(*new_args)
+        elif isinstance(expr, sp.Mul):
+            new_args = []
+            for arg in expr.args:
+                if arg.func == sp.Mul:
+                    new_args += [flatten(a) for a in arg.args]
+                else:
+                    new_args.append(flatten(arg))
+            return sp.Mul(*new_args)
+        return expr
+
+    def visit(expr):
+        if isinstance(expr, ResolvedFieldAccess):
+            return expr
+        elif hasattr(expr, 'body'):
+            old_parent = expr.body.parent if hasattr(expr.body, 'parent') else None
+            expr.body = visit(expr.body)
+            if old_parent is not None:
+                expr.body.parent = old_parent
+            return expr
+        elif isinstance(expr, Block):
+            return Block([visit(a) for a in expr.args])
+        elif expr.func == sp.Add:
+            expr = flatten(expr)
+            summands = list(expr.args)
+            if '-*+' in operators:
+                for summand in expr.args:
+                    if summand.func == sp.Mul and len(summand.args) >= 3 and -1 in summand.args:
+                        summands.remove(summand)
+                        factors = list(summand.args)
+                        factors.remove(-1)
+                        factors = [visit(f) for f in factors]
+                        if not ('-*-' in operators and all(s.func == sp.Mul and -1 in s.args for s in summands)):
+                            summands = [visit(s) for s in summands]
+                            return sp.Add(fnmadd(factors[0], sp.Mul(*factors[1:]), summands[0]), *summands[1:])
+            summands = list(expr.args)
+            if '-*-' in operators:
+                negative = [s for s in summands if s.func == sp.Mul and -1 in s.args]
+                positive = []
+                if len(negative) > 1:
+                    positive = []
+                    for summand in negative:
+                        summands.remove(summand)
+                        factors = list(summand.args)
+                        factors.remove(-1)
+                        positive.append(sp.Mul(*factors))
+                for summand in positive:
+                    if summand.func == sp.Mul and len(summand.args) >= 2:
+                        positive.remove(summand)
+                        factors = list(summand.args)
+                        positive = [visit(s) for s in positive]
+                        summands = [visit(f) for f in summands]
+                        return sp.Add(fnmsub(factors[0], sp.Mul(*factors[1:]), sp.Add(*positive[:2])),
+                                      -sp.Add(*positive[2:]), *summands)
+            summands = list(expr.args)
+            if '*-' in operators:
+                for summand in summands:
+                    if summand.func == sp.Mul and len(summand.args) >= 2 and -1 not in summand.args:
+                        summands.remove(summand)
+                        factors = list(summand.args)
+                        for summand in summands:
+                            if summand.func == sp.Mul and -1 in summand.args:
+                                summands.remove(summand)
+                                subfactors = list(summand.args)
+                                subfactors.remove(-1)
+                                factors = [visit(f) for f in factors]
+                                summands = [visit(s) for s in summands]
+                                subfactors = [visit(f) for f in subfactors]
+                                return sp.Add(fmsub(factors[0], sp.Mul(*factors[1:]), sp.Mul(*subfactors)), *summands)
+            summands = list(expr.args)
+            for summand in summands:
+                if summand.func == sp.Mul and len(summand.args) >= 2:
+                    summands.remove(summand)
+                    factors = list(summand.args)
+                    factors = [visit(f) for f in factors]
+                    summands = [visit(s) for s in summands]
+                    return sp.Add(fmadd(factors[0], sp.Mul(*factors[1:]), summands[0]), *summands[1:])
+            return expr
+        elif expr.func == sp.Mul and -1 in expr.args:
+            expr = flatten(expr)
+            factors = list(expr.args)
+            factors.remove(-1)
+            factors = [visit(f) for f in factors]
+            if '-*+' in operators:
+                return fnmadd(factors[0], sp.Mul(*factors[1:]), 0)
+            elif '-*-' in operators:
+                return fnmsub(factors[0], sp.Mul(*factors[1:]), 0)
+        new_args = [visit(a) for a in expr.args]
+        return expr.func(*new_args) if new_args else expr
+
+    return _run(term, visit)
diff --git a/tests/test_vec_fma.py b/tests/test_vec_fma.py
new file mode 100644
index 0000000000000000000000000000000000000000..357c37de56188f944a64fcff4b7657e278f845bd
--- /dev/null
+++ b/tests/test_vec_fma.py
@@ -0,0 +1,127 @@
+import pytest
+import pystencils as ps
+import numpy as np
+
+from pystencils.backends.simd_instruction_sets import (get_supported_instruction_sets, get_vector_instruction_set)
+from pystencils.fast_approximation import fmadd, fmsub, fnmadd, fnmsub, Fma
+
+supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else []
+
+
+@pytest.mark.parametrize('dtype', ('float32', 'float64'))
+@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
+def test_fmadd(instruction_set, dtype):
+    da = 2 * np.ones((128, 128), dtype=dtype)
+    db = 3 * np.ones((128, 128), dtype=dtype)
+    dc = 5 * np.ones((128, 128), dtype=dtype)
+    dd = np.empty((128, 128), dtype=dtype)
+
+    a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
+    update_rule = [ps.Assignment(d.center(), a.center() * b.center() + c.center())]
+
+    config = ps.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
+    ast = ps.create_kernel(update_rule, config=config)
+
+    if '*+' in get_vector_instruction_set(dtype, instruction_set):
+        assert set(e.func for e in ast.atoms(Fma)) == set([fmadd]), "No FMAs found in AST"
+
+    func = ast.compile()
+    func(a=da, b=db, c=dc, d=dd)
+    np.testing.assert_equal(dd, da * db + dc)
+
+
+@pytest.mark.parametrize('dtype', ('float32', 'float64'))
+@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
+def test_fmsub(instruction_set, dtype):
+    da = 2 * np.ones((128, 128), dtype=dtype)
+    db = 3 * np.ones((128, 128), dtype=dtype)
+    dc = 5 * np.ones((128, 128), dtype=dtype)
+    dd = np.empty((128, 128), dtype=dtype)
+
+    a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
+    update_rule = [ps.Assignment(d.center(), a.center() * b.center() - c.center())]
+
+    config = ps.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
+    ast = ps.create_kernel(update_rule, config=config)
+
+    if '*-' in get_vector_instruction_set(dtype, instruction_set):
+        assert set(e.func for e in ast.atoms(Fma)) == set([fmsub]), "No FMAs found in AST"
+    elif '*+' in get_vector_instruction_set(dtype, instruction_set):
+        assert set([fmsub]), "No FMAs found in AST"
+
+    func = ast.compile()
+    func(a=da, b=db, c=dc, d=dd)
+    np.testing.assert_equal(dd, da * db - dc)
+
+
+@pytest.mark.parametrize('dtype', ('float32', 'float64'))
+@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
+def test_fnmadd(instruction_set, dtype):
+    da = 2 * np.ones((128, 128), dtype=dtype)
+    db = 3 * np.ones((128, 128), dtype=dtype)
+    dc = 5 * np.ones((128, 128), dtype=dtype)
+    dd = np.empty((128, 128), dtype=dtype)
+
+    a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
+    update_rule = [ps.Assignment(d.center(), -a.center() * b.center() + c.center())]
+
+    config = ps.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
+    ast = ps.create_kernel(update_rule, config=config)
+
+    if '-*+' in get_vector_instruction_set(dtype, instruction_set):
+        assert set(e.func for e in ast.atoms(Fma)) == set([fnmadd]), "No FMAs found in AST"
+    elif '*+' in get_vector_instruction_set(dtype, instruction_set):
+        assert set([fmsub]), "No FMAs found in AST"
+
+    func = ast.compile()
+    func(a=da, b=db, c=dc, d=dd)
+    np.testing.assert_equal(dd, -da * db + dc)
+
+
+@pytest.mark.parametrize('dtype', ('float32', 'float64'))
+@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
+def test_fnmsub(instruction_set, dtype):
+    da = 2 * np.ones((128, 128), dtype=dtype)
+    db = 3 * np.ones((128, 128), dtype=dtype)
+    dc = 5 * np.ones((128, 128), dtype=dtype)
+    dd = np.empty((128, 128), dtype=dtype)
+
+    a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
+    update_rule = [ps.Assignment(d.center(), -a.center() * b.center() - c.center())]
+
+    config = ps.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
+    ast = ps.create_kernel(update_rule, config=config)
+
+    if '-*-' in get_vector_instruction_set(dtype, instruction_set):
+        assert set(e.func for e in ast.atoms(Fma)) == set([fnmsub]), "No FMAs found in AST"
+    elif '*+' in get_vector_instruction_set(dtype, instruction_set):
+        assert set([fmsub]), "No FMAs found in AST"
+
+    func = ast.compile()
+    func(a=da, b=db, c=dc, d=dd)
+    np.testing.assert_equal(dd, -da * db - dc)
+
+
+@pytest.mark.parametrize('dtype', ('float32', 'float64'))
+@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
+def test_fnm(instruction_set, dtype):
+    da = 2 * np.ones((128, 128), dtype=dtype)
+    db = 3 * np.ones((128, 128), dtype=dtype)
+    dd = np.empty((128, 128), dtype=dtype)
+
+    a, b, d = ps.fields(a=da, b=db, d=dd)
+    update_rule = [ps.Assignment(d.center(), -a.center() * b.center())]
+
+    config = ps.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
+    ast = ps.create_kernel(update_rule, config=config)
+
+    if '-*+' in get_vector_instruction_set(dtype, instruction_set):
+        assert set(e.func for e in ast.atoms(Fma)) == set([fnmadd]), "No FMAs found in AST"
+    elif '-*-' in get_vector_instruction_set(dtype, instruction_set):
+        assert set(e.func for e in ast.atoms(Fma)) == set([fnmsub]), "No FMAs found in AST"
+    else:
+        assert set(e.func for e in ast.atoms(Fma)) == set(), "Unexpected FMAs found in AST"
+
+    func = ast.compile()
+    func(a=da, b=db, d=dd)
+    np.testing.assert_equal(dd, -da * db)