Skip to content
Snippets Groups Projects
Commit 866e9fc0 authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes in vectorization to also support float kernels

parent 501b2d7e
Branches
Tags
No related merge requests found
......@@ -213,6 +213,7 @@ class CustomSympyPrinter(CCodePrinter):
def __init__(self, constants_as_floats=False):
self._constantsAsFloats = constants_as_floats
super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32")
def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication"""
......@@ -224,8 +225,6 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Rational(self, expr):
"""Evaluate all rationals i.e. print 0.25 instead of 1.0/4.0"""
res = str(expr.evalf().num)
if self._constantsAsFloats:
res += "f"
return res
def _print_Equality(self, expr):
......@@ -237,12 +236,6 @@ class CustomSympyPrinter(CCodePrinter):
result = super(CustomSympyPrinter, self)._print_Piecewise(expr)
return result.replace("\n", "")
def _print_Float(self, expr):
res = str(expr)
if self._constantsAsFloats:
res += "f"
return res
def _print_Function(self, expr):
function_map = {
bitwise_xor: '^',
......@@ -255,7 +248,10 @@ class CustomSympyPrinter(CCodePrinter):
return expr.to_c(self._print)
if expr.func == cast_func:
arg, data_type = expr.args
return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
if isinstance(arg, sp.Number):
return self._typed_number(arg, data_type)
else:
return "*((%s)(& %s))" % (PointerType(data_type), self._print(arg))
elif expr.func == modulo_floor:
assert all(get_type_of_expression(e).is_int() for e in expr.args)
return "({dtype})({0} / {1}) * {1}".format(*expr.args, dtype=get_type_of_expression(expr.args[0]))
......@@ -264,6 +260,17 @@ class CustomSympyPrinter(CCodePrinter):
else:
return super(CustomSympyPrinter, self)._print_Function(expr)
def _typed_number(self, number, dtype):
res = self._print(number)
if dtype.is_float:
if dtype == self._float_type:
if '.' not in res:
res += ".0f"
else:
res += "f"
return res
else:
return res
# noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter):
......
......@@ -20,7 +20,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
'sqrt': 'sqrt[0]',
'makeVec': 'set[0,0,0,0]',
'makeVec': 'set[]',
'makeZero': 'setzero[]',
'loadU': 'loadu[0]',
......@@ -31,6 +31,7 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
}
headers = {
'avx512': ['<immintrin.h>'],
'avx': ['<immintrin.h>'],
'sse': ['<xmmintrin.h>', '<emmintrin.h>', '<pmmintrin.h>', '<tmmintrin.h>', '<smmintrin.h>', '<nmmintrin.h>']
}
......@@ -54,32 +55,37 @@ def get_vector_instruction_set(data_type='double', instruction_set='avx'):
("float", "avx512"): 16,
}
result = {}
result = {
'width': width[(data_type, instruction_set)],
}
pre = prefix[instruction_set]
suf = suffix[data_type]
for intrinsic_id, function_shortcut in base_names.items():
function_shortcut = function_shortcut.strip()
name = function_shortcut[:function_shortcut.index('[')]
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
if intrinsic_id == 'makeVec':
arg_string = "({})".format(",".join(["{0}"] * result['width']))
else:
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
for arg in args.split(","):
arg = arg.strip()
if not arg:
continue
if arg in ('0', '1', '2', '3', '4', '5'):
arg_string += "{" + arg + "},"
else:
arg_string += arg + ","
arg_string = arg_string[:-1] + ")"
result[intrinsic_id] = pre + "_" + name + "_" + suf + arg_string
result['width'] = width[(data_type, instruction_set)]
result['dataTypePrefix'] = {
'double': "_" + pre + 'd',
'float': "_" + pre,
}
bit_width = result['width'] * 64
bit_width = result['width'] * (64 if data_type == 'double' else 32)
result['double'] = "__m%dd" % (bit_width,)
result['float'] = "__m%d" % (bit_width,)
result['int'] = "__m%di" % (bit_width,)
......
......@@ -13,13 +13,13 @@ from pystencils.transformations import cut_loop, filtered_tree_iteration
from pystencils.field import Field
def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx',
def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
assume_aligned: bool = False, nontemporal: Union[bool, Container[Union[str, Field]]] = False):
"""Explicit vectorization using SIMD vectorization via intrinsics.
Args:
kernel_ast: abstract syntax tree (KernelFunction node)
vector_instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
instruction_set: one of the supported vector instruction sets, currently ('sse', 'avx' and 'avx512')
assume_aligned: assume that the first inner cell of each line is aligned. If false, only unaligned-loads are
used. If true, some of the loads are assumed to be from aligned memory addresses.
For example if x is the fastest coordinate, the access to center can be fetched via an
......@@ -42,7 +42,7 @@ def vectorize(kernel_ast: ast.KernelFunction, vector_instruction_set: str = 'avx
float_size = field_float_dtypes.pop().numpy_dtype.itemsize
assert float_size in (8, 4)
vector_is = get_vector_instruction_set('double' if float_size == 8 else 'float',
instruction_set=vector_instruction_set)
instruction_set=instruction_set)
vector_width = vector_is['width']
kernel_ast.instruction_set = vector_is
......
......@@ -289,7 +289,10 @@ def get_type_of_expression(expr):
from pystencils.astnodes import ResolvedFieldAccess
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type("int")
if expr == 1 or expr == -1:
return create_type("int16")
else:
return create_type("int")
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type("double")
elif isinstance(expr, ResolvedFieldAccess):
......@@ -316,6 +319,8 @@ def get_type_of_expression(expr):
if vec_args:
result = VectorType(result, width=vec_args[0].width)
return result
elif isinstance(expr, sp.Pow):
return get_type_of_expression(expr.args[0])
elif isinstance(expr, sp.Expr):
types = tuple(get_type_of_expression(a) for a in expr.args)
return collate_types(types)
......
......@@ -73,7 +73,7 @@ def create_kernel(assignments, target='cpu', data_type="double", iteration_slice
add_openmp(ast, num_threads=cpu_openmp)
if cpu_vectorize_info:
if cpu_vectorize_info is True:
vectorize(ast, vector_instruction_set='avx', assume_aligned=False, nontemporal=None)
vectorize(ast, instruction_set='avx', assume_aligned=False, nontemporal=None)
elif isinstance(cpu_vectorize_info, dict):
vectorize(ast, **cpu_vectorize_info)
else:
......
......@@ -205,10 +205,17 @@ class LLVMPrinter(Printer):
node = self._print(conversion.args[0])
to_dtype = get_type_of_expression(conversion)
from_dtype = get_type_of_expression(conversion.args[0])
if from_dtype == to_dtype:
return self._print(conversion.args[0])
# (From, to)
decision = {
(create_composite_type_from_string("int16"),
create_composite_type_from_string("int64")): lambda: ir.Constant(self.integer, node),
(create_composite_type_from_string("int"),
create_composite_type_from_string("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
(create_composite_type_from_string("int16"),
create_composite_type_from_string("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
(create_composite_type_from_string("double"),
create_composite_type_from_string("int")): functools.partial(self.builder.fptosi, node, self.integer),
(create_composite_type_from_string("double *"),
......
......@@ -8,7 +8,7 @@ from sympy.tensor import IndexedBase
from pystencils.assignment import Assignment
from pystencils.field import Field, FieldType
from pystencils.data_types import TypedSymbol, PointerType, StructType, get_base_type, cast_func, \
pointer_arithmetic_func, get_type_of_expression, collate_types
pointer_arithmetic_func, get_type_of_expression, collate_types, create_type
from pystencils.slicing import normalize_slice
import pystencils.astnodes as ast
......@@ -716,9 +716,18 @@ class KernelConstraintsCheck:
return rhs
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(symbol_name_to_variable_name(rhs.name), self._type_for_symbol[rhs.name])
else:
new_args = [self.process_expression(arg) for arg in rhs.args]
elif isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
elif isinstance(rhs, sp.Mul):
new_args = [self.process_expression(arg) if arg not in (-1, 1) else arg for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
else:
if isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers
return sp.Pow(self.process_expression(rhs.args[0]), rhs.args[1])
else:
new_args = [self.process_expression(arg) for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
@property
def fields_written(self):
......@@ -800,10 +809,13 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
def insert_casts(node):
"""Checks the types and inserts casts and pointer arithmetic where necessary
"""Checks the types and inserts casts and pointer arithmetic where necessary.
:param node: the head node of the ast
:return: modified ast
Args:
node: the head node of the ast
Returns:
modified AST
"""
def cast(zipped_args_types, target_dtype):
"""
......@@ -839,7 +851,7 @@ def insert_casts(node):
new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
return pointer_arithmetic_func(pointer, new_args)
if isinstance(node, sp.AtomicExpr):
if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
return node
args = []
for arg in node.args:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment