Skip to content
Snippets Groups Projects
Commit 4a8659e8 authored by Jan Hoenig's avatar Jan Hoenig
Browse files

it actually somehow comiles

parent b444ae25
No related branches found
No related tags found
No related merge requests found
...@@ -511,4 +511,10 @@ class Number(Node, sp.AtomicExpr): ...@@ -511,4 +511,10 @@ class Number(Node, sp.AtomicExpr):
def __repr__(self): def __repr__(self):
return repr(self.value) return repr(self.value)
def __float__(self):
return float(self.value)
def __int__(self):
return int(self.value)
import llvmlite.ir as ir import llvmlite.ir as ir
import functools
from sympy.printing.printer import Printer from sympy.printing.printer import Printer
from sympy import S from sympy import S
# S is numbers? # S is numbers?
from pystencils.llvm.control_flow import Loop from pystencils.llvm.control_flow import Loop
from ..types import DataType
from ..astnodes import Indexed
def generateLLVM(ast_node): def generateLLVM(ast_node):
...@@ -25,6 +28,7 @@ class LLVMPrinter(Printer): ...@@ -25,6 +28,7 @@ class LLVMPrinter(Printer):
self.fp_type = ir.DoubleType() self.fp_type = ir.DoubleType()
self.fp_pointer = self.fp_type.as_pointer() self.fp_pointer = self.fp_type.as_pointer()
self.integer = ir.IntType(64) self.integer = ir.IntType(64)
self.integer_pointer = self.integer.as_pointer()
self.void = ir.VoidType() self.void = ir.VoidType()
self.module = module self.module = module
self.builder = builder self.builder = builder
...@@ -35,8 +39,13 @@ class LLVMPrinter(Printer): ...@@ -35,8 +39,13 @@ class LLVMPrinter(Printer):
def _add_tmp_var(self, name, value): def _add_tmp_var(self, name, value):
self.tmp_var[name] = value self.tmp_var[name] = value
def _print_Number(self, n, **kwargs): def _print_Number(self, n):
return ir.Constant(self.fp_type, n) if n.dtype == DataType("int"):
return ir.Constant(self.integer, int(n))
elif n.dtype == DataType("double"):
return ir.Constant(self.fp_type, float(n))
else:
raise NotImplementedError("Numbers can only have int and double", n)
def _print_Float(self, expr): def _print_Float(self, expr):
return ir.Constant(self.fp_type, expr.p) return ir.Constant(self.fp_type, expr.p)
...@@ -81,16 +90,23 @@ class LLVMPrinter(Printer): ...@@ -81,16 +90,23 @@ class LLVMPrinter(Printer):
def _print_Mul(self, expr): def _print_Mul(self, expr):
nodes = [self._print(a) for a in expr.args] nodes = [self._print(a) for a in expr.args]
e = nodes[0] e = nodes[0]
if expr.dtype == DataType('double'):
mul = self.builder.fmul
else: # int TODO others?
mul = self.builder.mul
for node in nodes[1:]: for node in nodes[1:]:
e = self.builder.fmul(e, node) e = mul(e, node)
return e return e
def _print_Add(self, expr): def _print_Add(self, expr):
nodes = [self._print(a) for a in expr.args] nodes = [self._print(a) for a in expr.args]
e = nodes[0] e = nodes[0]
if expr.dtype == DataType('double'):
add = self.builder.fadd
else: # int TODO others?
add = self.builder.add
for node in nodes[1:]: for node in nodes[1:]:
print(e, node) e = add(e, node)
e = self.builder.fadd(e, node)
return e return e
def _print_KernelFunction(self, function): def _print_KernelFunction(self, function):
...@@ -118,6 +134,7 @@ class LLVMPrinter(Printer): ...@@ -118,6 +134,7 @@ class LLVMPrinter(Printer):
block = fn.append_basic_block(name="entry") block = fn.append_basic_block(name="entry")
self.builder = ir.IRBuilder(block) self.builder = ir.IRBuilder(block)
self._print(function.body) self._print(function.body)
self.builder.ret_void()
self.fn = fn self.fn = fn
return fn return fn
...@@ -129,29 +146,47 @@ class LLVMPrinter(Printer): ...@@ -129,29 +146,47 @@ class LLVMPrinter(Printer):
with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step), with Loop(self.builder, self._print(loop.start), self._print(loop.stop), self._print(loop.step),
loop.loopCounterName, loop.loopCounterSymbol.name) as i: loop.loopCounterName, loop.loopCounterSymbol.name) as i:
self._add_tmp_var(loop.loopCounterSymbol, i) self._add_tmp_var(loop.loopCounterSymbol, i)
# TODO remove tmp var
self._print(loop.body) self._print(loop.body)
def _print_SympyAssignment(self, assignment): def _print_SympyAssignment(self, assignment):
expr = self._print(assignment.rhs) expr = self._print(assignment.rhs)
lhs = assignment.lhs
if isinstance(lhs, Indexed):
ptr = self._print(lhs.base.label)
index = self._print(lhs.args[1])
gep = self.builder.gep(ptr, [index])
return self.builder.store(expr, gep)
self.func_arg_map[assignment.lhs.name] = expr
return expr
def _print_Conversion(self, conversion): def _print_Conversion(self, conversion):
node = self._print(conversion.args[0])
to_dtype = conversion.dtype to_dtype = conversion.dtype
from_dtype = conversion.args[0].dtype from_dtype = conversion.args[0].dtype
print(to_dtype, from_dtype) # (From, to)
# fp -> int: fptosi decision = {
# int -> fp: sitofp (DataType("int"), DataType("double")): functools.partial(self.builder.sitofp, node, self.fp_type),
# ptr -> int: ptrtoint (DataType("double"), DataType("int")): functools.partial(self.builder.fptosi, node, self.integer),
# int -> ptr: inttoptr (DataType("double *"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
# ?bitcast, ?addrspacecast (DataType("int"), DataType("double *")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
(DataType("double * __restrict__"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
(DataType("int"), DataType("double * __restrict__")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
(DataType("const double * __restrict__"), DataType("int")): functools.partial(self.builder.ptrtoint, node, self.integer),
(DataType("int"), DataType("const double * __restrict__")): functools.partial(self.builder.inttoptr, node, self.fp_pointer),
}
# TODO float, const, restrict
# TODO bitcast, addrspacecast
return decision[(from_dtype, to_dtype)]()
def _print_Indexed(self, indexed): def _print_Indexed(self, indexed):
pass ptr = self._print(indexed.base.label)
index = self._print(indexed.args[1])
gep = self.builder.gep(ptr, [index])
return self.builder.load(gep, name=indexed.base.label.name)
# Should have a list of math library functions to validate this.
# TODO function calls
# Should have a list of math library functions to validate this.
# TODO delete this -> NO this should be a function call
def _print_Function(self, expr): def _print_Function(self, expr):
name = expr.func.__name__ name = expr.func.__name__
e0 = self._print(expr.args[0]) e0 = self._print(expr.args[0])
...@@ -163,5 +198,5 @@ class LLVMPrinter(Printer): ...@@ -163,5 +198,5 @@ class LLVMPrinter(Printer):
return self.builder.call(fn, [e0], name) return self.builder.call(fn, [e0], name)
def emptyPrinter(self, expr): def emptyPrinter(self, expr):
raise TypeError("Unsupported type for LLVM JIT conversion: %s" raise TypeError("Unsupported type for LLVM JIT conversion: %s %s"
% type(expr)) % type(expr), expr)
from .kernelcreation import createKernel from .kernelcreation import createKernel
from .jit import compileLLVM
\ No newline at end of file
import llvmlite.binding as llvm import llvmlite.binding as llvm
import logging.config import logging.config
logger = logging.getLogger(__name__)
def compileLLVM(module):
return Eval().compile(module)
class Eval(object): class Eval(object):
def __init__(self): def __init__(self):
...@@ -63,9 +69,3 @@ class Eval(object): ...@@ -63,9 +69,3 @@ class Eval(object):
# result = fptr(2, 3) # result = fptr(2, 3)
# print(result) # print(result)
return 0 return 0
if __name__ == "__main__":
logger = logging.getLogger(__name__)
else:
logger = logging.getLogger(__name__)
...@@ -70,6 +70,9 @@ class DataType(object): ...@@ -70,6 +70,9 @@ class DataType(object):
if self.dtype > other.dtype: if self.dtype > other.dtype:
return True return True
def __hash__(self):
return hash(repr(self))
def get_type_from_sympy(node): def get_type_from_sympy(node):
# Rational, NumberSymbol? # Rational, NumberSymbol?
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment