Skip to content
Snippets Groups Projects
Commit c378ca19 authored by Martin Bauer
Browse files

Fixes in vectorization to also support float kernels

parent 27cf4f19
No related branches found
No related tags found
No related merge requests found
...@@ -262,7 +262,7 @@ class CustomSympyPrinter(CCodePrinter): ...@@ -262,7 +262,7 @@ class CustomSympyPrinter(CCodePrinter):
def _typed_number(self, number, dtype): def _typed_number(self, number, dtype):
res = self._print(number) res = self._print(number)
if dtype.is_float: if dtype.is_float():
if dtype == self._float_type: if dtype == self._float_type:
if '.' not in res: if '.' not in res:
res += ".0f" res += ".0f"
......
...@@ -35,7 +35,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx', ...@@ -35,7 +35,7 @@ def vectorize(kernel_ast: ast.KernelFunction, instruction_set: str = 'avx',
elif nontemporal is True: elif nontemporal is True:
nontemporal = all_fields nontemporal = all_fields
field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float) field_float_dtypes = set(f.dtype for f in all_fields if f.dtype.is_float())
if len(field_float_dtypes) != 1: if len(field_float_dtypes) != 1:
raise NotImplementedError("Cannot vectorize kernels that contain accesses " raise NotImplementedError("Cannot vectorize kernels that contain accesses "
"to differently typed floating point fields") "to differently typed floating point fields")
......
...@@ -276,6 +276,8 @@ def collate_types(types): ...@@ -276,6 +276,8 @@ def collate_types(types):
# now we should have a list of basic types - struct types are not yet supported # now we should have a list of basic types - struct types are not yet supported
assert all(type(t) is BasicType for t in types) assert all(type(t) is BasicType for t in types)
if any(t.is_float() for t in types):
types = tuple(t for t in types if t.is_float())
# use numpy collation -> create type from numpy type -> and, put vector type around if necessary # use numpy collation -> create type from numpy type -> and, put vector type around if necessary
result_numpy_type = np.result_type(*(t.numpy_dtype for t in types)) result_numpy_type = np.result_type(*(t.numpy_dtype for t in types))
result = BasicType(result_numpy_type) result = BasicType(result_numpy_type)
...@@ -289,10 +291,7 @@ def get_type_of_expression(expr): ...@@ -289,10 +291,7 @@ def get_type_of_expression(expr):
from pystencils.astnodes import ResolvedFieldAccess from pystencils.astnodes import ResolvedFieldAccess
expr = sp.sympify(expr) expr = sp.sympify(expr)
if isinstance(expr, sp.Integer): if isinstance(expr, sp.Integer):
if expr == 1 or expr == -1: return create_type("int")
return create_type("int16")
else:
return create_type("int")
elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float): elif isinstance(expr, sp.Rational) or isinstance(expr, sp.Float):
return create_type("double") return create_type("double")
elif isinstance(expr, ResolvedFieldAccess): elif isinstance(expr, ResolvedFieldAccess):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment