Skip to content
Snippets Groups Projects
Commit c378ca19 authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes in vectorization to also support float kernels

parent 27cf4f19
No related branches found
No related tags found
No related merge requests found
......@@ -262,7 +262,7 @@ class CustomSympyPrinter(CCodePrinter):
def _typed_number(self, number, dtype):
res = self._print(number)
if dtype.is_float:
if dtype.is_float():
if dtype == self._float_type:
if '.' not in res:
res += ".0f"
......
......@@ -35,7 +35,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
elif nontemporal is True:
nontemporal = all_fields
field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float)
field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float())
if len(field_float_dtypes) != 1:
raise NotImplementedError("Cannot vectorize kernels that contain accesses "
"to differently typed floating point fields")
......
......@@ -276,6 +276,8 @@ def collate_types(types):
# now we should have a list of basic types - struct types are not yet supported
assert all(type(t) is BasicType for t in types)
if any(t.is_float() for t in types):
types = tuple(t for t in types if t.is_float())
# use numpy collation -> create type from numpy type -> and, put vector type around if necessary
result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
result = BasicType(result_numpy_type)
......@@ -289,10 +291,7 @@ def get_type_of_expression(expr):
from pystencils.astnodes import ResolvedFieldAccess
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
if expr == 1 or expr == -1:
return create_type("int16")
else:
return create_type("int")
return create_type("int")
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type("double")
elif isinstance(expr, ResolvedFieldAccess):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment