Skip to content
Snippets Groups Projects

Add experimental half precison support

Merged Markus Holzer requested to merge holzer/pystencils:halfP into master
Files
8
+ 12
8
@@ -43,24 +43,26 @@ Then 'cl.exe' is used to compile.
For Windows compilers the qualifier should be ``__restrict``
"""
from appdirs import user_cache_dir, user_config_dir
from collections import OrderedDict
import hashlib
import json
import os
import platform
import shutil
import subprocess
import textwrap
from collections import OrderedDict
from sysconfig import get_paths
from tempfile import TemporaryDirectory, NamedTemporaryFile
import textwrap
import time
import warnings
import numpy as np
from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import generate_c, get_headers, CFunction
from pystencils.typing import CastFunc, VectorType, VectorMemoryAccess
from pystencils.typing import BasicType, CastFunc, VectorType, VectorMemoryAccess
from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.utils import atomic_file_write, recursive_dict_update
@@ -476,8 +478,6 @@ def load_kernel_from_file(module_name, function_name, path):
mod = module_from_spec(spec)
spec.loader.exec_module(mod)
except ImportError:
import time
import warnings
warnings.warn(f"Could not load {path}, trying on more time in 5 seconds ...")
time.sleep(5)
spec = spec_from_file_location(name=module_name, location=path)
@@ -520,9 +520,13 @@ class ExtensionModuleCode:
headers = {'<math.h>', '<stdint.h>'}
for ast in self._ast_nodes:
for field in ast.fields_accessed:
if isinstance(field.dtype, BasicType) and field.dtype.is_half():
# Add the half precision header only if half precision numbers occur in the AST
headers.add('"half_precision.h"')
headers.update(get_headers(ast))
header_list = list(headers)
header_list.sort()
header_list = sorted(headers)
header_list.insert(0, '"Python.h"')
ps_headers = [os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]) for h in header_list
if os.path.exists(os.path.join(os.path.dirname(__file__), '..', 'include', h[1:-1]))]
Loading