Skip to content
Snippets Groups Projects
Commit 5a059a11 authored by Markus Holzer's avatar Markus Holzer
Browse files

Merge branch 'bauerd/fix-vectorize' into 'master'

Vectorize all scalar symbols in vector expressions

See merge request pycodegen/pystencils!344
parents 6da822aa b65d1cc1
Branches
Tags
No related merge requests found
......@@ -257,22 +257,24 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
handled_functions = (sp.Add, sp.Mul, vec_any, vec_all, DivFunc, sp.Abs)
def visit_expr(expr, default_type='double'): # TODO Vectorization Revamp: get rid of default_type
# TODO Vectorization Revamp: get rid of default_type
def visit_expr(expr, default_type='double', force_vectorize=False):
if isinstance(expr, VectorMemoryAccess):
return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type), *expr.args[5:])
return VectorMemoryAccess(*expr.args[0:4], visit_expr(expr.args[4], default_type, force_vectorize),
*expr.args[5:])
elif isinstance(expr, CastFunc):
cast_type = expr.args[1]
arg = visit_expr(expr.args[0])
arg = visit_expr(expr.args[0], default_type, force_vectorize)
assert cast_type in [BasicType('float32'), BasicType('float64')],\
f'Vectorization cannot vectorize type {cast_type}'
return expr.func(arg, VectorType(cast_type, instruction_set['width']))
elif expr.func is sp.Abs and 'abs' not in instruction_set:
new_arg = visit_expr(expr.args[0], default_type)
new_arg = visit_expr(expr.args[0], default_type, force_vectorize)
base_type = get_type_of_expression(expr.args[0]).base_type if type(expr.args[0]) is VectorMemoryAccess \
else get_type_of_expression(expr.args[0])
pw = sp.Piecewise((-new_arg, new_arg < CastFunc(0, base_type.numpy_dtype)),
(new_arg, True))
return visit_expr(pw, default_type)
return visit_expr(pw, default_type, force_vectorize)
elif expr.func in handled_functions or isinstance(expr, sp.Rel) or isinstance(expr, BooleanFunction):
if expr.func is sp.Mul and expr.args[0] == -1:
# special treatment for the unary minus: make sure that the -1 has the same type as the argument
......@@ -287,7 +289,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
if dtype is np.float32:
default_type = 'float'
expr = sp.Mul(dtype(expr.args[0]), *expr.args[1:])
new_args = [visit_expr(a, default_type) for a in expr.args]
new_args = [visit_expr(a, default_type, force_vectorize) for a in expr.args]
arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
if not any(type(t) is VectorType for t in arg_types):
return expr
......@@ -306,7 +308,7 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
exp = expr.args[0].exp
expr = sp.UnevaluatedExpr(sp.Mul(*([base] * +exp), evaluate=False))
new_args = [visit_expr(a, default_type) for a in expr.args[0].args]
new_args = [visit_expr(a, default_type, force_vectorize) for a in expr.args[0].args]
arg_types = [get_type_of_expression(a, default_float_type=default_type) for a in new_args]
target_type = collate_types(arg_types)
......@@ -318,11 +320,11 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
for a, t in zip(new_args, arg_types)]
return expr.func(expr.args[0].func(*casted_args, evaluate=False))
elif expr.func is sp.Pow:
new_arg = visit_expr(expr.args[0], default_type)
new_arg = visit_expr(expr.args[0], default_type, force_vectorize)
return expr.func(new_arg, expr.args[1])
elif expr.func == sp.Piecewise:
new_results = [visit_expr(a[0], default_type) for a in expr.args]
new_conditions = [visit_expr(a[1], default_type) for a in expr.args]
new_results = [visit_expr(a[0], default_type, force_vectorize) for a in expr.args]
new_conditions = [visit_expr(a[1], default_type, force_vectorize) for a in expr.args]
types_of_results = [get_type_of_expression(a) for a in new_results]
types_of_conditions = [get_type_of_expression(a) for a in new_conditions]
......@@ -341,7 +343,14 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
for a, t in zip(new_conditions, types_of_conditions)]
return sp.Piecewise(*[(r, c) for r, c in zip(casted_results, casted_conditions)])
elif isinstance(expr, (sp.Number, TypedSymbol, BooleanAtom)):
elif isinstance(expr, TypedSymbol):
if force_vectorize:
expr_type = get_type_of_expression(expr)
if type(expr_type) is not VectorType:
vector_type = VectorType(expr_type, instruction_set['width'])
return CastFunc(expr, vector_type)
return expr
elif isinstance(expr, (sp.Number, BooleanAtom)):
return expr
else:
raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
......@@ -357,11 +366,18 @@ def insert_vector_casts(ast_node, instruction_set, default_float_type='double'):
# continue
subs_expr = fast_subs(assignment.rhs, substitution_dict,
skip=lambda e: isinstance(e, ast.ResolvedFieldAccess))
assignment.rhs = visit_expr(subs_expr, default_type)
rhs_type = get_type_of_expression(assignment.rhs)
# If either side contains a vectorized subexpression, both sides
# must be fully vectorized.
lhs_type = get_type_of_expression(assignment.lhs)
rhs_type = get_type_of_expression(subs_expr)
lhs_vectorized = type(lhs_type) is VectorType
rhs_vectorized = type(rhs_type) is VectorType
assignment.rhs = visit_expr(subs_expr, default_type, force_vectorize=lhs_vectorized or rhs_vectorized)
if isinstance(assignment.lhs, TypedSymbol):
lhs_type = assignment.lhs.dtype
if type(rhs_type) is VectorType and type(lhs_type) is not VectorType:
if rhs_vectorized and not lhs_vectorized:
new_lhs_type = VectorType(lhs_type, rhs_type.width)
new_lhs = TypedSymbol(assignment.lhs.name, new_lhs_type)
substitution_dict[assignment.lhs] = new_lhs
......
......@@ -6,6 +6,7 @@ import pystencils.config
import sympy as sp
import pystencils as ps
import pystencils.astnodes as ast
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
from pystencils.cpu.vectorization import vectorize
from pystencils.enums import Target
......@@ -40,6 +41,47 @@ def test_vector_type_propagation(instruction_set=instruction_set):
np.testing.assert_equal(dst[1:-1, 1:-1], 2 * 10.0 + 3)
def test_vectorize_moved_constants1(instruction_set=instruction_set):
    """Regression test: a scalar constant assigned once before the loop must
    still be vectorized when it appears on the rhs of a vectorized assignment.

    `move_constants_before_loop` is applied explicitly so the assignment
    ``x = 2.0`` is hoisted out of the loop; the vectorizer must then cast the
    scalar ``x`` to a vector when storing into ``f``.
    """
    opt = {'instruction_set': instruction_set, 'assume_inner_stride_one': True}
    f = ps.fields("f: [1D]")
    # `x` is a plain scalar TypedSymbol, not a field access
    x = ast.TypedSymbol("x", np.float64)
    kernel_func = ps.create_kernel(
        [ast.SympyAssignment(x, 2.0), ast.SympyAssignment(f[0], x)],
        cpu_prepend_optimizations=[ps.transformations.move_constants_before_loop],  # explicitly move constants
        cpu_vectorize_info=opt,
    )
    ps.show_code(kernel_func)  # fails if `x` on rhs was not correctly vectorized
    kernel = kernel_func.compile()
    f_arr = np.zeros(9)
    kernel(f=f_arr)
    # every element must have received the hoisted constant
    assert(np.all(f_arr == 2))
