Skip to content
Snippets Groups Projects
Commit 7b08c076 authored by Markus Holzer's avatar Markus Holzer Committed by Frederik Hennig
Browse files

Reveal base pointer spec

parent ea98bc8a
No related branches found
No related tags found
1 merge request!350Reveal base pointer spec
......@@ -2,7 +2,7 @@ from copy import copy
from collections import defaultdict
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Union, Tuple, List, Dict, Callable, Any, DefaultDict
from typing import Union, Tuple, List, Dict, Callable, Any, DefaultDict, Iterable
from pystencils import Target, Backend, Field
from pystencils.typing.typed_sympy import BasicType
......@@ -78,6 +78,15 @@ class CreateKernelConfig:
"""
If OpenMP is active: whether multiple outer loops are permitted
"""
base_pointer_specification: Union[List[Iterable[str]], List[Iterable[int]]] = None
"""
Specification of how many and which intermediate pointers are created for a field access.
For example [ (0), (2,3,)] creates on base pointer for coordinates 2 and 3 and writes the offset for coordinate
zero directly in the field access. These specifications are defined dependent on the loop ordering.
This function translates more readable version into the specification above.
For more information see: `pystencils.transformations.create_intermediate_base_pointer`
"""
gpu_indexing: str = 'block'
"""
Either 'block' or 'line' , or custom indexing class, see `pystencils.gpu.AbstractIndexing`
......
......@@ -73,7 +73,11 @@ def create_kernel(assignments: NodeCollection,
typed_split_groups = [[type_symbol(s) for s in split_group] for split_group in split_groups]
split_inner_loop(ast_node, typed_split_groups)
base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
base_pointer_spec = config.base_pointer_specification
if base_pointer_spec is None:
base_pointer_spec = []
if config.cpu_vectorize_info and config.cpu_vectorize_info.get('nontemporal'):
base_pointer_spec = [['spatialInner0'], ['spatialInner1']] if len(loop_order) >= 2 else [['spatialInner0']]
base_pointer_info = {field.name: parse_base_pointer_info(base_pointer_spec, loop_order,
field.spatial_dimensions, field.index_dimensions)
for field in fields_without_buffers}
......
......@@ -82,7 +82,9 @@ def create_cuda_kernel(assignments: NodeCollection, config: CreateKernelConfig):
assignments=assignments)
ast.global_variables.update(indexing.index_variables)
base_pointer_spec = [['spatialInner0']]
base_pointer_spec = config.base_pointer_specification
if base_pointer_spec is None:
base_pointer_spec = []
base_pointer_info = {f.name: parse_base_pointer_info(base_pointer_spec, [2, 1, 0],
f.spatial_dimensions, f.index_dimensions)
for f in all_fields}
......
import pytest
from pystencils import Assignment, CreateKernelConfig, Target, fields, create_kernel, get_code_str
@pytest.mark.parametrize('target', (Target.CPU, Target.GPU))
def test_intermediate_base_pointer(target):
x = fields(f'x: double[3d]')
y = fields(f'y: double[3d]')
update = Assignment(x.center, y.center)
config = CreateKernelConfig(base_pointer_specification=[], target=target)
ast = create_kernel(update, config=config)
code = get_code_str(ast)
# no intermediate base pointers are created
assert "_data_x[_stride_x_0*ctr_0 + _stride_x_1*ctr_1 + _stride_x_2*ctr_2] = " \
"_data_y[_stride_y_0*ctr_0 + _stride_y_1*ctr_1 + _stride_y_2*ctr_2];" in code
config = CreateKernelConfig(base_pointer_specification=[[0]], target=target)
ast = create_kernel(update, config=config)
code = get_code_str(ast)
# intermediate base pointers for y and z
assert "double * RESTRICT _data_x_10_20 = _data_x + _stride_x_1*ctr_1 + _stride_x_2*ctr_2;" in code
assert " double * RESTRICT _data_y_10_20 = _data_y + _stride_y_1*ctr_1 + _stride_y_2*ctr_2;" in code
assert "_data_x_10_20[_stride_x_0*ctr_0] = _data_y_10_20[_stride_y_0*ctr_0];" in code
config = CreateKernelConfig(base_pointer_specification=[[1]], target=target)
ast = create_kernel(update, config=config)
code = get_code_str(ast)
# intermediate base pointers for x and z
assert "double * RESTRICT _data_x_00_20 = _data_x + _stride_x_0*ctr_0 + _stride_x_2*ctr_2;" in code
assert "double * RESTRICT _data_y_00_20 = _data_y + _stride_y_0*ctr_0 + _stride_y_2*ctr_2;" in code
assert "_data_x_00_20[_stride_x_1*ctr_1] = _data_y_00_20[_stride_y_1*ctr_1];" in code
config = CreateKernelConfig(base_pointer_specification=[[2]], target=target)
ast = create_kernel(update, config=config)
code = get_code_str(ast)
# intermediate base pointers for x and y
assert "double * RESTRICT _data_x_00_10 = _data_x + _stride_x_0*ctr_0 + _stride_x_1*ctr_1;" in code
assert "double * RESTRICT _data_y_00_10 = _data_y + _stride_y_0*ctr_0 + _stride_y_1*ctr_1;" in code
assert "_data_x_00_10[_stride_x_2*ctr_2] = _data_y_00_10[_stride_y_2*ctr_2];" in code
config = CreateKernelConfig(target=target)
ast = create_kernel(update, config=config)
code = get_code_str(ast)
# by default no intermediate base pointers are created
assert "_data_x[_stride_x_0*ctr_0 + _stride_x_1*ctr_1 + _stride_x_2*ctr_2] = " \
"_data_y[_stride_y_0*ctr_0 + _stride_y_1*ctr_1 + _stride_y_2*ctr_2];" in code
......@@ -26,6 +26,8 @@ def test_type_interference():
assert 'const uint16_t f' in code
assert 'const int64_t e' in code
assert 'const float d = ((float)(b)) + ((float)(c)) + ((float)(e)) + _data_x_00_10[_stride_x_2*ctr_2];' in code
assert '_data_x_00_10[_stride_x_2*ctr_2] = ((float)(b)) + ((float)(c)) + _data_x_00_10[_stride_x_2*ctr_2];' in code
assert 'const float d = ((float)(b)) + ((float)(c)) + ((float)(e)) + ' \
'_data_x[_stride_x_0*ctr_0 + _stride_x_1*ctr_1 + _stride_x_2*ctr_2];' in code
assert '_data_x[_stride_x_0*ctr_0 + _stride_x_1*ctr_1 + _stride_x_2*ctr_2] = (' \
'(float)(b)) + ((float)(c)) + _data_x[_stride_x_0*ctr_0 + _stride_x_1*ctr_1 + _stride_x_2*ctr_2];' in code
assert 'const double g = a + ((double)(b)) + ((double)(d));' in code
......@@ -185,9 +185,11 @@ def test_integer_comparision(dtype):
# There should be an explicit cast for the integer zero to the type of the field on the rhs
if dtype == 'float64':
t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0): (_data_f_00[_stride_f_1*ctr_1]));"
t = "_data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1] = " \
"((((dir) == (1))) ? (0.0): (_data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1]));"
else:
t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0f): (_data_f_00[_stride_f_1*ctr_1]));"
t = "_data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1] = " \
"((((dir) == (1))) ? (0.0f): (_data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1]));"
assert t in code
......
......@@ -140,6 +140,7 @@ def test_aligned_and_nt_stores(openmp, instruction_set=instruction_set):
opt = {'instruction_set': instruction_set, 'assume_aligned': True, 'nontemporal': True,
'assume_inner_stride_one': True}
update_rule = [ps.Assignment(f.center(), 0.25 * (g[-1, 0] + g[1, 0] + g[0, -1] + g[0, 1]))]
# Without the base pointer spec, the inner store is not aligned
config = pystencils.config.CreateKernelConfig(target=dh.default_target, cpu_vectorize_info=opt, cpu_openmp=openmp)
ast = ps.create_kernel(update_rule, config=config)
if instruction_set in ['sse'] or instruction_set.startswith('avx'):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment