Skip to content
Snippets Groups Projects
Commit f71e4754 authored by Stephan Seitz's avatar Stephan Seitz Committed by Stephan Seitz
Browse files

Save WIP

parent 1de99ef0
No related branches found
No related tags found
No related merge requests found
...@@ -17,17 +17,24 @@ from pystencils.integer_functions import ( ...@@ -17,17 +17,24 @@ from pystencils.integer_functions import (
int_div, int_power_of_2, modulo_ceil) int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
from sympy.printing.codeprinter import requires
try: try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError: except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1 from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter'] __all__ = [
'generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers',
'CustomSympyPrinter'
]
KERNCRAFT_NO_TERNARY_MODE = False KERNCRAFT_NO_TERNARY_MODE = False
def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom_backend=None) -> str: def generate_c(ast_node: Node,
signature_only: bool = False,
dialect='c',
custom_backend=None) -> str:
"""Prints an abstract syntax tree node as C or CUDA code. """Prints an abstract syntax tree node as C or CUDA code.
This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded This function does not need to distinguish between C, C++ or CUDA code, it just prints 'C-like' code as encoded
...@@ -76,14 +83,28 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom ...@@ -76,14 +83,28 @@ def generate_c(ast_node: Node, signature_only: bool = False, dialect='c', custom
def get_global_declarations(ast): def get_global_declarations(ast):
global_declarations = [] global_declarations = []
# global_declarations.append(
# CustomCodeNode(
# "using namespace std::complex_literals;", set(),
# set()))
def visit_node(sub_ast): def visit_node(sub_ast):
nonlocal global_declarations
if hasattr(sub_ast, "required_global_declarations"): if hasattr(sub_ast, "required_global_declarations"):
nonlocal global_declarations
global_declarations += sub_ast.required_global_declarations global_declarations += sub_ast.required_global_declarations
if hasattr(sub_ast, "args"): if hasattr(sub_ast, "args"):
for node in sub_ast.args: for node in sub_ast.args:
visit_node(node) visit_node(node)
if isinstance(sub_ast, KernelFunction):
if any(
np.issubdtype(a.dtype.numpy_dtype, np.complexfloating)
for a in sub_ast.atoms(sp.Expr) if hasattr(a, 'dtype')
and hasattr(a.dtype, 'numpy_dtype')):
if sub_ast.backend == 'cpu':
global_declarations.append(
CustomCodeNode(
"using namespace std::complex_literals;", set(),
set()))
visit_node(ast) visit_node(ast)
...@@ -94,9 +115,17 @@ def get_headers(ast_node: Node) -> Set[str]: ...@@ -94,9 +115,17 @@ def get_headers(ast_node: Node) -> Set[str]:
"""Return a set of header files, necessary to compile the printed C-like code.""" """Return a set of header files, necessary to compile the printed C-like code."""
headers = set() headers = set()
headers.update({'"complex_helper.hpp"'})
if isinstance(ast_node, KernelFunction) and ast_node.instruction_set: if isinstance(ast_node, KernelFunction) and ast_node.instruction_set:
headers.update(ast_node.instruction_set['headers']) headers.update(ast_node.instruction_set['headers'])
if isinstance(ast_node, KernelFunction):
if any(
np.issubdtype(a.dtype.numpy_dtype, np.complexfloating)
for a in ast_node.atoms(sp.Symbol) if hasattr(a,'dtype') and hasattr(a.dtype, 'numpy_dtype')):
if ast_node.backend == 'c':
headers.update({"<complex>"})
if hasattr(ast_node, 'headers'): if hasattr(ast_node, 'headers'):
headers.update(ast_node.headers) headers.update(ast_node.headers)
for a in ast_node.args: for a in ast_node.args:
...@@ -136,8 +165,11 @@ class CustomCodeNode(Node): ...@@ -136,8 +165,11 @@ class CustomCodeNode(Node):
class PrintNode(CustomCodeNode):
    """Custom code node that streams a symbol's name and value to ``std::cout``.

    The emitted statement has the form ``std::cout << "<name> = " << <name> << std::endl;``
    and the node registers ``<iostream>`` in its ``headers`` list.
    """

    # noinspection SpellCheckingInspection
    def __init__(self, symbol_to_print):
        """Build the print statement for ``symbol_to_print``.

        Args:
            symbol_to_print: object with a ``name`` attribute; it is passed on as
                the only entry of ``symbols_read``.
        """
        symbol_name = symbol_to_print.name
        code = '\nstd::cout << "%s = " << %s << std::endl; \n' % (
            symbol_name, symbol_name)
        super(PrintNode, self).__init__(code,
                                        symbols_read=[symbol_to_print],
                                        symbols_defined=set())
        # std::cout / std::endl are declared in <iostream>
        self.headers.append("<iostream>")
...@@ -146,11 +178,15 @@ class PrintNode(CustomCodeNode): ...@@ -146,11 +178,15 @@ class PrintNode(CustomCodeNode):
# noinspection PyPep8Naming # noinspection PyPep8Naming
class CBackend: class CBackend:
def __init__(self,
def __init__(self, sympy_printer=None, signature_only=False, vector_instruction_set=None, dialect='c'): sympy_printer=None,
signature_only=False,
vector_instruction_set=None,
dialect='c'):
if sympy_printer is None: if sympy_printer is None:
if vector_instruction_set is not None: if vector_instruction_set is not None:
self.sympy_printer = VectorizedCustomSympyPrinter(vector_instruction_set) self.sympy_printer = VectorizedCustomSympyPrinter(
vector_instruction_set)
else: else:
self.sympy_printer = CustomSympyPrinter() self.sympy_printer = CustomSympyPrinter()
else: else:
...@@ -175,20 +211,25 @@ class CBackend: ...@@ -175,20 +211,25 @@ class CBackend:
method_name = "_print_" + cls.__name__ method_name = "_print_" + cls.__name__
if hasattr(self, method_name): if hasattr(self, method_name):
return getattr(self, method_name)(node) return getattr(self, method_name)(node)
raise NotImplementedError(self.__class__.__name__ + " does not support node of type " + node.__class__.__name__) raise NotImplementedError(self.__class__.__name__ +
" does not support node of type " +
node.__class__.__name__)
def _print_Type(self, node): def _print_Type(self, node):
return str(node) return str(node)
def _print_KernelFunction(self, node): def _print_KernelFunction(self, node):
function_arguments = ["%s %s" % (self._print(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()] function_arguments = [
"%s %s" % (self._print(s.symbol.dtype), s.symbol.name)
for s in node.get_parameters()
]
launch_bounds = "" launch_bounds = ""
if self._dialect == 'cuda': if self._dialect == 'cuda':
max_threads = node.indexing.max_threads_per_block() max_threads = node.indexing.max_threads_per_block()
if max_threads: if max_threads:
launch_bounds = "__launch_bounds__({}) ".format(max_threads) launch_bounds = "__launch_bounds__({}) ".format(max_threads)
func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (launch_bounds, node.function_name, func_declaration = "FUNC_PREFIX %svoid %s(%s)" % (
", ".join(function_arguments)) launch_bounds, node.function_name, ", ".join(function_arguments))
if self._signatureOnly: if self._signatureOnly:
return func_declaration return func_declaration
...@@ -197,16 +238,22 @@ class CBackend: ...@@ -197,16 +238,22 @@ class CBackend:
def _print_Block(self, node): def _print_Block(self, node):
block_contents = "\n".join([self._print(child) for child in node.args]) block_contents = "\n".join([self._print(child) for child in node.args])
return "{\n%s\n}" % (self._indent + self._indent.join(block_contents.splitlines(True))) return "{\n%s\n}" % (
self._indent + self._indent.join(block_contents.splitlines(True)))
def _print_PragmaBlock(self, node): def _print_PragmaBlock(self, node):
return "%s\n%s" % (node.pragma_line, self._print_Block(node)) return "%s\n%s" % (node.pragma_line, self._print_Block(node))
def _print_LoopOverCoordinate(self, node): def _print_LoopOverCoordinate(self, node):
counter_symbol = node.loop_counter_name counter_symbol = node.loop_counter_name
start = "int %s = %s" % (counter_symbol, self.sympy_printer.doprint(node.start)) start = "int %s = %s" % (counter_symbol,
condition = "%s < %s" % (counter_symbol, self.sympy_printer.doprint(node.stop)) self.sympy_printer.doprint(node.start))
update = "%s += %s" % (counter_symbol, self.sympy_printer.doprint(node.step),) condition = "%s < %s" % (counter_symbol,
self.sympy_printer.doprint(node.stop))
update = "%s += %s" % (
counter_symbol,
self.sympy_printer.doprint(node.step),
)
loop_str = "for (%s; %s; %s)" % (start, condition, update) loop_str = "for (%s; %s; %s)" % (start, condition, update)
prefix = "\n".join(node.prefix_lines) prefix = "\n".join(node.prefix_lines)
...@@ -221,11 +268,13 @@ class CBackend: ...@@ -221,11 +268,13 @@ class CBackend:
else: else:
prefix = '' prefix = ''
data_type = prefix + self._print(node.lhs.dtype) + " " data_type = prefix + self._print(node.lhs.dtype) + " "
return "%s%s = %s;" % (data_type, self.sympy_printer.doprint(node.lhs), return "%s%s = %s;" % (data_type,
self.sympy_printer.doprint(node.lhs),
self.sympy_printer.doprint(node.rhs)) self.sympy_printer.doprint(node.rhs))
else: else:
lhs_type = get_type_of_expression(node.lhs) lhs_type = get_type_of_expression(node.lhs)
if type(lhs_type) is VectorType and isinstance(node.lhs, cast_func): if type(lhs_type) is VectorType and isinstance(
node.lhs, cast_func):
arg, data_type, aligned, nontemporal = node.lhs.args arg, data_type, aligned, nontemporal = node.lhs.args
instr = 'storeU' instr = 'storeU'
if aligned: if aligned:
...@@ -237,10 +286,12 @@ class CBackend: ...@@ -237,10 +286,12 @@ class CBackend:
else: else:
rhs = node.rhs rhs = node.rhs
return self._vector_instruction_set[instr].format("&" + self.sympy_printer.doprint(node.lhs.args[0]), return self._vector_instruction_set[instr].format(
self.sympy_printer.doprint(rhs)) + ';' "&" + self.sympy_printer.doprint(node.lhs.args[0]),
self.sympy_printer.doprint(rhs)) + ';'
else: else:
return "%s = %s;" % (self.sympy_printer.doprint(node.lhs), self.sympy_printer.doprint(node.rhs)) return "%s = %s;" % (self.sympy_printer.doprint(
node.lhs), self.sympy_printer.doprint(node.rhs))
def _print_TemporaryMemoryAllocation(self, node): def _print_TemporaryMemoryAllocation(self, node):
align = 64 align = 64
...@@ -256,7 +307,8 @@ class CBackend: ...@@ -256,7 +307,8 @@ class CBackend:
def _print_TemporaryMemoryFree(self, node): def _print_TemporaryMemoryFree(self, node):
align = 64 align = 64
return "free(%s - %d);" % (self.sympy_printer.doprint(node.symbol.name), node.offset(align)) return "free(%s - %d);" % (self.sympy_printer.doprint(
node.symbol.name), node.offset(align))
def _print_SkipIteration(self, _): def _print_SkipIteration(self, _):
return "continue;" return "continue;"
...@@ -267,7 +319,9 @@ class CBackend: ...@@ -267,7 +319,9 @@ class CBackend:
def _print_Conditional(self, node): def _print_Conditional(self, node):
cond_type = get_type_of_expression(node.condition_expr) cond_type = get_type_of_expression(node.condition_expr)
if isinstance(cond_type, VectorType): if isinstance(cond_type, VectorType):
raise ValueError("Problem with Conditional inside vectorized loop - use vec_any or vec_all") raise ValueError(
"Problem with Conditional inside vectorized loop - use vec_any or vec_all"
)
condition_expr = self.sympy_printer.doprint(node.condition_expr) condition_expr = self.sympy_printer.doprint(node.condition_expr)
true_block = self._print_Block(node.true_block) true_block = self._print_Block(node.true_block)
result = "if (%s)\n%s " % (condition_expr, true_block) result = "if (%s)\n%s " % (condition_expr, true_block)
...@@ -279,14 +333,13 @@ class CBackend: ...@@ -279,14 +333,13 @@ class CBackend:
def _print_DestructuringBindingsForFieldClass(self, node): def _print_DestructuringBindingsForFieldClass(self, node):
# Define all undefined symbols # Define all undefined symbols
undefined_field_symbols = node.symbols_defined undefined_field_symbols = node.symbols_defined
destructuring_bindings = ["%s %s = %s.%s;" % destructuring_bindings = [
(u.dtype, "%s %s = %s.%s;" %
u.name, (u.dtype, u.name, u.field_name if hasattr(u, 'field_name') else
u.field_name if hasattr(u, 'field_name') else u.field_names[0], u.field_names[0], node.CLASS_TO_MEMBER_DICT[u.__class__] %
node.CLASS_TO_MEMBER_DICT[u.__class__] % (() if type(u) == FieldPointerSymbol else (u.coordinate, )))
(() if type(u) == FieldPointerSymbol else (u.coordinate,))) for u in undefined_field_symbols
for u in undefined_field_symbols ]
]
destructuring_bindings.sort() # only for code aesthetics destructuring_bindings.sort() # only for code aesthetics
return "{\n" + self._indent + \ return "{\n" + self._indent + \
("\n" + self._indent).join(destructuring_bindings) + \ ("\n" + self._indent).join(destructuring_bindings) + \
...@@ -300,7 +353,6 @@ class CBackend: ...@@ -300,7 +353,6 @@ class CBackend:
# noinspection PyPep8Naming # noinspection PyPep8Naming
class CustomSympyPrinter(CCodePrinter): class CustomSympyPrinter(CCodePrinter):
def __init__(self): def __init__(self):
super(CustomSympyPrinter, self).__init__() super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32") self._float_type = create_type("float32")
...@@ -312,12 +364,16 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -312,12 +364,16 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Pow(self, expr):
    """Don't use std::pow function, for small integer exponents, write as multiplication

    * expressions without free symbols are evaluated and printed as a typed number
    * literal integer exponents in (0, 8) become a parenthesised product
    * literal integer exponents in (-8, 0) become ``1 / (product)``
    * everything else is delegated to the base class printer
    """
    if not expr.free_symbols:
        return self._typed_number(expr.evalf(),
                                  get_type_of_expression(expr))

    exponent = expr.exp
    # only compare against the bounds when the exponent is a literal integer;
    # the short-circuit below keeps symbolic exponents out of the comparison
    is_literal_int = exponent.is_integer and exponent.is_number
    if is_literal_int and 0 < exponent < 8:
        product = sp.Mul(*[expr.base] * exponent, evaluate=False)
        return "(" + self._print(product) + ")"
    if is_literal_int and -8 < exponent < 0:
        product = sp.Mul(*[expr.base] * (-exponent), evaluate=False)
        return "1 / ({})".format(self._print(product))
    return super(CustomSympyPrinter, self)._print_Pow(expr)
...@@ -328,7 +384,8 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -328,7 +384,8 @@ class CustomSympyPrinter(CCodePrinter):
def _print_Equality(self, expr): def _print_Equality(self, expr):
"""Equality operator is not printable in default printer""" """Equality operator is not printable in default printer"""
return '((' + self._print(expr.lhs) + ") == (" + self._print(expr.rhs) + '))' return '((' + self._print(expr.lhs) + ") == (" + self._print(
expr.rhs) + '))'
def _print_Piecewise(self, expr): def _print_Piecewise(self, expr):
"""Print piecewise in one line (remove newlines)""" """Print piecewise in one line (remove newlines)"""
...@@ -347,9 +404,11 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -347,9 +404,11 @@ class CustomSympyPrinter(CCodePrinter):
return expr.to_c(self._print) return expr.to_c(self._print)
if isinstance(expr, reinterpret_cast_func): if isinstance(expr, reinterpret_cast_func):
arg, data_type = expr.args arg, data_type = expr.args
return "*((%s)(& %s))" % (PointerType(data_type, restrict=False), self._print(arg)) return "*((%s)(& %s))" % (PointerType(
data_type, restrict=False), self._print(arg))
elif isinstance(expr, address_of): elif isinstance(expr, address_of):
assert len(expr.args) == 1, "address_of must only have one argument" assert len(
expr.args) == 1, "address_of must only have one argument"
return "&(%s)" % self._print(expr.args[0]) return "&(%s)" % self._print(expr.args[0])
elif isinstance(expr, cast_func): elif isinstance(expr, cast_func):
arg, data_type = expr.args arg, data_type = expr.args
...@@ -366,11 +425,14 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -366,11 +425,14 @@ class CustomSympyPrinter(CCodePrinter):
elif isinstance(expr, fast_inv_sqrt): elif isinstance(expr, fast_inv_sqrt):
return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) return "({})".format(self._print(1 / sp.sqrt(expr.args[0])))
elif expr.func in infix_functions: elif expr.func in infix_functions:
return "(%s %s %s)" % (self._print(expr.args[0]), infix_functions[expr.func], self._print(expr.args[1])) return "(%s %s %s)" % (self._print(
expr.args[0]), infix_functions[expr.func],
self._print(expr.args[1]))
elif expr.func == int_power_of_2: elif expr.func == int_power_of_2:
return "(1 << (%s))" % (self._print(expr.args[0])) return "(1 << (%s))" % (self._print(expr.args[0]))
elif expr.func == int_div: elif expr.func == int_div:
return "((%s) / (%s))" % (self._print(expr.args[0]), self._print(expr.args[1])) return "((%s) / (%s))" % (self._print(
expr.args[0]), self._print(expr.args[1]))
else: else:
return super(CustomSympyPrinter, self)._print_Function(expr) return super(CustomSympyPrinter, self)._print_Function(expr)
...@@ -381,9 +443,9 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -381,9 +443,9 @@ class CustomSympyPrinter(CCodePrinter):
elif dtype.numpy_dtype == np.float64: elif dtype.numpy_dtype == np.float64:
return res + '.0' if '.' not in res else res return res + '.0' if '.' not in res else res
elif dtype.numpy_dtype == np.complex64: elif dtype.numpy_dtype == np.complex64:
return f"{_typed_number(number.real, np.float32)} + {_typed_number(number.real, np.float32).replace('f', 'if')}" return f"{self._typed_number(number.real, np.float32)} + {self._typed_number(number.real, np.float32).replace('f', 'if')}"
elif dtype.numpy_dtype == np.complex128: elif dtype.numpy_dtype == np.complex128:
return f"{_typed_number(number.real, np.float64)} + {_typed_number(number.real, np.float64)}i" return f"{self._typed_number(number.real, np.float64)} + {self._typed_number(number.real, np.float64)}i"
else: else:
return res return res
...@@ -406,7 +468,8 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -406,7 +468,8 @@ class CustomSympyPrinter(CCodePrinter):
end=self._print(end), end=self._print(end),
expr=self._print(expr.function), expr=self._print(expr.function),
increment=str(1), increment=str(1),
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' condition=self._print(var) + ' <= ' +
self._print(end) # if start < end else '>='
) )
return code return code
...@@ -429,12 +492,15 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -429,12 +492,15 @@ class CustomSympyPrinter(CCodePrinter):
end=self._print(end), end=self._print(end),
expr=self._print(expr.function), expr=self._print(expr.function),
increment=str(1), increment=str(1),
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>=' condition=self._print(var) + ' <= ' +
self._print(end) # if start < end else '>='
) )
return code return code
_print_Max = C89CodePrinter._print_Max _print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min _print_Min = C89CodePrinter._print_Min
# std::real is declared in the C++ <complex> header; 'stdbool.h' (as used for
# boolean literals in sympy's C printer) is unrelated to this function.
@requires(headers={'complex'})
def _print_re(self, expr):
    """Print the real part of the first argument as a ``std::real(...)`` call."""
    return f"std::real({self._print(expr.args[0])})"
...@@ -442,8 +508,11 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -442,8 +508,11 @@ class CustomSympyPrinter(CCodePrinter):
return f"std::imag({self._print(expr.args[0])})" return f"std::imag({self._print(expr.args[0])})"
def _print_ImaginaryUnit(self, expr): def _print_ImaginaryUnit(self, expr):
return "1if" return "std::complex<double>{0,1}"
def _print_Complex(self, expr):
    """Print a complex number as a single-precision complex literal.

    ``_typed_number`` reads ``dtype.numpy_dtype``, which the raw numpy scalar
    type ``np.complex64`` does not have, so the dtype is wrapped with
    ``create_type`` first.
    """
    return self._typed_number(expr, create_type(np.complex64))
# noinspection PyPep8Naming # noinspection PyPep8Naming
class VectorizedCustomSympyPrinter(CustomSympyPrinter): class VectorizedCustomSympyPrinter(CustomSympyPrinter):
...@@ -456,7 +525,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -456,7 +525,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def _scalarFallback(self, func_name, expr, *args, **kwargs):
    """Delegate to the scalar printer when ``expr`` is not vector-typed.

    Args:
        func_name: name of the printer method to look up on the parent class.
        expr: expression whose type decides whether the fallback applies.
        *args, **kwargs: forwarded unchanged to the parent-class method.

    Returns:
        The parent-class result for non-vector expressions, otherwise ``None``
        so that the caller emits vector code itself.
    """
    expr_type = get_type_of_expression(expr)
    if type(expr_type) is VectorType:
        # vector expressions must match the width of the active instruction set
        assert self.instruction_set['width'] == expr_type.width
        return None
    scalar_printer = getattr(super(VectorizedCustomSympyPrinter, self),
                             func_name)
    return scalar_printer(expr, *args, **kwargs)
...@@ -464,7 +534,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -464,7 +534,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
def _print_Function(self, expr): def _print_Function(self, expr):
if isinstance(expr, vector_memory_access): if isinstance(expr, vector_memory_access):
arg, data_type, aligned, _ = expr.args arg, data_type, aligned, _ = expr.args
instruction = self.instruction_set['loadA'] if aligned else self.instruction_set['loadU'] instruction = self.instruction_set[
'loadA'] if aligned else self.instruction_set['loadU']
return instruction.format("& " + self._print(arg)) return instruction.format("& " + self._print(arg))
elif isinstance(expr, cast_func): elif isinstance(expr, cast_func):
arg, data_type = expr.args arg, data_type = expr.args
...@@ -473,7 +544,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -473,7 +544,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
elif expr.func == fast_division: elif expr.func == fast_division:
result = self._scalarFallback('_print_Function', expr) result = self._scalarFallback('_print_Function', expr)
if not result: if not result:
result = self.instruction_set['/'].format(self._print(expr.args[0]), self._print(expr.args[1])) result = self.instruction_set['/'].format(
self._print(expr.args[0]), self._print(expr.args[1]))
return result return result
elif expr.func == fast_sqrt: elif expr.func == fast_sqrt:
return "({})".format(self._print(sp.sqrt(expr.args[0]))) return "({})".format(self._print(sp.sqrt(expr.args[0])))
...@@ -481,21 +553,25 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -481,21 +553,25 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._scalarFallback('_print_Function', expr) result = self._scalarFallback('_print_Function', expr)
if not result: if not result:
if self.instruction_set['rsqrt']: if self.instruction_set['rsqrt']:
return self.instruction_set['rsqrt'].format(self._print(expr.args[0])) return self.instruction_set['rsqrt'].format(
self._print(expr.args[0]))
else: else:
return "({})".format(self._print(1 / sp.sqrt(expr.args[0]))) return "({})".format(self._print(1 /
sp.sqrt(expr.args[0])))
elif isinstance(expr, vec_any): elif isinstance(expr, vec_any):
expr_type = get_type_of_expression(expr.args[0]) expr_type = get_type_of_expression(expr.args[0])
if type(expr_type) is not VectorType: if type(expr_type) is not VectorType:
return self._print(expr.args[0]) return self._print(expr.args[0])
else: else:
return self.instruction_set['any'].format(self._print(expr.args[0])) return self.instruction_set['any'].format(
self._print(expr.args[0]))
elif isinstance(expr, vec_all): elif isinstance(expr, vec_all):
expr_type = get_type_of_expression(expr.args[0]) expr_type = get_type_of_expression(expr.args[0])
if type(expr_type) is not VectorType: if type(expr_type) is not VectorType:
return self._print(expr.args[0]) return self._print(expr.args[0])
else: else:
return self.instruction_set['all'].format(self._print(expr.args[0])) return self.instruction_set['all'].format(
self._print(expr.args[0]))
return super(VectorizedCustomSympyPrinter, self)._print_Function(expr) return super(VectorizedCustomSympyPrinter, self)._print_Function(expr)
...@@ -545,7 +621,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -545,7 +621,8 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
assert len(summands) >= 2 assert len(summands) >= 2
processed = summands[0].term processed = summands[0].term
for summand in summands[1:]: for summand in summands[1:]:
func = self.instruction_set['-'] if summand.sign == -1 else self.instruction_set['+'] func = self.instruction_set[
'-'] if summand.sign == -1 else self.instruction_set['+']
processed = func.format(processed, summand.term) processed = func.format(processed, summand.term)
return processed return processed
...@@ -557,18 +634,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -557,18 +634,22 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
one = self.instruction_set['makeVec'].format(1.0) one = self.instruction_set['makeVec'].format(1.0)
if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8: if expr.exp.is_integer and expr.exp.is_number and 0 < expr.exp < 8:
return "(" + self._print(sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")" return "(" + self._print(
sp.Mul(*[expr.base] * expr.exp, evaluate=False)) + ")"
elif expr.exp == -1: elif expr.exp == -1:
one = self.instruction_set['makeVec'].format(1.0) one = self.instruction_set['makeVec'].format(1.0)
return self.instruction_set['/'].format(one, self._print(expr.base)) return self.instruction_set['/'].format(one,
self._print(expr.base))
elif expr.exp == 0.5: elif expr.exp == 0.5:
return self.instruction_set['sqrt'].format(self._print(expr.base)) return self.instruction_set['sqrt'].format(self._print(expr.base))
elif expr.exp == -0.5: elif expr.exp == -0.5:
root = self.instruction_set['sqrt'].format(self._print(expr.base)) root = self.instruction_set['sqrt'].format(self._print(expr.base))
return self.instruction_set['/'].format(one, root) return self.instruction_set['/'].format(one, root)
elif expr.exp.is_integer and expr.exp.is_number and - 8 < expr.exp < 0: elif expr.exp.is_integer and expr.exp.is_number and -8 < expr.exp < 0:
return self.instruction_set['/'].format(one, return self.instruction_set['/'].format(
self._print(sp.Mul(*[expr.base] * (-expr.exp), evaluate=False))) one,
self._print(sp.Mul(*[expr.base] * (-expr.exp),
evaluate=False)))
else: else:
raise ValueError("Generic exponential not supported: " + str(expr)) raise ValueError("Generic exponential not supported: " + str(expr))
...@@ -612,14 +693,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -612,14 +693,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
if len(b) > 0: if len(b) > 0:
denominator_str = b_str[0] denominator_str = b_str[0]
for item in b_str[1:]: for item in b_str[1:]:
denominator_str = self.instruction_set['*'].format(denominator_str, item) denominator_str = self.instruction_set['*'].format(
denominator_str, item)
result = self.instruction_set['/'].format(result, denominator_str) result = self.instruction_set['/'].format(result, denominator_str)
if inside_add: if inside_add:
return sign, result return sign, result
else: else:
if sign < 0: if sign < 0:
return self.instruction_set['*'].format(self._print(S.NegativeOne), result) return self.instruction_set['*'].format(
self._print(S.NegativeOne), result)
else: else:
return result return result
...@@ -627,13 +710,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -627,13 +710,15 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._scalarFallback('_print_Relational', expr) result = self._scalarFallback('_print_Relational', expr)
if result: if result:
return result return result
return self.instruction_set[expr.rel_op].format(self._print(expr.lhs), self._print(expr.rhs)) return self.instruction_set[expr.rel_op].format(
self._print(expr.lhs), self._print(expr.rhs))
def _print_Equality(self, expr): def _print_Equality(self, expr):
result = self._scalarFallback('_print_Equality', expr) result = self._scalarFallback('_print_Equality', expr)
if result: if result:
return result return result
return self.instruction_set['=='].format(self._print(expr.lhs), self._print(expr.rhs)) return self.instruction_set['=='].format(self._print(expr.lhs),
self._print(expr.rhs))
def _print_Piecewise(self, expr): def _print_Piecewise(self, expr):
result = self._scalarFallback('_print_Piecewise', expr) result = self._scalarFallback('_print_Piecewise', expr)
...@@ -651,13 +736,17 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -651,13 +736,17 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self._print(expr.args[-1][0]) result = self._print(expr.args[-1][0])
for true_expr, condition in reversed(expr.args[:-1]): for true_expr, condition in reversed(expr.args[:-1]):
if isinstance(condition, cast_func) and get_type_of_expression(condition.args[0]) == create_type("bool"): if isinstance(condition, cast_func) and get_type_of_expression(
condition.args[0]) == create_type("bool"):
if not KERNCRAFT_NO_TERNARY_MODE: if not KERNCRAFT_NO_TERNARY_MODE:
result = "(({}) ? ({}) : ({}))".format(self._print(condition.args[0]), self._print(true_expr), result = "(({}) ? ({}) : ({}))".format(
result) self._print(condition.args[0]), self._print(true_expr),
result)
else: else:
print("Warning - skipping ternary op") print("Warning - skipping ternary op")
else: else:
# noinspection SpellCheckingInspection # noinspection SpellCheckingInspection
result = self.instruction_set['blendv'].format(result, self._print(true_expr), self._print(condition)) result = self.instruction_set['blendv'].format(
result, self._print(true_expr), self._print(condition))
return result
return result return result
...@@ -123,7 +123,7 @@ class boolean_cast_func(cast_func, Boolean): ...@@ -123,7 +123,7 @@ class boolean_cast_func(cast_func, Boolean):
# noinspection PyPep8Naming # noinspection PyPep8Naming
class vector_memory_access(cast_func):
    """``cast_func`` variant that takes exactly four arguments.

    NOTE(review): the C backend unpacks the arguments as
    ``arg, data_type, aligned, <fourth>`` - confirm the meaning of the fourth
    entry against the vectorization code.
    """
    nargs = (4, )
# noinspection PyPep8Naming # noinspection PyPep8Naming
...@@ -172,7 +172,8 @@ class TypedSymbol(sp.Symbol): ...@@ -172,7 +172,8 @@ class TypedSymbol(sp.Symbol):
@property
def is_integer(self):
    """Whether the symbol is integer-valued.

    True when the symbol's dtype exposes a ``numpy_dtype`` that is a numpy
    integer type; in every other case the answer of the parent class is
    returned.
    """
    dtype = self.dtype
    if hasattr(dtype, 'numpy_dtype') and np.issubdtype(dtype.numpy_dtype,
                                                       np.integer):
        return True
    return super().is_integer
...@@ -292,12 +293,10 @@ to_ctypes.map = { ...@@ -292,12 +293,10 @@ to_ctypes.map = {
np.dtype(np.int16): ctypes.c_int16, np.dtype(np.int16): ctypes.c_int16,
np.dtype(np.int32): ctypes.c_int32, np.dtype(np.int32): ctypes.c_int32,
np.dtype(np.int64): ctypes.c_int64, np.dtype(np.int64): ctypes.c_int64,
np.dtype(np.uint8): ctypes.c_uint8, np.dtype(np.uint8): ctypes.c_uint8,
np.dtype(np.uint16): ctypes.c_uint16, np.dtype(np.uint16): ctypes.c_uint16,
np.dtype(np.uint32): ctypes.c_uint32, np.dtype(np.uint32): ctypes.c_uint32,
np.dtype(np.uint64): ctypes.c_uint64, np.dtype(np.uint64): ctypes.c_uint64,
np.dtype(np.float32): ctypes.c_float, np.dtype(np.float32): ctypes.c_float,
np.dtype(np.float64): ctypes.c_double, np.dtype(np.float64): ctypes.c_double,
} }
...@@ -330,7 +329,8 @@ def ctypes_from_llvm(data_type): ...@@ -330,7 +329,8 @@ def ctypes_from_llvm(data_type):
elif isinstance(data_type, ir.VoidType): elif isinstance(data_type, ir.VoidType):
return None # Void type is not supported by ctypes return None # Void type is not supported by ctypes
else: else:
raise NotImplementedError('Data type %s of %s is not supported yet' % (type(data_type), data_type)) raise NotImplementedError('Data type %s of %s is not supported yet' %
(type(data_type), data_type))
def to_llvm_type(data_type): def to_llvm_type(data_type):
...@@ -353,12 +353,10 @@ if ir: ...@@ -353,12 +353,10 @@ if ir:
np.dtype(np.int16): ir.IntType(16), np.dtype(np.int16): ir.IntType(16),
np.dtype(np.int32): ir.IntType(32), np.dtype(np.int32): ir.IntType(32),
np.dtype(np.int64): ir.IntType(64), np.dtype(np.int64): ir.IntType(64),
np.dtype(np.uint8): ir.IntType(8), np.dtype(np.uint8): ir.IntType(8),
np.dtype(np.uint16): ir.IntType(16), np.dtype(np.uint16): ir.IntType(16),
np.dtype(np.uint32): ir.IntType(32), np.dtype(np.uint32): ir.IntType(32),
np.dtype(np.uint64): ir.IntType(64), np.dtype(np.uint64): ir.IntType(64),
np.dtype(np.float32): ir.FloatType(), np.dtype(np.float32): ir.FloatType(),
np.dtype(np.float64): ir.DoubleType(), np.dtype(np.float64): ir.DoubleType(),
} }
...@@ -370,11 +368,21 @@ def peel_off_type(dtype, type_to_peel_off): ...@@ -370,11 +368,21 @@ def peel_off_type(dtype, type_to_peel_off):
return dtype return dtype
def collate_types(types): def collate_types(types, forbid_collation_to_complex=False, forbid_collation_to_float=False):
""" """
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy. Uses the collation rules from numpy.
""" """
if forbid_collation_to_complex:
types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.complexfloating)]
if not types:
types = [ create_type(np.float64)]
if forbid_collation_to_float:
types = [t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)]
if not types:
types = [ create_type(np.int64) ]
# Pointer arithmetic case i.e. pointer + integer is allowed # Pointer arithmetic case i.e. pointer + integer is allowed
if any(type(t) is PointerType for t in types): if any(type(t) is PointerType for t in types):
...@@ -382,7 +390,8 @@ def collate_types(types): ...@@ -382,7 +390,8 @@ def collate_types(types):
for t in types: for t in types:
if type(t) is PointerType: if type(t) is PointerType:
if pointer_type is not None: if pointer_type is not None:
raise ValueError("Cannot collate the combination of two pointer types") raise ValueError(
"Cannot collate the combination of two pointer types")
pointer_type = t pointer_type = t
elif type(t) is BasicType: elif type(t) is BasicType:
if not (t.is_int() or t.is_uint()): if not (t.is_int() or t.is_uint()):
...@@ -394,7 +403,8 @@ def collate_types(types): ...@@ -394,7 +403,8 @@ def collate_types(types):
# peel of vector types, if at least one vector type occurred the result will also be the vector type # peel of vector types, if at least one vector type occurred the result will also be the vector type
vector_type = [t for t in types if type(t) is VectorType] vector_type = [t for t in types if type(t) is VectorType]
if not all_equal(t.width for t in vector_type): if not all_equal(t.width for t in vector_type):
raise ValueError("Collation failed because of vector types with different width") raise ValueError(
"Collation failed because of vector types with different width")
types = [peel_off_type(t, VectorType) for t in types] types = [peel_off_type(t, VectorType) for t in types]
# now we should have a list of basic types - struct types are not yet supported # now we should have a list of basic types - struct types are not yet supported
...@@ -429,6 +439,8 @@ def get_type_of_expression(expr, ...@@ -429,6 +439,8 @@ def get_type_of_expression(expr,
expr = sp.sympify(expr) expr = sp.sympify(expr)
if isinstance(expr, sp.Integer): if isinstance(expr, sp.Integer):
return create_type(default_int_type) return create_type(default_int_type)
elif expr.is_real == False:
return create_type(default_complex_type)
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type(default_float_type) return create_type(default_float_type)
elif isinstance(expr, ResolvedFieldAccess): elif isinstance(expr, ResolvedFieldAccess):
...@@ -453,7 +465,8 @@ def get_type_of_expression(expr, ...@@ -453,7 +465,8 @@ def get_type_of_expression(expr,
elif isinstance(expr, sp.Indexed): elif isinstance(expr, sp.Indexed):
typed_symbol = expr.base.label typed_symbol = expr.base.label
return typed_symbol.dtype.base_type return typed_symbol.dtype.base_type
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction): elif isinstance(expr, sp.boolalg.Boolean) or isinstance(
expr, sp.boolalg.BooleanFunction):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool") result = create_type("bool")
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)] vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
...@@ -499,14 +512,15 @@ class BasicType(Type): ...@@ -499,14 +512,15 @@ class BasicType(Type):
return 'std::complex<double>' return 'std::complex<double>'
elif name.startswith('int'): elif name.startswith('int'):
width = int(name[len("int"):]) width = int(name[len("int"):])
return "int%d_t" % (width,) return "int%d_t" % (width, )
elif name.startswith('uint'): elif name.startswith('uint'):
width = int(name[len("uint"):]) width = int(name[len("uint"):])
return "uint%d_t" % (width,) return "uint%d_t" % (width, )
elif name == 'bool': elif name == 'bool':
return 'bool' return 'bool'
else: else:
raise NotImplementedError("Can map numpy to C name for %s" % (name,)) raise NotImplementedError("Can map numpy to C name for %s" %
(name, ))
def __init__(self, dtype, const=False): def __init__(self, dtype, const=False):
self.const = const self.const = const
...@@ -534,7 +548,8 @@ class BasicType(Type): ...@@ -534,7 +548,8 @@ class BasicType(Type):
return 1 return 1
def is_int(self): def is_int(self):
return self.numpy_dtype in np.sctypes['int'] or self.numpy_dtype in np.sctypes['uint'] return self.numpy_dtype in np.sctypes[
'int'] or self.numpy_dtype in np.sctypes['uint']
def is_float(self): def is_float(self):
return self.numpy_dtype in np.sctypes['float'] return self.numpy_dtype in np.sctypes['float']
...@@ -565,7 +580,8 @@ class BasicType(Type): ...@@ -565,7 +580,8 @@ class BasicType(Type):
if not isinstance(other, BasicType): if not isinstance(other, BasicType):
return False return False
else: else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) return (self.numpy_dtype, self.const) == (other.numpy_dtype,
other.const)
def __hash__(self): def __hash__(self):
return hash(str(self)) return hash(str(self))
...@@ -590,7 +606,8 @@ class VectorType(Type): ...@@ -590,7 +606,8 @@ class VectorType(Type):
if not isinstance(other, VectorType): if not isinstance(other, VectorType):
return False return False
else: else:
return (self.base_type, self.width) == (other.base_type, other.width) return (self.base_type, self.width) == (other.base_type,
other.width)
def __str__(self): def __str__(self):
if self.instruction_set is None: if self.instruction_set is None:
...@@ -639,7 +656,9 @@ class PointerType(Type): ...@@ -639,7 +656,9 @@ class PointerType(Type):
if not isinstance(other, PointerType): if not isinstance(other, PointerType):
return False return False
else: else:
return (self.base_type, self.const, self.restrict) == (other.base_type, other.const, other.restrict) return (self.base_type, self.const,
self.restrict) == (other.base_type, other.const,
other.restrict)
def __str__(self): def __str__(self):
components = [str(self.base_type), '*'] components = [str(self.base_type), '*']
...@@ -690,7 +709,8 @@ class StructType: ...@@ -690,7 +709,8 @@ class StructType:
if not isinstance(other, StructType): if not isinstance(other, StructType):
return False return False
else: else:
return (self.numpy_dtype, self.const) == (other.numpy_dtype, other.const) return (self.numpy_dtype, self.const) == (other.numpy_dtype,
other.const)
def __str__(self): def __str__(self):
# structs are handled byte-wise # structs are handled byte-wise
......
/*
* complex_helper.hpp
* Copyright (C) 2019 Stephan Seitz <stephan.seitz@fau.de>
*
* Distributed under terms of the GPLv3 license.
*/
#pragma once
#include <complex>
/**
 * Multiply a complex number by a scalar of a possibly different type.
 *
 * Both components are scaled by `scalar`. The products are converted back to
 * U explicitly: with brace-initialisation a scalar wider than U (for example
 * std::complex<float> * double) would be a narrowing conversion, which is
 * ill-formed in list-initialisation.
 *
 * @param complexNumber left-hand complex operand
 * @param scalar        right-hand scalar operand
 * @return complexNumber scaled by scalar, with component type U
 */
template <class U, class V>
auto operator*(const std::complex<U> &complexNumber, const V &scalar)
    -> std::complex<U> {
  return std::complex<U>(static_cast<U>(std::real(complexNumber) * scalar),
                         static_cast<U>(std::imag(complexNumber) * scalar));
}
/**
 * Multiply a scalar by a complex number of a possibly different component
 * type (scalar on the left-hand side).
 *
 * Both components are scaled by `scalar`. The products are converted back to
 * U explicitly so that a scalar wider than U (for example
 * double * std::complex<float>) does not require a narrowing conversion
 * inside a braced initialiser.
 *
 * @param scalar        left-hand scalar operand
 * @param complexNumber right-hand complex operand
 * @return complexNumber scaled by scalar, with component type U
 */
template <class U, class V>
auto operator*(const V &scalar, const std::complex<U> &complexNumber)
    -> std::complex<U> {
  return std::complex<U>(static_cast<U>(std::real(complexNumber) * scalar),
                         static_cast<U>(std::imag(complexNumber) * scalar));
}
/**
 * Add a scalar of a possibly different type to a complex number.
 *
 * Only the real part changes; the imaginary part is passed through. The sum
 * is converted back to U explicitly so that a scalar wider than U does not
 * require a narrowing conversion inside a braced initialiser.
 *
 * @param complexNumber left-hand complex operand
 * @param scalar        right-hand scalar operand
 * @return complexNumber + scalar, with component type U
 */
template <class U, class V>
auto operator+(const std::complex<U> &complexNumber, const V &scalar)
    -> std::complex<U> {
  return std::complex<U>(static_cast<U>(std::real(complexNumber) + scalar),
                         std::imag(complexNumber));
}
/**
 * Add a complex number to a scalar of a possibly different type (scalar on
 * the left-hand side).
 *
 * Only the real part changes; the imaginary part is passed through. The sum
 * is converted back to U explicitly so that a scalar wider than U does not
 * require a narrowing conversion inside a braced initialiser.
 *
 * @param scalar        left-hand scalar operand
 * @param complexNumber right-hand complex operand
 * @return scalar + complexNumber, with component type U
 */
template <class U, class V>
auto operator+(const V &scalar, const std::complex<U> &complexNumber)
    -> std::complex<U> {
  return std::complex<U>(static_cast<U>(std::real(complexNumber) + scalar),
                         std::imag(complexNumber));
}
/**
 * Subtract a scalar of a possibly different type from a complex number.
 *
 * Only the real part changes; the imaginary part is passed through. The
 * difference is converted back to U explicitly so that a scalar wider than U
 * does not require a narrowing conversion inside a braced initialiser.
 *
 * @param complexNumber left-hand complex operand (minuend)
 * @param scalar        right-hand scalar operand (subtrahend)
 * @return complexNumber - scalar, with component type U
 */
template <class U, class V>
auto operator-(const std::complex<U> &complexNumber, const V &scalar)
    -> std::complex<U> {
  return std::complex<U>(static_cast<U>(std::real(complexNumber) - scalar),
                         std::imag(complexNumber));
}
/**
 * Subtract a complex number from a scalar of a possibly different type
 * (scalar on the left-hand side).
 *
 * s - (a + bi) = (s - a) - bi, so the imaginary part of the result is the
 * negated imaginary part of the operand. The real part is converted back to
 * U explicitly so that a scalar wider than U does not require a narrowing
 * conversion inside a braced initialiser.
 *
 * @param scalar        left-hand scalar operand (minuend)
 * @param complexNumber right-hand complex operand (subtrahend)
 * @return scalar - complexNumber, with component type U
 */
template <class U, class V>
auto operator-(const V &scalar, const std::complex<U> &complexNumber)
    -> std::complex<U> {
  return std::complex<U>(static_cast<U>(scalar - std::real(complexNumber)),
                         -std::imag(complexNumber));
}
/**
 * Divide a complex number by a scalar of a possibly different type.
 *
 * Both components are divided by `scalar`. The quotients are converted back
 * to U explicitly so that a scalar wider than U (for example
 * std::complex<float> / double) does not require a narrowing conversion
 * inside a braced initialiser.
 *
 * @param complexNumber left-hand complex operand (dividend)
 * @param scalar        right-hand scalar operand (divisor)
 * @return complexNumber / scalar, with component type U
 */
template <class U, class V>
auto operator/(const std::complex<U> &complexNumber, const V &scalar)
    -> std::complex<U> {
  return std::complex<U>(static_cast<U>(std::real(complexNumber) / scalar),
                         static_cast<U>(std::imag(complexNumber) / scalar));
}
/**
 * Divide a scalar of a possibly different type by a complex number (scalar
 * on the left-hand side).
 *
 * The scalar is promoted to a std::complex<U> with zero imaginary part and
 * the division is delegated to the complex / complex operator. The scalar is
 * converted to U explicitly because passing a V that is not U through a
 * braced initialiser is a narrowing conversion.
 *
 * @param scalar        left-hand scalar operand (dividend)
 * @param complexNumber right-hand complex operand (divisor)
 * @return scalar / complexNumber, with component type U
 */
template <class U, class V>
auto operator/(const V &scalar, const std::complex<U> &complexNumber)
    -> std::complex<U> {
  return std::complex<U>(static_cast<U>(scalar), U(0)) / complexNumber;
}
...@@ -887,7 +887,7 @@ class KernelConstraintsCheck: ...@@ -887,7 +887,7 @@ class KernelConstraintsCheck:
def process_assignment(self, assignment): def process_assignment(self, assignment):
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
new_rhs = self.process_expression(assignment.rhs) new_rhs = self.process_expression(assignment.rhs)
new_lhs = self._process_lhs(assignment.lhs) new_lhs = self._process_lhs(assignment.lhs, assignment.rhs)
return ast.SympyAssignment(new_lhs, new_rhs) return ast.SympyAssignment(new_lhs, new_rhs)
def process_expression(self, rhs, type_constants=True): def process_expression(self, rhs, type_constants=True):
...@@ -933,7 +933,7 @@ class KernelConstraintsCheck: ...@@ -933,7 +933,7 @@ class KernelConstraintsCheck:
def fields_written(self): def fields_written(self):
return set(k.field for k, v in self._field_writes.items() if len(v)) return set(k.field for k, v in self._field_writes.items() if len(v))
def _process_lhs(self, lhs): def _process_lhs(self, lhs, rhs):
assert isinstance(lhs, sp.Symbol) assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs) self._update_accesses_lhs(lhs)
if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)): if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import itertools
import pytest
import sympy
from sympy.functions import im, re
import pystencils
from pystencils import AssignmentCollection
from pystencils.data_types import create_type, TypedSymbol
# Fixtures for the single-precision complex test below.
# NOTE: the same module-level names are rebound further down in this file for
# the complex128 variant; the parametrize decorator of the first test reads
# them at decoration time, so the order of these statements matters.
# Two 2D fields with complex64 elements.
X, Y = pystencils.fields('x, y: complex64[2d]')
# Two 2D fields with float32 elements, used as real-valued write targets.
A, B = pystencils.fields('a, b: float32[2d]')
# Plain (untyped) sympy symbols used as temporaries.
S1, S2 = sympy.symbols('S1, S2')
# Symbol 't' explicitly typed as complex64.
T64 = TypedSymbol('t', create_type('complex64'))
# Each collection exercises a different combination of complex expressions.
TEST_ASSIGNMENTS = [
# write a purely imaginary constant into a complex field
AssignmentCollection({X[0, 0]: 1j}),
# split a complex value into re/im temporaries and recombine them
AssignmentCollection({
S1: re(Y.center),
S2: im(Y.center),
X[0, 0]: 2j * S1 + S2
}),
# store real and imaginary parts into the float fields
AssignmentCollection({
A.center: re(Y.center),
B.center: im(Y.center),
}),
# sum of a real part, a complex field value and a complex constant
AssignmentCollection({
Y.center: re(Y.center) + X.center + 2j,
}),
# assign a complex constant to the typed symbol, then divide by it
AssignmentCollection({
T64: 2 + 4j,
Y.center: X.center / T64,
})
]
# Values passed as `data_type` to create_kernel by the test below.
SCALAR_DTYPES = ['float32', 'float64']
@pytest.mark.parametrize("assignment, scalar_dtypes",
                         itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES))
def test_complex_numbers(assignment, scalar_dtypes):
    """Build, print and compile a CPU kernel for one assignment collection.

    Runs once per combination of an entry of TEST_ASSIGNMENTS and an entry of
    SCALAR_DTYPES (the latter is forwarded as ``data_type``). The test fails
    if the printed code contains the text "Not supported" or if compiling the
    kernel yields ``None``.
    """
    kernel_ast = pystencils.create_kernel(
        assignment, target='cpu', data_type=scalar_dtypes)

    generated = str(pystencils.show_code(kernel_ast))
    print(generated)  # keep the generated code visible in the pytest output
    assert "Not supported" not in generated

    compiled = kernel_ast.compile()
    assert compiled is not None
# Fixtures for the double-precision complex test below.
# NOTE: these statements rebind module-level names that are already defined
# earlier in this file; the parametrize decorator of the following test reads
# them at decoration time, so they must stay between the two test functions.
# Two 2D fields with complex128 elements.
X, Y = pystencils.fields('x, y: complex128[2d]')
# Two 2D fields with float64 elements, used as real-valued write targets.
A, B = pystencils.fields('a, b: float64[2d]')
# Plain (untyped) sympy symbols used as temporaries.
S1, S2 = sympy.symbols('S1, S2')
# Symbol 't' explicitly typed as complex128.
T128 = TypedSymbol('t', create_type('complex128'))
# Each collection exercises a different combination of complex expressions.
TEST_ASSIGNMENTS = [
# write a purely imaginary constant into a complex field
AssignmentCollection({X[0, 0]: 1j}),
# split a complex value into re/im temporaries and recombine them
AssignmentCollection({
S1: re(Y.center),
S2: im(Y.center),
X[0, 0]: 2j * S1 + S2
}),
# store real and imaginary parts into the float fields
AssignmentCollection({
A.center: re(Y.center),
B.center: im(Y.center),
}),
# sum of a real part, a complex field value and a complex constant
AssignmentCollection({
Y.center: re(Y.center) + X.center + 2j,
}),
# assign a complex constant to the typed symbol, then divide by it
AssignmentCollection({
T128: 2 + 4j,
Y.center: X.center / T128,
})
]
# Values passed as `data_type` to create_kernel by the test below.
SCALAR_DTYPES = ['float32', 'float64']
@pytest.mark.parametrize("assignment, scalar_dtypes",
                         itertools.product(TEST_ASSIGNMENTS, SCALAR_DTYPES))
def test_complex_numbers_64(assignment, scalar_dtypes):
    """Build, print and compile a CPU kernel for one assignment collection.

    Runs once per combination of an entry of TEST_ASSIGNMENTS and an entry of
    SCALAR_DTYPES (the latter is forwarded as ``data_type``). The test fails
    if the printed code contains the text "Not supported" or if compiling the
    kernel yields ``None``.
    """
    kernel_ast = pystencils.create_kernel(
        assignment, target='cpu', data_type=scalar_dtypes)

    generated = str(pystencils.show_code(kernel_ast))
    print(generated)  # keep the generated code visible in the pytest output
    assert "Not supported" not in generated

    compiled = kernel_ast.compile()
    assert compiled is not None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment