Skip to content
Snippets Groups Projects

Add experimental half precision support

Merged Markus Holzer requested to merge holzer/pystencils:halfP into master
Files
9
+ 22
19
@@ -43,26 +43,30 @@ Then 'cl.exe' is used to compile.
@@ -43,26 +43,30 @@ Then 'cl.exe' is used to compile.
For Windows compilers the qualifier should be ``__restrict``
For Windows compilers the qualifier should be ``__restrict``
"""
"""
 
from appdirs import user_cache_dir, user_config_dir
 
from collections import OrderedDict
import hashlib
import hashlib
 
import importlib.util
import json
import json
import os
import os
import platform
import platform
import shutil
import shutil
import subprocess
import subprocess
 
import sysconfig
 
import tempfile
import textwrap
import textwrap
from collections import OrderedDict
import time
from sysconfig import get_paths
import warnings
from tempfile import TemporaryDirectory, NamedTemporaryFile
import numpy as np
import numpy as np
from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType
from pystencils import FieldType
from pystencils.astnodes import LoopOverCoordinate
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import generate_c, get_headers, CFunction
from pystencils.backends.cbackend import generate_c, get_headers, CFunction
from pystencils.typing import CastFunc, VectorType, VectorMemoryAccess
from pystencils.cpu.msvc_detection import get_environment
from pystencils.include import get_pystencils_include_path
from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.kernel_wrapper import KernelWrapper
 
from pystencils.typing import BasicType, CastFunc, VectorType, VectorMemoryAccess
from pystencils.utils import atomic_file_write, recursive_dict_update
from pystencils.utils import atomic_file_write, recursive_dict_update
@@ -216,12 +220,11 @@ def read_config():
@@ -216,12 +220,11 @@ def read_config():
shutil.rmtree(config['cache']['object_cache'], ignore_errors=True)
shutil.rmtree(config['cache']['object_cache'], ignore_errors=True)
create_folder(config['cache']['object_cache'], False)
create_folder(config['cache']['object_cache'], False)
with NamedTemporaryFile('w', dir=os.path.dirname(cache_status_file), delete=False) as f:
with tempfile.NamedTemporaryFile('w', dir=os.path.dirname(cache_status_file), delete=False) as f:
json.dump(config['compiler'], f, indent=4)
json.dump(config['compiler'], f, indent=4)
os.replace(f.name, cache_status_file)
os.replace(f.name, cache_status_file)
if config['compiler']['os'] == 'windows':
if config['compiler']['os'] == 'windows':
from pystencils.cpu.msvc_detection import get_environment
msvc_env = get_environment(config['compiler']['msvc_version'], config['compiler']['arch'])
msvc_env = get_environment(config['compiler']['msvc_version'], config['compiler']['arch'])
if 'env' not in config['compiler']:
if 'env' not in config['compiler']:
config['compiler']['env'] = {}
config['compiler']['env'] = {}
@@ -470,18 +473,15 @@ def create_module_boilerplate_code(module_name, names):
@@ -470,18 +473,15 @@ def create_module_boilerplate_code(module_name, names):
def _exec_module_from_path(module_name, path):
    """Build a module named ``module_name`` from the file at ``path`` and execute it.

    Returns the executed module object. Any exception raised while creating
    the spec or executing the module propagates to the caller.
    """
    # Function-scope import keeps this helper self-contained.
    import importlib.util

    spec = importlib.util.spec_from_file_location(name=module_name, location=path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod


def load_kernel_from_file(module_name, function_name, path):
    """Load the module stored at ``path`` and return one attribute of it.

    Args:
        module_name: name under which the module spec is created.
        function_name: name of the attribute to fetch from the loaded module.
        path: file system location of the module to load.

    Returns:
        The attribute ``function_name`` of the loaded module.

    Raises:
        ImportError: if loading fails with ``ImportError`` on the second attempt.
        AttributeError: if the loaded module has no attribute ``function_name``.
    """
    import time
    import warnings

    try:
        mod = _exec_module_from_path(module_name, path)
    except ImportError:
        # A single retry after a fixed delay; presumably this tolerates a file
        # that is not yet fully available on disk -- TODO confirm the motivation.
        warnings.warn(f"Could not load {path}, trying on more time in 5 seconds ...")
        time.sleep(5)
        mod = _exec_module_from_path(module_name, path)
    return getattr(mod, function_name)
@@ -520,9 +520,13 @@ class ExtensionModuleCode:
@@ -520,9 +520,13 @@ class ExtensionModuleCode:
headers = {'<math.h>', '<stdint.h>'}
headers = {'<math.h>', '<stdint.h>'}
for ast in self._ast_nodes:
for ast in self._ast_nodes:
 
for field in ast.fields_accessed:
 
if isinstance(field.dtype, BasicType) and field.dtype.is_half():
 
# Add the half precision header only if half precision numbers occur in the AST
 
headers.add('"half_precision.h"')
headers.update(get_headers(ast))
headers.update(get_headers(ast))
header_list = list(headers)
header_list.sort()
header_list = sorted(headers)
header_list.insert(0, '"Python.h"')
header_list.insert(0, '"Python.h"')
ps_headers = [os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]) for h in header_list
ps_headers = [os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]) for h in header_list
if os.path.exists(os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]))]
if os.path.exists(os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]))]
@@ -559,7 +563,7 @@ def compile_module(code, code_hash, base_dir, compile_flags=None):
@@ -559,7 +563,7 @@ def compile_module(code, code_hash, base_dir, compile_flags=None):
compile_flags = []
compile_flags = []
compiler_config = get_compiler_config()
compiler_config = get_compiler_config()
extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] + compile_flags
extra_flags = ['-I' + sysconfig.get_paths()['include'], '-I' + get_pystencils_include_path()] + compile_flags
if compiler_config['os'].lower() == 'windows':
if compiler_config['os'].lower() == 'windows':
lib_suffix = '.pyd'
lib_suffix = '.pyd'
@@ -593,7 +597,6 @@ def compile_module(code, code_hash, base_dir, compile_flags=None):
@@ -593,7 +597,6 @@ def compile_module(code, code_hash, base_dir, compile_flags=None):
# Linking
# Linking
if windows:
if windows:
import sysconfig
config_vars = sysconfig.get_config_vars()
config_vars = sysconfig.get_config_vars()
py_lib = os.path.join(config_vars["installed_base"], "libs",
py_lib = os.path.join(config_vars["installed_base"], "libs",
f"python{config_vars['py_version_nodot']}.lib")
f"python{config_vars['py_version_nodot']}.lib")
@@ -627,7 +630,7 @@ def compile_and_load(ast, custom_backend=None):
@@ -627,7 +630,7 @@ def compile_and_load(ast, custom_backend=None):
compile_flags = ast.instruction_set['compile_flags']
compile_flags = ast.instruction_set['compile_flags']
if cache_config['object_cache'] is False:
if cache_config['object_cache'] is False:
with TemporaryDirectory() as base_dir:
with tempfile.TemporaryDirectory() as base_dir:
lib_file = compile_module(code, code_hash_str, base_dir, compile_flags=compile_flags)
lib_file = compile_module(code, code_hash_str, base_dir, compile_flags=compile_flags)
result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)
result = load_kernel_from_file(code_hash_str, ast.function_name, lib_file)
else:
else:
Loading