diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py index b4b681e1c9cc233e94bd10d5f38f76eb8d34e220..f1fa4b8a141400c0880672f4fdbcd356b59d4ccd 100644 --- a/pystencils/kernel_contrains_check.py +++ b/pystencils/kernel_contrains_check.py @@ -10,10 +10,11 @@ from pystencils.field import Field from pystencils.node_collection import NodeCollection from pystencils.transformations import NestedScopes - +# TODO use this in Constraint Checker accepted_functions = [ sp.Pow, sp.sqrt, + sp.log, # TODO trigonometric functions (and whatever tests will fail) ] diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index b5af7e4b21ae348925f4d81cc5db196cc179bf95..ddffd61ced02b3603e7a21a784860d49127e1b5f 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -212,7 +212,7 @@ class TypeAdder: new_args.append(a) return expr.func(*new_args) if new_args else expr, collated_type elif isinstance(expr, (sp.Pow, sp.exp, InverseTrigonometricFunction, TrigonometricFunction, - HyperbolicFunction)): + HyperbolicFunction, sp.log)): args_types = [self.figure_out_type(arg) for arg in expr.args] collated_type = collate_types([t for _, t in args_types]) new_args = [a if t.dtype_eq(collated_type) else CastFunc(a, collated_type) for a, t in args_types] diff --git a/pystencils_tests/test_logarithm.py b/pystencils_tests/test_logarithm.py new file mode 100644 index 0000000000000000000000000000000000000000..85d7814a336663f76ecb40ccaf9bcc2e5ef14102 --- /dev/null +++ b/pystencils_tests/test_logarithm.py @@ -0,0 +1,26 @@ +import pytest +import numpy as np +import sympy as sp + +import pystencils as ps + + +@pytest.mark.parametrize('dtype', ["float64", "float32"]) +def test_log(dtype): + a = sp.Symbol("a") + x = ps.fields(f'x: {dtype}[1d]') + + assignments = ps.AssignmentCollection({x.center(): sp.log(a)}) + + ast = ps.create_kernel(assignments) + code = ps.get_code_str(ast) + kernel = ast.compile() + + # ps.show_code(ast) + + if dtype == "float64": + assert "float" not in code + + array = np.zeros((10,), dtype=dtype) + kernel(x=array, a=100) + assert np.allclose(array, 4.60517019)