diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 03ab76ea9e711d8945618975e85061aa9736ca96..ad2fd4b7522394c74ae66f7895fdd32769f528f1 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -495,8 +495,8 @@ class CustomSympyPrinter(CCodePrinter): known = self.known_functions[arg.__class__.__name__.lower()] code = self._print(arg) return code.replace(known, f"{known}f") - elif isinstance(arg, sp.Pow) and data_type == BasicType('float32'): - known = ['sqrt', 'cbrt', 'pow'] + elif isinstance(arg, (sp.Pow, sp.exp)) and data_type == BasicType('float32'): + known = ['sqrt', 'cbrt', 'pow', 'exp'] code = self._print(arg) for k in known: if k in code: diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index b4648835a662027b124a6c5b0192f67b76da5980..0d133038688f722d7d428c01442db2d8fb2458a9 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -216,7 +216,8 @@ class TypeAdder: else: new_args.append(a) return expr.func(*new_args) if new_args else expr, collated_type - elif isinstance(expr, (sp.Pow, InverseTrigonometricFunction, TrigonometricFunction, HyperbolicFunction)): + elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction, + HyperbolicFunction)): args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] diff --git a/pystencils_tests/test_simplifications.py b/pystencils_tests/test_simplifications.py index fbce59aca95db083a35a29d4f2e0335767c4f43f..21a5d5a9b54114288ea33468951d88dbee9dfc1c 100644 --- a/pystencils_tests/test_simplifications.py +++ b/pystencils_tests/test_simplifications.py @@ -4,6 +4,7 @@ import pytest import pystencils.config import sympy as sp import pystencils as ps +import numpy as np from pystencils.simp import subexpression_substitution_in_main_assignments from pystencils.simp import add_subexpressions_for_divisions @@ -143,29 +144,27 @@ def test_add_subexpressions_for_field_reads(): @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) -@pytest.mark.parametrize('simplification', (True, False)) +@pytest.mark.parametrize('dtype', ('float32', 'float64')) @pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason") -def test_sympy_optimizations(target, simplification): +def test_sympy_optimizations(target, dtype): if target == ps.Target.GPU: pytest.importorskip("pycuda") - src, dst = ps.fields('src, dst: float32[2d]') + src, dst = ps.fields(f'src, dst: {dtype}[2d]') - # Triggers Sympy's expm1 optimization - # Sympy's expm1 optimization is tedious to use and the behaviour is highly depended on the sympy version. In - # some cases the exp expression has to be encapsulated in brackets or multiplied with 1 or 1.0 - # for sympy to work properly ... assignments = ps.AssignmentCollection({ src[0, 0]: 1.0 * (sp.exp(dst[0, 0]) - 1) }) - config = pystencils.config.CreateKernelConfig(target=target, default_assignment_simplifications=simplification) + config = pystencils.config.CreateKernelConfig(target=target, default_number_float=dtype) ast = ps.create_kernel(assignments, config=config) + ps.show_code(ast) + code = ps.get_code_str(ast) - if simplification: - assert 'expm1(' in code - else: - assert 'expm1(' not in code + if dtype == 'float32': + assert 'expf(' in code + elif dtype == 'float64': + assert 'exp(' in code @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))