Skip to content
Snippets Groups Projects
Commit 4370b2bc authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add compiler info classes for gcc and clang

parent 6efa0439
No related branches found
No related tags found
1 merge request!445Object-Oriented CPU JIT API and Prototype Implementation
Pipeline #72859 failed
from .compiler_info import GccInfo, ClangInfo
from .cpu_pybind11 import CpuJit
__all__ = [
"GccInfo",
"ClangInfo",
"CpuJit"
]
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from ...codegen.target import Target
@dataclass
class CompilerInfo(ABC):
openmp: bool = True
optlevel: str | None = "fast"
cxx_standard: str = "c++11"
target: Target = Target.CurrentCPU
@abstractmethod
def cxx(self) -> str: ...
@abstractmethod
def cxxflags(self) -> list[str]: ...
@abstractmethod
def restrict_qualifier(self) -> str: ...
class _GnuLikeCliCompiler(CompilerInfo):
def cxxflags(self) -> list[str]:
flags = ["-DNDEBUG", f"-std={self.cxx_standard}"]
if self.optlevel is not None:
flags.append(f"-O{self.optlevel}")
if self.openmp:
flags.append("-fopenmp")
match self.target:
case Target.CurrentCPU:
flags.append("-march=native")
case Target.X86_SSE:
flags += ["-march=x86-64-v2"]
case Target.X86_AVX:
flags += ["-march=x86-64-v3"]
case Target.X86_AVX512:
flags += ["-march=x86-64-v4"]
case Target.X86_AVX512_FP16:
flags += ["-march=x86-64-v4", "-mavx512fp16"]
return flags
def restrict_qualifier(self) -> str:
return "__restrict__"
class GccInfo(_GnuLikeCliCompiler):
def cxx(self) -> str:
return "g++"
@dataclass
class ClangInfo(_GnuLikeCliCompiler):
llvm_version: int | None = None
def cxx(self) -> str:
if self.llvm_version is None:
return "clang"
else:
return f"clang-{self.llvm_version}"
def cxxflags(self) -> list[str]:
return super().cxxflags() + ["-lstdc++"]
......@@ -5,6 +5,7 @@ from types import ModuleType
from pathlib import Path
from textwrap import indent
import subprocess
from copy import copy
from ...types import PsPointerType, PsType
from ...field import Field
......@@ -14,6 +15,7 @@ from ...codegen import Kernel, Parameter
from ...codegen.properties import FieldBasePtr, FieldShape, FieldStride
from ..jit import JitError, JitBase
from .compiler_info import CompilerInfo, GccInfo
_module_template = Path(__file__).parent / "kernel_module.tmpl.cpp"
......@@ -24,16 +26,12 @@ class CpuJit(JitBase):
def __init__(
self,
cxx: str = "g++",
cxxflags: Sequence[str] = ("-O3", "-fopenmp"),
compiler_info: CompilerInfo | None = None,
objcache: str | Path | None = None,
strict_scalar_types: bool = False,
):
self._cxx = cxx
self._cxxflags: list[str] = list(cxxflags)
self._compiler_info = copy(compiler_info) if compiler_info is not None else GccInfo()
self._strict_scalar_types = strict_scalar_types
self._restrict_qualifier = "__restrict__"
if objcache is None:
from appdirs import AppDirs
......@@ -64,20 +62,12 @@ class CpuJit(JitBase):
return self._objcache
@property
def cxx(self) -> str:
return self._cxx
@cxx.setter
def cxx(self, path: str):
self._cxx = path
@property
def cxxflags(self) -> list[str]:
return self._cxxflags
def compiler_info(self) -> CompilerInfo:
return self._compiler_info
@cxxflags.setter
def cxxflags(self, flags: Sequence[str]):
self._cxxflags = list(flags)
@compiler_info.setter
def compiler_info(self, info: CompilerInfo):
self._compiler_info = info
@property
def strict_scalar_types(self) -> bool:
......@@ -139,9 +129,9 @@ class CpuJit(JitBase):
def _compile_extension_module(self, src_file: Path, libfile: Path):
args = (
[self._cxx]
[self._compiler_info.cxx()]
+ self._cxx_fixed_flags
+ self._cxxflags
+ self._compiler_info.cxxflags()
+ ["-o", str(libfile), str(src_file)]
)
......@@ -195,7 +185,7 @@ class KernelModuleBuilder:
templ = Template(_module_template.read_text())
code_str = templ.substitute(
includes="\n".join(includes),
restrict_qualifier=self._jit.restrict_qualifier,
restrict_qualifier=self._jit.compiler_info.restrict_qualifier(),
module_name=self._module_name,
kernel_name=kernel.function_name,
param_binds=", ".join(self._param_binds),
......
......@@ -5,7 +5,7 @@ from pystencils.jit import CpuJit
def test_basic_cpu_kernel(tmp_path):
jit = CpuJit("g++", ["-O3"], objcache=tmp_path)
jit = CpuJit(objcache=tmp_path)
f, g = fields("f, g: [2D]")
asm = Assignment(f.center(), 2.0 * g.center())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment