Skip to content
Snippets Groups Projects

Fix Opencl and LLVM GPU tests

5 files
+ 33
11
Compare changes
  • Side-by-side
  • Inline

Files

+ 10
3
@@ -5,6 +5,7 @@ import numpy as np
@@ -5,6 +5,7 @@ import numpy as np
import sympy as sp
import sympy as sp
from sympy.core import S
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter
from sympy.printing.ccode import C89CodePrinter
 
from pystencils.astnodes import KernelFunction, Node
from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
from pystencils.data_types import (
@@ -15,6 +16,11 @@ from pystencils.integer_functions import (
@@ -15,6 +16,11 @@ from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
int_div, int_power_of_2, modulo_ceil)
 
try:
 
from sympy.boolalg import BooleanTrue, BooleanFalse
 
except Exception:
 
from sympy.logic.boolalg import BooleanTrue, BooleanFalse
 
try:
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
except ImportError:
@@ -292,9 +298,9 @@ class CBackend:
@@ -292,9 +298,9 @@ class CBackend:
return ""
return ""
def _print_Conditional(self, node):
def _print_Conditional(self, node):
if type(node.condition_expr) is sp.boolalg.BooleanTrue:
if type(node.condition_expr) is BooleanTrue:
return self._print_Block(node.true_block)
return self._print_Block(node.true_block)
elif type(node.condition_expr) is sp.boolalg.BooleanFalse:
elif type(node.condition_expr) is BooleanFalse:
return self._print_Block(node.false_block)
return self._print_Block(node.false_block)
cond_type = get_type_of_expression(node.condition_expr)
cond_type = get_type_of_expression(node.condition_expr)
if isinstance(cond_type, VectorType):
if isinstance(cond_type, VectorType):
@@ -385,8 +391,9 @@ class CustomSympyPrinter(CCodePrinter):
@@ -385,8 +391,9 @@ class CustomSympyPrinter(CCodePrinter):
elif expr.func == int_div:
elif expr.func == int_div:
return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1]))
return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1]))
else:
else:
 
name = expr.name if hasattr(expr, 'name') else expr.__class__.__name__
arg_str = ', '.join(self._print(a) for a in expr.args)
arg_str = ', '.join(self._print(a) for a in expr.args)
return f'{expr.name}({arg_str})'
return f'{name}({arg_str})'
def _typed_number(self, number, dtype):
def _typed_number(self, number, dtype):
res = self._print(number)
res = self._print(number)
Loading