Skip to content
Snippets Groups Projects
Commit eda2f772 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Use default_int_type, default_float_type in collate_types

parent 0460532f
Branches
Tags
No related merge requests found
...@@ -4,14 +4,14 @@ from functools import partial ...@@ -4,14 +4,14 @@ from functools import partial
from typing import Tuple from typing import Tuple
import numpy as np import numpy as np
import sympy as sp
import sympy.codegen.ast
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
import pystencils import pystencils
import sympy as sp
import sympy.codegen.ast
from pystencils.cache import memorycache, memorycache_if_hashable from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal from pystencils.utils import all_equal
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
try: try:
import llvmlite.ir as ir import llvmlite.ir as ir
...@@ -432,7 +432,9 @@ def peel_off_type(dtype, type_to_peel_off): ...@@ -432,7 +432,9 @@ def peel_off_type(dtype, type_to_peel_off):
def collate_types(types, def collate_types(types,
forbid_collation_to_complex=False, forbid_collation_to_complex=False,
forbid_collation_to_float=False): forbid_collation_to_float=False,
default_float_type='float64',
default_int_type='int64'):
""" """
Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double Takes a sequence of types and returns their "common type" e.g. (float, double, float) -> double
Uses the collation rules from numpy. Uses the collation rules from numpy.
...@@ -443,14 +445,14 @@ def collate_types(types, ...@@ -443,14 +445,14 @@ def collate_types(types,
if not np.issubdtype(t.numpy_dtype, np.complexfloating) if not np.issubdtype(t.numpy_dtype, np.complexfloating)
] ]
if not types: if not types:
return create_type(np.float64) return create_type(default_float_type)
if forbid_collation_to_float: if forbid_collation_to_float:
types = [ types = [
t for t in types if not np.issubdtype(t.numpy_dtype, np.floating) t for t in types if not np.issubdtype(t.numpy_dtype, np.floating)
] ]
if not types: if not types:
return create_type(np.int32) return create_type(default_int_type)
# Pointer arithmetic case i.e. pointer + integer is allowed # Pointer arithmetic case i.e. pointer + integer is allowed
if any(type(t) is PointerType for t in types): if any(type(t) is PointerType for t in types):
...@@ -549,7 +551,9 @@ def get_type_of_expression(expr, ...@@ -549,7 +551,9 @@ def get_type_of_expression(expr,
return collate_types( return collate_types(
types, types,
forbid_collation_to_complex=expr.is_real is True, forbid_collation_to_complex=expr.is_real is True,
forbid_collation_to_float=expr.is_integer is True) forbid_collation_to_float=expr.is_integer is True,
default_float_type=default_float_type,
default_int_type=default_int_type)
else: else:
if expr.is_integer: if expr.is_integer:
return create_type(default_int_type) return create_type(default_int_type)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment