Skip to content
Snippets Groups Projects
Commit 3c02d58e authored by Markus Holzer's avatar Markus Holzer
Browse files

Implemented Min and Max printer

parent 0653f52d
No related branches found
No related tags found
1 merge request!163Volume-of-Fluid: better tests and make it actually work
Pipeline #25159 failed
...@@ -330,10 +330,10 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -330,10 +330,10 @@ class CustomSympyPrinter(CCodePrinter):
def __init__(self): def __init__(self):
super(CustomSympyPrinter, self).__init__() super(CustomSympyPrinter, self).__init__()
self._float_type = create_type("float32") self._float_type = create_type("float32")
if 'Min' in self.known_functions: #if 'Min' in self.known_functions:
del self.known_functions['Min'] # del self.known_functions['Min']
if 'Max' in self.known_functions: # if 'Max' not in self.known_functions:
del self.known_functions['Max'] # self.known_functions.update({'Max': 'Max'})
def _print_Pow(self, expr): def _print_Pow(self, expr):
"""Don't use std::pow function, for small integer exponents, write as multiplication""" """Don't use std::pow function, for small integer exponents, write as multiplication"""
...@@ -402,6 +402,8 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -402,6 +402,8 @@ class CustomSympyPrinter(CCodePrinter):
return f"({self._print(1 / sp.sqrt(expr.args[0]))})" return f"({self._print(1 / sp.sqrt(expr.args[0]))})"
elif isinstance(expr, sp.Abs): elif isinstance(expr, sp.Abs):
return f"abs({self._print(expr.args[0])})" return f"abs({self._print(expr.args[0])})"
elif isinstance(expr, sp.Max):
return self._print(expr)
elif isinstance(expr, sp.Mod): elif isinstance(expr, sp.Mod):
if expr.args[0].is_integer and expr.args[1].is_integer: if expr.args[0].is_integer and expr.args[1].is_integer:
return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})" return f"({self._print(expr.args[0])} % {self._print(expr.args[1])})"
...@@ -476,8 +478,25 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -476,8 +478,25 @@ class CustomSympyPrinter(CCodePrinter):
def _print_ConditionalFieldAccess(self, node): def _print_ConditionalFieldAccess(self, node):
return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True))) return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))
_print_Max = C89CodePrinter._print_Max def _print_Max(self, expr):
_print_Min = C89CodePrinter._print_Min def inner_print_max(args):
if len(args) == 1:
return self._print(args[0])
half = len(args) // 2
a = inner_print_max(args[:half])
b = inner_print_max(args[half:])
return f"(({a} > {b}) ? {a} : {b})"
return inner_print_max(expr.args)
def _print_Min(self, expr):
def inner_print_min(args):
if len(args) == 1:
return self._print(args[0])
half = len(args) // 2
a = inner_print_min(args[:half])
b = inner_print_min(args[half:])
return f"(({a} < {b}) ? {a} : {b})"
return inner_print_min(expr.args)
def _print_re(self, expr): def _print_re(self, expr):
return f"real({self._print(expr.args[0])})" return f"real({self._print(expr.args[0])})"
...@@ -575,6 +594,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter): ...@@ -575,6 +594,9 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
result = self.instruction_set['&'].format(result, item) result = self.instruction_set['&'].format(result, item)
return result return result
def _print_Max(self, expr):
return "test"
def _print_Or(self, expr): def _print_Or(self, expr):
result = self._scalarFallback('_print_Or', expr) result = self._scalarFallback('_print_Or', expr)
if result: if result:
......
import sympy
import numpy
import pystencils
from pystencils.datahandling import create_data_handling
def test_max():
dh = create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1)
dh.fill("x", 0.0, ghost_layers=True)
y = dh.add_array('y', values_per_cell=1)
dh.fill("y", 1.0, ghost_layers=True)
z = dh.add_array('z', values_per_cell=1)
dh.fill("z", 2.0, ghost_layers=True)
# test sp.Max with one argument
assignment_1 = pystencils.Assignment(x.center, sympy.Max(y.center + 3.3))
ast_1 = pystencils.create_kernel(assignment_1)
kernel_1 = ast_1.compile()
# test sp.Max with two arguments
assignment_2 = pystencils.Assignment(x.center, sympy.Max(0.5, y.center - 1.5))
ast_2 = pystencils.create_kernel(assignment_2)
kernel_2 = ast_2.compile()
# test sp.Max with many arguments
assignment_3 = pystencils.Assignment(x.center, sympy.Max(z.center, 4.5, y.center - 1.5, y.center + z.center))
ast_3 = pystencils.create_kernel(assignment_3)
kernel_3 = ast_3.compile()
dh.run_kernel(kernel_1)
assert numpy.all(dh.cpu_arrays["x"] == 4.3)
dh.run_kernel(kernel_2)
assert numpy.all(dh.cpu_arrays["x"] == 0.5)
dh.run_kernel(kernel_3)
assert numpy.all(dh.cpu_arrays["x"] == 4.5)
def test_min():
dh = create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1)
dh.fill("x", 0.0, ghost_layers=True)
y = dh.add_array('y', values_per_cell=1)
dh.fill("y", 1.0, ghost_layers=True)
z = dh.add_array('z', values_per_cell=1)
dh.fill("z", 2.0, ghost_layers=True)
# test sp.Min with one argument
assignment_1 = pystencils.Assignment(x.center, sympy.Min(y.center + 3.3))
ast_1 = pystencils.create_kernel(assignment_1)
kernel_1 = ast_1.compile()
# test sp.Min with two arguments
assignment_2 = pystencils.Assignment(x.center, sympy.Min(0.5, y.center - 1.5))
ast_2 = pystencils.create_kernel(assignment_2)
kernel_2 = ast_2.compile()
# test sp.Min with many arguments
assignment_3 = pystencils.Assignment(x.center, sympy.Min(z.center, 4.5, y.center - 1.5, y.center + z.center))
ast_3 = pystencils.create_kernel(assignment_3)
kernel_3 = ast_3.compile()
dh.run_kernel(kernel_1)
assert numpy.all(dh.cpu_arrays["x"] == 4.3)
dh.run_kernel(kernel_2)
assert numpy.all(dh.cpu_arrays["x"] == - 0.5)
dh.run_kernel(kernel_3)
assert numpy.all(dh.cpu_arrays["x"] == - 0.5)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment