Skip to content
Snippets Groups Projects
Commit 48be020a authored by Frederik Hennig's avatar Frederik Hennig
Browse files

add experimental-jit mode to test suite

parent fc169886
Branches
Tags
1 merge request!445Object-Oriented CPU JIT API and Prototype Implementation
...@@ -3,6 +3,7 @@ import runpy ...@@ -3,6 +3,7 @@ import runpy
import sys import sys
import tempfile import tempfile
import warnings import warnings
import pathlib
import nbformat import nbformat
import pytest import pytest
...@@ -185,24 +186,10 @@ class IPyNbFile(pytest.File): ...@@ -185,24 +186,10 @@ class IPyNbFile(pytest.File):
pass pass
if pytest_version >= 70000: def pytest_collect_file(file_path: pathlib.Path, parent):
# Since pytest 7.0, usage of `py.path.local` is deprecated and `pathlib.Path` should be used instead glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
import pathlib if any(file_path.match(g) for g in glob_exprs):
return IPyNbFile.from_parent(path=file_path, parent=parent)
def pytest_collect_file(file_path: pathlib.Path, parent):
glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
if any(file_path.match(g) for g in glob_exprs):
return IPyNbFile.from_parent(path=file_path, parent=parent)
else:
def pytest_collect_file(path, parent):
glob_exprs = ["*demo*.ipynb", "*tutorial*.ipynb", "test_*.ipynb"]
if any(path.fnmatch(g) for g in glob_exprs):
if pytest_version >= 50403:
return IPyNbFile.from_parent(fspath=path, parent=parent)
else:
return IPyNbFile(path, parent)
# Fixtures # Fixtures
......
...@@ -114,7 +114,7 @@ class CpuJit(JitBase): ...@@ -114,7 +114,7 @@ class CpuJit(JitBase):
""" """
# Get the Code # Get the Code
module_name = f"{kernel.function_name}_jit" module_name = f"{kernel.name}_jit"
cpp_code = self._ext_module_builder.render_module(kernel, module_name) cpp_code = self._ext_module_builder.render_module(kernel, module_name)
# Get compiler information # Get compiler information
...@@ -124,11 +124,12 @@ class CpuJit(JitBase): ...@@ -124,11 +124,12 @@ class CpuJit(JitBase):
lib_suffix = f"{so_abi}.so" lib_suffix = f"{so_abi}.so"
# Compute Code Hash # Compute Code Hash
code_utf8 = cpp_code.encode("utf-8") code_utf8: bytes = cpp_code.encode("utf-8")
compiler_utf8: bytes = (" ".join([self._cxx] + self._cxx_fixed_flags)).encode("utf-8")
import hashlib import hashlib
code_hash = hashlib.sha256(code_utf8) module_hash = hashlib.sha256(code_utf8 + compiler_utf8)
module_stem = f"module_{code_hash.hexdigest()}" module_stem = f"module_{module_hash.hexdigest()}"
def compile_and_load(module_dir: Path): def compile_and_load(module_dir: Path):
cpp_file = module_dir / f"{module_stem}.cpp" cpp_file = module_dir / f"{module_stem}.cpp"
......
...@@ -31,12 +31,9 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): ...@@ -31,12 +31,9 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
def __init__( def __init__(
self, self,
compiler_info: CompilerInfo, compiler_info: CompilerInfo,
strict_scalar_types: bool = False,
): ):
self._compiler_info = compiler_info self._compiler_info = compiler_info
self._strict_scalar_types = strict_scalar_types
self._actual_field_types: dict[Field, PsType] self._actual_field_types: dict[Field, PsType]
self._param_binds: list[str] self._param_binds: list[str]
self._public_params: list[str] self._public_params: list[str]
...@@ -63,7 +60,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): ...@@ -63,7 +60,7 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
includes="\n".join(includes), includes="\n".join(includes),
restrict_qualifier=self._compiler_info.restrict_qualifier(), restrict_qualifier=self._compiler_info.restrict_qualifier(),
module_name=module_name, module_name=module_name,
kernel_name=kernel.function_name, kernel_name=kernel.name,
param_binds=", ".join(self._param_binds), param_binds=", ".join(self._param_binds),
public_params=", ".join(self._public_params), public_params=", ".join(self._public_params),
param_check_lines=indent("\n".join(self._param_check_lines), prefix=" "), param_check_lines=indent("\n".join(self._param_check_lines), prefix=" "),
...@@ -118,8 +115,6 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase): ...@@ -118,8 +115,6 @@ class Pybind11KernelModuleBuilder(ExtensionModuleBuilderBase):
def _add_scalar_param(self, sc_param: Parameter): def _add_scalar_param(self, sc_param: Parameter):
param_bind = f'py::arg("{sc_param.name}")' param_bind = f'py::arg("{sc_param.name}")'
if self._strict_scalar_types:
param_bind += ".noconvert()"
self._param_binds.append(param_bind) self._param_binds.append(param_bind)
kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}" kernel_param = f"{sc_param.dtype.c_string()} {sc_param.name}"
......
...@@ -15,7 +15,6 @@ by your tests: ...@@ -15,7 +15,6 @@ by your tests:
import pytest import pytest
from types import ModuleType from types import ModuleType
from dataclasses import replace
import pystencils as ps import pystencils as ps
...@@ -32,6 +31,14 @@ AVAILABLE_TARGETS += ps.Target.available_vector_cpu_targets() ...@@ -32,6 +31,14 @@ AVAILABLE_TARGETS += ps.Target.available_vector_cpu_targets()
TARGET_IDS = [t.name for t in AVAILABLE_TARGETS] TARGET_IDS = [t.name for t in AVAILABLE_TARGETS]
def pytest_addoption(parser: pytest.Parser):
parser.addoption(
"--experimental-cpu-jit",
dest="experimental_cpu_jit",
action="store_true"
)
@pytest.fixture(params=AVAILABLE_TARGETS, ids=TARGET_IDS) @pytest.fixture(params=AVAILABLE_TARGETS, ids=TARGET_IDS)
def target(request) -> ps.Target: def target(request) -> ps.Target:
"""Provides all code generation targets available on the current hardware""" """Provides all code generation targets available on the current hardware"""
...@@ -39,7 +46,7 @@ def target(request) -> ps.Target: ...@@ -39,7 +46,7 @@ def target(request) -> ps.Target:
@pytest.fixture @pytest.fixture
def gen_config(target: ps.Target): def gen_config(request: pytest.FixtureRequest, target: ps.Target):
"""Default codegen configuration for the current target. """Default codegen configuration for the current target.
For GPU targets, set default indexing options. For GPU targets, set default indexing options.
...@@ -52,6 +59,11 @@ def gen_config(target: ps.Target): ...@@ -52,6 +59,11 @@ def gen_config(target: ps.Target):
gen_config.cpu.vectorize.enable = True gen_config.cpu.vectorize.enable = True
gen_config.cpu.vectorize.assume_inner_stride_one = True gen_config.cpu.vectorize.assume_inner_stride_one = True
if target.is_cpu() and request.config.getoption("experimental_cpu_jit"):
from pystencils.jit.cpu import CpuJit, GccInfo
gen_config.jit = CpuJit.create(compiler_info=GccInfo(target=target))
return gen_config return gen_config
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment