Skip to content
Snippets Groups Projects

Improve Vectorisation

Merged Markus Holzer requested to merge holzer/pystencils:ImproveVec into master
3 unresolved threads
9 files
+ 227
70
Compare changes
  • Side-by-side
  • Inline
Files
9
@@ -629,6 +629,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def _print_CastFunc(self, expr):
arg, data_type = expr.args
if type(data_type) is VectorType:
base_type = data_type.base_type
# vector_memory_access is a cast_func itself so it shouldn't be directly inside a cast_func
assert not isinstance(arg, VectorMemoryAccess)
if isinstance(arg, sp.Tuple):
@@ -648,19 +649,18 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
elif isinstance(arg, TypedSymbol):
return self._typed_vectorized_symbol(arg, data_type)
elif isinstance(arg, (InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)) \
and data_type == BasicType('float32'):
and base_type == BasicType('float32'):
raise NotImplementedError('Vectorizer is not tested for trigonometric functions yet')
# known = self.known_functions[arg.__class__.__name__.lower()]
# code = self._print(arg)
# return code.replace(known, f"{known}f")
elif isinstance(arg, sp.Pow) and data_type == BasicType('float32'):
raise NotImplementedError('Vectorizer cannot print casted aka. not double pow')
# known = ['sqrt', 'cbrt', 'pow']
# code = self._print(arg)
# for k in known:
# if k in code:
# return code.replace(k, f'{k}f')
# raise ValueError(f"{code} doesn't give {known=} function back.")
elif isinstance(arg, sp.Pow):
if base_type == BasicType('float32') or base_type == BasicType('float64'):
return self._print_Pow(arg)
else:
raise NotImplementedError('Integer Pow is not implemented')
elif isinstance(arg, sp.UnevaluatedExpr):
return self._print(arg.args[0])
else:
raise NotImplementedError('Vectorizer cannot cast between different datatypes')
# to_type = self.instruction_set['suffix'][data_type.base_type.c_name]
@@ -770,6 +770,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return processed
def _print_Pow(self, expr):
# Due to loop cutting sp.Mul is evaluated again.
try:
result = self._scalarFallback('_print_Pow', expr)
except ValueError:
@@ -778,26 +780,21 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return result
one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
if isinstance(expr.exp, CastFunc) and expr.exp.args[0].is_number:
exp = expr.exp.args[0]
else:
exp = expr.exp
# TODO the printer should not have any intelligence like this.
# TODO To remove all of these cases the vectoriser needs to be reworked. See loop cutting
if exp.is_integer and exp.is_number and 0 < exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * exp, evaluate=False)) + ")"
elif exp == -1:
one = self.instruction_set['makeVecConst'].format(1.0, **self._kwargs)
return self.instruction_set['/'].format(one, self._print(expr.base), **self._kwargs)
return self._print(sp.Mul(*[expr.base] * exp, evaluate=False))
elif exp == 0.5:
return self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
return root
elif exp == -0.5:
root = self.instruction_set['sqrt'].format(self._print(expr.base), **self._kwargs)
return self.instruction_set['/'].format(one, root, **self._kwargs)
elif exp.is_integer and exp.is_number and - 8 < exp < 0:
return self.instruction_set['/'].format(one,
self._print(sp.Mul(*[expr.base] * (-exp), evaluate=False)),
**self._kwargs)
else:
raise ValueError("Generic exponential not supported: " + str(expr))
Loading