Skip to content
Snippets Groups Projects
Commit 56563511 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fix test cases

parent 83c93124
No related branches found
No related tags found
1 merge request!275WIP: Revamp the type system
......@@ -495,8 +495,8 @@ class CustomSympyPrinter(CCodePrinter):
known = self.known_functions[arg.__class__.__name__.lower()]
code = self._print(arg)
return code.replace(known, f"{known}f")
elif isinstance(arg, sp.Pow) and data_type == BasicType('float32'):
known = ['sqrt', 'cbrt', 'pow']
elif isinstance(arg, (sp.Pow, sp.exp)) and data_type == BasicType('float32'):
known = ['sqrt', 'cbrt', 'pow', 'exp']
code = self._print(arg)
for k in known:
if k in code:
......
......@@ -216,7 +216,8 @@ class TypeAdder:
else:
new_args.append(a)
return expr.func(*new_args) if new_args else expr, collated_type
elif isinstance(expr, (sp.Pow, InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)):
elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction,
HyperbolicFunction)):
args_types = [self.figure_out_type(arg) for arg in expr.args]
collated_type = collate_types([t for _, t in args_types])
new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types]
......
......@@ -4,6 +4,7 @@ import pytest
import pystencils.config
import sympy as sp
import pystencils as ps
import numpy as np
from pystencils.simp import subexpression_substitution_in_main_assignments
from pystencils.simp import add_subexpressions_for_divisions
......@@ -143,29 +144,27 @@ def test_add_subexpressions_for_field_reads():
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
@pytest.mark.parametrize('simplification', (True, False))
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
def test_sympy_optimizations(target, simplification):
def test_sympy_optimizations(target, dtype):
if target == ps.Target.GPU:
pytest.importorskip("pycuda")
src, dst = ps.fields('src, dst: float32[2d]')
src, dst = ps.fields(f'src, dst: {dtype}[2d]')
# Triggers Sympy's expm1 optimization
# Sympy's expm1 optimization is tedious to use and the behaviour is highly depended on the sympy version. In
# some cases the exp expression has to be encapsulated in brackets or multiplied with 1 or 1.0
# for sympy to work properly ...
assignments = ps.AssignmentCollection({
src[0, 0]: 1.0 * (sp.exp(dst[0, 0]) - 1)
})
config = pystencils.config.CreateKernelConfig(target=target, default_assignment_simplifications=simplification)
config = pystencils.config.CreateKernelConfig(target=target, default_number_float=dtype)
ast = ps.create_kernel(assignments, config=config)
ps.show_code(ast)
code = ps.get_code_str(ast)
if simplification:
assert 'expm1(' in code
else:
assert 'expm1(' not in code
if dtype == 'float32':
assert 'expf(' in code
elif dtype == 'float64':
assert 'exp(' in code
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment