Skip to content
Snippets Groups Projects
Commit c4e92d45 authored by Martin Bauer's avatar Martin Bauer
Browse files

Fix: type of sqrt(int) was int not floating point type

parent 8ca8b2ee
Branches
Tags
No related merge requests found
...@@ -548,7 +548,13 @@ def get_type_of_expression(expr, ...@@ -548,7 +548,13 @@ def get_type_of_expression(expr,
if vec_args: if vec_args:
result = VectorType(result, width=vec_args[0].width) result = VectorType(result, width=vec_args[0].width)
return result return result
elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)): elif isinstance(expr, sp.Pow):
base_type = get_type(expr.args[0])
if expr.exp.is_integer:
return base_type
else:
return collate_types([create_type(default_float_type), base_type])
elif isinstance(expr, (sp.Sum, sp.Product)):
return get_type(expr.args[0]) return get_type(expr.args[0])
elif isinstance(expr, sp.Expr): elif isinstance(expr, sp.Expr):
expr: sp.Expr expr: sp.Expr
......
import sympy as sp import sympy as sp
import numpy as np
import pystencils as ps
from pystencils import data_types from pystencils import data_types
from pystencils.data_types import * from pystencils.data_types import TypedSymbol, get_type_of_expression, VectorType, collate_types, create_type
from pystencils.kernelparameters import FieldShapeSymbol
def test_parsing(): def test_parsing():
...@@ -25,7 +25,6 @@ def test_collation(): ...@@ -25,7 +25,6 @@ def test_collation():
def test_dtype_of_constants(): def test_dtype_of_constants():
# Some come constants are neither of type Integer,Float,Rational and don't have args # Some come constants are neither of type Integer,Float,Rational and don't have args
# >>> isinstance(pi, Integer) # >>> isinstance(pi, Integer)
# False # False
...@@ -39,13 +38,25 @@ def test_dtype_of_constants(): ...@@ -39,13 +38,25 @@ def test_dtype_of_constants():
def test_assumptions(): def test_assumptions():
x = ps.fields('x: float32[3d]')
x = pystencils.fields('x: float32[3d]')
assert x.shape[0].is_nonnegative assert x.shape[0].is_nonnegative
assert (2 * x.shape[0]).is_nonnegative assert (2 * x.shape[0]).is_nonnegative
assert (2 * x.shape[0]).is_integer assert (2 * x.shape[0]).is_integer
assert(TypedSymbol('a', create_type('uint64'))).is_nonnegative assert (TypedSymbol('a', create_type('uint64'))).is_nonnegative
assert (TypedSymbol('a', create_type('uint64'))).is_positive is None assert (TypedSymbol('a', create_type('uint64'))).is_positive is None
assert (TypedSymbol('a', create_type('uint64')) + 1).is_positive assert (TypedSymbol('a', create_type('uint64')) + 1).is_positive
assert (x.shape[0] + 1).is_real assert (x.shape[0] + 1).is_real
def test_sqrt_of_integer():
"""Regression test for bug where sqrt(3) was classified as integer"""
f = ps.fields("f: [1D]")
tmp = sp.symbols("tmp")
assignments = [ps.Assignment(tmp, sp.sqrt(3)),
ps.Assignment(f[0], tmp)]
arr = np.array([1], dtype=np.float64)
kernel = ps.create_kernel(assignments).compile()
kernel(f=arr)
assert 1.7 < arr[0] < 1.8
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment