Skip to content
Snippets Groups Projects
Commit 74236fab authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Fix #10: Avoid jinja2 dependency

This commit avoid dependency of core pystencils on jinja2.
However this could make the printing of some AST-nodes less elegant.
parent 6942ed0b
No related branches found
No related tags found
No related merge requests found
import uuid import uuid
from typing import Any, List, Optional, Sequence, Set, Union from typing import Any, List, Optional, Sequence, Set, Union
import jinja2
import sympy as sp import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type from pystencils.data_types import TypedSymbol, cast_func, create_type
...@@ -673,7 +672,7 @@ class DestructuringBindingsForFieldClass(Node): ...@@ -673,7 +672,7 @@ class DestructuringBindingsForFieldClass(Node):
FieldShapeSymbol: "shape[%i]", FieldShapeSymbol: "shape[%i]",
FieldStrideSymbol: "stride[%i]" FieldStrideSymbol: "stride[%i]"
} }
CLASS_NAME_TEMPLATE = jinja2.Template("PyStencilsField<{{ dtype }}, {{ ndim }}>") CLASS_NAME_TEMPLATE = "PyStencilsField<{dtype}, {ndim}>"
@property @property
def fields_accessed(self) -> Set['ResolvedFieldAccess']: def fields_accessed(self) -> Set['ResolvedFieldAccess']:
...@@ -703,7 +702,7 @@ class DestructuringBindingsForFieldClass(Node): ...@@ -703,7 +702,7 @@ class DestructuringBindingsForFieldClass(Node):
undefined_field_symbols = self.symbols_defined undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')} corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')} corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&') return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.format(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
for f in corresponding_field_names} | \ for f in corresponding_field_names} | \
(self.body.undefined_symbols - undefined_field_symbols) (self.body.undefined_symbols - undefined_field_symbols)
......
from collections import namedtuple from collections import namedtuple
from typing import Set from typing import Set
import jinja2
import sympy as sp import sympy as sp
from sympy.core import S from sympy.core import S
from sympy.printing.ccode import C89CodePrinter from sympy.printing.ccode import C89CodePrinter
...@@ -9,11 +8,12 @@ from sympy.printing.ccode import C89CodePrinter ...@@ -9,11 +8,12 @@ from sympy.printing.ccode import C89CodePrinter
from pystencils.astnodes import KernelFunction, Node from pystencils.astnodes import KernelFunction, Node
from pystencils.cpu.vectorization import vec_all, vec_any from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import ( from pystencils.data_types import (
PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression, reinterpret_cast_func, PointerType, VectorType, address_of, cast_func, create_type, get_type_of_expression,
vector_memory_access) reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.integer_functions import ( from pystencils.integer_functions import (
bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor, int_div, int_power_of_2, modulo_ceil) bit_shift_left, bit_shift_right, bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.kernelparameters import FieldPointerSymbol
try: try:
...@@ -271,18 +271,11 @@ class CBackend: ...@@ -271,18 +271,11 @@ class CBackend:
for u in undefined_field_symbols for u in undefined_field_symbols
] ]
destructuring_bindings.sort() # only for code aesthetics destructuring_bindings.sort() # only for code aesthetics
template = jinja2.Template( return "{\n" + self._indent + \
"""{ ("\n" + self._indent).join(destructuring_bindings) + \
{% for binding in bindings -%} "\n" + self._indent + \
{{ binding | indent(3) }} ("\n" + self._indent).join(self._print(node.body).splitlines()) + \
{% endfor -%} "\n}"
{{ block | indent(3) }}
}
""")
code = template.render(bindings=destructuring_bindings,
block=self._print(node.body))
return code
# ------------------------------------------ Helper function & classes ------------------------------------------------- # ------------------------------------------ Helper function & classes -------------------------------------------------
......
import sympy import sympy
import jinja2
import pystencils import pystencils
from pystencils.astnodes import DestructuringBindingsForFieldClass from pystencils.astnodes import DestructuringBindingsForFieldClass
from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol from pystencils.kernelparameters import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol
def test_destructuring_field_class(): def test_destructuring_field_class():
...@@ -28,12 +25,13 @@ class DestructuringEmojiClass(DestructuringBindingsForFieldClass): ...@@ -28,12 +25,13 @@ class DestructuringEmojiClass(DestructuringBindingsForFieldClass):
FieldShapeSymbol: "😳_%i", FieldShapeSymbol: "😳_%i",
FieldStrideSymbol: "🥵_%i" FieldStrideSymbol: "🥵_%i"
} }
CLASS_NAME_TEMPLATE = jinja2.Template("🤯<{{ dtype }}, {{ ndim }}>") CLASS_NAME_TEMPLATE = "🤯<{dtype}, {ndim}>"
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.headers = [] self.headers = []
def test_destructuring_alternative_field_class(): def test_destructuring_alternative_field_class():
z, x, y = pystencils.fields("z, y, x: [2d]") z, x, y = pystencils.fields("z, y, x: [2d]")
...@@ -44,6 +42,7 @@ def test_destructuring_alternative_field_class(): ...@@ -44,6 +42,7 @@ def test_destructuring_alternative_field_class():
ast.body = DestructuringEmojiClass(ast.body) ast.body = DestructuringEmojiClass(ast.body)
print(pystencils.show_code(ast)) print(pystencils.show_code(ast))
def main(): def main():
test_destructuring_field_class() test_destructuring_field_class()
test_destructuring_alternative_field_class() test_destructuring_alternative_field_class()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment