diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index 85c6033c513bb42646c76168eb13abc51f9b37ce..ca37e054ad149faae6af9080fe85af62972a1960 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -247,7 +247,7 @@ class KernelFunction(Node): if hasattr(symbol, 'field_name'): return field_map[symbol.field_name], elif hasattr(symbol, 'field_names'): - return tuple(field_map[fn] for fn in symbol.field_names) + return tuple(field_map.get(fn, field_map.get('diff' + fn)) for fn in symbol.field_names) return () argument_symbols = self._body.undefined_symbols - self.global_variables @@ -297,7 +297,7 @@ class Block(Node): except AttributeError: pass - @property + @ property def args(self): return self._nodes @@ -361,7 +361,7 @@ class Block(Node): replacements.parent = self self._nodes.insert(idx, replacements) - @property + @ property def symbols_defined(self): result = set() for a in self.args: @@ -371,7 +371,7 @@ class Block(Node): result.update(a.symbols_defined) return result - @property + @ property def undefined_symbols(self): result = set() defined_symbols = set() @@ -443,7 +443,7 @@ class LoopOverCoordinate(Node): self.step = fast_subs(self.step, subs_dict, skip) return self - @property + @ property def args(self): result = [self.body] for e in [self.start, self.stop, self.step]: @@ -461,11 +461,11 @@ class LoopOverCoordinate(Node): elif child == self.stop: self.stop = replacement - @property + @ property def symbols_defined(self): return {self.loop_counter_symbol} - @property + @ property def undefined_symbols(self): result = self.body.undefined_symbols for possible_symbol in [self.start, self.stop, self.step]: @@ -473,22 +473,22 @@ class LoopOverCoordinate(Node): result.update(possible_symbol.atoms(sp.Symbol)) return result - {self.loop_counter_symbol} - @staticmethod + @ staticmethod def get_loop_counter_name(coordinate_to_loop_over): return f"{LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}" - @staticmethod + @ staticmethod def get_block_loop_counter_name(coordinate_to_loop_over): return f"{LoopOverCoordinate.BlOCK_LOOP_COUNTER_NAME_PREFIX}_{coordinate_to_loop_over}" - @property + @ property def loop_counter_name(self): if self.is_block_loop: return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over) else: return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over) - @staticmethod + @ staticmethod def is_loop_counter_symbol(symbol): prefix = LoopOverCoordinate.LOOP_COUNTER_NAME_PREFIX if not symbol.name.startswith(prefix): @@ -498,29 +498,29 @@ class LoopOverCoordinate(Node): coordinate = int(symbol.name[len(prefix) + 1:]) return coordinate - @staticmethod + @ staticmethod def get_loop_counter_symbol(coordinate_to_loop_over): return TypedSymbol(LoopOverCoordinate.get_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True) - @staticmethod + @ staticmethod def get_block_loop_counter_symbol(coordinate_to_loop_over): return TypedSymbol(LoopOverCoordinate.get_block_loop_counter_name(coordinate_to_loop_over), 'int', nonnegative=True) - @property + @ property def loop_counter_symbol(self): if self.is_block_loop: return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over) else: return self.get_loop_counter_symbol(self.coordinate_to_loop_over) - @property + @ property def is_outermost_loop(self): from pystencils.transformations import get_next_parent_of_type return get_next_parent_of_type(self, LoopOverCoordinate) is None - @property + @ property def is_innermost_loop(self): return len(self.atoms(LoopOverCoordinate)) == 0 @@ -552,11 +552,11 @@ class SympyAssignment(Node): return False return True - @property + @ property def lhs(self): return self._lhs_symbol - @lhs.setter + @ lhs.setter def lhs(self, new_value): self._lhs_symbol = new_value self._is_declaration = self.__is_declaration() @@ -572,17 +572,17 @@ class SympyAssignment(Node): except Exception: pass - @property + @ property def args(self): return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)] - @property + @ property def symbols_defined(self): if not self._is_declaration: return set() return {self._lhs_symbol} - @property + @ property def undefined_symbols(self): result = {s for s in self.rhs.free_symbols if not isinstance(s, sp.Indexed)} # Add loop counters if there a field accesses @@ -596,11 +596,11 @@ class SympyAssignment(Node): result.update(self._lhs_symbol.atoms(sp.Symbol)) return result - @property + @ property def is_declaration(self): return self._is_declaration - @property + @ property def is_const(self): return self._is_const @@ -657,7 +657,7 @@ class ResolvedFieldAccess(sp.Indexed): super_class_contents = super(ResolvedFieldAccess, self)._hashable_content() return super_class_contents + tuple(self.offsets) + (repr(self.idx_coordinate_values), hash(self.field)) - @property + @ property def typed_symbol(self): return self.base.label @@ -687,18 +687,18 @@ class TemporaryMemoryAllocation(Node): self.headers = ['<stdlib.h>'] self._align_offset = align_offset - @property + @ property def symbols_defined(self): return {self.symbol} - @property + @ property def undefined_symbols(self): if isinstance(self.size, sp.Basic): return self.size.atoms(sp.Symbol) else: return set() - @property + @ property def args(self): return [self.symbol] @@ -714,22 +714,22 @@ class TemporaryMemoryFree(Node): super(TemporaryMemoryFree, self).__init__(parent=None) self.alloc_node = alloc_node - @property + @ property def symbol(self): return self.alloc_node.symbol def offset(self, byte_alignment): return self.alloc_node.offset(byte_alignment) - @property + @ property def symbols_defined(self): return set() - @property + @ property def undefined_symbols(self): return set() - @property + @ property def args(self): return [] @@ -747,15 +747,15 @@ class SourceCodeComment(Node): def __init__(self, text): self.text = text - @property + @ property def args(self): return [] - @property + @ property def symbols_defined(self): return set() - @property + @ property def undefined_symbols(self): return set() @@ -770,15 +770,15 @@ class EmptyLine(Node): def __init__(self): pass - @property + @ property def args(self): return [] - @property + @ property def symbols_defined(self): return set() - @property + @ property def undefined_symbols(self): return set() @@ -798,15 +798,15 @@ class ConditionalFieldAccess(sp.Function): def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0): return sp.Function.__new__(cls, field_access, outofbounds_condition, sp.S(outofbounds_value)) - @property + @ property def access(self): return self.args[0] - @property + @ property def outofbounds_condition(self): return self.args[1] - @property + @ property def outofbounds_value(self): return self.args[2] diff --git a/pystencils/gpucuda/kernelcreation.py b/pystencils/gpucuda/kernelcreation.py index 52a4dc8bd90e6caa8dd462b0ced27ce33e6c1ae2..c7ad0c6892e9b6d3ac4b8920bb63e42379ff7235 100644 --- a/pystencils/gpucuda/kernelcreation.py +++ b/pystencils/gpucuda/kernelcreation.py @@ -15,7 +15,8 @@ def create_cuda_kernel(assignments, iteration_slice=None, ghost_layers=None, skip_independence_check=False, - use_textures_for_interpolation=True): + use_textures_for_interpolation=True, + do_unify_shape_symbols=True): assert assignments, "Assignments must not be empty!" fields_read, fields_written, assignments = add_types(assignments, type_info, not skip_independence_check) all_fields = fields_read.union(fields_written) @@ -60,7 +61,8 @@ def create_cuda_kernel(assignments, block = Block(assignments) block = indexing.guard(block, common_shape) - unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers) + if do_unify_shape_symbols: + unify_shape_symbols(block, common_shape=common_shape, fields=fields_without_buffers) ast = KernelFunction(block, 'gpu', diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index b158754c8716a2d0bdae495ceda59341f945c8d8..ee6cf8a5e1afaf4b512ddf67a2ecf51ccca13ef8 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -30,6 +30,7 @@ def create_kernel(assignments, use_textures_for_interpolation=True, cpu_prepend_optimizations=[], use_auto_for_assignments=False, + do_unify_shape_symbols=True, opencl_queue=None, opencl_ctx=None): """ @@ -119,7 +120,8 @@ def create_kernel(assignments, indexing_creator=indexing_creator_from_params(gpu_indexing, gpu_indexing_params), iteration_slice=iteration_slice, ghost_layers=ghost_layers, skip_independence_check=skip_independence_check, - use_textures_for_interpolation=use_textures_for_interpolation) + use_textures_for_interpolation=use_textures_for_interpolation, + do_unify_shape_symbols=do_unify_shape_symbols) if target == 'opencl': from pystencils.opencl.opencljit import make_python_function ast._backend = 'opencl'