Commit 43f6d5de authored by Stephan Seitz, committed by Stephan Seitz

Use get_type_of_expression in typing_from_sympy_inspection to infer types

parent d6301eea
No related branches found
No related tags found
1 merge request: !43 Use get_type_of_expression in typing_from_sympy_inspection to infer types
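Illustrative sketch (not part of the commit) of the user-visible effect, mirroring the new test further down: a symbol that is only assigned a uint16 cast is now declared as uint16_t in the generated kernel instead of falling back to double. The field x and symbol b are placeholder names.

import sympy as sp
import pystencils
from pystencils.data_types import cast_func, create_type

b = sp.Symbol('b')                       # untyped symbol, its type is now inferred from the rhs
x = pystencils.fields('x: float32[2d]')
assignments = pystencils.AssignmentCollection({
    b: cast_func(10, create_type('uint16')),
    x.center: b + x.center
})
ast = pystencils.create_kernel(assignments)
print(pystencils.show_code(ast))         # declares 'uint16_t b' rather than 'double b'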
import os
from collections import Hashable
from functools import partial
from itertools import chain
try:
from functools import lru_cache as memorycache
except ImportError:
from backports.functools_lru_cache import lru_cache as memorycache
try:
from joblib import Memory
from appdirs import user_cache_dir
......@@ -22,6 +26,20 @@ except ImportError:
return o
def _wrapper(wrapped_func, cached_func, *args, **kwargs):
if all(isinstance(a, Hashable) for a in chain(args, kwargs.values())):
return cached_func(*args, **kwargs)
else:
return wrapped_func(*args, **kwargs)
def memorycache_if_hashable(maxsize=128, typed=False):
def wrapper(func):
return partial(_wrapper, func, memorycache(maxsize, typed)(func))
return wrapper
# Disable memory cache:
# disk_cache = lambda o: o
# disk_cache_no_fallback = lambda o: o
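Minimal usage sketch of the new memorycache_if_hashable decorator defined above (illustrative only; collate_example is a made-up function): caching applies when all arguments are hashable, otherwise the call transparently falls back to the uncached function.

from pystencils.cache import memorycache_if_hashable

@memorycache_if_hashable(maxsize=128)
def collate_example(arg):
    # purely illustrative helper: returns a tuple copy of its argument
    return tuple(arg)

collate_example((1, 2, 3))  # hashable tuple: served through the lru_cache wrapper
collate_example([1, 2, 3])  # unhashable list: falls back to the plain, uncached function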
import ctypes
from collections import defaultdict
from functools import partial
import numpy as np
import sympy as sp
from sympy.core.cache import cacheit
from sympy.logic.boolalg import Boolean
from pystencils.cache import memorycache
from pystencils.cache import memorycache, memorycache_if_hashable
from pystencils.utils import all_equal
try:
......@@ -408,11 +410,22 @@ def collate_types(types):
return result
@memorycache(maxsize=2048)
def get_type_of_expression(expr, default_float_type='double', default_int_type='int'):
@memorycache_if_hashable(maxsize=2048)
def get_type_of_expression(expr,
default_float_type='double',
default_int_type='int',
symbol_type_dict=None):
from pystencils.astnodes import ResolvedFieldAccess
from pystencils.cpu.vectorization import vec_all, vec_any
if not symbol_type_dict:
symbol_type_dict = defaultdict(lambda: create_type('double'))
get_type = partial(get_type_of_expression,
default_float_type=default_float_type,
default_int_type=default_int_type,
symbol_type_dict=symbol_type_dict)
expr = sp.sympify(expr)
if isinstance(expr, sp.Integer):
return create_type(default_int_type)
......@@ -423,14 +436,15 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
elif isinstance(expr, TypedSymbol):
return expr.dtype
elif isinstance(expr, sp.Symbol):
raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
return symbol_type_dict[expr.name]
# raise ValueError("All symbols inside this expression have to be typed! ", str(expr))
elif isinstance(expr, cast_func):
return expr.args[1]
elif isinstance(expr, vec_any) or isinstance(expr, vec_all):
elif isinstance(expr, (vec_any, vec_all)):
return create_type("bool")
elif hasattr(expr, 'func') and expr.func == sp.Piecewise:
collated_result_type = collate_types(tuple(get_type_of_expression(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type_of_expression(a[1]) for a in expr.args))
collated_result_type = collate_types(tuple(get_type(a[0]) for a in expr.args))
collated_condition_type = collate_types(tuple(get_type(a[1]) for a in expr.args))
if type(collated_condition_type) is VectorType and type(collated_result_type) is not VectorType:
collated_result_type = VectorType(collated_result_type, width=collated_condition_type.width)
return collated_result_type
......@@ -440,16 +454,16 @@ def get_type_of_expression(expr, default_float_type='double', default_int_type='
elif isinstance(expr, sp.boolalg.Boolean) or isinstance(expr, sp.boolalg.BooleanFunction):
# if any arg is of vector type return a vector boolean, else return a normal scalar boolean
result = create_type("bool")
vec_args = [get_type_of_expression(a) for a in expr.args if isinstance(get_type_of_expression(a), VectorType)]
vec_args = [get_type(a) for a in expr.args if isinstance(get_type(a), VectorType)]
if vec_args:
result = VectorType(result, width=vec_args[0].width)
return result
elif isinstance(expr, sp.Pow):
return get_type_of_expression(expr.args[0])
elif isinstance(expr, (sp.Pow, sp.Sum, sp.Product)):
return get_type(expr.args[0])
elif isinstance(expr, sp.Expr):
expr: sp.Expr
if expr.args:
types = tuple(get_type_of_expression(a) for a in expr.args)
types = tuple(get_type(a) for a in expr.args)
return collate_types(types)
else:
if expr.is_integer:
......
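Hedged sketch of the relaxed behaviour of get_type_of_expression above: an untyped sympy Symbol no longer raises a ValueError but is looked up in symbol_type_dict, which defaults unknown names to double; symbol names here are illustrative.

import sympy as sp
from pystencils.data_types import TypedSymbol, create_type, get_type_of_expression

x = TypedSymbol('x', create_type('float32'))
y = sp.Symbol('y')                    # untyped: used to raise a ValueError
print(get_type_of_expression(x + y))  # y falls back to double, so the collated type is double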
from sympy.abc import a, b, c, d, e, f
import pystencils
from pystencils.data_types import cast_func, create_type
def test_type_interference():
x = pystencils.fields('x: float32[3d]')
assignments = pystencils.AssignmentCollection({
a: cast_func(10, create_type('float64')),
b: cast_func(10, create_type('uint16')),
e: 11,
c: b,
f: c + b,
d: c + b + x.center + e,
x.center: c + b + x.center
})
ast = pystencils.create_kernel(assignments)
code = str(pystencils.show_code(ast))
print(code)
assert 'double a' in code
assert 'uint16_t b' in code
assert 'uint16_t f' in code
assert 'int64_t e' in code
......@@ -147,7 +147,10 @@ def get_common_shape(field_set):
return shape
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
def make_loop_over_domain(body,
iteration_slice=None,
ghost_layers=None,
loop_order=None):
"""Uses :class:`pystencils.field.Field.Access` to create (multiple) loops around given AST.
Args:
......@@ -189,17 +192,21 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
if iteration_slice is None:
begin = ghost_layers[loop_coordinate][0]
end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate,
begin, end, 1)
current_body = ast.Block([new_loop])
else:
slice_component = iteration_slice[loop_coordinate]
if type(slice_component) is slice:
sc = slice_component
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, sc.start, sc.stop, sc.step)
new_loop = ast.LoopOverCoordinate(current_body,
loop_coordinate, sc.start,
sc.stop, sc.step)
current_body = ast.Block([new_loop])
else:
assignment = ast.SympyAssignment(ast.LoopOverCoordinate.get_loop_counter_symbol(loop_coordinate),
sp.sympify(slice_component))
assignment = ast.SympyAssignment(
ast.LoopOverCoordinate.get_loop_counter_symbol(
loop_coordinate), sp.sympify(slice_component))
current_body.insert_front(assignment)
return current_body, ghost_layers
......@@ -238,9 +245,11 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
offset += field.strides[coordinate_id] * coordinate_value
if coordinate_id < field.spatial_dimensions:
offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id]
offset += field.strides[coordinate_id] * field_access.offsets[
coordinate_id]
if type(field_access.offsets[coordinate_id]) is int:
name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id])
name += "_%d%d" % (coordinate_id,
field_access.offsets[coordinate_id])
else:
list_to_hash.append(field_access.offsets[coordinate_id])
else:
......@@ -257,7 +266,8 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
return new_ptr, offset
def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dimensions, index_dimensions):
def parse_base_pointer_info(base_pointer_specification, loop_order,
spatial_dimensions, index_dimensions):
"""
Creates base pointer specification for :func:`resolve_field_accesses` function.
......@@ -298,8 +308,10 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime
raise ValueError("Coordinate %d does not exist" % (elem, ))
new_group.append(elem)
if elem in specified_coordinates:
raise ValueError("Coordinate %d specified two times" % (elem,))
raise ValueError("Coordinate %d specified two times" %
(elem, ))
specified_coordinates.add(elem)
for element in spec_group:
if type(element) is int:
add_new_element(element)
......@@ -345,30 +357,42 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
base buffer index - required by 'resolve_buffer_accesses' function
"""
if loop_counters is None or loop_iterations is None:
loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, ast.SympyAssignment)]
loops = [
l for l in filtered_tree_iteration(
ast_node, ast.LoopOverCoordinate, ast.SympyAssignment)
]
loops.reverse()
parents_of_innermost_loop = list(parents_of_type(loops[0], ast.LoopOverCoordinate, include_current=True))
parents_of_innermost_loop = list(
parents_of_type(loops[0],
ast.LoopOverCoordinate,
include_current=True))
assert len(loops) == len(parents_of_innermost_loop)
assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
assert all(l1 is l2
for l1, l2 in zip(loops, parents_of_innermost_loop))
loop_iterations = [(l.stop - l.start) / l.step for l in loops]
loop_counters = [l.loop_counter_symbol for l in loops]
field_accesses = ast_node.atoms(AbstractField.AbstractAccess)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
buffer_accesses = {
fa
for fa in field_accesses if FieldType.is_buffer(fa.field)
}
loop_counters = [v * len(buffer_accesses) for v in loop_counters]
base_buffer_index = loop_counters[0]
stride = 1
for idx, var in enumerate(loop_counters[1:]):
cur_stride = loop_iterations[idx]
stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
stride *= int(cur_stride) if isinstance(cur_stride,
float) else cur_stride
base_buffer_index += var * stride
return base_buffer_index
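The reversed loop counters above are linearized with the loop iteration counts as strides; a standalone sketch with made-up values (the real code additionally scales the counters by the number of buffer accesses):

counters, iterations = [2, 1, 3], [10, 10, 10]
index, stride = counters[0], 1
for counter, n_iterations in zip(counters[1:], iterations):
    stride *= n_iterations
    index += counter * stride
print(index)  # 2 + 1*10 + 3*100 = 312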
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()):
def resolve_buffer_accesses(ast_node,
base_buffer_index,
read_only_field_names=set()):
def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, AbstractField.AbstractAccess):
field_access = expr
......@@ -378,17 +402,24 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
return expr
buffer = field_access.field
field_ptr = FieldPointerSymbol(buffer.name, buffer.dtype, const=buffer.name in read_only_field_names)
field_ptr = FieldPointerSymbol(
buffer.name,
buffer.dtype,
const=buffer.name in read_only_field_names)
buffer_index = base_buffer_index
if len(field_access.index) > 1:
raise RuntimeError('Only indexing dimensions up to 1 are currently supported in buffers!')
raise RuntimeError(
'Only indexing dimensions up to 1 are currently supported in buffers!'
)
if len(field_access.index) > 0:
cell_index = field_access.index[0]
buffer_index += cell_index
result = ast.ResolvedFieldAccess(field_ptr, buffer_index, field_access.field, field_access.offsets,
result = ast.ResolvedFieldAccess(field_ptr, buffer_index,
field_access.field,
field_access.offsets,
field_access.index)
return visit_sympy_expr(result, enclosing_block, sympy_assignment)
......@@ -396,16 +427,23 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
if isinstance(expr, ast.ResolvedFieldAccess):
return expr
new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
new_args = [
visit_sympy_expr(e, enclosing_block, sympy_assignment)
for e in expr.args
]
kwargs = {
'evaluate': False
} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
return expr.func(*new_args, **kwargs) if new_args else expr
def visit_node(sub_ast):
if isinstance(sub_ast, ast.SympyAssignment):
enclosing_block = sub_ast.parent
assert type(enclosing_block) is ast.Block
sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block,
sub_ast)
sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block,
sub_ast)
else:
for i, a in enumerate(sub_ast.args):
visit_node(a)
......@@ -413,7 +451,8 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
return visit_node(ast_node)
def resolve_field_accesses(ast_node, read_only_field_names=set(),
def resolve_field_accesses(ast_node,
read_only_field_names=set(),
field_to_base_pointer_info=MappingProxyType({}),
field_to_fixed_coordinates=MappingProxyType({})):
"""
......@@ -430,8 +469,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
Returns
transformed AST
"""
field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
field_to_base_pointer_info = OrderedDict(
sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
field_to_fixed_coordinates = OrderedDict(
sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, AbstractField.AbstractAccess):
......@@ -439,20 +480,29 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
field = field_access.field
if field_access.indirect_addressing_fields:
new_offsets = tuple(visit_sympy_expr(off, enclosing_block, sympy_assignment)
new_offsets = tuple(
visit_sympy_expr(off, enclosing_block, sympy_assignment)
for off in field_access.offsets)
new_indices = tuple(visit_sympy_expr(ind, enclosing_block, sympy_assignment)
if isinstance(ind, sp.Basic) else ind
new_indices = tuple(
visit_sympy_expr(ind, enclosing_block, sympy_assignment
) if isinstance(ind, sp.Basic) else ind
for ind in field_access.index)
field_access = Field.Access(field_access.field, new_offsets,
new_indices, field_access.is_absolute_access)
new_indices,
field_access.is_absolute_access)
if field.name in field_to_base_pointer_info:
base_pointer_info = field_to_base_pointer_info[field.name]
else:
base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]
base_pointer_info = [
list(
range(field.index_dimensions + field.spatial_dimensions))
]
field_ptr = FieldPointerSymbol(field.name, field.dtype, const=field.name in read_only_field_names)
field_ptr = FieldPointerSymbol(
field.name,
field.dtype,
const=field.name in read_only_field_names)
def create_coordinate_dict(group_param):
coordinates = {}
......@@ -460,12 +510,15 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if e < field.spatial_dimensions:
if field.name in field_to_fixed_coordinates:
if not field_access.is_absolute_access:
coordinates[e] = field_to_fixed_coordinates[field.name][e]
coordinates[e] = field_to_fixed_coordinates[
field.name][e]
else:
coordinates[e] = 0
else:
if not field_access.is_absolute_access:
coordinates[e] = ast.LoopOverCoordinate.get_loop_counter_symbol(e)
coordinates[
e] = ast.LoopOverCoordinate.get_loop_counter_symbol(
e)
else:
coordinates[e] = 0
coordinates[e] *= field.dtype.item_size
......@@ -474,9 +527,11 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
assert field.index_dimensions == 1
accessed_field_name = field_access.index[0]
assert isinstance(accessed_field_name, str)
coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
coordinates[e] = field.dtype.get_element_offset(
accessed_field_name)
else:
coordinates[e] = field_access.index[e - field.spatial_dimensions]
coordinates[e] = field_access.index[
e - field.spatial_dimensions]
return coordinates
......@@ -484,19 +539,27 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
for group in reversed(base_pointer_info[1:]):
coord_dict = create_coordinate_dict(group)
new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
new_ptr, offset = create_intermediate_base_pointer(
field_access, coord_dict, last_pointer)
if new_ptr not in enclosing_block.symbols_defined:
new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False)
enclosing_block.insert_before(new_assignment, sympy_assignment)
new_assignment = ast.SympyAssignment(new_ptr,
last_pointer + offset,
is_const=False)
enclosing_block.insert_before(new_assignment,
sympy_assignment)
last_pointer = new_ptr
coord_dict = create_coordinate_dict(base_pointer_info[0])
_, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
result = ast.ResolvedFieldAccess(last_pointer, offset, field_access.field,
field_access.offsets, field_access.index)
_, offset = create_intermediate_base_pointer(
field_access, coord_dict, last_pointer)
result = ast.ResolvedFieldAccess(last_pointer, offset,
field_access.field,
field_access.offsets,
field_access.index)
if isinstance(get_base_type(field_access.field.dtype), StructType):
new_type = field_access.field.dtype.get_element_type(field_access.index[0])
new_type = field_access.field.dtype.get_element_type(
field_access.index[0])
result = reinterpret_cast_func(result, new_type)
return visit_sympy_expr(result, enclosing_block, sympy_assignment)
......@@ -504,20 +567,28 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if isinstance(expr, ast.ResolvedFieldAccess):
return expr
new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
new_args = [
visit_sympy_expr(e, enclosing_block, sympy_assignment)
for e in expr.args
]
kwargs = {
'evaluate': False
} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
return expr.func(*new_args, **kwargs) if new_args else expr
def visit_node(sub_ast):
if isinstance(sub_ast, ast.SympyAssignment):
enclosing_block = sub_ast.parent
assert type(enclosing_block) is ast.Block
sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block, sub_ast)
sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block, sub_ast)
sub_ast.lhs = visit_sympy_expr(sub_ast.lhs, enclosing_block,
sub_ast)
sub_ast.rhs = visit_sympy_expr(sub_ast.rhs, enclosing_block,
sub_ast)
elif isinstance(sub_ast, ast.Conditional):
enclosing_block = sub_ast.parent
assert type(enclosing_block) is ast.Block
sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr, enclosing_block, sub_ast)
sub_ast.condition_expr = visit_sympy_expr(sub_ast.condition_expr,
enclosing_block, sub_ast)
visit_node(sub_ast.true_block)
if sub_ast.false_block:
visit_node(sub_ast.false_block)
......@@ -561,11 +632,14 @@ def move_constants_before_loop(ast_node):
element = element.parent
return last_block, last_block_child
def check_if_assignment_already_in_block(assignment, target_block, rhs_or_lhs=True):
def check_if_assignment_already_in_block(assignment,
target_block,
rhs_or_lhs=True):
for arg in target_block.args:
if type(arg) is not ast.SympyAssignment:
continue
if (rhs_or_lhs and arg.rhs == assignment.rhs) or (not rhs_or_lhs and arg.lhs == assignment.lhs):
if (rhs_or_lhs and arg.rhs == assignment.rhs) or (
not rhs_or_lhs and arg.lhs == assignment.lhs):
return arg
return None
......@@ -588,7 +662,8 @@ def move_constants_before_loop(ast_node):
# Before traversing the next child, all symbols are substituted first.
child.subs(substitute_variables)
if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments
if not isinstance(
child, ast.SympyAssignment): # only move SympyAssignments
block.append(child)
continue
......@@ -597,12 +672,14 @@ def move_constants_before_loop(ast_node):
target.append(child)
else:
if isinstance(child, ast.SympyAssignment):
exists_already = check_if_assignment_already_in_block(child, target, False)
exists_already = check_if_assignment_already_in_block(
child, target, False)
else:
exists_already = False
if not exists_already:
rhs_identical = check_if_assignment_already_in_block(child, target, True)
rhs_identical = check_if_assignment_already_in_block(
child, target, True)
if rhs_identical:
# there is already an assignment out there with the same rhs
# -> replace all lhs symbols in this block with the lhs of the outer assignment
......@@ -617,7 +694,9 @@ def move_constants_before_loop(ast_node):
# -> symbol has to be renamed
assert isinstance(child.lhs, TypedSymbol)
new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before)
target.insert_before(
ast.SympyAssignment(new_symbol, child.rhs),
child_to_insert_before)
substitute_variables[child.lhs] = new_symbol
......@@ -633,7 +712,9 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
"""
all_loops = ast_node.atoms(ast.LoopOverCoordinate)
inner_loop = [l for l in all_loops if l.is_innermost_loop]
assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
assert len(
inner_loop
) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
inner_loop = inner_loop[0]
assert type(inner_loop.body) is ast.Block
outer_loop = [l for l in all_loops if l.is_outermost_loop]
......@@ -664,28 +745,38 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if not isinstance(symbol, AbstractField.AbstractAccess):
assert type(symbol) is TypedSymbol
new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
symbols_with_temporary_array[symbol] = sp.IndexedBase(new_ts,
shape=(1,))[inner_loop.loop_counter_symbol]
symbols_with_temporary_array[symbol] = sp.IndexedBase(
new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
assignment_group = []
for assignment in inner_loop.body.args:
if assignment.lhs in symbols_resolved:
new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items())
if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group:
new_rhs = assignment.rhs.subs(
symbols_with_temporary_array.items())
if not isinstance(assignment.lhs, AbstractField.AbstractAccess
) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
new_lhs = sp.IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol]
new_ts = TypedSymbol(assignment.lhs.name,
PointerType(assignment.lhs.dtype))
new_lhs = sp.IndexedBase(
new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
else:
new_lhs = assignment.lhs
assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
assignment_groups.append(assignment_group)
new_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups]
new_loops = [
inner_loop.new_loop_with_different_body(ast.Block(group))
for group in assignment_groups
]
inner_loop.parent.replace(inner_loop, ast.Block(new_loops))
for tmp_array in symbols_with_temporary_array:
tmp_array_pointer = TypedSymbol(tmp_array.name, PointerType(tmp_array.dtype))
alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer, inner_loop.stop, inner_loop.start)
tmp_array_pointer = TypedSymbol(tmp_array.name,
PointerType(tmp_array.dtype))
alloc_node = ast.TemporaryMemoryAllocation(tmp_array_pointer,
inner_loop.stop,
inner_loop.start)
free_node = ast.TemporaryMemoryFree(alloc_node)
outer_loop.parent.insert_front(alloc_node)
outer_loop.parent.append(free_node)
......@@ -715,7 +806,8 @@ def cut_loop(loop_node, cutting_points):
elif new_end - new_start == 0:
pass
else:
new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
new_loop = ast.LoopOverCoordinate(
deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
new_start, new_end, loop_node.step)
new_loops.append(new_loop)
new_start = new_end
......@@ -723,7 +815,8 @@ def cut_loop(loop_node, cutting_points):
return new_loops
def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = False) -> None:
def simplify_conditionals(node: ast.Node,
loop_counter_simplification: bool = False) -> None:
"""Removes conditionals that are always true/false.
Args:
......@@ -739,14 +832,18 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa
if conditional.condition_expr == sp.true:
conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false:
conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
conditional.parent.replace(
conditional,
[conditional.false_block] if conditional.false_block else [])
elif loop_counter_simplification:
try:
# noinspection PyUnresolvedReferences
from pystencils.integer_set_analysis import simplify_loop_counter_dependent_conditional
simplify_loop_counter_dependent_conditional(conditional)
except ImportError:
warnings.warn("Integer simplifications in conditionals skipped, because ISLpy package not installed")
warnings.warn(
"Integer simplifications in conditionals skipped, because ISLpy package not installed"
)
def cleanup_blocks(node: ast.Node) -> None:
......@@ -808,18 +905,28 @@ class KernelConstraintsCheck:
elif type_constants and isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
elif isinstance(rhs, sp.Mul):
new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args]
new_args = [
self.process_expression(arg, type_constants)
if arg not in (-1, 1) else arg for arg in rhs.args
]
return rhs.func(*new_args) if new_args else rhs
elif isinstance(rhs, sp.Indexed):
return rhs
elif isinstance(rhs, cast_func):
return cast_func(self.process_expression(rhs.args[0], type_constants=False), rhs.dtype)
return cast_func(
self.process_expression(rhs.args[0], type_constants=False),
rhs.dtype)
else:
if isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers
return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])
return sp.Pow(
self.process_expression(rhs.args[0], type_constants),
rhs.args[1])
else:
new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
new_args = [
self.process_expression(arg, type_constants)
for arg in rhs.args
]
return rhs.func(*new_args) if new_args else rhs
@property
......@@ -829,7 +936,7 @@ class KernelConstraintsCheck:
def _process_lhs(self, lhs):
assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs)
if not isinstance(lhs, AbstractField.AbstractAccess) and not isinstance(lhs, TypedSymbol):
if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
else:
return lhs
......@@ -839,22 +946,32 @@ class KernelConstraintsCheck:
fai = self.FieldAndIndex(lhs.field, lhs.index)
self._field_writes[fai].add(lhs.offsets)
if len(self._field_writes[fai]) > 1:
raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
raise ValueError(
"Field {} is written at two different locations".format(
lhs.field.name))
elif isinstance(lhs, sp.Symbol):
if self.scopes.is_defined_locally(lhs):
raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
raise ValueError(
"Assignments not in SSA form, multiple assignments to {}".
format(lhs.name))
if lhs in self.scopes.free_parameters:
raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
raise ValueError(
"Symbol {} is written, after it has been read".format(
lhs.name))
self.scopes.define_symbol(lhs)
def _update_accesses_rhs(self, rhs):
if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
if isinstance(rhs, AbstractField.AbstractAccess
) and self.check_independence_condition:
writes = self._field_writes[self.FieldAndIndex(
rhs.field, rhs.index)]
for write_offset in writes:
assert len(writes) == 1
if write_offset != rhs.offsets:
raise ValueError("Violation of loop independence condition. Field "
"{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
raise ValueError(
"Violation of loop independence condition. Field "
"{} is read at {} and written at {}".format(
rhs.field, rhs.offsets, write_offset))
self.fields_read.add(rhs.field)
elif isinstance(rhs, sp.Symbol):
self.scopes.access_symbol(rhs)
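Hedged illustration of the loop independence check above (only reformatted by this commit): writing the same field at two different offsets raises during kernel creation; the field name f is a placeholder.

import pystencils

f = pystencils.fields('f: double[2d]')
try:
    pystencils.create_kernel([pystencils.Assignment(f[0, 0], 1.0),
                              pystencils.Assignment(f[1, 0], 2.0)])
except ValueError as err:
    print(err)  # Field f is written at two different locations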
......@@ -875,21 +992,29 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
list of equations where symbols have been replaced by typed symbols
"""
if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
if isinstance(type_for_symbol,
str) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
# assignments = ast.Block(eqs).atoms(ast.Assignment)
# type_for_symbol.update( {a.lhs: get_type_of_expression(a.rhs) for a in assignments})
# print(type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol,
check_independence_condition)
def visit(obj):
if isinstance(obj, list) or isinstance(obj, tuple):
if isinstance(obj, (list, tuple)):
return [visit(e) for e in obj]
if isinstance(obj, sp.Eq) or isinstance(obj, ast.SympyAssignment) or isinstance(obj, Assignment):
if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)):
return check.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
check.scopes.push()
false_block = None if obj.false_block is None else visit(obj.false_block)
result = ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
true_block=visit(obj.true_block), false_block=false_block)
false_block = None if obj.false_block is None else visit(
obj.false_block)
result = ast.Conditional(check.process_expression(
obj.condition_expr, type_constants=False),
true_block=visit(obj.true_block),
false_block=false_block)
check.scopes.pop()
return result
elif isinstance(obj, ast.Block):
......@@ -897,7 +1022,8 @@ def add_types(eqs, type_for_symbol, check_independence_condition):
result = ast.Block([visit(e) for e in obj.args])
check.scopes.pop()
return result
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
elif isinstance(
obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
raise ValueError("Invalid object in kernel " + str(type(obj)))
......@@ -956,7 +1082,8 @@ def insert_casts(node):
for arg in node.args:
args.append(insert_casts(arg))
# TODO indexed, LoopOverCoordinate
if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne,
sp.Lt, sp.Le, sp.Gt, sp.Ge):
# TODO optimize pow, don't cast integer on double
types = [get_type_of_expression(arg) for arg in args]
assert len(types) > 0
......@@ -974,7 +1101,8 @@ def insert_casts(node):
if target.func is PointerType:
return node.func(*args) # TODO fix, not complete
else:
return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
return node.func(
lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
elif node.func is ast.ResolvedFieldAccess:
return node
elif node.func is ast.Block:
......@@ -991,19 +1119,30 @@ def insert_casts(node):
target = collate_types(types)
zipped = list(zip(expressions, types))
casted_expressions = cast(zipped, target)
args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_expressions)]
args = [
arg.func(*[expr, arg.cond])
for (arg, expr) in zip(args, casted_expressions)
]
return node.func(*args)
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction
) -> None:
"""Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element"""
all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop]
assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
all_inner_loops = [
l for l in function_node.atoms(ast.LoopOverCoordinate)
if l.is_innermost_loop
]
assert len(
all_inner_loops
) == 1, "Transformation works only on kernels with exactly one inner loop"
inner_loop = all_inner_loops.pop()
for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
for loop in parents_of_type(inner_loop,
ast.LoopOverCoordinate,
include_current=True):
cut_loop(loop, [loop.stop - 1])
simplify_conditionals(function_node.body, loop_counter_simplification=True)
......@@ -1016,7 +1155,7 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -
# --------------------------------------- Helper Functions -------------------------------------------------------------
def typing_from_sympy_inspection(eqs, default_type="double"):
def typing_from_sympy_inspection(eqs, default_type="double", default_int_type='int64'):
"""
Creates a default symbol name to type mapping.
If a sympy Boolean is assigned to a symbol it is assumed to be 'bool', otherwise the default type (usually 'double') is used
......@@ -1032,17 +1171,25 @@ def typing_from_sympy_inspection(eqs, default_type="double"):
if isinstance(eq, ast.Conditional):
result.update(typing_from_sympy_inspection(eq.true_block.args))
if eq.false_block:
result.update(typing_from_sympy_inspection(eq.false_block.args))
result.update(typing_from_sympy_inspection(
eq.false_block.args))
elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
continue
else:
from pystencils.cpu.vectorization import vec_all, vec_any
if isinstance(eq.rhs, vec_all) or isinstance(eq.rhs, vec_any):
if isinstance(eq.rhs, (vec_all, vec_any)):
result[eq.lhs.name] = "bool"
# problematic case here is when rhs is a symbol: then it is impossible to decide here without
# further information what type the left hand side is - default fallback is the dict value then
if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
result[eq.lhs.name] = "bool"
try:
result[eq.lhs.name] = get_type_of_expression(eq.rhs,
default_float_type=default_type,
default_int_type=default_int_type,
symbol_type_dict=result)
except Exception:
pass # gracefully fail in case get_type_of_expression cannot determine type
return result
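Hedged sketch of the new inference path in typing_from_sympy_inspection: the type recorded for a left-hand side is now obtained from its right-hand side via get_type_of_expression, so a cast on the rhs propagates instead of the blanket 'double' default; symbol names are illustrative.

import sympy as sp
from pystencils.data_types import cast_func, create_type
from pystencils.transformations import typing_from_sympy_inspection

a, b = sp.symbols('a b')
eqs = [sp.Eq(a, cast_func(3, create_type('uint16'))),
       sp.Eq(b, a + 1)]
type_map = typing_from_sympy_inspection(eqs)
print(type_map['a'])  # uint16, inferred from the cast, instead of the former blanket 'double'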
......@@ -1084,13 +1231,17 @@ def get_optimal_loop_ordering(fields):
ref_field = next(iter(fields))
for field in fields:
if field.spatial_dimensions != ref_field.spatial_dimensions:
raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
+ str({f.name: f.spatial_shape for f in fields}))
raise ValueError(
"All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
+ str({f.name: f.spatial_shape
for f in fields}))
layouts = set([field.layout for field in fields])
if len(layouts) > 1:
raise ValueError("Due to different layout of the fields no optimal loop ordering exists "
+ str({f.name: f.layout for f in fields}))
raise ValueError(
"Due to different layout of the fields no optimal loop ordering exists "
+ str({f.name: f.layout
for f in fields}))
layout = list(layouts)[0]
return list(layout)
......@@ -1135,7 +1286,9 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
"""
inner_loops = []
inner_loop_counters = set()
for loop in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment):
for loop in filtered_tree_iteration(ast_node,
ast.LoopOverCoordinate,
stop_type=ast.SympyAssignment):
if loop.is_innermost_loop:
inner_loops.append(loop)
inner_loop_counters.add(loop.coordinate_to_loop_over)
......@@ -1146,8 +1299,10 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
inner_loop_counter = inner_loop_counters.pop()
parameters = ast_node.get_parameters()
stride_params = [p.symbol for p in parameters
if p.is_field_stride and p.symbol.coordinate == inner_loop_counter]
stride_params = [
p.symbol for p in parameters
if p.is_field_stride and p.symbol.coordinate == inner_loop_counter
]
subs_dict = {stride_param: 1 for stride_param in stride_params}
if subs_dict:
ast_node.subs(subs_dict)
......@@ -1163,7 +1318,10 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
Returns:
number of dimensions blocked
"""
loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)]
loops = [
l for l in filtered_tree_iteration(
ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
]
body = ast_node.body
coordinates = []
......@@ -1183,8 +1341,12 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
outer_loop = None
for coord in reversed(coordinates):
body = ast.Block([outer_loop]) if outer_loop else body
outer_loop = ast.LoopOverCoordinate(body, coord, loop_starts[coord], loop_stops[coord],
step=block_size[coord], is_block_loop=True)
outer_loop = ast.LoopOverCoordinate(body,
coord,
loop_starts[coord],
loop_stops[coord],
step=block_size[coord],
is_block_loop=True)
ast_node.body = ast.Block([outer_loop])
......@@ -1193,7 +1355,8 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
coord = inner_loop.coordinate_to_loop_over
block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord)
loop_range = inner_loop.stop - inner_loop.start
if sp.sympify(loop_range).is_number and loop_range % block_size[coord] == 0:
if sp.sympify(
loop_range).is_number and loop_range % block_size[coord] == 0:
stop = block_ctr + block_size[coord]
else:
stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord])
......