Skip to content
Snippets Groups Projects

WIP: Astnodes for interpolation

2 files
+ 180
0
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -366,6 +366,54 @@ class CustomSympyPrinter(CCodePrinter):
else:
return res
def _print_Sum(self, expr):
template = jinja2.Template(
"""[&]() {
{{dtype}} sum = ({{dtype}}) 0;
for ( {{iterator_dtype}} {{var}} = {{start}}; {{condition}}; {{var}} += {{increment}} ) {
sum += {{expr}};
}
return sum;
}()""")
var = expr.limits[0][0]
start = expr.limits[0][1]
end = expr.limits[0][2]
code = template.render(
dtype=get_type_of_expression(expr.args[0]),
iterator_dtype='int',
var=self._print(var),
start=self._print(start),
end=self._print(end),
expr=self._print(expr.function),
increment=str(1),
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
)
return code
def _print_Product(self, expr):
template = jinja2.Template(
"""[&]() {
{{dtype}} product = ({{dtype}}) 1;
for ( {{iterator_dtype}} {{var}} = {{start}}; {{condition}}; {{var}} += {{increment}} ) {
product *= {{expr}};
}
return product;
}()""")
var = expr.limits[0][0]
start = expr.limits[0][1]
end = expr.limits[0][2]
code = template.render(
dtype=get_type_of_expression(expr.args[0]),
iterator_dtype='int',
var=self._print(var),
start=self._print(start),
end=self._print(end),
expr=self._print(expr.function),
increment=str(1),
condition=self._print(var) + ' <= ' + self._print(end) # if start < end else '>='
)
return code
_print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min
Loading