Skip to content
Snippets Groups Projects

Updated Kerncraft (kc) coupling to support layer-condition analysis

Merged Julian Hammer requested to merge hammer/pystencils:kc_lc_support into master
2 files
+ 25
- 27
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -3,10 +3,13 @@ import fcntl
@@ -3,10 +3,13 @@ import fcntl
from collections import defaultdict
from collections import defaultdict
from tempfile import TemporaryDirectory
from tempfile import TemporaryDirectory
import textwrap
import textwrap
 
import itertools
 
import string
from jinja2 import Environment, PackageLoader, StrictUndefined, Template
from jinja2 import Environment, PackageLoader, StrictUndefined, Template
import sympy as sp
import sympy as sp
from kerncraft.kerncraft import KernelCode
from kerncraft.kerncraft import KernelCode
 
from kerncraft.kernel import symbol_pos_int
from kerncraft.machinemodel import MachineModel
from kerncraft.machinemodel import MachineModel
from pystencils.astnodes import \
from pystencils.astnodes import \
@@ -75,10 +78,6 @@ class PyStencilsKerncraftKernel(KernelCode):
@@ -75,10 +78,6 @@ class PyStencilsKerncraftKernel(KernelCode):
cur_node = cur_node.parent
cur_node = cur_node.parent
self._loop_stack = list(reversed(self._loop_stack))
self._loop_stack = list(reversed(self._loop_stack))
# Data sources & destinations
self.sources = defaultdict(list)
self.destinations = defaultdict(list)
def get_layout_tuple(f):
def get_layout_tuple(f):
if f.has_fixed_shape:
if f.has_fixed_shape:
return get_layout_from_strides(f.strides)
return get_layout_from_strides(f.strides)
@@ -88,23 +87,37 @@ class PyStencilsKerncraftKernel(KernelCode):
@@ -88,23 +87,37 @@ class PyStencilsKerncraftKernel(KernelCode):
layout_list.insert(0 if assumed_layout == 'SoA' else -1, max(layout_list) + 1)
layout_list.insert(0 if assumed_layout == 'SoA' else -1, max(layout_list) + 1)
return layout_list
return layout_list
 
# Variables (arrays) and Constants (scalar sizes)
 
const_names_iter = itertools.product(string.ascii_uppercase, repeat=1)
 
constants_reversed = {}
 
fields_accessed = self.kernel_ast.fields_accessed
 
for field in fields_accessed:
 
layout = get_layout_tuple(field)
 
permuted_shape = list(field.shape[i] for i in layout)
 
# Replace shape dimensions with constant variables (necessary for layer condition
 
# analysis)
 
for i, d in enumerate(permuted_shape):
 
if d not in self.constants.values():
 
const_symbol = symbol_pos_int(''.join(next(const_names_iter)))
 
self.set_constant(const_symbol, d)
 
constants_reversed[d] = const_symbol
 
permuted_shape[i] = constants_reversed[d]
 
self.set_variable(field.name, (str(field.dtype),), tuple(permuted_shape))
 
 
# Data sources & destinations
 
self.sources = defaultdict(list)
 
self.destinations = defaultdict(list)
 
reads, writes = search_resolved_field_accesses_in_ast(inner_loop)
reads, writes = search_resolved_field_accesses_in_ast(inner_loop)
for accesses, target_dict in [(reads, self.sources), (writes, self.destinations)]:
for accesses, target_dict in [(reads, self.sources), (writes, self.destinations)]:
for fa in accesses:
for fa in accesses:
coord = [sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i), positive=True, integer=True) + off
coord = [symbol_pos_int(LoopOverCoordinate.get_loop_counter_name(i)) + off
for i, off in enumerate(fa.offsets)]
for i, off in enumerate(fa.offsets)]
coord += list(fa.idx_coordinate_values)
coord += list(fa.idx_coordinate_values)
layout = get_layout_tuple(fa.field)
layout = get_layout_tuple(fa.field)
permuted_coord = [sp.sympify(coord[i]) for i in layout]
permuted_coord = [sp.sympify(coord[i]) for i in layout]
target_dict[fa.field.name].append(permuted_coord)
target_dict[fa.field.name].append(permuted_coord)
# Variables (arrays)
fields_accessed = self.kernel_ast.fields_accessed
for field in fields_accessed:
layout = get_layout_tuple(field)
permuted_shape = list(field.shape[i] for i in layout)
self.set_variable(field.name, (str(field.dtype),), tuple(permuted_shape))
# Scalars may be safely ignored
# Scalars may be safely ignored
# for param in self.kernel_ast.get_parameters():
# for param in self.kernel_ast.get_parameters():
# if not param.is_field_parameter:
# if not param.is_field_parameter:
Loading