Skip to content
Snippets Groups Projects
Commit 231fb6af authored by Jan Hönig's avatar Jan Hönig
Browse files

Added demo for spinodal decomposition, which serves as a test.

Fixed a severe bug. Renamed makePythonFunction of the llvm backend.
Deleted code duplicity.
parent b7d48508
No related branches found
No related tags found
No related merge requests found
...@@ -90,13 +90,13 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}): ...@@ -90,13 +90,13 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}):
:return: kernel functor :return: kernel functor
""" """
# build up list of CType arguments # build up list of CType arguments
func = compileAndLoad(kernelFunctionNode)
func.restype = None
try: try:
args = buildCTypeArgumentList(kernelFunctionNode.parameters, argumentDict) args = buildCTypeArgumentList(kernelFunctionNode.parameters, argumentDict)
except KeyError: except KeyError:
# not all parameters specified yet # not all parameters specified yet
return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict) return makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict, func)
func = compileAndLoad(kernelFunctionNode)
func.restype = None
return lambda: func(*args) return lambda: func(*args)
...@@ -427,9 +427,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict): ...@@ -427,9 +427,7 @@ def buildCTypeArgumentList(parameterSpecification, argumentDict):
return ctArguments return ctArguments
def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict): def makePythonFunctionIncompleteParams(kernelFunctionNode, argumentDict, func):
func = compileAndLoad(kernelFunctionNode)
func.restype = None
parameters = kernelFunctionNode.parameters parameters = kernelFunctionNode.parameters
cache = {} cache = {}
......
from .kernelcreation import createKernel, createIndexedKernel from .kernelcreation import createKernel, createIndexedKernel
from .llvmjit import compileLLVM, generate_and_jit, Jit, make_python_function from .llvmjit import compileLLVM, generate_and_jit, Jit, makePythonFunction
from .llvm import generateLLVM from .llvm import generateLLVM
...@@ -71,7 +71,6 @@ class LLVMPrinter(Printer): ...@@ -71,7 +71,6 @@ class LLVMPrinter(Printer):
return val return val
def _print_Pow(self, expr): def _print_Pow(self, expr):
#print(expr)
base0 = self._print(expr.base) base0 = self._print(expr.base)
if expr.exp == S.NegativeOne: if expr.exp == S.NegativeOne:
return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0) return self.builder.fdiv(ir.Constant(self.fp_type, 1.0), base0)
...@@ -84,6 +83,8 @@ class LLVMPrinter(Printer): ...@@ -84,6 +83,8 @@ class LLVMPrinter(Printer):
return self.builder.call(fn, [base0], "sqrt") return self.builder.call(fn, [base0], "sqrt")
if expr.exp == 2: if expr.exp == 2:
return self.builder.fmul(base0, base0) return self.builder.fmul(base0, base0)
elif expr.exp == 3:
return self.builder.fmul(self.builder.fmul(base0, base0), base0)
exp0 = self._print(expr.exp) exp0 = self._print(expr.exp)
fn = self.ext_fn.get("pow") fn = self.ext_fn.get("pow")
......
...@@ -7,7 +7,7 @@ import shutil ...@@ -7,7 +7,7 @@ import shutil
from ..data_types import toCtypes, createType, ctypes_from_llvm from ..data_types import toCtypes, createType, ctypes_from_llvm
from .llvm import generateLLVM from .llvm import generateLLVM
from ..cpu.cpujit import buildCTypeArgumentList from ..cpu.cpujit import buildCTypeArgumentList, makePythonFunctionIncompleteParams
def generate_and_jit(ast): def generate_and_jit(ast):
...@@ -18,41 +18,18 @@ def generate_and_jit(ast): ...@@ -18,41 +18,18 @@ def generate_and_jit(ast):
return compileLLVM(gen.module) return compileLLVM(gen.module)
def make_python_function(ast, argumentDict={}, func=None): def makePythonFunction(ast, argumentDict={}, func=None):
if func is None:
jit = generate_and_jit(ast)
func = jit.get_function_ptr(ast.functionName)
try: try:
args = buildCTypeArgumentList(ast.parameters, argumentDict) args = buildCTypeArgumentList(ast.parameters, argumentDict)
except KeyError: except KeyError:
# not all parameters specified yet # not all parameters specified yet
return make_python_function_incomplete(ast, argumentDict, func) return makePythonFunctionIncompleteParams(ast, argumentDict, func)
if func is None:
jit = generate_and_jit(ast)
func = jit.get_function_ptr(ast.functionName)
return lambda: func(*args) return lambda: func(*args)
def make_python_function_incomplete(ast, argumentDict, func=None):
if func is None:
jit = generate_and_jit(ast)
func = jit.get_function_ptr(ast.functionName)
parameters = ast.parameters
cache = {}
def wrapper(**kwargs):
key = hash(tuple((k, id(v)) for k, v in kwargs.items()))
try:
args = cache[key]
func(*args)
except KeyError:
fullArguments = argumentDict.copy()
fullArguments.update(kwargs)
args = buildCTypeArgumentList(parameters, fullArguments)
cache[key] = args
func(*args)
return wrapper
def compileLLVM(module): def compileLLVM(module):
jit = Jit() jit = Jit()
jit.parse(module) jit.parse(module)
......
...@@ -51,8 +51,8 @@ def insertCasts(node): ...@@ -51,8 +51,8 @@ def insertCasts(node):
args = [] args = []
for arg in node.args: for arg in node.args:
args.append(insertCasts(arg)) args.append(insertCasts(arg))
# TODO indexed, SympyAssignment, LoopOverCoordinate, Pow # TODO indexed, SympyAssignment, LoopOverCoordinate
if node.func in (sp.Add, sp.Mul, sp.Pow): if node.func in (sp.Add, sp.Mul, sp.Pow): # TODO fix pow, don't cast integer on double
types = [getTypeOfExpression(arg) for arg in args] types = [getTypeOfExpression(arg) for arg in args]
assert len(types) > 0 assert len(types) > 0
target = collateTypes(types) target = collateTypes(types)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment