From 4370b2bcb24eb24a9b352fdecb8d9b415c76a200 Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Mon, 27 Jan 2025 16:42:00 +0100 Subject: [PATCH] add compiler info classes for gcc and clang --- src/pystencils/jit/cpu/__init__.py | 3 ++ src/pystencils/jit/cpu/compiler_info.py | 72 +++++++++++++++++++++++++ src/pystencils/jit/cpu/cpu_pybind11.py | 36 +++++-------- tests/jit/test_cpujit.py | 2 +- 4 files changed, 89 insertions(+), 24 deletions(-) create mode 100644 src/pystencils/jit/cpu/compiler_info.py diff --git a/src/pystencils/jit/cpu/__init__.py b/src/pystencils/jit/cpu/__init__.py index 69985c914..335ec52d4 100644 --- a/src/pystencils/jit/cpu/__init__.py +++ b/src/pystencils/jit/cpu/__init__.py @@ -1,5 +1,8 @@ +from .compiler_info import GccInfo, ClangInfo from .cpu_pybind11 import CpuJit __all__ = [ + "GccInfo", + "ClangInfo", "CpuJit" ] diff --git a/src/pystencils/jit/cpu/compiler_info.py b/src/pystencils/jit/cpu/compiler_info.py new file mode 100644 index 000000000..07a5eebbd --- /dev/null +++ b/src/pystencils/jit/cpu/compiler_info.py @@ -0,0 +1,72 @@ +from __future__ import annotations +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from ...codegen.target import Target + + +@dataclass +class CompilerInfo(ABC): + openmp: bool = True + + optlevel: str | None = "fast" + + cxx_standard: str = "c++11" + + target: Target = Target.CurrentCPU + + @abstractmethod + def cxx(self) -> str: ... + + @abstractmethod + def cxxflags(self) -> list[str]: ... + + @abstractmethod + def restrict_qualifier(self) -> str: ... + + +class _GnuLikeCliCompiler(CompilerInfo): + def cxxflags(self) -> list[str]: + flags = ["-DNDEBUG", f"-std={self.cxx_standard}"] + + if self.optlevel is not None: + flags.append(f"-O{self.optlevel}") + + if self.openmp: + flags.append("-fopenmp") + + match self.target: + case Target.CurrentCPU: + flags.append("-march=native") + case Target.X86_SSE: + flags += ["-march=x86-64-v2"] + case Target.X86_AVX: + flags += ["-march=x86-64-v3"] + case Target.X86_AVX512: + flags += ["-march=x86-64-v4"] + case Target.X86_AVX512_FP16: + flags += ["-march=x86-64-v4", "-mavx512fp16"] + + return flags + + def restrict_qualifier(self) -> str: + return "__restrict__" + + +class GccInfo(_GnuLikeCliCompiler): + def cxx(self) -> str: + return "g++" + + +@dataclass +class ClangInfo(_GnuLikeCliCompiler): + llvm_version: int | None = None + + def cxx(self) -> str: + if self.llvm_version is None: + return "clang" + else: + return f"clang-{self.llvm_version}" + + def cxxflags(self) -> list[str]: + return super().cxxflags() + ["-lstdc++"] diff --git a/src/pystencils/jit/cpu/cpu_pybind11.py b/src/pystencils/jit/cpu/cpu_pybind11.py index 128313a74..d0368a390 100644 --- a/src/pystencils/jit/cpu/cpu_pybind11.py +++ b/src/pystencils/jit/cpu/cpu_pybind11.py @@ -5,6 +5,7 @@ from types import ModuleType from pathlib import Path from textwrap import indent import subprocess +from copy import copy from ...types import PsPointerType, PsType from ...field import Field @@ -14,6 +15,7 @@ from ...codegen import Kernel, Parameter from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride from ..jit import JitError, JitBase +from .compiler_info import CompilerInfo, GccInfo _module_template = Path(__file__).parent / "kernel_module.tmpl.cpp" @@ -24,16 +26,12 @@ class CpuJit(JitBase): def __init__( self, - cxx: str = "g++", - cxxflags: Sequence[str] = ("-O3", "-fopenmp"), + compiler_info: CompilerInfo | None = None, objcache: str | Path | None = None, strict_scalar_types: bool = False, ): - self._cxx = cxx - self._cxxflags: list[str] = list(cxxflags) - + self._compiler_info = copy(compiler_info) if compiler_info is not None else GccInfo() self._strict_scalar_types = strict_scalar_types - self._restrict_qualifier = "__restrict__" if objcache is None: from appdirs import AppDirs @@ -64,20 +62,12 @@ class CpuJit(JitBase): return self._objcache @property - def cxx(self) -> str: - return self._cxx - - @cxx.setter - def cxx(self, path: str): - self._cxx = path - - @property - def cxxflags(self) -> list[str]: - return self._cxxflags - - @cxxflags.setter - def cxxflags(self, flags: Sequence[str]): - self._cxxflags = list(flags) + def compiler_info(self) -> CompilerInfo: + return self._compiler_info + + @compiler_info.setter + def compiler_info(self, info: CompilerInfo): + self._compiler_info = info @property def strict_scalar_types(self) -> bool: @@ -139,9 +129,9 @@ class CpuJit(JitBase): def _compile_extension_module(self, src_file: Path, libfile: Path): args = ( - [self._cxx] + [self._compiler_info.cxx()] + self._cxx_fixed_flags - + self._cxxflags + + self._compiler_info.cxxflags() + ["-o", str(libfile), str(src_file)] ) @@ -195,7 +185,7 @@ class KernelModuleBuilder: templ = Template(_module_template.read_text()) code_str = templ.substitute( includes="\n".join(includes), - restrict_qualifier=self._jit.restrict_qualifier, + restrict_qualifier=self._jit.compiler_info.restrict_qualifier(), module_name=self._module_name, kernel_name=kernel.function_name, param_binds=", ".join(self._param_binds), diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py index 0b09d66b0..aa4f50f3f 100644 --- a/tests/jit/test_cpujit.py +++ b/tests/jit/test_cpujit.py @@ -5,7 +5,7 @@ from pystencils.jit import CpuJit def test_basic_cpu_kernel(tmp_path): - jit = CpuJit("g++", ["-O3"], objcache=tmp_path) + jit = CpuJit(objcache=tmp_path) f, g = fields("f, g: [2D]") asm = Assignment(f.center(), 2.0 * g.center()) -- GitLab