Skip to content
Snippets Groups Projects
Commit a1959437 authored by Julian Hammer's avatar Julian Hammer
Browse files

updated kc coupling to support layercondition analysis

parent 74bb2c23
No related merge requests found
...@@ -3,10 +3,13 @@ import fcntl ...@@ -3,10 +3,13 @@ import fcntl
from collections import defaultdict from collections import defaultdict
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
import textwrap import textwrap
import itertools
import string
from jinja2 import Environment, PackageLoader, StrictUndefined, Template from jinja2 import Environment, PackageLoader, StrictUndefined, Template
import sympy as sp import sympy as sp
from kerncraft.kerncraft import KernelCode from kerncraft.kerncraft import KernelCode
from kerncraft.kernel import symbol_pos_int
from kerncraft.machinemodel import MachineModel from kerncraft.machinemodel import MachineModel
from pystencils.astnodes import \ from pystencils.astnodes import \
...@@ -75,10 +78,6 @@ class PyStencilsKerncraftKernel(KernelCode): ...@@ -75,10 +78,6 @@ class PyStencilsKerncraftKernel(KernelCode):
cur_node = cur_node.parent cur_node = cur_node.parent
self._loop_stack = list(reversed(self._loop_stack)) self._loop_stack = list(reversed(self._loop_stack))
# Data sources & destinations
self.sources = defaultdict(list)
self.destinations = defaultdict(list)
def get_layout_tuple(f): def get_layout_tuple(f):
if f.has_fixed_shape: if f.has_fixed_shape:
return get_layout_from_strides(f.strides) return get_layout_from_strides(f.strides)
...@@ -88,23 +87,37 @@ class PyStencilsKerncraftKernel(KernelCode): ...@@ -88,23 +87,37 @@ class PyStencilsKerncraftKernel(KernelCode):
layout_list.insert(0 if assumed_layout == 'SoA' else -1, max(layout_list) + 1) layout_list.insert(0 if assumed_layout == 'SoA' else -1, max(layout_list) + 1)
return layout_list return layout_list
# Variables (arrays) and Constants (scalar sizes)
const_names_iter = itertools.product(string.ascii_uppercase, repeat=1)
constants_reversed = {}
fields_accessed = self.kernel_ast.fields_accessed
for field in fields_accessed:
layout = get_layout_tuple(field)
permuted_shape = list(field.shape[i] for i in layout)
# Replace shape dimensions with constant variables (necessary for layer condition
# analysis)
for i, d in enumerate(permuted_shape):
if d not in self.constants.values():
const_symbol = symbol_pos_int(''.join(next(const_names_iter)))
self.set_constant(const_symbol, d)
constants_reversed[d] = const_symbol
permuted_shape[i] = constants_reversed[d]
self.set_variable(field.name, (str(field.dtype),), tuple(permuted_shape))
# Data sources & destinations
self.sources = defaultdict(list)
self.destinations = defaultdict(list)
reads, writes = search_resolved_field_accesses_in_ast(inner_loop) reads, writes = search_resolved_field_accesses_in_ast(inner_loop)
for accesses, target_dict in [(reads, self.sources), (writes, self.destinations)]: for accesses, target_dict in [(reads, self.sources), (writes, self.destinations)]:
for fa in accesses: for fa in accesses:
coord = [sp.Symbol(LoopOverCoordinate.get_loop_counter_name(i), positive=True, integer=True) + off coord = [symbol_pos_int(LoopOverCoordinate.get_loop_counter_name(i)) + off
for i, off in enumerate(fa.offsets)] for i, off in enumerate(fa.offsets)]
coord += list(fa.idx_coordinate_values) coord += list(fa.idx_coordinate_values)
layout = get_layout_tuple(fa.field) layout = get_layout_tuple(fa.field)
permuted_coord = [sp.sympify(coord[i]) for i in layout] permuted_coord = [sp.sympify(coord[i]) for i in layout]
target_dict[fa.field.name].append(permuted_coord) target_dict[fa.field.name].append(permuted_coord)
# Variables (arrays)
fields_accessed = self.kernel_ast.fields_accessed
for field in fields_accessed:
layout = get_layout_tuple(field)
permuted_shape = list(field.shape[i] for i in layout)
self.set_variable(field.name, (str(field.dtype),), tuple(permuted_shape))
# Scalars may be safely ignored # Scalars may be safely ignored
# for param in self.kernel_ast.get_parameters(): # for param in self.kernel_ast.get_parameters():
# if not param.is_field_parameter: # if not param.is_field_parameter:
......
...@@ -165,18 +165,3 @@ def test_benchmark(): ...@@ -165,18 +165,3 @@ def test_benchmark():
timeloop_time = timeloop.benchmark(number_of_time_steps_for_estimation=1) timeloop_time = timeloop.benchmark(number_of_time_steps_for_estimation=1)
np.testing.assert_almost_equal(c_benchmark_run, timeloop_time, decimal=4) np.testing.assert_almost_equal(c_benchmark_run, timeloop_time, decimal=4)
@pytest.mark.kerncraft
def test_kerncraft_generic_field():
machine_file_path = INPUT_FOLDER / "Example_SandyBridgeEP_E5-2680.yml"
machine = MachineModel(path_to_yaml=machine_file_path)
a = fields('a: double[3D]')
b = fields('b: double[3D]')
s = sp.Symbol("s")
rhs = a[0, -1, 0] + a[0, 1, 0] + a[-1, 0, 0] + a[1, 0, 0] + a[0, 0, -1] + a[0, 0, 1]
update_rule = Assignment(b[0, 0, 0], s * rhs)
ast = create_kernel([update_rule])
k = PyStencilsKerncraftKernel(ast, machine, debug_print=True)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment