Skip to content
Snippets Groups Projects
Commit b444ae25 authored by Jan Hoenig's avatar Jan Hoenig
Browse files

work work

parent 3de207f4
No related branches found
No related tags found
No related merge requests found
import sympy as sp
from sympy.tensor import IndexedBase, Indexed
from pystencils.field import Field
from pystencils.types import TypedSymbol, DataType, get_type_from_sympy
from pystencils.types import TypedSymbol, DataType, get_type_from_sympy, _c_dtype_dict
class Node(object):
......@@ -294,7 +293,7 @@ class SympyAssignment(Node):
self._lhsSymbol = lhsSymbol
self.rhs = rhsTerm
self._isDeclaration = True
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, IndexedBase):
if isinstance(self._lhsSymbol, Field.Access) or isinstance(self._lhsSymbol, sp.IndexedBase):
self._isDeclaration = False
self._isConst = isConst
......@@ -393,8 +392,6 @@ class TemporaryMemoryFree(Node):
# TODO implement defined & undefinedSymbols
class Conversion(Node):
def __init__(self, child, dtype, parent=None):
super(Conversion, self).__init__(parent)
......@@ -421,9 +418,9 @@ class Conversion(Node):
raise set()
def __repr__(self):
return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
return '(%s(%s))' % (repr(self.dtype), repr(self.args[0].dtype)) + repr(self.args)
# TODO everything which is not Atomic expression: Pow)
# TODO Pow
_expr_dict = {'Add': ' + ', 'Mul': ' * ', 'Pow': '**'}
......@@ -482,6 +479,8 @@ class Indexed(Expr):
def __init__(self, args, base, parent=None):
super(Indexed, self).__init__(args, parent)
self.base = base
#Get dtype from label, and unpointer it
self.dtype = DataType(base.label.dtype.dtype)
def __repr__(self):
return '%s[%s]' % (self.args[0], self.args[1])
......@@ -492,7 +491,6 @@ class Number(Node, sp.AtomicExpr):
super(Number, self).__init__(parent)
self.dtype, self.value = get_type_from_sympy(number)
#TODO why does it have to be a tuple()?
self._args = tuple()
@property
......
......@@ -36,10 +36,10 @@ class LLVMPrinter(Printer):
self.tmp_var[name] = value
def _print_Number(self, n, **kwargs):
return ir.Constant(self.fp_type, float(n))
return ir.Constant(self.fp_type, n)
def _print_Float(self, expr):
return ir.Constant(self.fp_type, float(expr.p))
return ir.Constant(self.fp_type, expr.p)
def _print_Integer(self, expr):
return ir.Constant(self.integer, expr.p)
......@@ -134,6 +134,19 @@ class LLVMPrinter(Printer):
def _print_SympyAssignment(self, assignment):
expr = self._print(assignment.rhs)
def _print_Conversion(self, conversion):
to_dtype = conversion.dtype
from_dtype = conversion.args[0].dtype
print(to_dtype, from_dtype)
# fp -> int: fptosi
# int -> fp: sitofp
# ptr -> int: ptrtoint
# int -> ptr: inttoptr
# ?bitcast, ?addrspacecast
def _print_Indexed(self, indexed):
pass
# Should have a list of math library functions to validate this.
......
......@@ -60,11 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
moveConstantsBeforeLoop(code)
print(code)
# print(code)
desympy_ast(code)
print(code)
# print(code)
insert_casts(code)
return code
\ No newline at end of file
return code
......@@ -552,24 +552,22 @@ def insert_casts(node):
:param node: ast which should be traversed
:return: node
"""
def add_conversion(node, dtype):
return node
for arg in node.args:
print(arg)
insert_casts(arg)
if isinstance(node, ast.Indexed):
node.dtype = node.base.label.dtype
#TODO revmove this
pass
elif isinstance(node, ast.Expr):
print(node)
print([(arg, type(arg), arg.dtype, type(arg.dtype)) for arg in node.args])
args = sorted((arg for arg in node.args), key=attrgetter('dtype'))
target = args[0]
for i in range(len(args)):
args[i] = add_conversion(args[i], target.dtype)
if args[i].dtype != target.dtype:
args[i] = ast.Conversion(args[i], target.dtype, node)
node.args = args
node.dtype = target.dtype
print(node)
elif isinstance(node, ast.SympyAssignment):
if node.lhs.dtype != node.rhs.dtype:
node.replace(node.rhs, ast.Conversion(node.rhs, node.lhs.dtype))
elif isinstance(node, ast.LoopOverCoordinate):
pass
return node
......@@ -601,7 +599,7 @@ def desympy_ast(node):
#elif isinstance(arg, sp.containers.Tuple):
#
else:
print('Not transforming:', arg, type(arg))
print('Not transforming:', type(arg), arg)
for arg in node.args:
desympy_ast(arg)
return node
......
......@@ -82,7 +82,6 @@ def get_type_from_sympy(node):
raise TypeError(node, 'is not a sp.Number')
if isinstance(node, sp.Float) or isinstance(node, sp.RealNumber):
# TODO when float?
return DataType('double'), float(node)
elif isinstance(node, sp.Integer):
return DataType('int'), int(node)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment