Skip to content
Snippets Groups Projects
Commit 43f6d5de authored by Stephan Seitz's avatar Stephan Seitz Committed by Stephan Seitz
Browse files

Use get_type_of_expression in typing_form_sympy_inspection to infer types

parent d6301eea
Branches
Tags
1 merge request!43Use get_type_of_expression in typing_form_sympy_inspection to infer types
import os import os
from collections import Hashable
from functools import partial
from itertools import chain
try: try:
from functools import lru_cache as memorycache from functools import lru_cache as memorycache
except ImportError: except ImportError:
from backports.functools_lru_cache import lru_cache as memorycache from backports.functools_lru_cache import lru_cache as memorycache
try: try:
from joblib import Memory from joblib import Memory
from appdirs import user_cache_dir from appdirs import user_cache_dir
...@@ -22,6 +26,20 @@ except ImportError: ...@@ -22,6 +26,20 @@ except ImportError:
return o return o
def _wrapper(wrapped_func, cached_func, *args, **kwargs):
if all(isinstance(a, Hashable) for a in chain(args, kwargs.values())):
return cached_func(*args, **kwargs)
else:
return wrapped_func(*args, **kwargs)
def memorycache_if_hashable(maxsize=128, typed=False):
def wrapper(func):
return partial(_wrapper, func, memorycache(maxsize, typed)(func))
return wrapper
# Disable memory cache: # Disable memory cache:
# disk_cache = lambda o: o # disk_cache = lambda o: o
# disk_cache_no_fallback = lambda o: o # disk_cache_no_fallback = lambda o: o
import ctypes import ctypes
from collections import defaultdict
from functools import partial
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.core.cache import cacheit from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean from sympy.logic.boolalg import Boolean
from pystencils.cache import memorycache from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal from pystencils.utils import all_equal
try: try:
...@@ -408,11 +410,22 @@ def collate_types(types): ...@@ -408,11 +410,22 @@ def collate_types(types):
return result return result
@memorycache(maxsize=2048) @memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr, default_float_type='double', default_int_type='int'): def get_type_of_expression(expr,
default_float_type='double',
default_int_type='int',
symbol_type_dict=None):
from pystencils.astnodes import ResolvedFieldAccess from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.cpu.vectorization import vec_all, vec_any
if not symbol_type_dict:
symbol_type_dict = defaultdict(lambda: create_type('double'))
get_type = partial(get_type_of_expression,
default_float_type=default_float_type,
default_int_type=default_int_type,
symbol_type_dict=symbol_type_dict)
expr = sp.sympify(expr) expr = sp.sympify(expr)
if isinstance(expr, sp.Integer): if isinstance(expr, sp.Integer):
return create_type(default_int_type) return create_type(default_int_type)
...@@ -423,14 +436,15 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type=' ...@@ -423,14 +436,15 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
elif isinstance(expr, TypedSymbol): elif isinstance(expr, TypedSymbol):
return expr.dtype return expr.dtype
elif isinstance(expr, sp.Symbol): elif isinstance(expr, sp.Symbol):
raise ValueError("All symbols inside this expression have to be typed! ", str(expr)) return symbol_type_dict[expr.name]
# raise ValueError("All symbols iside this expression have to be typed! ", str(expr))
elif isinstance(expr, cast_func): elif isinstance(expr, cast_func):
return expr.args[1] return expr.args[1]
elif isinstance(expr, vec_any) or isinstance(expr, vec_all): elif isinstance(expr, (vec_any, vec_all)):
return create_type("bool") return create_type("bool")
elif hasattr(expr, 'func') and expr.func == sp.Piecewise: elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args)) collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args)) collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType: if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width) collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
return collated_result_type return collated_result_type
...@@ -440,16 +454,16 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type=' ...@@ -440,16 +454,16 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction): elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean # if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool") result = create_type("bool")
vec_args = [get_type_of_expression(a) for a in expr.args if isinstance(get_type_of_expression(a), VectorType)] vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
if vec_args: if vec_args:
result = VectorType(result, width=vec_args[0].width) result = VectorType(result, width=vec_args[0].width)
return result return result
elif isinstance(expr, sp.Pow): elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
return get_type_of_expression(expr.args[0]) return get_type(expr.args[0])
elif isinstance(expr, sp.Expr): elif isinstance(expr, sp.Expr):
expr: sp.Expr expr: sp.Expr
if expr.args: if expr.args:
types = tuple(get_type_of_expression(a) for a in expr.args) types = tuple(get_type(a) for a in expr.args)
return collate_types(types) return collate_types(types)
else: else:
if expr.is_integer: if expr.is_integer:
......
from sympy.abc import a, b, c, d, e, f
import pystencils
from pystencils.data_types import cast_func, create_type
def test_type_interference():
x = pystencils.fields('x: float32[3d]')
assignments = pystencils.AssignmentCollection({
a: cast_func(10, create_type('float64')),
b: cast_func(10, create_type('uint16')),
e: 11,
c: b,
f: c + b,
d: c + b + x.center + e,
x.center: c + b + x.center
})
ast = pystencils.create_kernel(assignments)
code = str(pystencils.show_code(ast))
print(code)
assert 'double a' in code
assert 'uint16_t b' in code
assert 'uint16_t f' in code
assert 'int64_t e' in code
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment