From 4370b2bcb24eb24a9b352fdecb8d9b415c76a200 Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Mon, 27 Jan 2025 16:42:00 +0100
Subject: [PATCH] add compiler info classes for gcc and clang

---
 src/pystencils/jit/cpu/__init__.py      |  3 ++
 src/pystencils/jit/cpu/compiler_info.py | 72 +++++++++++++++++++++++++
 src/pystencils/jit/cpu/cpu_pybind11.py  | 36 +++++--------
 tests/jit/test_cpujit.py                |  2 +-
 4 files changed, 89 insertions(+), 24 deletions(-)
 create mode 100644 src/pystencils/jit/cpu/compiler_info.py

diff --git a/src/pystencils/jit/cpu/__init__.py b/src/pystencils/jit/cpu/__init__.py
index 69985c914..335ec52d4 100644
--- a/src/pystencils/jit/cpu/__init__.py
+++ b/src/pystencils/jit/cpu/__init__.py
@@ -1,5 +1,8 @@
+from .compiler_info import GccInfo, ClangInfo
 from .cpu_pybind11 import CpuJit
 
 __all__ = [
+    "GccInfo",
+    "ClangInfo",
     "CpuJit"
 ]
diff --git a/src/pystencils/jit/cpu/compiler_info.py b/src/pystencils/jit/cpu/compiler_info.py
new file mode 100644
index 000000000..07a5eebbd
--- /dev/null
+++ b/src/pystencils/jit/cpu/compiler_info.py
@@ -0,0 +1,72 @@
+from __future__ import annotations
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+from ...codegen.target import Target
+
+
+@dataclass
+class CompilerInfo(ABC):
+    """Abstract description of a C++ compiler and of the options the JIT passes to it."""
+
+    # Options are only stored here; concrete subclasses decide how they become flags.
+    openmp: bool = True
+    optlevel: str | None = "fast"
+    cxx_standard: str = "c++11"
+    target: Target = Target.CurrentCPU
+
+    @abstractmethod
+    def cxx(self) -> str: ...  # compiler executable placed first on the command line
+
+    @abstractmethod
+    def cxxflags(self) -> list[str]: ...  # flags derived from the options above
+
+    @abstractmethod
+    def restrict_qualifier(self) -> str: ...  # qualifier text substituted into the module template
+
+
+class _GnuLikeCliCompiler(CompilerInfo):
+    """Flag generation shared by compilers with a GCC-style command line."""
+
+    def cxxflags(self) -> list[str]:
+        """Build the flag list from the options; unlisted targets add no ``-march``."""
+        args = ["-DNDEBUG", f"-std={self.cxx_standard}"]
+        if self.optlevel is not None:
+            args.append(f"-O{self.optlevel}")
+        if self.openmp:
+            args.append("-fopenmp")
+        arch = self.target
+        if arch == Target.CurrentCPU:
+            args.append("-march=native")
+        elif arch == Target.X86_SSE:
+            args.append("-march=x86-64-v2")
+        elif arch == Target.X86_AVX:
+            args.append("-march=x86-64-v3")
+        elif arch == Target.X86_AVX512:
+            args.append("-march=x86-64-v4")
+        elif arch == Target.X86_AVX512_FP16:
+            args.extend(["-march=x86-64-v4", "-mavx512fp16"])
+        return args
+
+    def restrict_qualifier(self) -> str:
+        """Return the pointer-aliasing qualifier in its GNU spelling."""
+        return "__restrict__"
+
+
+class GccInfo(_GnuLikeCliCompiler):
+    def cxx(self) -> str:
+        return "g++"  # GCC's C++ driver; all flag handling is inherited from the base class
+
+
+@dataclass
+class ClangInfo(_GnuLikeCliCompiler):
+    """Clang compiler info; a set ``llvm_version`` selects the ``clang-<N>`` binary."""
+
+    llvm_version: int | None = None
+
+    def cxx(self) -> str:
+        return "clang" if self.llvm_version is None else f"clang-{self.llvm_version}"
+
+    def cxxflags(self) -> list[str]:
+        # Explicitly names libstdc++ for linking on top of the shared GNU-style flags.
+        return [*super().cxxflags(), "-lstdc++"]
diff --git a/src/pystencils/jit/cpu/cpu_pybind11.py b/src/pystencils/jit/cpu/cpu_pybind11.py
index 128313a74..d0368a390 100644
--- a/src/pystencils/jit/cpu/cpu_pybind11.py
+++ b/src/pystencils/jit/cpu/cpu_pybind11.py
@@ -5,6 +5,7 @@ from types import ModuleType
 from pathlib import Path
 from textwrap import indent
 import subprocess
+from copy import copy
 
 from ...types import PsPointerType, PsType
 from ...field import Field
@@ -14,6 +15,7 @@ from ...codegen import Kernel, Parameter
 from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
 
 from ..jit import JitError, JitBase
+from .compiler_info import CompilerInfo, GccInfo
 
 
 _module_template = Path(__file__).parent / "kernel_module.tmpl.cpp"
@@ -24,16 +26,12 @@ class CpuJit(JitBase):
 
     def __init__(
         self,
-        cxx: str = "g++",
-        cxxflags: Sequence[str] = ("-O3", "-fopenmp"),
+        compiler_info: CompilerInfo | None = None,
         objcache: str | Path | None = None,
         strict_scalar_types: bool = False,
     ):
-        self._cxx = cxx
-        self._cxxflags: list[str] = list(cxxflags)
-
+        self._compiler_info = copy(compiler_info) if compiler_info is not None else GccInfo()
         self._strict_scalar_types = strict_scalar_types
-        self._restrict_qualifier = "__restrict__"
 
         if objcache is None:
             from appdirs import AppDirs
@@ -64,20 +62,12 @@ class CpuJit(JitBase):
         return self._objcache
 
     @property
-    def cxx(self) -> str:
-        return self._cxx
-
-    @cxx.setter
-    def cxx(self, path: str):
-        self._cxx = path
-
-    @property
-    def cxxflags(self) -> list[str]:
-        return self._cxxflags
-
-    @cxxflags.setter
-    def cxxflags(self, flags: Sequence[str]):
-        self._cxxflags = list(flags)
+    def compiler_info(self) -> CompilerInfo:
+        return self._compiler_info  # the JIT's own instance is returned, not a copy
+
+    @compiler_info.setter
+    def compiler_info(self, info: CompilerInfo):
+        self._compiler_info = info  # stored by reference, whereas __init__ stores a copy
 
     @property
     def strict_scalar_types(self) -> bool:
@@ -139,9 +129,9 @@ class CpuJit(JitBase):
 
     def _compile_extension_module(self, src_file: Path, libfile: Path):
         args = (
-            [self._cxx]
+            [self._compiler_info.cxx()]
             + self._cxx_fixed_flags
-            + self._cxxflags
+            + self._compiler_info.cxxflags()
             + ["-o", str(libfile), str(src_file)]
         )
 
@@ -195,7 +185,7 @@ class KernelModuleBuilder:
         templ = Template(_module_template.read_text())
         code_str = templ.substitute(
             includes="\n".join(includes),
-            restrict_qualifier=self._jit.restrict_qualifier,
+            restrict_qualifier=self._jit.compiler_info.restrict_qualifier(),
             module_name=self._module_name,
             kernel_name=kernel.function_name,
             param_binds=", ".join(self._param_binds),
diff --git a/tests/jit/test_cpujit.py b/tests/jit/test_cpujit.py
index 0b09d66b0..aa4f50f3f 100644
--- a/tests/jit/test_cpujit.py
+++ b/tests/jit/test_cpujit.py
@@ -5,7 +5,7 @@ from pystencils.jit import CpuJit
 
 
 def test_basic_cpu_kernel(tmp_path):
-    jit = CpuJit("g++", ["-O3"], objcache=tmp_path)
+    jit = CpuJit(objcache=tmp_path)
 
     f, g = fields("f, g: [2D]")
     asm = Assignment(f.center(), 2.0 * g.center())
-- 
GitLab