Skip to content
Snippets Groups Projects
Commit 53b20223 authored by Jan Hoenig's avatar Jan Hoenig
Browse files

TODO wtf is indexed????

parent eb9e40d7
No related branches found
No related tags found
No related merge requests found
import sympy as sp
from sympy.tensor import IndexedBase, Indexed
from pystencils.field import Field
from pystencils.types import TypedSymbol, DataType, _c_dtype_dict
from pystencils.types import TypedSymbol, DataType, get_type_from_sympy
class Node(object):
......@@ -481,11 +481,14 @@ class Indexed(Expr):
def __repr__(self):
return '%s[%s]' % (self.args[0], self.args[1])
class Number(Node):
class Number(Node, sp.AtomicExpr):
def __init__(self, number, parent=None):
super(Number, self).__init__(parent)
self._args = None
self.dtype = dtype
self.dtype, self.value = get_type_from_sympy(number)
#TODO why does it have to be a tuple()?
self._args = tuple()
@property
def args(self):
......@@ -503,6 +506,6 @@ class Number(Node):
raise set()
def __repr__(self):
return '(%s)' % (_c_dtype_dict(self.dtype)) + repr(self.args)
return repr(self.dtype) + repr(self.value)
......@@ -60,7 +60,9 @@ def createKernel(listOfEquations, functionName="kernel", typeForSymbol=None, spl
resolveFieldAccesses(code, readOnlyFields, fieldToBasePointerInfo=basePointerInfos)
moveConstantsBeforeLoop(code)
print(code)
desympy_ast(code)
print(code)
insert_casts(code)
......
......@@ -556,15 +556,20 @@ def insert_casts(node):
return node
for arg in node.args:
print(arg)
insert_casts(arg)
if isinstance(node, ast.Indexed):
pass
elif isinstance(node, ast.Expr):
args = sorted((arg.dtype for arg in node.args), key=attrgetter('ptr', 'dtype'))
print(node)
print([(arg, type(arg), arg.dtype, type(arg.dtype)) for arg in node.args])
args = sorted((arg for arg in node.args), key=attrgetter('dtype'))
target = args[0]
for i in range(len(args)):
args[i] = add_conversion(args[i], target.dtype)
node.args = args
node.dtype = target.dtype
print(node)
elif isinstance(node, ast.LoopOverCoordinate):
pass
return node
......@@ -577,16 +582,21 @@ def desympy_ast(node):
:param node: ast which should be traversed. Only node's children will be modified.
:return: (modified) node
"""
if node.args is None:
return node
for i in range(len(node.args)):
arg = node.args[i]
if isinstance(arg, sp.Add):
node.replace(arg, ast.Add(arg.args, node))
elif isinstance(arg, sp.Number):
node.replace(arg, ast.Number(arg, node))
elif isinstance(arg, sp.Mul):
node.replace(arg, ast.Mul(arg.args, node))
elif isinstance(arg, sp.Pow):
node.replace(arg, ast.Pow(arg.args, node))
elif isinstance(arg, sp.tensor.Indexed):
node.replace(arg, ast.Indexed(arg.args, node))
#elif isinstance(arg, )
for arg in node.args:
desympy_ast(arg)
return node
......@@ -64,5 +64,29 @@ class DataType(object):
else:
return False
def __gt__(self, other):
if self.ptr and not other.ptr:
return True
if self.dtype > other.dtype:
return True
def get_type_from_sympy(node):
return DataType('int')
\ No newline at end of file
# Rational, NumberSymbol?
# Zero, One, NegativeOne )= Integer
# Half )= Rational
# NAN, Infinity, Negative Inifinity,
# Exp1, Imaginary Unit, Pi, EulerGamma, Catalan, Golden Ratio
# Pow, Mul, Add, Mod, Relational
if not isinstance(node, sp.Number):
raise TypeError(node, 'is not a sp.Number')
if isinstance(node, sp.Float) or isinstance(node, sp.RealNumber):
# TODO when float?
return DataType('double'), float(node)
elif isinstance(node, sp.Integer):
return DataType('int'), int(node)
elif isinstance(node, sp.Rational):
raise NotImplementedError('Rationals are not supported yet')
else:
raise TypeError(node, ' is not a supported type!')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment