Skip to content
Snippets Groups Projects
Commit 8bebd100 authored by Michael Kuron :mortar_board:
Browse files

Fix FMA insertion failures in remainder loops and with RNG

parent ff0e41ee
Branches
1 merge request: !414 Draft: Fused-multiply-add vectorization
Pipeline #68927 failed with stages
in 7 minutes and 11 seconds
......@@ -571,6 +571,11 @@ class CustomSympyPrinter(CCodePrinter):
return f"(({self._print(expr.args[0])}) / ({self._print(expr.args[1])}))"
elif expr.func == DivFunc:
return f'(({self._print(expr.divisor)}) / ({self._print(expr.dividend)}))'
elif isinstance(expr, Fma):
a = expr.args[0] * (-1 if expr.instruction[0] == '-' else 1)
b = expr.args[1]
c = expr.args[2] * (-1 if expr.instruction[-1] == '-' else 1)
return f"fma({self._print(a)}, {self._print(b)}, {self._print(c)})"
else:
name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
arg_str = ', '.join(self._print(a) for a in expr.args)
......@@ -729,6 +734,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
elif isinstance(expr, fast_inv_sqrt):
raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
elif isinstance(expr, Fma):
result = self._scalarFallback('_print_Function', expr)
if result:
return result
return self.instruction_set[expr.instruction].format(self._print(expr.args[0]), self._print(expr.args[1]),
self._print(expr.args[2]), **self._kwargs)
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
......
......@@ -121,6 +121,8 @@ def insert_fast_divisions(term: Union[sp.Expr, List[sp.Expr], AssignmentCollecti
def insert_fma(term, operators):
from pystencils.rng import RNGBase # late import to avoid cyclic dependency
if '*+' not in operators:
return term
......@@ -144,8 +146,11 @@ def insert_fma(term, operators):
return expr
def visit(expr):
# Special treatments for various types that cannot be reconstructed from their args
if isinstance(expr, ResolvedFieldAccess):
return expr
elif isinstance(expr, RNGBase):
return expr
elif hasattr(expr, 'body'):
old_parent = expr.body.parent if hasattr(expr.body, 'parent') else None
expr.body = visit(expr.body)
......@@ -154,7 +159,9 @@ def insert_fma(term, operators):
return expr
elif isinstance(expr, Block):
return Block([visit(a) for a in expr.args])
elif expr.func == sp.Add:
# Find patterns of Add and Mul nodes that can be fused
if expr.func == sp.Add:
expr = flatten(expr)
summands = list(expr.args)
if '-*+' in operators:
......@@ -210,6 +217,7 @@ def insert_fma(term, operators):
summands = [visit(s) for s in summands]
return sp.Add(fmadd(factors[0], sp.Mul(*factors[1:]), summands[0]), *summands[1:])
return expr
# Find Mul with three factors, one of them -1, which can be fused
elif expr.func == sp.Mul and -1 in expr.args:
expr = flatten(expr)
factors = list(expr.args)
......
......@@ -11,10 +11,10 @@ supported_instruction_sets = get_supported_instruction_sets() if get_supported_i
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_fmadd(instruction_set, dtype):
da = 2 * np.ones((128, 128), dtype=dtype)
db = 3 * np.ones((128, 128), dtype=dtype)
dc = 5 * np.ones((128, 128), dtype=dtype)
dd = np.empty((128, 128), dtype=dtype)
da = 2 * np.ones((129, 129), dtype=dtype)
db = 3 * np.ones((129, 129), dtype=dtype)
dc = 5 * np.ones((129, 129), dtype=dtype)
dd = np.empty((129, 129), dtype=dtype)
a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
update_rule = [ps.Assignment(d.center(), a.center() * b.center() + c.center())]
......@@ -33,10 +33,10 @@ def test_fmadd(instruction_set, dtype):
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_fmsub(instruction_set, dtype):
da = 2 * np.ones((128, 128), dtype=dtype)
db = 3 * np.ones((128, 128), dtype=dtype)
dc = 5 * np.ones((128, 128), dtype=dtype)
dd = np.empty((128, 128), dtype=dtype)
da = 2 * np.ones((129, 129), dtype=dtype)
db = 3 * np.ones((129, 129), dtype=dtype)
dc = 5 * np.ones((129, 129), dtype=dtype)
dd = np.empty((129, 129), dtype=dtype)
a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
update_rule = [ps.Assignment(d.center(), a.center() * b.center() - c.center())]
......@@ -57,10 +57,10 @@ def test_fmsub(instruction_set, dtype):
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_fnmadd(instruction_set, dtype):
da = 2 * np.ones((128, 128), dtype=dtype)
db = 3 * np.ones((128, 128), dtype=dtype)
dc = 5 * np.ones((128, 128), dtype=dtype)
dd = np.empty((128, 128), dtype=dtype)
da = 2 * np.ones((129, 129), dtype=dtype)
db = 3 * np.ones((129, 129), dtype=dtype)
dc = 5 * np.ones((129, 129), dtype=dtype)
dd = np.empty((129, 129), dtype=dtype)
a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
update_rule = [ps.Assignment(d.center(), -a.center() * b.center() + c.center())]
......@@ -81,10 +81,10 @@ def test_fnmadd(instruction_set, dtype):
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_fnmsub(instruction_set, dtype):
da = 2 * np.ones((128, 128), dtype=dtype)
db = 3 * np.ones((128, 128), dtype=dtype)
dc = 5 * np.ones((128, 128), dtype=dtype)
dd = np.empty((128, 128), dtype=dtype)
da = 2 * np.ones((129, 129), dtype=dtype)
db = 3 * np.ones((129, 129), dtype=dtype)
dc = 5 * np.ones((129, 129), dtype=dtype)
dd = np.empty((129, 129), dtype=dtype)
a, b, c, d = ps.fields(a=da, b=db, c=dc, d=dd)
update_rule = [ps.Assignment(d.center(), -a.center() * b.center() - c.center())]
......@@ -105,9 +105,9 @@ def test_fnmsub(instruction_set, dtype):
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_fnm(instruction_set, dtype):
da = 2 * np.ones((128, 128), dtype=dtype)
db = 3 * np.ones((128, 128), dtype=dtype)
dd = np.empty((128, 128), dtype=dtype)
da = 2 * np.ones((129, 129), dtype=dtype)
db = 3 * np.ones((129, 129), dtype=dtype)
dd = np.empty((129, 129), dtype=dtype)
a, b, d = ps.fields(a=da, b=db, d=dd)
update_rule = [ps.Assignment(d.center(), -a.center() * b.center())]
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment