Skip to content
Snippets Groups Projects
Commit f21634e6 authored by Markus Holzer's avatar Markus Holzer
Browse files

New try

parent d8072b06
No related branches found
No related tags found
1 merge request!226Fix field size
......@@ -494,19 +494,22 @@ def run_compile_step(command):
class ExtensionModuleCode:
def __init__(self, module_name='generated', custom_backend=None, generated_code=None):
def __init__(self, module_name='generated', custom_backend=None):
self.module_name = module_name
self._ast_nodes = []
self._function_names = []
self._custom_backend = custom_backend
self._generated_code = generated_code
self.code_string = str()
self._code_hash = None
def add_function(self, ast, name=None):
self._ast_nodes.append(ast)
self._function_names.append(name if name is not None else ast.function_name)
def write_to_file(self, restrict_qualifier, function_prefix, file):
def create_code_string(self, restrict_qualifier, function_prefix):
self.code_string = str()
headers = {'<math.h>', '<stdint.h>'}
for ast in self._ast_nodes:
headers.update(get_headers(ast))
......@@ -515,22 +518,28 @@ class ExtensionModuleCode:
header_list.insert(0, '"Python.h"')
includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
print(includes, file=file)
print("\n", file=file)
print(f"#define RESTRICT {restrict_qualifier}", file=file)
print(f"#define FUNC_PREFIX {function_prefix}", file=file)
print("\n", file=file)
self.code_string += includes
self.code_string += "\n"
self.code_string += f"#define RESTRICT {restrict_qualifier} \n"
self.code_string += f"#define FUNC_PREFIX {function_prefix}"
self.code_string += "\n"
for ast, name in zip(self._ast_nodes, self._function_names):
old_name = ast.function_name
ast.function_name = "kernel_" + name
if self._generated_code:
print(self._generated_code, file=file)
else:
print(generate_c(ast, custom_backend=self._custom_backend), file=file)
print(create_function_boilerplate_code(ast.get_parameters(), name, ast), file=file)
self.code_string += generate_c(ast, custom_backend=self._custom_backend)
self.code_string += create_function_boilerplate_code(ast.get_parameters(), name, ast)
ast.function_name = old_name
print(create_module_boilerplate_code(self.module_name, self._function_names), file=file)
self._code_hash = "mod_" + hashlib.sha256(self.code_string.encode()).hexdigest()
self.code_string += create_module_boilerplate_code(self._code_hash, self._function_names)
def get_hash_of_code(self):
return self._code_hash
def write_to_file(self, file):
assert self.code_string, "The code must be generated first"
print(self.code_string, file=file)
def compile_module(code, code_hash, base_dir):
......@@ -554,7 +563,7 @@ def compile_module(code, code_hash, base_dir):
if not os.path.exists(object_file):
with file_handle_for_atomic_write(src_file) as f:
code.write_to_file(compiler_config['restrict_qualifier'], function_prefix, f)
code.write_to_file(f)
if windows:
compile_cmd = ['cl.exe', '/c', '/EHsc'] + compiler_config['flags'].split()
......@@ -588,15 +597,15 @@ def compile_module(code, code_hash, base_dir):
def compile_and_load(ast, custom_backend=None):
cache_config = get_cache_config()
generated_code = generate_c(ast, dialect='c', custom_backend=custom_backend)
fields_accessed = str(ast.fields_accessed)
compiler_config = get_compiler_config()
function_prefix = '__declspec(dllexport)' if compiler_config['os'].lower() == 'windows' else ''
# Also the information of the field size should be contained in the hash string. Due to padding the generated code
# can look similar for different field sizes.
code_hash_str = "mod_" + hashlib.sha256((generated_code + fields_accessed).encode()).hexdigest()
code = ExtensionModuleCode(module_name=code_hash_str, custom_backend=custom_backend, generated_code=generated_code)
code = ExtensionModuleCode(custom_backend=custom_backend)
code.add_function(ast, ast.function_name)
code.create_code_string(compiler_config['restrict_qualifier'], function_prefix)
code_hash_str = code.get_hash_of_code()
if cache_config['object_cache'] is False:
with TemporaryDirectory() as base_dir:
lib_file = compile_module(code, code_hash_str, base_dir)
......
......@@ -5,7 +5,6 @@ import pytest
import pystencils as ps
from pystencils.astnodes import Block, Conditional
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86
from pystencils.cpu.vectorization import vec_all, vec_any
......@@ -29,7 +28,7 @@ def test_vec_any():
kernel = ast.compile()
kernel(data=data_arr)
width = get_vector_instruction_set_x86(instruction_set=instruction_set)['width']
width = ast.instruction_set['width']
np.testing.assert_equal(data_arr[3:9, 0:width], 2.0)
......
......@@ -4,8 +4,7 @@ import numpy as np
import sympy as sp
import pystencils as ps
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
from pystencils.backends.x86_instruction_sets import get_vector_instruction_set_x86
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets, get_vector_instruction_set
from pystencils.data_types import cast_func, VectorType
supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else []
......@@ -57,7 +56,7 @@ def test_vectorized_abs(instruction_set, dtype):
@pytest.mark.parametrize('instruction_set', supported_instruction_sets)
def test_alignment_and_correct_ghost_layers(instruction_set, dtype):
itemsize = 8 if dtype == 'double' else 4
alignment = get_vector_instruction_set_x86(dtype, instruction_set)['width'] * itemsize
alignment = get_vector_instruction_set(dtype, instruction_set)['width'] * itemsize
dtype = np.float64 if dtype == 'double' else np.float32
domain_size = (128, 128)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment