Skip to content
Snippets Groups Projects

WIP: Astnodes for interpolation

2 files
+ 178
0
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -366,6 +366,52 @@ class CustomSympyPrinter(CCodePrinter):
@@ -366,6 +366,52 @@ class CustomSympyPrinter(CCodePrinter):
else:
else:
return res
return res
 
def _print_Sum(self, expr):
 
template = """[&]() {{
 
{dtype} sum = ({dtype}) 0;
 
for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
 
sum += {expr};
 
}}
 
return sum;
 
}}()"""
 
var = expr.limits[0][0]
 
start = expr.limits[0][1]
 
end = expr.limits[0][2]
 
code = template.format(
 
dtype=get_type_of_expression(expr.args[0]),
 
iterator_dtype='int',
 
var=self._print(var),
 
start=self._print(start),
 
end=self._print(end),
 
expr=self._print(expr.function),
 
increment=str(1),
 
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
 
)
 
return code
 
 
def _print_Product(self, expr):
 
template = """[&]() {{
 
{dtype} product = ({dtype}) 1;
 
for ( {iterator_dtype} {var} = {start}; {condition}; {var} += {increment} ) {{
 
product *= {expr};
 
}}
 
return product;
 
}}()"""
 
var = expr.limits[0][0]
 
start = expr.limits[0][1]
 
end = expr.limits[0][2]
 
code = template.format(
 
dtype=get_type_of_expression(expr.args[0]),
 
iterator_dtype='int',
 
var=self._print(var),
 
start=self._print(start),
 
end=self._print(end),
 
expr=self._print(expr.function),
 
increment=str(1),
 
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
 
)
 
return code
 
_print_Max = C89CodePrinter._print_Max
_print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min
_print_Min = C89CodePrinter._print_Min
Loading