Skip to content
Snippets Groups Projects
Commit 7d1c7d11 authored by Markus Holzer's avatar Markus Holzer
Browse files

Minor clean up

parent f21634e6
No related branches found
No related tags found
1 merge request!226Fix field size
...@@ -500,7 +500,7 @@ class ExtensionModuleCode: ...@@ -500,7 +500,7 @@ class ExtensionModuleCode:
self._ast_nodes = [] self._ast_nodes = []
self._function_names = [] self._function_names = []
self._custom_backend = custom_backend self._custom_backend = custom_backend
self.code_string = str() self._code_string = str()
self._code_hash = None self._code_hash = None
def add_function(self, ast, name=None): def add_function(self, ast, name=None):
...@@ -508,7 +508,7 @@ class ExtensionModuleCode: ...@@ -508,7 +508,7 @@ class ExtensionModuleCode:
self._function_names.append(name if name is not None else ast.function_name) self._function_names.append(name if name is not None else ast.function_name)
def create_code_string(self, restrict_qualifier, function_prefix): def create_code_string(self, restrict_qualifier, function_prefix):
self.code_string = str() self._code_string = str()
headers = {'<math.h>', '<stdint.h>'} headers = {'<math.h>', '<stdint.h>'}
for ast in self._ast_nodes: for ast in self._ast_nodes:
...@@ -518,28 +518,28 @@ class ExtensionModuleCode: ...@@ -518,28 +518,28 @@ class ExtensionModuleCode:
header_list.insert(0, '"Python.h"') header_list.insert(0, '"Python.h"')
includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list]) includes = "\n".join(["#include %s" % (include_file,) for include_file in header_list])
self.code_string += includes self._code_string += includes
self.code_string += "\n" self._code_string += "\n"
self.code_string += f"#define RESTRICT {restrict_qualifier} \n" self._code_string += f"#define RESTRICT {restrict_qualifier} \n"
self.code_string += f"#define FUNC_PREFIX {function_prefix}" self._code_string += f"#define FUNC_PREFIX {function_prefix}"
self.code_string += "\n" self._code_string += "\n"
for ast, name in zip(self._ast_nodes, self._function_names): for ast, name in zip(self._ast_nodes, self._function_names):
old_name = ast.function_name old_name = ast.function_name
ast.function_name = "kernel_" + name ast.function_name = "kernel_" + name
self.code_string += generate_c(ast, custom_backend=self._custom_backend) self._code_string += generate_c(ast, custom_backend=self._custom_backend)
self.code_string += create_function_boilerplate_code(ast.get_parameters(), name, ast) self._code_string += create_function_boilerplate_code(ast.get_parameters(), name, ast)
ast.function_name = old_name ast.function_name = old_name
self._code_hash = "mod_" + hashlib.sha256(self.code_string.encode()).hexdigest() self._code_hash = "mod_" + hashlib.sha256(self._code_string.encode()).hexdigest()
self.code_string += create_module_boilerplate_code(self._code_hash, self._function_names) self._code_string += create_module_boilerplate_code(self._code_hash, self._function_names)
def get_hash_of_code(self): def get_hash_of_code(self):
return self._code_hash return self._code_hash
def write_to_file(self, file): def write_to_file(self, file):
assert self.code_string, "The code must be generated first" assert self._code_string, "The code must be generated first"
print(self.code_string, file=file) print(self._code_string, file=file)
def compile_module(code, code_hash, base_dir): def compile_module(code, code_hash, base_dir):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment