diff --git a/pystencils/cpu/cpujit.py b/pystencils/cpu/cpujit.py index 5b011535826fbb364595f75ea25f2922980ad6c0..658ed93f0e84c695e3a2f92c75f5a821d7595265 100644 --- a/pystencils/cpu/cpujit.py +++ b/pystencils/cpu/cpujit.py @@ -64,7 +64,7 @@ from pystencils.kernel_wrapper import KernelWrapper from pystencils.utils import atomic_file_write, file_handle_for_atomic_write, recursive_dict_update -def make_python_function(kernel_function_node, custom_backend=None): +def make_python_function(kernel_function_node, custom_backend=None, get_assembly=False): """ Creates C code from the abstract syntax tree, compiles it and makes it accessible as Python function @@ -73,9 +73,11 @@ def make_python_function(kernel_function_node, custom_backend=None): - all symbols which are not defined in the kernel itself are expected as parameters :param kernel_function_node: the abstract syntax tree + :param custom_backend: Use custom backend for the code generation + :param get_assembly: Compiles only the assembly file :return: kernel functor """ - result = compile_and_load(kernel_function_node, custom_backend) + result = compile_and_load(kernel_function_node, custom_backend, get_assembly) return result @@ -503,7 +505,7 @@ class ExtensionModuleCode: print(create_module_boilerplate_code(self.module_name, self._function_names), file=file) -def compile_module(code, code_hash, base_dir): +def compile_module(code, code_hash, base_dir, get_assembly=False): compiler_config = get_compiler_config() extra_flags = ['-I' + get_paths()['include'], '-I' + get_pystencils_include_path()] @@ -522,6 +524,22 @@ def compile_module(code, code_hash, base_dir): lib_file = os.path.join(base_dir, code_hash + lib_suffix) object_file = os.path.join(base_dir, code_hash + object_suffix) + if get_assembly: + assembly_file = os.path.join(base_dir, code_hash + ".s") + if not os.path.exists(assembly_file): + with file_handle_for_atomic_write(src_file) as f: + code.write_to_file(compiler_config['restrict_qualifier'], function_prefix, f) + if windows: + compile_cmd = ['cl.exe', '/Fa', '/EHsc'] + compiler_config['flags'].split() + compile_cmd += [*extra_flags, src_file, '/Fo' + assembly_file] + else: + compile_cmd = [compiler_config['command'], '-S'] + compiler_config['flags'].split() + compile_cmd += [*extra_flags, '-o', assembly_file, src_file] + run_compile_step(compile_cmd) + with open(assembly_file, "r") as assembly: + data = assembly.read() + return data + if not os.path.exists(object_file): with file_handle_for_atomic_write(src_file) as f: code.write_to_file(compiler_config['restrict_qualifier'], function_prefix, f) @@ -555,12 +573,15 @@ def compile_module(code, code_hash, base_dir): return lib_file -def compile_and_load(ast, custom_backend=None): +def compile_and_load(ast, custom_backend=None, get_assembly=False): cache_config = get_cache_config() code_hash_str = "mod_" + hashlib.sha256(generate_c(ast, dialect='c', custom_backend=custom_backend).encode()).hexdigest() code = ExtensionModuleCode(module_name=code_hash_str, custom_backend=custom_backend) code.add_function(ast, ast.function_name) + if get_assembly: + with TemporaryDirectory() as base_dir: + return compile_module(code, code_hash_str, base_dir=base_dir, get_assembly=get_assembly) if cache_config['object_cache'] is False: with TemporaryDirectory() as base_dir: