Skip to content
Snippets Groups Projects

Draft: Loop counter dependent kernels: Vector casts and smaller fixes

Closed Daniel Bauer requested to merge terraneo/pystencils:bauerd/vector-casts into master
6 files
+ 228
63
Compare changes
  • Side-by-side
  • Inline
Files
6
@@ -14,7 +14,7 @@ from sympy.functions.elementary.hyperbolic import HyperbolicFunction
@@ -14,7 +14,7 @@ from sympy.functions.elementary.hyperbolic import HyperbolicFunction
from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
from pystencils.typing import (
from pystencils.typing import (
PointerType, VectorType, CastFunc, create_type, get_type_of_expression,
PointerType, VectorType, CastFunc, collate_types, create_type, get_type_of_expression,
ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
from pystencils.enums import Backend
from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
@@ -107,7 +107,7 @@ def get_headers(ast_node: Node) -> Set[str]:
@@ -107,7 +107,7 @@ def get_headers(ast_node: Node) -> Set[str]:
headers = set()
headers = set()
if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
headers.update(ast_node.instruction_set['headers'])
headers.update(ast_node.instruction_set.headers)
if hasattr(ast_node, 'headers'):
if hasattr(ast_node, 'headers'):
headers.update(ast_node.headers)
headers.update(ast_node.headers)
@@ -330,8 +330,8 @@ class CBackend:
@@ -330,8 +330,8 @@ class CBackend:
pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
pre_code = f"if ({first_cond} && {size_cond}) " + "{\n\t" + \
self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'
self._vector_instruction_set['cachelineZero'].format(ptr, **self._kwargs) + ';\n}\n'
code = self._vector_instruction_set[instr].format(ptr, self.sympy_printer.doprint(rhs),
code = self._vector_instruction_set.operation(instr, data_type).format(ptr, self.sympy_printer.doprint(rhs),
printed_mask, **self._kwargs) + ';'
printed_mask, **self._kwargs) + ';'
flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}"
flushcond = f"((uintptr_t) {ptr} & {CachelineSize.mask_symbol}) == {CachelineSize.last_symbol}"
if nontemporal and 'flushCacheline' in self._vector_instruction_set:
if nontemporal and 'flushCacheline' in self._vector_instruction_set:
code2 = self._vector_instruction_set['flushCacheline'].format(
code2 = self._vector_instruction_set['flushCacheline'].format(
@@ -610,24 +610,20 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -610,24 +610,20 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if type(expr_type) is not VectorType:
if type(expr_type) is not VectorType:
return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
return getattr(super(VectorizedCustomSympyPrinter, self), func_name)(expr, *args, **kwargs)
else:
else:
assert self.instruction_set['width'] == expr_type.width
# assert self.instruction_set['width'] == expr_type.width
return None
return None
def _print_Abs(self, expr):
def _print_Abs(self, expr):
if 'abs' in self.instruction_set and isinstance(expr.args[0], VectorMemoryAccess):
expr_type = get_type_of_expression(expr)
return self.instruction_set['abs'].format(self._print(expr.args[0]), **self._kwargs)
if isinstance(expr_type, VectorType) and self.instruction_set.supports('abs'):
 
return self.instruction_set.operation('abs', expr_type).format(self._print(expr.args[0]), **self._kwargs)
return super()._print_Abs(expr)
return super()._print_Abs(expr)
def _typed_vectorized_number(self, expr, data_type):
def _typed_vectorized_number(self, expr, data_type):
basic_data_type = data_type.base_type
basic_data_type = data_type.base_type
number = self._typed_number(expr, basic_data_type)
number = self._typed_number(expr, basic_data_type)
instruction = 'makeVecConst'
instruction = 'makeVecConst'
if basic_data_type.is_bool():
return self.instruction_set.operation(instruction, data_type).format(number, **self._kwargs)
instruction = 'makeVecConstBool'
# TODO Vectorization Revamp: is int, or sint, or uint (my guess is sint)
elif basic_data_type.is_int():
instruction = 'makeVecConstInt'
return self.instruction_set[instruction].format(number, **self._kwargs)
def _typed_vectorized_symbol(self, expr, data_type):
def _typed_vectorized_symbol(self, expr, data_type):
if not isinstance(expr, TypedSymbol):
if not isinstance(expr, TypedSymbol):
@@ -638,12 +634,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -638,12 +634,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
symbol = f'(({basic_data_type})({symbol}))'
symbol = f'(({basic_data_type})({symbol}))'
instruction = 'makeVecConst'
instruction = 'makeVecConst'
if basic_data_type.is_bool():
return self.instruction_set.operation(instruction, data_type).format(symbol, **self._kwargs)
instruction = 'makeVecConstBool'
# TODO Vectorization Revamp: is int, or sint, or uint (my guess is sint)
elif basic_data_type.is_int():
instruction = 'makeVecConstInt'
return self.instruction_set[instruction].format(symbol, **self._kwargs)
def _print_CastFunc(self, expr):
def _print_CastFunc(self, expr):
arg, data_type = expr.args
arg, data_type = expr.args
@@ -652,16 +643,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -652,16 +643,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
# vector_memory_access is a cast_func itself so it shouldn't be directly inside a cast_func
# vector_memory_access is a cast_func itself so it shouldn't be directly inside a cast_func
assert not isinstance(arg, VectorMemoryAccess)
assert not isinstance(arg, VectorMemoryAccess)
if isinstance(arg, sp.Tuple):
if isinstance(arg, sp.Tuple):
is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
is_integer = get_type_of_expression(arg[0]) == create_type("int")
is_integer = get_type_of_expression(arg[0]) == create_type("int")
printed_args = [self._print(a) for a in arg]
printed_args = [self._print(a) for a in arg]
instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec'
instruction = 'makeVec'
if instruction == 'makeVecInt' and 'makeVecIndex' in self.instruction_set:
if is_integer and self.instruction_set.supports('makeVecIndex'):
increments = np.array(arg)[1:] - np.array(arg)[:-1]
increments = np.array(arg)[1:] - np.array(arg)[:-1]
if len(set(increments)) == 1:
if len(set(increments)) == 1:
return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0],
return self.instruction_set['makeVecIndex'].format(printed_args[0], increments[0],
**self._kwargs)
**self._kwargs)
return self.instruction_set[instruction].format(*printed_args, **self._kwargs)
return self.instruction_set.operation(instruction, data_type).format(*printed_args, **self._kwargs)
else:
else:
if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
if arg.is_Number and not isinstance(arg, (sp.core.numbers.Infinity, sp.core.numbers.NegativeInfinity)):
return self._typed_vectorized_number(arg, data_type)
return self._typed_vectorized_number(arg, data_type)
@@ -681,7 +671,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -681,7 +671,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
elif isinstance(arg, sp.UnevaluatedExpr):
elif isinstance(arg, sp.UnevaluatedExpr):
return self._print(arg.args[0])
return self._print(arg.args[0])
else:
else:
raise NotImplementedError('Vectorizer cannot cast between different datatypes')
return self.instruction_set.operation('convert', data_type, [get_type_of_expression(arg)]).format(self._print(arg), **self._kwargs)
 
# raise NotImplementedError(f'Vectorizer cannot cast between different datatypes in expression "{expr}" ({sp.srepr(expr)})')
# to_type = self.instruction_set['suffix'][data_type.base_type.c_name]
# to_type = self.instruction_set['suffix'][data_type.base_type.c_name]
# from_type = self.instruction_set['suffix'][get_type_of_expression(arg).base_type.c_name]
# from_type = self.instruction_set['suffix'][get_type_of_expression(arg).base_type.c_name]
# return self.instruction_set['cast'].format(from_type, to_type, self._print(arg))
# return self.instruction_set['cast'].format(from_type, to_type, self._print(arg))
@@ -690,17 +681,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -690,17 +681,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
# raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.')
# raise ValueError(f'Non VectorType cast "{data_type}" in vectorized code.')
def _print_Function(self, expr):
def _print_Function(self, expr):
 
expr_type = get_type_of_expression(expr.args[0])
if isinstance(expr, VectorMemoryAccess):
if isinstance(expr, VectorMemoryAccess):
arg, data_type, aligned, _, mask, stride = expr.args
arg, data_type, aligned, _, mask, stride = expr.args
if stride != 1:
if stride != 1:
return self.instruction_set['loadS'].format(f"& {self._print(arg)}", stride, **self._kwargs)
return self.instruction_set.operation('loadS', data_type).format(f"& {self._print(arg)}", stride, **self._kwargs)
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU']
instruction = self.instruction_set.operation('loadA', data_type) if aligned else self.instruction_set.operation('loadU', data_type)
return instruction.format(f"& {self._print(arg)}", **self._kwargs)
return instruction.format(f"& {self._print(arg)}", **self._kwargs)
elif expr.func == DivFunc:
elif expr.func == DivFunc:
result = self._scalarFallback('_print_Function', expr)
result = self._scalarFallback('_print_Function', expr)
if not result:
if not result:
result = self.instruction_set['/'].format(self._print(expr.divisor), self._print(expr.dividend),
result = self.instruction_set.operation('/', expr_type).format(self._print(expr.divisor), self._print(expr.dividend),
**self._kwargs)
**self._kwargs)
return result
return result
elif isinstance(expr, fast_division):
elif isinstance(expr, fast_division):
raise ValueError("fast_division is only supported for Taget.GPU")
raise ValueError("fast_division is only supported for Taget.GPU")
@@ -710,7 +702,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -710,7 +702,6 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
raise ValueError("fast_inv_sqrt is only supported for Taget.GPU")
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
instr = 'any' if isinstance(expr, vec_any) else 'all'
instr = 'any' if isinstance(expr, vec_any) else 'all'
expr_type = get_type_of_expression(expr.args[0])
if type(expr_type) is not VectorType:
if type(expr_type) is not VectorType:
return self._print(expr.args[0])
return self._print(expr.args[0])
else:
else:
@@ -757,15 +748,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -757,15 +748,13 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
args = expr.args
args = expr.args
# special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
# special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
suffix = ""
if all([(type(e) is CastFunc and isinstance(e.dtype, BasicType) and e.dtype.is_int()) or isinstance(e, sp.Integer)
if all([(type(e) is CastFunc and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]):
or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]):
dtype = set([e.dtype for e in args if type(e) is CastFunc])
dtype = set([e.dtype for e in args if type(e) is CastFunc])
assert len(dtype) == 1
assert len(dtype) == 1
dtype = dtype.pop()
dtype = dtype.pop()
args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
args = [CastFunc(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
for e in args]
for e in args]
suffix = "int"
summands = []
summands = []
for term in args:
for term in args:
@@ -777,19 +766,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -777,19 +766,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
summands.append(self.SummandInfo(sign, t))
summands.append(self.SummandInfo(sign, t))
# Use positive terms first
# Use positive terms first
summands.sort(key=lambda e: e.sign, reverse=True)
summands.sort(key=lambda e: e.sign, reverse=True)
 
arg_types = [get_type_of_expression(a) for a in args]
 
target_type = collate_types(arg_types)
# if no positive term exists, prepend a zero
# if no positive term exists, prepend a zero
if summands[0].sign == -1:
if summands[0].sign == -1:
summands.insert(0, self.SummandInfo(1, "0"))
summands.insert(0, self.SummandInfo(1, self._print(CastFunc(0, target_type))))
assert len(summands) >= 2
assert len(summands) >= 2
processed = summands[0].term
processed = summands[0].term
for summand in summands[1:]:
for summand in summands[1:]:
func = self.instruction_set['-' + suffix] if summand.sign == -1 else self.instruction_set['+' + suffix]
op = '-' if summand.sign == -1 else '+'
processed = func.format(processed, summand.term, **self._kwargs)
processed = self.instruction_set.operation(op, target_type).format(processed, summand.term, **self._kwargs)
return processed
return processed
def _print_Pow(self, expr):
def _print_Pow(self, expr):
# Due to loop cutting sp.Mul is evaluated again.
# Due to loop cutting sp.Mul is evaluated again.
 
expr_type = get_type_of_expression(expr)
try:
try:
result = self._scalarFallback('_print_Pow', expr)
result = self._scalarFallback('_print_Pow', expr)
@@ -798,8 +790,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -798,8 +790,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if result:
if result:
return result
return result
one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
one = self.instruction_set.operation('makeVecConst', expr_type).format(1.0, **self._kwargs)
root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
root = self.instruction_set.operation('sqrt', expr_type).format(self._print(expr.base), **self._kwargs)
if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number:
if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number:
exp = expr.exp.args[0]
exp = expr.exp.args[0]
@@ -813,7 +805,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -813,7 +805,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
elif exp == 0.5:
elif exp == 0.5:
return root
return root
elif exp == -0.5:
elif exp == -0.5:
return self.instruction_set['/'].format(one, root, **self._kwargs)
return self.instruction_set.operation('/', expr_type).format(one, root, **self._kwargs)
else:
else:
raise ValueError("Generic exponential not supported: " + str(expr))
raise ValueError("Generic exponential not supported: " + str(expr))
@@ -821,6 +813,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -821,6 +813,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
# noinspection PyProtectedMember
# noinspection PyProtectedMember
from sympy.core.mul import _keep_coeff
from sympy.core.mul import _keep_coeff
 
expr_type = get_type_of_expression(expr)
 
if not inside_add:
if not inside_add:
result = self._scalarFallback('_print_Mul', expr)
result = self._scalarFallback('_print_Mul', expr)
else:
else:
@@ -855,19 +849,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -855,19 +849,19 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = a_str[0]
result = a_str[0]
for item in a_str[1:]:
for item in a_str[1:]:
result = self.instruction_set['*'].format(result, item, **self._kwargs)
result = self.instruction_set.operation('*', expr_type).format(result, item, **self._kwargs)
if len(b) > 0:
if len(b) > 0:
denominator_str = b_str[0]
denominator_str = b_str[0]
for item in b_str[1:]:
for item in b_str[1:]:
denominator_str = self.instruction_set['*'].format(denominator_str, item, **self._kwargs)
denominator_str = self.instruction_set.operation('*', expr_type).format(denominator_str, item, **self._kwargs)
result = self.instruction_set['/'].format(result, denominator_str, **self._kwargs)
result = self.instruction_set.operation('/', expr_type).format(result, denominator_str, **self._kwargs)
if inside_add:
if inside_add:
return sign, result
return sign, result
else:
else:
if sign < 0:
if sign < 0:
return self.instruction_set['*'].format(self._print(S.NegativeOne), result, **self._kwargs)
return self.instruction_set.operation('*', expr_type).format(self._print(S.NegativeOne), result, **self._kwargs)
else:
else:
return result
return result
Loading