Skip to content
Snippets Groups Projects

RNG SIMD

Merged Michael Kuron requested to merge philox-simd into master
Compare and
13 files
+ 1556
190
Compare changes
  • Side-by-side
  • Inline
Files
13
@@ -11,7 +11,7 @@ from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
reinterpret_cast_func, vector_memory_access)
reinterpret_cast_func, vector_memory_access, BasicType, TypedSymbol)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
@@ -134,7 +134,7 @@ class CustomCodeNode(Node):
self._symbols_defined = set(symbols_defined)
self.headers = []
def get_code(self, dialect, vector_instruction_set):
def get_code(self, dialect, vector_instruction_set, print_arg):
return self._code
@property
@@ -297,7 +297,7 @@ class CBackend:
return "continue;"
def _print_CustomCodeNode(self, node):
return node.get_code(self._dialect, self._vector_instruction_set)
return node.get_code(self._dialect, self._vector_instruction_set, print_arg=self.sympy_printer._print)
def _print_SourceCodeComment(self, node):
return f"/* {node.text } */"
@@ -548,12 +548,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if type(data_type) is VectorType:
if isinstance(arg, sp.Tuple):
is_boolean = get_type_of_expression(arg[0]) == create_type("bool")
is_integer = get_type_of_expression(arg[0]) == create_type("int")
printed_args = [self._print(a) for a in arg]
instruction = 'makeVecBool' if is_boolean else 'makeVec'
instruction = 'makeVecBool' if is_boolean else 'makeVecInt' if is_integer else 'makeVec'
return self.instruction_set[instruction].format(*printed_args)
else:
is_boolean = get_type_of_expression(arg) == create_type("bool")
instruction = 'makeVecConstBool' if is_boolean else 'makeVecConst'
is_integer = get_type_of_expression(arg) == create_type("int") or \
(isinstance(arg, TypedSymbol) and arg.dtype.is_int())
instruction = 'makeVecConstBool' if is_boolean else \
'makeVecConstInt' if is_integer else 'makeVecConst'
return self.instruction_set[instruction].format(self._print(arg))
elif expr.func == fast_division:
result = self._scalarFallback('_print_Function', expr)
@@ -609,12 +613,27 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
return result
def _print_Add(self, expr, order=None):
result = self._scalarFallback('_print_Add', expr)
try:
result = self._scalarFallback('_print_Add', expr)
except Exception:
result = None
if result:
return result
args = expr.args
# special treatment for all-integer args, for loop index arithmetic until we have proper int vectorization
suffix = ""
if all([(type(e) is cast_func and str(e.dtype) == self.instruction_set['int']) or isinstance(e, sp.Integer)
or (type(e) is TypedSymbol and isinstance(e.dtype, BasicType) and e.dtype.is_int()) for e in args]):
dtype = set([e.dtype for e in args if type(e) is cast_func])
assert len(dtype) == 1
dtype = dtype.pop()
args = [cast_func(e, dtype) if (isinstance(e, sp.Integer) or isinstance(e, TypedSymbol)) else e
for e in args]
suffix = "int"
summands = []
for term in expr.args:
for term in args:
if term.func == sp.Mul:
sign, t = self._print_Mul(term, inside_add=True)
else:
@@ -630,7 +649,7 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(summands) >= 2
processed = summands[0].term
for summand in summands[1:]:
func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+']
func = self.instruction_set['-' + suffix] if summand.sign == -1 else self.instruction_set['+' + suffix]
processed = func.format(processed, summand.term)
return processed
Loading