def test_vectorize_moved_constants2(instruction_set=instruction_set):
    """Regression test: TWO scalar constants hoisted before the loop must both
    be vectorized when combined (``x + y``) on the rhs of a vectorized store.

    Same setup as ``test_vectorize_moved_constants1`` but with a binary
    expression of scalar symbols, exercising vectorization of every scalar
    operand inside a vector expression.
    """
    opt = {'instruction_set': instruction_set, 'assume_inner_stride_one': True}
    f = ps.fields("f: [1D]")
    # two plain scalar TypedSymbols, both hoisted out of the loop
    x = ast.TypedSymbol("x", np.float64)
    y = ast.TypedSymbol("y", np.float64)
    kernel_func = ps.create_kernel(
        [ast.SympyAssignment(x, 2.0), ast.SympyAssignment(y, 3.0), ast.SympyAssignment(f[0], x + y)],
        cpu_prepend_optimizations=[ps.transformations.move_constants_before_loop],  # explicitly move constants
        cpu_vectorize_info=opt,
    )
    ps.show_code(kernel_func)  # fails if `x` or `y` on rhs was not correctly vectorized
    kernel = kernel_func.compile()
    f_arr = np.zeros(9)
    kernel(f=f_arr)
    # every element must equal the sum of the two hoisted constants
    assert(np.all(f_arr == 5))
@pytest.mark.parametrize('openmp', [True, False])
def test_aligned_and_nt_stores(openmp, instruction_set=instruction_set):
domain_size = (24, 24)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment