Skip to content
Snippets Groups Projects

Fix field size

Merged Markus Holzer requested to merge holzer/pystencils:FixFieldSize into master
Files
3
+ 54
18
@@ -59,6 +59,7 @@ from appdirs import user_cache_dir, user_config_dir
from pystencils import FieldType
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.data_types import cast_func, VectorType
from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update
@@ -266,7 +267,6 @@ type_mapping = {
np.complex128: (('PyComplex_RealAsDouble', 'PyComplex_ImagAsDouble'), 'ComplexDouble'),
}
template_extract_scalar = """
PyObject * obj_{name} = PyDict_GetItemString(kwargs, "{name}");
if( obj_{name} == NULL) {{ PyErr_SetString(PyExc_TypeError, "Keyword argument '{name}' missing"); return NULL; }};
@@ -357,7 +357,7 @@ def equal_size_check(fields):
return template_size_check.format(cond=cond)
def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
def create_function_boilerplate_code(parameter_info, name, ast_node, insert_checks=True):
pre_call_code = ""
parameters = []
post_call_code = ""
@@ -375,6 +375,25 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
np_dtype = field.dtype.numpy_dtype
item_size = np_dtype.itemsize
aligned = False
if ast_node.assignments:
aligned = any([a.lhs.args[2] for a in ast_node.assignments
if hasattr(a, 'lhs') and isinstance(a.lhs, cast_func)
and hasattr(a.lhs, 'dtype') and isinstance(a.lhs.dtype, VectorType)])
if ast_node.instruction_set and aligned:
byte_width = ast_node.instruction_set['width'] * item_size
offset = max(max(ast_node.ghost_layers)) * item_size
offset_cond = f"(((uintptr_t) buffer_{field.name}.buf) + {offset}) % {byte_width} == 0"
message = str(offset) + ". This is probably due to a different number of ghost_layers chosen for " \
"the arrays and the kernel creation. If the number of ghost layers for " \
"the kernel creation is not specified it will choose a suitable value " \
"automatically. This value might not " \
"be compatible with the allocated arrays."
pre_call_code += template_check_array.format(cond=offset_cond, what="offset", name=field.name,
expected=message)
if (np_dtype.isbuiltin and FieldType.is_generic(field)
and not np.issubdtype(field.dtype.numpy_dtype, np.complexfloating)):
dtype_cond = "buffer_{name}.format[0] == '{format}'".format(name=field.name,
@@ -418,7 +437,7 @@ def create_function_boilerplate_code(parameter_info, name, insert_checks=True):
extract_function_imag=extract_function[1],
target_type=target_type,
real_type="float" if target_type == "ComplexFloat"
else "double",
else "double",
name=param.symbol.name)
else:
pre_call_code += template_extract_scalar.format(extract_function=extract_function,
@@ -481,12 +500,16 @@ class ExtensionModuleCode:
self._ast_nodes = []
self._function_names = []
self._custom_backend = custom_backend
self._code_string = str()
self._code_hash = None
def add_function(self, ast, name=None):
self._ast_nodes.append(ast)
self._function_names.append(name if name is not None else ast.function_name)
def write_to_file(self, restrict_qualifier, function_prefix, file):
def create_code_string(self, restrict_qualifier, function_prefix):
self._code_string = str()
headers = {'<math.h>', '<stdint.h>'}
for ast in self._ast_nodes:
headers.update(get_headers(ast))
@@ -495,19 +518,29 @@ class ExtensionModuleCode:
header_list.insert(0, '"Python.h"')
includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
print(includes, file=file)
print("\n", file=file)
print(f"#define RESTRICT {restrict_qualifier}", file=file)
print(f"#define FUNC_PREFIX {function_prefix}", file=file)
print("\n", file=file)
self._code_string += includes
self._code_string += "\n"
self._code_string += f"#define RESTRICT {restrict_qualifier} \n"
self._code_string += f"#define FUNC_PREFIX {function_prefix}"
self._code_string += "\n"
for ast, name in zip(self._ast_nodes, self._function_names):
old_name = ast.function_name
ast.function_name = "kernel_" + name
print(generate_c(ast, custom_backend=self._custom_backend), file=file)
print(create_function_boilerplate_code(ast.get_parameters(), name), file=file)
self._code_string += generate_c(ast, custom_backend=self._custom_backend)
self._code_string += create_function_boilerplate_code(ast.get_parameters(), name, ast)
ast.function_name = old_name
print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
self._code_hash = "mod_" + hashlib.sha256(self._code_string.encode()).hexdigest()
self._code_string += create_module_boilerplate_code(self._code_hash, self._function_names)
def get_hash_of_code(self):
assert self._code_string, "The code must be generated first"
return self._code_hash
def write_to_file(self, file):
assert self._code_string, "The code must be generated first"
print(self._code_string, file=file)
def compile_module(code, code_hash, base_dir):
@@ -515,12 +548,10 @@ def compile_module(code, code_hash, base_dir):
extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()]
if compiler_config['os'].lower() == 'windows':
function_prefix = '__declspec(dllexport)'
lib_suffix = '.pyd'
object_suffix = '.obj'
windows = True
else:
function_prefix = ''
lib_suffix = '.so'
object_suffix = '.o'
windows = False
@@ -531,7 +562,7 @@ def compile_module(code, code_hash, base_dir):
if not os.path.exists(object_file):
with file_handle_for_atomic_write(src_file) as f:
code.write_to_file(compiler_config['restrict_qualifier'], function_prefix, f)
code.write_to_file(f)
if windows:
compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
@@ -564,11 +595,16 @@ def compile_module(code, code_hash, base_dir):
def compile_and_load(ast, custom_backend=None):
cache_config = get_cache_config()
code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c',
custom_backend=custom_backend).encode()).hexdigest()
code = ExtensionModuleCode(module_name=code_hash_str, custom_backend=custom_backend)
compiler_config = get_compiler_config()
function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else ''
code = ExtensionModuleCode(custom_backend=custom_backend)
code.add_function(ast, ast.function_name)
code.create_code_string(compiler_config['restrict_qualifier'], function_prefix)
code_hash_str = code.get_hash_of_code()
if cache_config['object_cache'] is False:
with TemporaryDirectory() as base_dir:
lib_file = compile_module(code, code_hash_str, base_dir)
Loading