Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (8)
......@@ -465,11 +465,13 @@ class LoopOverCoordinate(Node):
@staticmethod
def get_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int')
return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True)
@staticmethod
def get_block_loop_counter_symbol(coordinate_to_loop_over):
return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over), 'int')
return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over),
'int',
nonnegative=True)
@property
def loop_counter_symbol(self):
......@@ -503,7 +505,7 @@ class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True):
super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol
self.rhs = sp.simplify(rhs_expr)
self.rhs = sp.sympify(rhs_expr)
self._is_const = is_const
self._is_declaration = self.__is_declaration()
......
......@@ -42,12 +42,11 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke
Returns:
AST node representing a function, that can be printed as C or CUDA code
"""
def type_symbol(term):
if isinstance(term, Field.Access) or isinstance(term, TypedSymbol):
return term
elif isinstance(term, sp.Symbol):
if not hasattr(type_info, '__getitem__'):
if isinstance(type_info, str) or not hasattr(type_info, '__getitem__'):
return TypedSymbol(term.name, create_type(type_info))
else:
return TypedSymbol(term.name, type_info[term.name])
......
......@@ -236,6 +236,14 @@ class TypedSymbol(sp.Symbol):
def __getnewargs__(self):
return self.name, self.dtype
@property
def canonical(self):
return self
@property
def reversed(self):
return self
def create_type(specification):
"""Creates a subclass of Type according to a string or an object of subclass Type.
......
......@@ -111,6 +111,7 @@ class AssignmentCollection:
"Not in SSA form - same symbol assigned multiple times"
return bound_symbols_set
@property
def free_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.free_symbols if hasattr(s, 'field')}
......
......@@ -46,6 +46,7 @@ def test_inplace_update():
kernel(f=arr)
np.testing.assert_equal(arr, 2)
def test_vectorization_fixed_size():
configurations = []
# Fixed size - multiple of four
......