Skip to content
Snippets Groups Projects

Fix RNG vectorization for LB

Merged Michael Kuron requested to merge random-vectorization into master
2 files
+ 24
21
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -192,7 +192,9 @@ class CBackend:
@@ -192,7 +192,9 @@ class CBackend:
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'):
if sympy_printer is None:
if sympy_printer is None:
if vector_instruction_set is not None:
if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
self.vector_sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set)
 
self.scalar_sympy_printer = CustomSympyPrinter()
 
self.sympy_printer = self.vector_sympy_printer
else:
else:
self.sympy_printer = CustomSympyPrinter()
self.sympy_printer = CustomSympyPrinter()
else:
else:
@@ -259,6 +261,12 @@ class CBackend:
@@ -259,6 +261,12 @@ class CBackend:
prefix = "\n".join(node.prefix_lines)
prefix = "\n".join(node.prefix_lines)
if prefix:
if prefix:
prefix += "\n"
prefix += "\n"
 
if self._vector_instruction_set and hasattr(node, 'instruction_set') and node.instruction_set is None:
 
# the tail loop must not be vectorized
 
self.sympy_printer = self.scalar_sympy_printer
 
code = f"{prefix}{loop_str}\n{self._print(node.body)}"
 
self.sympy_printer = self.vector_sympy_printer
 
return code
return f"{prefix}{loop_str}\n{self._print(node.body)}"
return f"{prefix}{loop_str}\n{self._print(node.body)}"
def _print_SympyAssignment(self, node):
def _print_SympyAssignment(self, node):
@@ -670,7 +678,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
@@ -670,7 +678,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
else:
else:
is_boolean = get_type_of_expression(arg) == create_type("bool")
is_boolean = get_type_of_expression(arg) == create_type("bool")
is_integer = get_type_of_expression(arg) == create_type("int") or \
is_integer = get_type_of_expression(arg) == create_type("int") or \
(isinstance(arg, TypedSymbol) and arg.dtype.is_int())
(isinstance(arg, TypedSymbol) and not isinstance(arg.dtype, VectorType) and arg.dtype.is_int())
instruction = 'makeVecConstBool' if is_boolean else \
instruction = 'makeVecConstBool' if is_boolean else \
'makeVecConstInt' if is_integer else 'makeVecConst'
'makeVecConstInt' if is_integer else 'makeVecConst'
return self.instruction_set[instruction].format(self._print(arg), **self._kwargs)
return self.instruction_set[instruction].format(self._print(arg), **self._kwargs)
Loading