Skip to content
Snippets Groups Projects

Add experimental half precision support

All threads resolved!

Files

+ 16
1
@@ -50,6 +50,7 @@ import platform
import shutil
import subprocess
import textwrap
import warnings
from collections import OrderedDict
from sysconfig import get_paths
from tempfile import TemporaryDirectory, NamedTemporaryFile
@@ -60,7 +61,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import generate_c, get_headers, CFunction
from pystencils.typing import CastFunc, VectorType, VectorMemoryAccess
from pystencils.typing import BasicType, CastFunc, VectorType, VectorMemoryAccess
from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.utils import atomic_file_write, recursive_dict_update
@@ -520,7 +521,21 @@ class ExtensionModuleCode:
headers = {'<math.h>', '<stdint.h>'}
for ast in self._ast_nodes:
for field in ast.fields_accessed:
if isinstance(field.dtype, BasicType) and field.dtype.is_half():
if not platform.machine() in ['arm64', 'aarch64']:
warnings.warn(f"The AST contains half precision data types but platform is: "
f"{platform.machine()}. Using half precision might not work properly on this "
f"platform")
if 'clang' not in get_compiler_config()['command']:
warnings.warn(f"The AST contains half precision data types but compiler is: "
f"{get_compiler_config()['command']}. Using half precision is only tested with "
f"the Clang compiler")
# Add the half precision header only if half precision numbers occur in the AST
headers.add('"half_precision.h"')
headers.update(get_headers(ast))
header_list = list(headers)
header_list.sort()
header_list.insert(0, '"Python.h"')
Loading