Skip to content
Snippets Groups Projects
Commit be262db4 authored by Markus Holzer's avatar Markus Holzer
Browse files

Implemented float math functions

parent 2f1c1194
No related branches found
No related tags found
No related merge requests found
...@@ -8,6 +8,7 @@ import sympy as sp ...@@ -8,6 +8,7 @@ import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.core.cache import cacheit from sympy.core.cache import cacheit
from sympy.logic.boolalg import BooleanFalse, BooleanTrue from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
...@@ -493,6 +494,10 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -493,6 +494,10 @@ class CustomSympyPrinter(CCodePrinter):
arg, data_type = expr.args arg, data_type = expr.args
if isinstance(arg, sp.Number) and arg.is_finite: if isinstance(arg, sp.Number) and arg.is_finite:
return self._typed_number(arg, data_type) return self._typed_number(arg, data_type)
elif isinstance(arg, (sp.Pow, InverseTrigonometricFunction, TrigonometricFunction)) and data_type == BasicType('float32'):
known = self.known_functions[arg.__class__.__name__]
code = self._print(arg)
return code.replace(known, f"{known}f")
else: else:
return f"(({data_type})({self._print(arg)}))" return f"(({data_type})({self._print(arg)}))"
elif isinstance(expr, fast_division): elif isinstance(expr, fast_division):
......
...@@ -8,6 +8,8 @@ import sympy as sp ...@@ -8,6 +8,8 @@ import sympy as sp
from sympy import Piecewise from sympy import Piecewise
from sympy.core.relational import Relational from sympy.core.relational import Relational
from sympy.functions.elementary.piecewise import ExprCondPair from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.functions.elementary.trigonometric import TrigonometricFunction
from sympy.functions.elementary.trigonometric import InverseTrigonometricFunction
from sympy.codegen import Assignment from sympy.codegen import Assignment
from sympy.logic.boolalg import BooleanFunction from sympy.logic.boolalg import BooleanFunction
from sympy.logic.boolalg import BooleanAtom from sympy.logic.boolalg import BooleanAtom
...@@ -185,10 +187,6 @@ class TypeAdder: ...@@ -185,10 +187,6 @@ class TypeAdder:
# # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)] # # args_types = [self.figure_out_type(arg) for arg in expr.args if arg not in (-1, 1)]
elif isinstance(expr, sp.Indexed): elif isinstance(expr, sp.Indexed):
raise NotImplementedError('sp.Indexed') raise NotImplementedError('sp.Indexed')
elif isinstance(expr, sp.Pow):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
return expr.func(*[a for a, _ in args_types]), collated_type
elif isinstance(expr, ExprCondPair): elif isinstance(expr, ExprCondPair):
expr_expr, expr_type = self.figure_out_type(expr.expr) expr_expr, expr_type = self.figure_out_type(expr.expr)
condition, condition_type = self.figure_out_type(expr.cond) condition, condition_type = self.figure_out_type(expr.cond)
...@@ -208,6 +206,15 @@ class TypeAdder: ...@@ -208,6 +206,15 @@ class TypeAdder:
else: else:
new_args.append(a) new_args.append(a)
return expr.func(*new_args) if new_args else expr, collated_type return expr.func(*new_args) if new_args else expr, collated_type
elif isinstance(expr, (sp.Pow, InverseTrigonometricFunction, TrigonometricFunction)):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
new_func = expr.func(*new_args) if new_args else expr
if collated_type == BasicType('float64'):
return new_func, collated_type
else:
return CastFunc(new_func, collated_type), collated_type
elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)): elif isinstance(expr, (sp.Add, sp.Mul, sp.Abs, sp.Min, sp.Max, DivFunc, sp.UnevaluatedExpr)):
args_types = [self.figure_out_type(arg) for arg in expr.args] args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types]) collated_type = collate_types([t for _, t in args_types])
......
import pytest
import sympy as sp
import numpy as np
import math
import pystencils as ps
@pytest.mark.parametrize('dtype', ["float64", "float32"])
def test_trigonometric_functions(dtype):
dh = ps.create_data_handling(domain_size=(10, 10), periodicity=True)
x = dh.add_array('x', values_per_cell=1, dtype=dtype)
dh.fill("x", 0.0, ghost_layers=True)
y = dh.add_array('y', values_per_cell=1, dtype=dtype)
dh.fill("y", 1.0, ghost_layers=True)
z = dh.add_array('z', values_per_cell=1, dtype=dtype)
dh.fill("z", 2.0, ghost_layers=True)
# config = pystencils.CreateKernelConfig(default_number_float=dtype)
# test sp.Max with one argument
up = ps.Assignment(x.center, sp.atan2(y.center, z.center))
ast = ps.create_kernel(up)
code = ps.get_code_str(ast)
kernel = ast.compile()
dh.run_kernel(kernel)
np.testing.assert_allclose(dh.gather_array("x")[0, 0], math.atan2(1.0, 2.0))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment