Skip to content
Snippets Groups Projects

WIP: Astnodes for interpolation

Files
27
@@ -9,11 +9,12 @@ from sympy.printing.ccode import C89CodePrinter
@@ -9,11 +9,12 @@ from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import KernelFunction, Node
from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func,
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
vector_memory_access)
reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil)
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
 
int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol
from pystencils.kernelparameters import FieldPointerSymbol
try:
try:
@@ -124,6 +125,12 @@ class CustomCodeNode(Node):
@@ -124,6 +125,12 @@ class CustomCodeNode(Node):
def undefined_symbols(self):
def undefined_symbols(self):
return self._symbols_read - self._symbols_defined
return self._symbols_read - self._symbols_defined
 
def __eq___(self, other):
 
return self._code == other._code
 
 
def __hash__(self):
 
return hash(self._code)
 
class PrintNode(CustomCodeNode):
class PrintNode(CustomCodeNode):
# noinspection SpellCheckingInspection
# noinspection SpellCheckingInspection
@@ -168,7 +175,8 @@ class CBackend:
@@ -168,7 +175,8 @@ class CBackend:
raise NotImplementedError(self.__class__ + " does not support node of type " + str(type(node)))
raise NotImplementedError(self.__class__ + " does not support node of type " + str(type(node)))
def _print_KernelFunction(self, node):
def _print_KernelFunction(self, node):
function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in node.get_parameters()]
parameters = node.get_parameters()
 
function_arguments = ["%s %s" % (str(s.symbol.dtype), s.symbol.name) for s in parameters]
launch_bounds = ""
launch_bounds = ""
if self.__class__ == 'cuda':
if self.__class__ == 'cuda':
max_threads = node.indexing.max_threads_per_block()
max_threads = node.indexing.max_threads_per_block()
@@ -247,6 +255,12 @@ class CBackend:
@@ -247,6 +255,12 @@ class CBackend:
def _print_CustomCodeNode(self, node):
def _print_CustomCodeNode(self, node):
return node.get_code(self._dialect, self._vector_instruction_set)
return node.get_code(self._dialect, self._vector_instruction_set)
 
def _print_SourceCodeComment(self, node):
 
return "/* " + node.text + " */"
 
 
def _print_EmptyLine(self, node):
 
return ""
 
def _print_Conditional(self, node):
def _print_Conditional(self, node):
cond_type = get_type_of_expression(node.condition_expr)
cond_type = get_type_of_expression(node.condition_expr)
if isinstance(cond_type, VectorType):
if isinstance(cond_type, VectorType):
@@ -369,10 +383,61 @@ class CustomSympyPrinter(CCodePrinter):
@@ -369,10 +383,61 @@ class CustomSympyPrinter(CCodePrinter):
res += ".0f"
res += ".0f"
else:
else:
res += "f"
res += "f"
 
else:
 
if '.' not in res:
 
res += "."
return res
return res
else:
else:
return res
return res
 
def _print_Sum(self, expr):
 
template = jinja2.Template(
 
"""[&]() {
 
{{dtype}} sum = ({{dtype}}) 0;
 
for ( {{iterator_dtype}} {{var}} = {{start}}; {{condition}}; {{var}} += {{increment}} ) {
 
sum += {{expr}};
 
}
 
return sum;
 
}()""")
 
var = expr.limits[0][0]
 
start = expr.limits[0][1]
 
end = expr.limits[0][2]
 
code = template.render(
 
dtype=get_type_of_expression(expr.args[0]),
 
iterator_dtype='int',
 
var=self._print(var),
 
start=self._print(start),
 
end=self._print(end),
 
expr=self._print(expr.function),
 
increment=str(1),
 
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
 
)
 
return code
 
 
def _print_Product(self, expr):
 
template = jinja2.Template(
 
"""[&]() {
 
{{dtype}} product = ({{dtype}}) 1;
 
for ( {{iterator_dtype}} {{var}} = {{start}}; {{condition}}; {{var}} += {{increment}} ) {
 
product *= {{expr}};
 
}
 
return product;
 
}()""")
 
var = expr.limits[0][0]
 
start = expr.limits[0][1]
 
end = expr.limits[0][2]
 
code = template.render(
 
dtype=get_type_of_expression(expr.args[0]),
 
iterator_dtype='int',
 
var=self._print(var),
 
start=self._print(start),
 
end=self._print(end),
 
expr=self._print(expr.function),
 
increment=str(1),
 
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
 
)
 
return code
 
_print_Max = C89CodePrinter._print_Max
_print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min
_print_Min = C89CodePrinter._print_Min