From f49635e8c4f5e4ac9c9f4670f1fcf12a09a0f257 Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Tue, 12 Sep 2023 11:50:57 +0200 Subject: [PATCH] Fix test cases --- pystencils_tests/test_type_interference.py | 6 ++++-- pystencils_tests/test_types.py | 6 ++++-- pystencils_tests/test_vectorization.py | 4 +++- pystencils_tests/test_vectorization_specific.py | 3 ++- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/pystencils_tests/test_type_interference.py b/pystencils_tests/test_type_interference.py index d240cebc..89529f33 100644 --- a/pystencils_tests/test_type_interference.py +++ b/pystencils_tests/test_type_interference.py @@ -26,6 +26,8 @@ def test_type_interference(): assert 'const uint16_t f' in code assert 'const int64_t e' in code - assert 'const float d = ((float)(b)) + ((float)(c)) + ((float)(e)) + _data_x_00_10[_stride_x_2*ctr_2];' in code - assert '_data_x_00_10[_stride_x_2*ctr_2] = ((float)(b)) + ((float)(c)) + _data_x_00_10[_stride_x_2*ctr_2];' in code + assert 'const float d = ((float)(b)) + ((float)(c)) + ((float)(e)) + ' \ + '_data_x[_stride_x_0*ctr_0 + _stride_x_1*ctr_1 + _stride_x_2*ctr_2];' in code + assert '_data_x[_stride_x_0*ctr_0 + _stride_x_1*ctr_1 + _stride_x_2*ctr_2] = (' \ + '(float)(b)) + ((float)(c)) + _data_x[_stride_x_0*ctr_0 + _stride_x_1*ctr_1 + _stride_x_2*ctr_2];' in code assert 'const double g = a + ((double)(b)) + ((double)(d));' in code diff --git a/pystencils_tests/test_types.py b/pystencils_tests/test_types.py index 16466df5..2198e51b 100644 --- a/pystencils_tests/test_types.py +++ b/pystencils_tests/test_types.py @@ -185,9 +185,11 @@ def test_integer_comparision(dtype): # There should be an explicit cast for the integer zero to the type of the field on the rhs if dtype == 'float64': - t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0): (_data_f_00[_stride_f_1*ctr_1]));" + t = "_data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1] = " \ + "((((dir) == (1))) ? (0.0): (_data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1]));" else: - t = "_data_f_00[_stride_f_1*ctr_1] = ((((dir) == (1))) ? (0.0f): (_data_f_00[_stride_f_1*ctr_1]));" + t = "_data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1] = " \ + "((((dir) == (1))) ? (0.0f): (_data_f[_stride_f_0*ctr_0 + _stride_f_1*ctr_1]));" assert t in code diff --git a/pystencils_tests/test_vectorization.py b/pystencils_tests/test_vectorization.py index 3ab4bd39..071bc240 100644 --- a/pystencils_tests/test_vectorization.py +++ b/pystencils_tests/test_vectorization.py @@ -140,7 +140,9 @@ def test_aligned_and_nt_stores(openmp, instruction_set=instruction_set): opt = {'instruction_set': instruction_set, 'assume_aligned': True, 'nontemporal': True, 'assume_inner_stride_one': True} update_rule = [ps.Assignment(f.center(), 0.25 * (g[-1, 0] + g[1, 0] + g[0, -1] + g[0, 1]))] - config = pystencils.config.CreateKernelConfig(target=dh.default_target, cpu_vectorize_info=opt, cpu_openmp=openmp) + # Without the base pointer spec, the inner store is not aligned + config = pystencils.config.CreateKernelConfig(target=dh.default_target, cpu_vectorize_info=opt, cpu_openmp=openmp, + base_pointer_specification=[['spatialInner0']]) ast = ps.create_kernel(update_rule, config=config) if instruction_set in ['sse'] or instruction_set.startswith('avx'): assert 'stream' in ast.instruction_set diff --git a/pystencils_tests/test_vectorization_specific.py b/pystencils_tests/test_vectorization_specific.py index 46e13c2d..610f671e 100644 --- a/pystencils_tests/test_vectorization_specific.py +++ b/pystencils_tests/test_vectorization_specific.py @@ -116,7 +116,8 @@ def test_alignment_and_correct_ghost_layers(gl_field, gl_kernel, instruction_set opt = {'instruction_set': instruction_set, 'assume_aligned': True, 'nontemporal': True, 'assume_inner_stride_one': True} config = pystencils.config.CreateKernelConfig(target=dh.default_target, - cpu_vectorize_info=opt, ghost_layers=gl_kernel) + cpu_vectorize_info=opt, ghost_layers=gl_kernel, + base_pointer_specification=[['spatialInner0']]) ast = ps.create_kernel(update_rule, config=config) kernel = ast.compile() if gl_kernel != gl_field: -- GitLab