Skip to content
Snippets Groups Projects
Commit 8b30de1d authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'avoid-jinja2-dependency' into 'master'

Fix #10: Avoid jinja2 dependency

Closes #10

See merge request !18
parents 6942ed0b 74236fab
Branches
Tags
No related merge requests found
import uuid
from typing import Any, List, Optional, Sequence, Set, Union
import jinja2
import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type
......@@ -673,7 +672,7 @@ class DestructuringBindingsForFieldClass(Node):
FieldShapeSymbol: "shape[%i]",
FieldStrideSymbol: "stride[%i]"
}
CLASS_NAME_TEMPLATE = jinja2.Template("PyStencilsField<{{ dtype }}, {{ ndim }}>")
CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>"
@property
def fields_accessed(self) -> Set['ResolvedFieldAccess']:
......@@ -703,7 +702,7 @@ class DestructuringBindingsForFieldClass(Node):
undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
for f in corresponding_field_names} | \
(self.body.undefined_symbols - undefined_field_symbols)
......
from collections import namedtuple
from typing import Set
import jinja2
import sympy as sp
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter
......@@ -9,11 +8,12 @@ from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func,
vector_memory_access)
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil)
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol
try:
......@@ -271,18 +271,11 @@ class CBackend:
for u in undefined_field_symbols
]
destructuring_bindings.sort() # only for code aesthetics
template = jinja2.Template(
"""{
{% for binding in bindings -%}
{{ binding | indent(3) }}
{% endfor -%}
{{ block | indent(3) }}
}
""")
code = template.render(bindings=destructuring_bindings,
block=self._print(node.body))
return code
return "{\n" + self._indent + \
("\n" + self._indent).join(destructuring_bindings) + \
"\n" + self._indent + \
("\n" + self._indent).join(self._print(node.body).splitlines()) + \
"\n}"
# ------------------------------------------ Helper function & classes -------------------------------------------------
......
import sympy
import jinja2
import pystencils
from pystencils.astnodes import DestructuringBindingsForFieldClass
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
def test_destructuring_field_class():
......@@ -28,12 +25,13 @@ class DestructuringEmojiClass(DestructuringBindingsForFieldClass):
FieldShapeSymbol: "😳_%i",
FieldStrideSymbol: "🥵_%i"
}
CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>")
CLASS_NAME_TEMPLATE = "🤯<{dtype}, {ndim}>"
def __init__(self, node):
super().__init__(node)
self.headers = []
def test_destructuring_alternative_field_class():
z, x, y = pystencils.fields("z, y, x: [2d]")
......@@ -44,6 +42,7 @@ def test_destructuring_alternative_field_class():
ast.body = DestructuringEmojiClass(ast.body)
print(pystencils.show_code(ast))
def main():
test_destructuring_field_class()
test_destructuring_alternative_field_class()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment