diff --git a/src/pystencils/codegen/config.py b/src/pystencils/codegen/config.py index 07dd2691ddc806e5ca941a40e6382482f9b37479..6587f16b7650d3e3e3f64116cf464ad4e50fa7c6 100644 --- a/src/pystencils/codegen/config.py +++ b/src/pystencils/codegen/config.py @@ -1,11 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, TypeGuard from warnings import warn from abc import ABC from collections.abc import Collection -from typing import Sequence, Generic, TypeVar, Callable, Any, cast +from typing import TYPE_CHECKING, Sequence, Generic, TypeVar, Callable, Any, cast, TypeGuard from dataclasses import dataclass, InitVar, fields from .target import Target @@ -208,10 +207,7 @@ class Category(Generic[Category_T]): setattr(obj, self._lookup, cat.copy() if cat is not None else None) -class _AUTO_TYPE: - @staticmethod - def is_auto(val: Any) -> TypeGuard[_AUTO_TYPE]: - return val == AUTO +class _AUTO_TYPE: ... # noqa: E701 AUTO = _AUTO_TYPE() diff --git a/src/pystencils/jit/cpu/compiler_info.py b/src/pystencils/jit/cpu/compiler_info.py index 07a5eebbdcd61b69e85892631f754677a31e3230..6ed568d72d28b4254825bd3e84901f5f3b48c5be 100644 --- a/src/pystencils/jit/cpu/compiler_info.py +++ b/src/pystencils/jit/cpu/compiler_info.py @@ -1,4 +1,5 @@ from __future__ import annotations +from typing import Sequence from abc import ABC, abstractmethod from dataclasses import dataclass @@ -21,13 +22,19 @@ class CompilerInfo(ABC): @abstractmethod def cxxflags(self) -> list[str]: ... + @abstractmethod + def linker_flags(self) -> list[str]: ... + + @abstractmethod + def include_flags(self, include_dirs: Sequence[str]) -> list[str]: ... + @abstractmethod def restrict_qualifier(self) -> str: ... class _GnuLikeCliCompiler(CompilerInfo): def cxxflags(self) -> list[str]: - flags = ["-DNDEBUG", f"-std={self.cxx_standard}"] + flags = ["-DNDEBUG", f"-std={self.cxx_standard}", "-fPIC"] if self.optlevel is not None: flags.append(f"-O{self.optlevel}") @@ -48,6 +55,12 @@ class _GnuLikeCliCompiler(CompilerInfo): flags += ["-march=x86-64-v4", "-mavx512fp16"] return flags + + def linker_flags(self) -> list[str]: + return ["-shared"] + + def include_flags(self, include_dirs: Sequence[str]) -> list[str]: + return [f"-I{d}" for d in include_dirs] def restrict_qualifier(self) -> str: return "__restrict__" diff --git a/src/pystencils/jit/cpu/cpujit.py b/src/pystencils/jit/cpu/cpujit.py index eb1b60abec079b55a52d11fed1ed85ac1e85f017..64045d54951639b47e7beb3d2d99def8db6ce304 100644 --- a/src/pystencils/jit/cpu/cpujit.py +++ b/src/pystencils/jit/cpu/cpujit.py @@ -18,7 +18,7 @@ class CpuJit(JitBase): :Creation: In most cases, objects of this class should be instantiated using the `create` factory method. - + :Implementation Details: The `CpuJit` combines two separate components: @@ -66,19 +66,21 @@ class CpuJit(JitBase): # Include Directories import sysconfig - - python_include = sysconfig.get_path("include") - from ...include import get_pystencils_include_path - pystencils_include = get_pystencils_include_path() + include_dirs = [ + sysconfig.get_path("include"), + get_pystencils_include_path(), + ] + self._ext_module_builder.include_dirs() - self._cxx_fixed_flags = [ - "-shared", - "-fPIC", - f"-I{python_include}", - f"-I{pystencils_include}", - ] + self._ext_module_builder.include_flags() + # Compiler Flags + + self._cxx = self._compiler_info.cxx() + self._cxx_fixed_flags = ( + self._compiler_info.cxxflags() + + self._compiler_info.include_flags(include_dirs) + + self._compiler_info.linker_flags() + ) @property def objcache(self) -> Path | None: @@ -153,9 +155,8 @@ class CpuJit(JitBase): def _compile_extension_module(self, src_file: Path, libfile: Path): args = ( - [self._compiler_info.cxx()] + [self._cxx] + self._cxx_fixed_flags - + self._compiler_info.cxxflags() + ["-o", str(libfile), str(src_file)] ) @@ -182,10 +183,12 @@ class CpuJit(JitBase): class ExtensionModuleBuilderBase(ABC): @staticmethod @abstractmethod - def include_flags() -> list[str]: ... + def include_dirs() -> list[str]: ... @abstractmethod def __call__(self, kernel: Kernel, module_name: str) -> str: ... @abstractmethod - def get_wrapper(self, kernel: Kernel, extension_module: ModuleType) -> KernelWrapper: ... + def get_wrapper( + self, kernel: Kernel, extension_module: ModuleType + ) -> KernelWrapper: ... diff --git a/src/pystencils/jit/cpu/cpujit_pybind11.py b/src/pystencils/jit/cpu/cpujit_pybind11.py index ba9065c8a552c4ad70ce6179ce5339daef637253..2f43ef1965278eb3148fbfb448a9cfb001542adf 100644 --- a/src/pystencils/jit/cpu/cpujit_pybind11.py +++ b/src/pystencils/jit/cpu/cpujit_pybind11.py @@ -22,11 +22,11 @@ _module_template = Path(__file__).parent / "pybind11_kernel_module.tmpl.cpp" class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): @staticmethod - def include_flags() -> list[str]: + def include_dirs() -> list[str]: import pybind11 as pb11 pybind11_include = pb11.get_include() - return [f"-I{pybind11_include}"] + return [pybind11_include] def __init__( self,