diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index b9f13ae641aae01c8f6e91d29a88586a8e68012b..4d9a77e479ca1602d67157c56f7cad197334281e 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -9,7 +9,7 @@ import pystencils from pystencils.typing import TypedSymbol, CastFunc, create_type, get_next_parent_of_type from pystencils.enums import Target, Backend from pystencils.field import Field -from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol +from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.sympyextensions import fast_subs NodeOrExpr = Union['Node', sp.Expr] diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py index 7b2719fd7b0889aae64794c8e97231ba6e89d241..f2dc0ff933a1e51d6bfc67cc0b90963b478f010d 100644 --- a/pystencils/cpu/kernelcreation.py +++ b/pystencils/cpu/kernelcreation.py @@ -59,6 +59,8 @@ def create_kernel(assignments: AssignmentOrAstNodeList, function_name: str = "ke else: raise ValueError("Term has to be field access or symbol") + # TODO 1) check kernel + # TODO 2) add leaf types fields_read, fields_written, assignments = add_types( assignments, type_info, not skip_independence_check, check_double_write_condition=not allow_double_writes) all_fields = fields_read.union(fields_written) diff --git a/pystencils/field.py b/pystencils/field.py index 146a1cacb8a25561025247b45b1b70368283eb40..91b33eed39c45998256b9e58e1705404e1f7d437 100644 --- a/pystencils/field.py +++ b/pystencils/field.py @@ -19,7 +19,7 @@ from pystencils.stencil import ( direction_string_to_offset, inverse_direction, offset_to_direction_string) from pystencils.sympyextensions import is_integer_sequence -__all__ = ['Field', 'fields', 'FieldType', 'AbstractField'] +__all__ = ['Field', 'fields', 'FieldType', 'Field'] class FieldType(Enum): @@ -137,13 +137,7 @@ def fields(description=None, index_dimensions=0, layout=None, field_type=FieldTy return result -# TODO why this??? Why abstarct? -class AbstractField: - class AbstractAccess: - pass - - -class Field(AbstractField): +class Field: """ With fields one can formulate stencil-like update rules on structured grids. This Field class knows about the dimension, memory layout (strides) and optionally about the size of an array. @@ -625,7 +619,7 @@ class Field(AbstractField): self.coordinate_origin = -sp.Matrix([i / 2 for i in self.spatial_shape]) # noinspection PyAttributeOutsideInit,PyUnresolvedReferences - class Access(TypedSymbol, AbstractField.AbstractAccess): + class Access(TypedSymbol, Field.Access): """Class representing a relative access into a `Field`. This class behaves like a normal sympy Symbol, it is actually derived from it. One can built up diff --git a/pystencils/kernel_contrains_check.py b/pystencils/kernel_contrains_check.py index 55f141201e9a65e13185a17ba058b13ab0f4d10f..842e70ad93cbdc3cd16710c212ecfd51b71b4456 100644 --- a/pystencils/kernel_contrains_check.py +++ b/pystencils/kernel_contrains_check.py @@ -1,20 +1,17 @@ from collections import namedtuple, defaultdict +from typing import Union -import numpy as np - -import pystencils.integer_functions import sympy as sp +from sympy.codegen import Assignment + from pystencils import astnodes as ast, TypedSymbol -from pystencils.bit_masks import flag_cond -from pystencils.field import AbstractField +from pystencils.field import Field from pystencils.transformations import NestedScopes -from pystencils.typing import CastFunc, create_type, get_type_of_expression, collate_types -from sympy.logic.boolalg import BooleanFunction class KernelConstraintsCheck: - # TODO: Logs # TODO: specification + # TODO: More checks :) """Checks if the input to create_kernel is valid. Test the following conditions: @@ -33,100 +30,41 @@ class KernelConstraintsCheck: self._type_for_symbol = type_for_symbol self.scopes = NestedScopes() - self._field_writes = defaultdict(set) + self.field_writes = defaultdict(set) self.fields_read = set() self.check_independence_condition = check_independence_condition self.check_double_write_condition = check_double_write_condition - def process_assignment(self, assignment): + def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]): # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 - new_rhs = self.process_expression(assignment.rhs) - new_lhs = self._process_lhs(assignment.lhs) - return ast.SympyAssignment(new_lhs, new_rhs) + self.process_expression(assignment.rhs) + self.process_lhs(assignment.lhs) def process_expression(self, rhs, type_constants=True): - - self._update_accesses_rhs(rhs) - if isinstance(rhs, AbstractField.AbstractAccess): + self.update_accesses_rhs(rhs) + if isinstance(rhs, Field.Access): self.fields_read.add(rhs.field) self.fields_read.update(rhs.indirect_addressing_fields) - return rhs - # TODO remove this - #elif isinstance(rhs, ImaginaryUnit): - # return TypedImaginaryUnit(create_type(self._type_for_symbol['_complex_type'])) - elif isinstance(rhs, TypedSymbol): - return rhs - elif isinstance(rhs, sp.Symbol): - return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name]) - elif type_constants and isinstance(rhs, np.generic): - return CastFunc(rhs, create_type(rhs.dtype)) - elif type_constants and isinstance(rhs, sp.Number): - return CastFunc(rhs, create_type(self._type_for_symbol['_constant'])) - # Very important that this clause comes before BooleanFunction - elif isinstance(rhs, sp.Equality): - if isinstance(rhs.args[1], sp.Number): - return sp.Equality( - self.process_expression(rhs.args[0], type_constants), - rhs.args[1]) - else: - return sp.Equality( - self.process_expression(rhs.args[0], type_constants), - self.process_expression(rhs.args[1], type_constants)) - elif isinstance(rhs, CastFunc): - return CastFunc( - self.process_expression(rhs.args[0], type_constants=False), - rhs.dtype) - elif isinstance(rhs, BooleanFunction) or \ - type(rhs) in pystencils.integer_functions.__dict__.values(): - new_args = [self.process_expression(a, type_constants) for a in rhs.args] - types_of_expressions = [get_type_of_expression(a) for a in new_args] - arg_type = collate_types(types_of_expressions, forbid_collation_to_float=True) - new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type - else CastFunc(a, arg_type) - for a in new_args] - return rhs.func(*new_args) - elif isinstance(rhs, flag_cond): - # do not process the arguments to the bit shift - they must remain integers - processed_args = (self.process_expression(a) for a in rhs.args[2:]) - return flag_cond(rhs.args[0], rhs.args[1], *processed_args) - elif isinstance(rhs, sp.Mul): - new_args = [ - self.process_expression(arg, type_constants) - if arg not in (-1, 1) else arg for arg in rhs.args - ] - return rhs.func(*new_args) if new_args else rhs - elif isinstance(rhs, sp.Indexed): - return rhs else: - if isinstance(rhs, sp.Pow): - # don't process exponents -> they should remain integers - return sp.Pow( - self.process_expression(rhs.args[0], type_constants), - rhs.args[1]) - else: - new_args = [ - self.process_expression(arg, type_constants) - for arg in rhs.args - ] - return rhs.func(*new_args) if new_args else rhs + for arg in rhs.args: + self.process_expression(arg, type_constants) @property def fields_written(self): - return set(k.field for k, v in self._field_writes.items() if len(v)) + """ + Return all rhs fields + """ + return set(k.field for k, v in self.field_writes.items() if len(v)) - def _process_lhs(self, lhs): + def process_lhs(self, lhs: Union[Field.Access, TypedSymbol, sp.Symbol]): assert isinstance(lhs, sp.Symbol) - self._update_accesses_lhs(lhs) - if not isinstance(lhs, (AbstractField.AbstractAccess, TypedSymbol)): - return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) - else: - return lhs + self.update_accesses_lhs(lhs) - def _update_accesses_lhs(self, lhs): - if isinstance(lhs, AbstractField.AbstractAccess): + def update_accesses_lhs(self, lhs): + if isinstance(lhs, Field.Access): fai = self.FieldAndIndex(lhs.field, lhs.index) - self._field_writes[fai].add(lhs.offsets) - if self.check_double_write_condition and len(self._field_writes[fai]) > 1: + self.field_writes[fai].add(lhs.offsets) + if self.check_double_write_condition and len(self.field_writes[fai]) > 1: raise ValueError( f"Field {lhs.field.name} is written at two different locations") elif isinstance(lhs, sp.Symbol): @@ -136,15 +74,15 @@ class KernelConstraintsCheck: raise ValueError(f"Symbol {lhs.name} is written, after it has been read") self.scopes.define_symbol(lhs) - def _update_accesses_rhs(self, rhs): - if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition: - writes = self._field_writes[self.FieldAndIndex( + def update_accesses_rhs(self, rhs): + if isinstance(rhs, Field.Access) and self.check_independence_condition: + writes = self.field_writes[self.FieldAndIndex( rhs.field, rhs.index)] for write_offset in writes: assert len(writes) == 1 if write_offset != rhs.offsets: - raise ValueError("Violation of loop independence condition. Field " - "{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset)) + raise ValueError(f"Violation of loop independence condition. Field " + f"{rhs.field} is read at {rhs.offsets} and written at {write_offset}") self.fields_read.add(rhs.field) elif isinstance(rhs, sp.Symbol): - self.scopes.access_symbol(rhs) \ No newline at end of file + self.scopes.access_symbol(rhs) diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index e3b9ed90730a917a5e7d09f12759d596d23af858..a67b01e35c438cb5e2b0cd49b4c6f7cfe9c03878 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -373,7 +373,15 @@ def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclus Returns: AST, see `create_kernel` """ - assert 'iteration_slice' not in kwargs and 'ghost_layers' not in kwargs and 'omp_single_loop' not in kwargs + if 'ghost_layers' in kwargs: + assert kwargs['ghost_layers'] is None + del kwargs['ghost_layers'] + if 'iteration_slice' in kwargs: + assert kwargs['iteration_slice'] is None + del kwargs['iteration_slice'] + if 'omp_single_loop' in kwargs: + assert kwargs['omp_single_loop'] is False + del kwargs['omp_single_loop'] if isinstance(assignments, AssignmentCollection): subexpressions = assignments.subexpressions + [a for a in assignments.main_assignments @@ -476,6 +484,9 @@ def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclus remove_start_conditional = any([gl[0] == 0 for gl in ghost_layers]) prepend_optimizations = [lambda ast: remove_conditionals_in_staggered_kernel(ast, remove_start_conditional), move_constants_before_loop] + if 'cpu_prepend_optimizations' in kwargs: + prepend_optimizations += kwargs['cpu_prepend_optimizations'] + del kwargs['cpu_prepend_optimizations'] ast = create_kernel(final_assignments, ghost_layers=ghost_layers, target=target, omp_single_loop=False, cpu_prepend_optimizations=prepend_optimizations, **kwargs) return ast diff --git a/pystencils/leaf_typing.py b/pystencils/leaf_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..789bb4a8d8601e6a8cbabb5c87277c9e3ddc15c9 --- /dev/null +++ b/pystencils/leaf_typing.py @@ -0,0 +1,129 @@ +from collections import namedtuple, defaultdict +from typing import List, Union + +import numpy as np + +import pystencils.integer_functions +import sympy as sp + +from pystencils import astnodes as ast, TypedSymbol +from pystencils.bit_masks import flag_cond +from pystencils.field import Field +from pystencils.transformations import NestedScopes +from pystencils.typing import CastFunc, create_type, get_type_of_expression, collate_types +from sympy.codegen import Assignment +from sympy.logic.boolalg import BooleanFunction + + +class KernelConstraintsCheck: # TODO rename + # TODO: Logs + # TODO: specification + # TODO: split this into checker and leaf typing + """Checks if the input to create_kernel is valid. + + Test the following conditions: + + - SSA Form for pure symbols: + - Every pure symbol may occur only once as left-hand-side of an assignment + - Every pure symbol that is read, may not be written to later + - Independence / Parallelization condition: + - a field that is written may only be read at exact the same spatial position + + (Pure symbols are symbols that are not Field.Accesses) + """ + FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index']) + + def __init__(self, type_for_symbol, check_independence_condition, check_double_write_condition=True): + self._type_for_symbol = type_for_symbol + + self.scopes = NestedScopes() + self.field_writes = defaultdict(set) + self.fields_read = set() + self.check_independence_condition = check_independence_condition + self.check_double_write_condition = check_double_write_condition + + def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]) -> ast.SympyAssignment: + # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 + new_rhs = self.process_expression(assignment.rhs) + new_lhs = self.process_lhs(assignment.lhs) + return ast.SympyAssignment(new_lhs, new_rhs) + + + # Expression + # 1) ask children if they are cocksure about a type + # 1b) Postpone clueless children (see 5) + # cocksure: Children have somewhere type from Field.Access, TypedSymbol, CastFunction or Function^TM + # clueless: Children without Field.Access,... + # 1c) none child is cocksure -> do nothing a return None, wait for recall from parent + # 2) collate_type of children + # 3) apply collated type on children + # 4) issue warnings of casts on cocksure children + # 5a) resume on clueless children with the collated type as default datatype, issue warning + # 5b) or apply special circumstances + + def process_expression(self, rhs, type_constants=True): # TODO default_type as parameter + if isinstance(rhs, Field.Access): + return rhs + elif isinstance(rhs, TypedSymbol): + return rhs + elif isinstance(rhs, sp.Symbol): + return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name]) + elif type_constants and isinstance(rhs, np.generic): + assert False, f'Why do we have a np.generic in rhs???? {rhs}' + # return CastFunc(rhs, create_type(rhs.dtype)) + elif type_constants and isinstance(rhs, sp.Number): + return CastFunc(rhs, create_type(self._type_for_symbol['_constant'])) + # Very important that this clause comes before BooleanFunction + elif isinstance(rhs, sp.Equality): + if isinstance(rhs.args[1], sp.Number): + return sp.Equality( + self.process_expression(rhs.args[0], type_constants), + rhs.args[1]) # TODO: process args[1] as number with a good type + else: + return sp.Equality( + self.process_expression(rhs.args[0], type_constants), + self.process_expression(rhs.args[1], type_constants)) + elif isinstance(rhs, CastFunc): + return CastFunc( + self.process_expression(rhs.args[0], type_constants=False), # TODO: recommend type + rhs.dtype) + elif isinstance(rhs, BooleanFunction) or \ + type(rhs) in pystencils.integer_functions.__dict__.values(): + new_args = [self.process_expression(a, type_constants) for a in rhs.args] # TODO: recommend type + types_of_expressions = [get_type_of_expression(a) for a in new_args] + arg_type = collate_types(types_of_expressions, forbid_collation_to_float=True) # TODO: this must go + new_args = [a if not hasattr(a, 'dtype') or a.dtype == arg_type + else CastFunc(a, arg_type) + for a in new_args] + return rhs.func(*new_args) + elif isinstance(rhs, flag_cond): # TODO + # do not process the arguments to the bit shift - they must remain integers + processed_args = (self.process_expression(a) for a in rhs.args[2:]) + return flag_cond(rhs.args[0], rhs.args[1], *processed_args) + elif isinstance(rhs, sp.Mul): + new_args = [ + self.process_expression(arg, type_constants) + if arg not in (-1, 1) else arg for arg in rhs.args + ] + return rhs.func(*new_args) if new_args else rhs + elif isinstance(rhs, sp.Indexed): + return rhs + elif isinstance(rhs, sp.Pow): + # don't process exponents -> they should remain integers # TODO + return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1]) + else: + new_args = [self.process_expression(arg, type_constants) for arg in rhs.args] + return rhs.func(*new_args) if new_args else rhs + + @property + def fields_written(self): + """ + Return all rhs fields + """ + return set(k.field for k, v in self.field_writes.items() if len(v)) + + def process_lhs(self, lhs: Union[Field.Access, TypedSymbol, sp.Symbol]): + if not isinstance(lhs, (Field.Access, TypedSymbol)): + return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name]) + else: + return lhs diff --git a/pystencils/simp/simplifications.py b/pystencils/simp/simplifications.py index 114b86a4013cccc3a74b641b0aabf747dbd3035d..c36ba558e0a788a9d1751f2057d4647104d7e5a1 100644 --- a/pystencils/simp/simplifications.py +++ b/pystencils/simp/simplifications.py @@ -8,7 +8,7 @@ from sympy.codegen.rewriting import ReplaceOptim from pystencils.assignment import Assignment from pystencils.astnodes import Node, SympyAssignment -from pystencils.field import AbstractField, Field +from pystencils.field import Field, Field from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect @@ -164,7 +164,7 @@ def add_subexpressions_for_sums(ac): for eq in ac.all_assignments: search_addends(eq.rhs) - addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)] + addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, Field.Access)] new_symbol_gen = ac.subexpression_symbol_generator substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)} return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) diff --git a/pystencils/transformations.py b/pystencils/transformations.py index beb5d287eb3e45ab48c704a6cb408afcfc765a2c..e47c80aa8665cb1bc20fc65d85b5cba126fd7e92 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -11,7 +11,7 @@ import pystencils.astnodes as ast from pystencils.assignment import Assignment from pystencils.typing import ( PointerType, StructType, TypedSymbol, get_base_type, ReinterpretCastFunc, get_next_parent_of_type, parents_of_type) -from pystencils.field import AbstractField, Field, FieldType +from pystencils.field import Field, Field, FieldType from pystencils.typing import FieldPointerSymbol from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.slicing import normalize_slice @@ -160,7 +160,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or tuple of loop-node, ghost_layer_info """ # find correct ordering by inspecting participating FieldAccesses - field_accesses = body.atoms(AbstractField.AbstractAccess) + field_accesses = body.atoms(Field.Access) field_accesses = {e for e in field_accesses if not e.is_absolute_access} # exclude accesses to buffers from field_list, because buffers are treated separately @@ -359,7 +359,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): actual_sizes = loop_iterations actual_steps = loop_counters - field_accesses = ast_node.atoms(AbstractField.AbstractAccess) + field_accesses = ast_node.atoms(Field.Access) buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)} buffer_index_size = len(buffer_accesses) @@ -378,7 +378,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=N read_only_field_names = set() def visit_sympy_expr(expr, enclosing_block, sympy_assignment): - if isinstance(expr, AbstractField.AbstractAccess): + if isinstance(expr, Field.Access): field_access = expr # Do not apply transformation if field is not a buffer @@ -444,7 +444,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=None, field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) def visit_sympy_expr(expr, enclosing_block, sympy_assignment): - if isinstance(expr, AbstractField.AbstractAccess): + if isinstance(expr, Field.Access): field_access = expr field = field_access.field @@ -686,13 +686,13 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): if s in assignment_map: # if there is no assignment inside the loop body it is independent already for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol): - if not isinstance(new_symbol, AbstractField.AbstractAccess) and \ + if not isinstance(new_symbol, Field.Access) and \ new_symbol not in symbols_with_temporary_array: symbols_to_process.append(new_symbol) symbols_resolved.add(s) for symbol in symbol_group: - if not isinstance(symbol, AbstractField.AbstractAccess): + if not isinstance(symbol, Field.Access): assert type(symbol) is TypedSymbol new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype)) symbols_with_temporary_array[symbol] = sp.IndexedBase( @@ -703,7 +703,7 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): if assignment.lhs in symbols_resolved: new_rhs = assignment.rhs.subs( symbols_with_temporary_array.items()) - if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group: + if not isinstance(assignment.lhs, Field.Access) and assignment.lhs in symbol_group: assert type(assignment.lhs) is TypedSymbol new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) new_lhs = sp.IndexedBase(new_ts, shape=(1, ))[inner_loop.loop_counter_symbol] diff --git a/pystencils/typing/utilities.py b/pystencils/typing/utilities.py index 8187d929e59336a2b86b66e9d29344f3e20149fa..a7a506f0cbfcd724d0d9f6b0f6e900568f94dda1 100644 --- a/pystencils/typing/utilities.py +++ b/pystencils/typing/utilities.py @@ -161,7 +161,7 @@ def get_type_of_expression(expr, return create_type(default_float_type) elif isinstance(expr, ResolvedFieldAccess): return expr.field.dtype - elif isinstance(expr, pystencils.field.Field.AbstractAccess): + elif isinstance(expr, pystencils.field.Field.Access): return expr.field.dtype elif isinstance(expr, TypedSymbol): return expr.dtype @@ -284,6 +284,7 @@ def add_types(eqs: List[Assignment], type_for_symbol: Dict[sp.Symbol, np.dtype], # TODO what does this do???? # TODO: ask Martin + # TODO: use correct one/rename check = KernelConstraintsCheck(type_for_symbol, check_independence_condition, check_double_write_condition=check_double_write_condition)