From fe3cd6cd19f2c11166e2b8f670a0a8a337b898ef Mon Sep 17 00:00:00 2001
From: zy69guqi <richard.angersbach@fau.de>
Date: Mon, 17 Feb 2025 16:02:05 +0100
Subject: [PATCH] Add generator for SIMD horizontal operations and the emitted code.

---
 .../include/simd_horizontal_helpers.h | 232 ++++++++++++-
 util/generate_simd_horizontal_op.py   | 309 ++++++++++++++++++
 2 files changed, 535 insertions(+), 6 deletions(-)
 create mode 100644 util/generate_simd_horizontal_op.py

diff --git a/src/pystencils/include/simd_horizontal_helpers.h b/src/pystencils/include/simd_horizontal_helpers.h
index 6a80f2107..cd4bd5730 100644
--- a/src/pystencils/include/simd_horizontal_helpers.h
+++ b/src/pystencils/include/simd_horizontal_helpers.h
@@ -1,11 +1,231 @@
 #pragma once
 
+#include <cmath>
+
+#if defined(__SSE3__)
+#include <immintrin.h>
+
+inline double _mm_horizontal_add_pd(double dst, __m128d src) {
+	__m128d _v = src;
+	return dst + _mm_cvtsd_f64(_mm_hadd_pd(_v, _v));
+}
+
+inline float _mm_horizontal_add_ps(float dst, __m128 src) {
+	__m128 _v = src;
+	__m128 _h = _mm_hadd_ps(_v, _v);
+	return dst + _mm_cvtss_f32(_mm_add_ps(_h, _mm_movehdup_ps(_h)));
+}
+
+inline double _mm_horizontal_mul_pd(double dst, __m128d src) {
+	__m128d _v = src;
+	double _r = _mm_cvtsd_f64(_mm_mul_pd(_v, _mm_shuffle_pd(_v, _v, 1)));
+	return dst * _r;
+}
+
+inline float _mm_horizontal_mul_ps(float dst, __m128 src) {
+	__m128 _v = src;
+	__m128 _h = _mm_mul_ps(_v, _mm_shuffle_ps(_v, _v, 177));
+	float _r = _mm_cvtss_f32(_mm_mul_ps(_h, _mm_shuffle_ps(_h, _h, 10)));
+	return dst * _r;
+}
+
+inline double _mm_horizontal_min_pd(double dst, __m128d src) {
+	__m128d _v = src;
+	double _r = _mm_cvtsd_f64(_mm_min_pd(_v, _mm_shuffle_pd(_v, _v, 1)));
+	return fmin(_r, dst);
+}
+
+inline float _mm_horizontal_min_ps(float dst, __m128 src) {
+	__m128 _v = src;
+	__m128 _h = _mm_min_ps(_v, _mm_shuffle_ps(_v, _v, 177));
+	float _r = _mm_cvtss_f32(_mm_min_ps(_h, _mm_shuffle_ps(_h, _h, 10)));
+	return fmin(_r, dst);
+}
+
+inline double _mm_horizontal_max_pd(double dst, __m128d src) {
+	__m128d _v = src;
+	double _r = _mm_cvtsd_f64(_mm_max_pd(_v, _mm_shuffle_pd(_v, _v, 1)));
+	return fmax(_r, dst);
+}
+
+inline float _mm_horizontal_max_ps(float dst, __m128 src) {
+	__m128 _v = src;
+	__m128 _h = _mm_max_ps(_v, _mm_shuffle_ps(_v, _v, 177));
+	float _r = _mm_cvtss_f32(_mm_max_ps(_h, _mm_shuffle_ps(_h, _h, 10)));
+	return fmax(_r, dst);
+}
+
+#endif
+
+#if defined(__AVX__)
+#include <immintrin.h>
+
+inline double _mm256_horizontal_add_pd(double dst, __m256d src) {
+	__m256d _v = src;
+	__m256d _h = _mm256_hadd_pd(_v, _v);
+	return dst + _mm_cvtsd_f64(_mm_add_pd(_mm256_extractf128_pd(_h,1), _mm256_castpd256_pd128(_h)));
+}
+
+inline float _mm256_horizontal_add_ps(float dst, __m256 src) {
+	__m256 _v = src;
+	__m256 _h = _mm256_hadd_ps(_v, _v);
+	__m128 _i = _mm_add_ps(_mm256_extractf128_ps(_h,1), _mm256_castps256_ps128(_h));
+	return dst + _mm_cvtss_f32(_mm_hadd_ps(_i,_i));
+}
+
+inline double _mm256_horizontal_mul_pd(double dst, __m256d src) {
+	__m256d _v = src;
+	__m128d _w = _mm_mul_pd(_mm256_extractf128_pd(_v,1), _mm256_castpd256_pd128(_v));
+	double _r = _mm_cvtsd_f64(_mm_mul_pd(_w, _mm_permute_pd(_w,1)));
+	return dst * _r;
+}
+
+inline float _mm256_horizontal_mul_ps(float dst, __m256 src) {
+	__m256 _v = src;
+	__m128 _w = _mm_mul_ps(_mm256_extractf128_ps(_v,1), _mm256_castps256_ps128(_v));
+	__m128 _h = _mm_mul_ps(_w, _mm_shuffle_ps(_w, _w, 177));
+	float _r = _mm_cvtss_f32(_mm_mul_ps(_h, _mm_shuffle_ps(_h, _h, 10)));
+	return dst * _r;
+}
+
+inline double _mm256_horizontal_min_pd(double dst, __m256d src) {
+	__m256d _v = src;
+	__m128d _w = _mm_min_pd(_mm256_extractf128_pd(_v,1), _mm256_castpd256_pd128(_v));
+	double _r = _mm_cvtsd_f64(_mm_min_pd(_w, _mm_permute_pd(_w,1)));
+	return fmin(_r, dst);
+}
+
+inline float _mm256_horizontal_min_ps(float dst, __m256 src) {
+	__m256 _v = src;
+	__m128 _w = _mm_min_ps(_mm256_extractf128_ps(_v,1), _mm256_castps256_ps128(_v));
+	__m128 _h = _mm_min_ps(_w, _mm_shuffle_ps(_w, _w, 177));
+	float _r = _mm_cvtss_f32(_mm_min_ps(_h, _mm_shuffle_ps(_h, _h, 10)));
+	return fmin(_r, dst);
+}
+
+inline double _mm256_horizontal_max_pd(double dst, __m256d src) {
+	__m256d _v = src;
+	__m128d _w = _mm_max_pd(_mm256_extractf128_pd(_v,1), _mm256_castpd256_pd128(_v));
+	double _r = _mm_cvtsd_f64(_mm_max_pd(_w, _mm_permute_pd(_w,1)));
+	return fmax(_r, dst);
+}
+
+inline float _mm256_horizontal_max_ps(float dst, __m256 src) {
+	__m256 _v = src;
+	__m128 _w = _mm_max_ps(_mm256_extractf128_ps(_v,1), _mm256_castps256_ps128(_v));
+	__m128 _h = _mm_max_ps(_w, _mm_shuffle_ps(_w, _w, 177));
+	float _r = _mm_cvtss_f32(_mm_max_ps(_h, _mm_shuffle_ps(_h, _h, 10)));
+	return fmax(_r, dst);
+}
+
+#endif
+
+#if defined(__AVX512VL__)
 #include <immintrin.h>
 
-#define QUALIFIERS inline
+inline double _mm512_horizontal_add_pd(double dst, __m512d src) {
+	double _r = _mm512_reduce_add_pd(src);
+	return dst + _r;
+}
+
+inline float _mm512_horizontal_add_ps(float dst, __m512 src) {
+	float _r = _mm512_reduce_add_ps(src);
+	return dst + _r;
+}
+
+inline double _mm512_horizontal_mul_pd(double dst, __m512d src) {
+	double _r = _mm512_reduce_mul_pd(src);
+	return dst * _r;
+}
+
+inline float _mm512_horizontal_mul_ps(float dst, __m512 src) {
+	float _r = _mm512_reduce_mul_ps(src);
+	return dst * _r;
+}
+
+inline double _mm512_horizontal_min_pd(double dst, __m512d src) {
+	double _r = _mm512_reduce_min_pd(src);
+	return fmin(_r, dst);
+}
+
+inline float _mm512_horizontal_min_ps(float dst, __m512 src) {
+	float _r = _mm512_reduce_min_ps(src);
+	return fmin(_r, dst);
+}
+
+inline double _mm512_horizontal_max_pd(double dst, __m512d src) {
+	double _r = _mm512_reduce_max_pd(src);
+	return fmax(_r, dst);
+}
+
+inline float _mm512_horizontal_max_ps(float dst, __m512 src) {
+	float _r = _mm512_reduce_max_ps(src);
+	return fmax(_r, dst);
+}
+
+#endif
+
+#if defined(_M_ARM64)
+#include <arm_neon.h>
+
+inline double vgetq_horizontal_add_f64(double dst, float64x2_t src) {
+	float64x2_t _v = src;
+	double _r = vgetq_lane_f64(_v,0);
+	_r += vgetq_lane_f64(_v,1);
+	return dst + _r;
+}
+
+inline float vget_horizontal_add_f32(float dst, float32x4_t src) {
+	float32x4_t _v = src;
+	float32x2_t _w = vadd_f32(vget_high_f32(_v), vget_low_f32(_v));
+	float _r = vget_lane_f32(_w,0);
+	_r += vget_lane_f32(_w,1);
+	return dst + _r;
+}
+
+inline double vgetq_horizontal_mul_f64(double dst, float64x2_t src) {
+	float64x2_t _v = src;
+	double _r = vgetq_lane_f64(_v,0);
+	_r *= vgetq_lane_f64(_v,1);
+	return dst * _r;
+}
+
+inline float vget_horizontal_mul_f32(float dst, float32x4_t src) {
+	float32x4_t _v = src;
+	float32x2_t _w = vmul_f32(vget_high_f32(_v), vget_low_f32(_v));
+	float _r = vget_lane_f32(_w,0);
+	_r *= vget_lane_f32(_w,1);
+	return dst * _r;
+}
+
+inline double vgetq_horizontal_min_f64(double dst, float64x2_t src) {
+	float64x2_t _v = src;
+	double _r = vgetq_lane_f64(_v,0);
+	_r = fmin(_r, vgetq_lane_f64(_v,1));
+	return fmin(_r, dst);
+}
+
+inline float vget_horizontal_min_f32(float dst, float32x4_t src) {
+	float32x4_t _v = src;
+	float32x2_t _w = vmin_f32(vget_high_f32(_v), vget_low_f32(_v));
+	float _r = vget_lane_f32(_w,0);
+	_r = fmin(_r, vget_lane_f32(_w,1));
+	return fmin(_r, dst);
+}
+
+inline double vgetq_horizontal_max_f64(double dst, float64x2_t src) {
+	float64x2_t _v = src;
+	double _r = vgetq_lane_f64(_v,0);
+	_r = fmax(_r, vgetq_lane_f64(_v,1));
+	return fmax(_r, dst);
+}
+
+inline float vget_horizontal_max_f32(float dst, float32x4_t src) {
+	float32x4_t _v = src;
+	float32x2_t _w = vmax_f32(vget_high_f32(_v), vget_low_f32(_v));
+	float _r = vget_lane_f32(_w,0);
+	_r = fmax(_r, vget_lane_f32(_w,1));
+	return fmax(_r, dst);
+}
 
-QUALIFIERS double _mm256_horizontal_add_pd(double a, __m256d b) {
-	__m256d _v = b;
-	__m256d _h = _mm256_hadd_pd(_v,_v);
-	return a + _mm_cvtsd_f64(_mm_add_pd(_mm256_extractf128_pd(_h,1), _mm256_castpd256_pd128(_h)));
-}
\ No newline at end of file
+#endif
\ No newline at end of file
diff --git a/util/generate_simd_horizontal_op.py b/util/generate_simd_horizontal_op.py
new file mode 100644
index 000000000..aebbf35bb
--- /dev/null
+++ b/util/generate_simd_horizontal_op.py
@@ -0,0 +1,309 @@
+from enum import Enum
+
+FCT_QUALIFIERS = "inline"
+
+
+class InstructionSets(Enum):
+    SSE3 = "SSE3"
+    AVX = "AVX"
+    AVX512 = "AVX512"
+    NEON = "NEON"
+
+    def __str__(self):
+        return self.value
+
+
+class ReductionOps(Enum):
+    Add = ("add", "+")
+    Mul = ("mul", "*")
+    Min = ("min", "min")
+    Max = ("max", "max")
+
+    def __init__(self, op_name, op_str):
+        self.op_name = op_name
+        self.op_str = op_str
+
+
+class ScalarTypes(Enum):
+    Double = "double"
+    Float = "float"
+
+    def __str__(self):
+        return self.value
+
+
+class VectorTypes(Enum):
+    SSE3_128d = "__m128d"
+    SSE3_128 = "__m128"
+
+    AVX_256d = "__m256d"
+    AVX_256 = "__m256"
+    AVX_128 = "__m128"
+
+    AVX_512d = "__m512d"
+    AVX_512 = "__m512"
+
+    NEON_64x2 = "float64x2_t"
+    NEON_32x4 = "float32x4_t"
+
+    def __str__(self):
+        return self.value
+
+
+class Variable:
+    def __init__(self, name: str, dtype: ScalarTypes | VectorTypes):
+        self._name = name
+        self._dtype = dtype
+
+    def __str__(self):
+        return f"{self._dtype} {self._name}"
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def dtype(self) -> ScalarTypes | VectorTypes:
+        return self._dtype
+
+
+def get_intrin_from_vector_type(vtype: VectorTypes) -> InstructionSets:
+    match vtype:
+        case VectorTypes.SSE3_128 | VectorTypes.SSE3_128d:
+            return InstructionSets.SSE3
+        case VectorTypes.AVX_256 | VectorTypes.AVX_256d:
+            return InstructionSets.AVX
+        case VectorTypes.AVX_512 | VectorTypes.AVX_512d:
+            return InstructionSets.AVX512
+        case VectorTypes.NEON_32x4 | VectorTypes.NEON_64x2:
+            return InstructionSets.NEON
+
+
+def intrin_prefix(instruction_set: InstructionSets, double_prec: bool):
+    match instruction_set:
+        case InstructionSets.SSE3:
+            return "_mm"
+        case InstructionSets.AVX:
+            return "_mm256"
+        case InstructionSets.AVX512:
+            return "_mm512"
+        case InstructionSets.NEON:
+            return "vgetq" if double_prec else "vget"
+        case _:
+            raise ValueError(f"Unknown instruction set {instruction_set}")
+
+
+def intrin_suffix(instruction_set: InstructionSets, double_prec: bool):
+    if instruction_set in [InstructionSets.SSE3, InstructionSets.AVX, InstructionSets.AVX512]:
+        return "pd" if double_prec else "ps"
+    elif instruction_set in [InstructionSets.NEON]:
+        return "f64" if double_prec else "f32"
+    else:
+        raise ValueError(f"Unknown instruction set {instruction_set}")
+
+
+def generate_hadd_intrin(instruction_set: InstructionSets, double_prec: bool, v: str):
+    return f"{intrin_prefix(instruction_set, double_prec)}_hadd_{intrin_suffix(instruction_set, double_prec)}({v}, {v})"
+
+
+def generate_shuffle_intrin(instruction_set: InstructionSets, double_prec: bool, v: str, offset):
+    return f"_mm_shuffle_{intrin_suffix(instruction_set, double_prec)}({v}, {v}, {offset})"
+
+
+def generate_op_intrin(instruction_set: InstructionSets, double_prec: bool, reduction_op: ReductionOps, a: str, b: str):
+    return f"_mm_{reduction_op.op_name}_{intrin_suffix(instruction_set, double_prec)}({a}, {b})"
+
+
+def generate_cvts_intrin(double_prec: bool, v: str):
+    convert_suffix = "f64" if double_prec else "f32"
+    intrin_suffix = "d" if double_prec else "s"
+    return f"_mm_cvts{intrin_suffix}_{convert_suffix}({v})"
+
+
+def generate_fct_name(instruction_set: InstructionSets, double_prec: bool, op: ReductionOps):
+    prefix = intrin_prefix(instruction_set, double_prec)
+    suffix = intrin_suffix(instruction_set, double_prec)
+    return f"{prefix}_horizontal_{op.op_name}_{suffix}"
+
+
+def generate_fct_decl(instruction_set: InstructionSets, op: ReductionOps, svar: Variable, vvar: Variable):
+    double_prec = svar.dtype is ScalarTypes.Double
+    return f"{FCT_QUALIFIERS} {svar.dtype} {generate_fct_name(instruction_set, double_prec, op)}({svar}, {vvar}) {{ \n"
+
+
+# SSE & AVX provide a horizontal add 'hadd' intrinsic that allows for specialized handling
+def generate_simd_horizontal_add(scalar_var: Variable, vector_var: Variable):
+    reduction_op = ReductionOps.Add
+    instruction_set = get_intrin_from_vector_type(vector_var.dtype)
+    double_prec = scalar_var.dtype is ScalarTypes.Double
+
+    sname = scalar_var.name
+    vtype = vector_var.dtype
+    vname = vector_var.name
+
+    simd_op = lambda a, b: generate_op_intrin(instruction_set, double_prec, reduction_op, a, b)
+    hadd = lambda var: generate_hadd_intrin(instruction_set, double_prec, var)
+    cvts = lambda var: generate_cvts_intrin(double_prec, var)
+
+    # function body
+    body = f"\t{vtype} _v = {vname};\n"
+    match instruction_set:
+        case InstructionSets.SSE3:
+            if double_prec:
+                body += f"\treturn {sname} + {cvts(hadd('_v'))};\n"
+            else:
+                body += f"\t{vtype} _h = {hadd('_v')};\n" \
+                        f"\treturn {sname} + {cvts(simd_op('_h', '_mm_movehdup_ps(_h)'))};\n"
+
+        case InstructionSets.AVX:
+            if double_prec:
+                body += f"\t{vtype} _h = {hadd('_v')};\n" \
+                        f"\treturn {sname} + {cvts(simd_op('_mm256_extractf128_pd(_h,1)', '_mm256_castpd256_pd128(_h)'))};\n"
+            else:
+                add_i = "_mm_hadd_ps(_i,_i)"
+                body += f"\t{vtype} _h = {hadd('_v')};\n" \
+                        f"\t__m128 _i = {simd_op('_mm256_extractf128_ps(_h,1)', '_mm256_castps256_ps128(_h)')};\n" \
+                        f"\treturn {sname} + {cvts(add_i)};\n"
+
+        case _:
+            raise ValueError(f"No specialized version of horizontal_add available for {instruction_set}")
+
+    # function decl
+    decl = generate_fct_decl(instruction_set, reduction_op, scalar_var, vector_var)
+
+    return decl + body + "}\n"
+
+
+def generate_simd_horizontal_op(reduction_op: ReductionOps, scalar_var: Variable, vector_var: Variable):
+    instruction_set = get_intrin_from_vector_type(vector_var.dtype)
+    double_prec = scalar_var.dtype is ScalarTypes.Double
+
+    # generate specialized version for add operation
+    if reduction_op == ReductionOps.Add and instruction_set in [InstructionSets.SSE3, InstructionSets.AVX]:
+        return generate_simd_horizontal_add(scalar_var, vector_var)
+
+    sname = scalar_var.name
+    stype = scalar_var.dtype
+    vtype = vector_var.dtype
+    vname = vector_var.name
+
+    opname = reduction_op.op_name
+    opstr = reduction_op.op_str
+
+    reduction_function = f"f{opname}" \
+        if reduction_op in [ReductionOps.Max, ReductionOps.Min] else None
+
+    simd_op = lambda a, b: generate_op_intrin(instruction_set, double_prec, reduction_op, a, b)
+    cvts = lambda var: generate_cvts_intrin(double_prec, var)
+    shuffle = lambda var, offset: generate_shuffle_intrin(instruction_set, double_prec, var, offset)
+
+    # function body
+    body = f"\t{vtype} _v = {vname};\n" if instruction_set != InstructionSets.AVX512 else ""
+    match instruction_set:
+        case InstructionSets.SSE3:
+            if double_prec:
+                body += f"\t{stype} _r = {cvts(simd_op('_v', shuffle('_v', 1)))};\n"
+            else:
+                body += f"\t{vtype} _h = {simd_op('_v', shuffle('_v', 177))};\n" \
+                        f"\t{stype} _r = {cvts(simd_op('_h', shuffle('_h', 10)))};\n"
+
+        case InstructionSets.AVX:
+            if double_prec:
+                body += f"\t__m128d _w = {simd_op('_mm256_extractf128_pd(_v,1)', '_mm256_castpd256_pd128(_v)')};\n" \
+                        f"\t{stype} _r = {cvts(simd_op('_w', '_mm_permute_pd(_w,1)'))}; \n"
+            else:
+                body += f"\t__m128 _w = {simd_op('_mm256_extractf128_ps(_v,1)', '_mm256_castps256_ps128(_v)')};\n" \
+                        f"\t__m128 _h = {simd_op('_w', shuffle('_w', 177))};\n" \
+                        f"\t{stype} _r = {cvts(simd_op('_h', shuffle('_h', 10)))};\n"
+
+        case InstructionSets.AVX512:
+            suffix = intrin_suffix(instruction_set, double_prec)
+            body += f"\t{stype} _r = _mm512_reduce_{opname}_{suffix}({vname});\n"
+
+        case InstructionSets.NEON:
+            if double_prec:
+                body += f"\t{stype} _r = vgetq_lane_f64(_v,0);\n"
+                if reduction_function:
+                    body += f"\t_r = {reduction_function}(_r, vgetq_lane_f64(_v,1));\n"
+                else:
+                    body += f"\t_r {opstr}= vgetq_lane_f64(_v,1);\n"
+            else:
+                body += f"\tfloat32x2_t _w = v{opname}_f32(vget_high_f32(_v), vget_low_f32(_v));\n" \
+                        f"\t{stype} _r = vget_lane_f32(_w,0);\n"
+                if reduction_function:
+                    body += f"\t_r = {reduction_function}(_r, vget_lane_f32(_w,1));\n"
+                else:
+                    body += f"\t_r {opstr}= vget_lane_f32(_w,1);\n"
+
+        case _:
+            raise ValueError(f"Unsupported instruction set {instruction_set}")
+
+    # finalize reduction
+    if reduction_function:
+        body += f"\treturn {reduction_function}(_r, {sname});\n"
+    else:
+        body += f"\treturn {sname} {opstr} _r;\n"
+
+    # function decl
+    decl = generate_fct_decl(instruction_set, reduction_op, scalar_var, vector_var)
+
+    return decl + body + "}\n"
+
+
+stypes = {
+    True: ScalarTypes.Double,
+    False: ScalarTypes.Float
+}
+
+vtypes_for_instruction_set = {
+    InstructionSets.SSE3: {
+        True: VectorTypes.SSE3_128d,
+        False: VectorTypes.SSE3_128
+    },
+    InstructionSets.AVX: {
+        True: VectorTypes.AVX_256d,
+        False: VectorTypes.AVX_256
+    },
+    InstructionSets.AVX512: {
+        True: VectorTypes.AVX_512d,
+        False: VectorTypes.AVX_512
+    },
+    InstructionSets.NEON: {
+        True: VectorTypes.NEON_64x2,
+        False: VectorTypes.NEON_32x4
+    },
+}
+
+guards_for_instruction_sets = {
+    InstructionSets.SSE3: "__SSE3__",
+    InstructionSets.AVX: "__AVX__",
+    InstructionSets.AVX512: '__AVX512VL__',
+    InstructionSets.NEON: '_M_ARM64',
+}
+
+code = """#pragma once
+
+#include <cmath>
+
+"""
+
+for instruction_set in InstructionSets:
+    code += f"#if defined({guards_for_instruction_sets[instruction_set]})\n"
+
+    if instruction_set in [InstructionSets.SSE3, InstructionSets.AVX, InstructionSets.AVX512]:
+        code += "#include <immintrin.h>\n\n"
+    elif instruction_set == InstructionSets.NEON:
+        code += "#include <arm_neon.h>\n\n"
+    else:
+        raise ValueError(f"Missing header include for instruction set {instruction_set}")
+
+    for reduction_op in ReductionOps:
+        for double_prec in [True, False]:
+            scalar_var = Variable("dst", stypes[double_prec])
+            vector_var = Variable("src", vtypes_for_instruction_set[instruction_set][double_prec])
+
+            code += generate_simd_horizontal_op(reduction_op, scalar_var, vector_var) + "\n"
+
+    code += "#endif\n\n"
+
+print(code)
--
GitLab
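
Usage sketch (illustrative, not part of the patch): the generated helpers fold the lanes of a SIMD register into a scalar accumulator at the tail of a vectorized reduction. In the snippet below, only _mm256_horizontal_add_pd and the header name come from the patch; the function sum_avx, the array data, and the loop structure are assumptions made for the example.

    #include <immintrin.h>
    #include "simd_horizontal_helpers.h"

    /* Illustrative AVX sum: accumulate 4 doubles per iteration, then fold the
       vector partial sums into the scalar accumulator with the generated
       _mm256_horizontal_add_pd(dst, src) helper. */
    double sum_avx(const double *data, int n) {
        double acc = 0.0;                   /* scalar accumulator ("dst") */
        __m256d vacc = _mm256_setzero_pd(); /* vector partial sums ("src") */
        int i = 0;
        for (; i + 4 <= n; i += 4)
            vacc = _mm256_add_pd(vacc, _mm256_loadu_pd(data + i));
        for (; i < n; ++i)                  /* scalar remainder */
            acc += data[i];
        return _mm256_horizontal_add_pd(acc, vacc);
    }

The other helpers are used the same way for mul/min/max. The header itself is presumably regenerated by running util/generate_simd_horizontal_op.py and redirecting its printed output into src/pystencils/include/simd_horizontal_helpers.h.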