Skip to content
Snippets Groups Projects

Destructuring field binding

Merged Stephan Seitz requested to merge seitz/pystencils:destructuring-field-binding into master
Files
5
import sympy as sp
from collections import namedtuple
from sympy.core import S
from typing import Set
import jinja2
import sympy as sp
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter
from pystencils.cpu.vectorization import vec_any, vec_all
from pystencils.astnodes import (DestructuringBindingsForFieldClass,
KernelFunction, Node)
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (PointerType, VectorType, address_of,
cast_func, create_type, reinterpret_cast_func,
cast_func, create_type,
get_type_of_expression,
vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
fast_sqrt)
from pystencils.integer_functions import (bit_shift_left, bit_shift_right,
bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
bitwise_or, modulo_ceil, int_div, int_power_of_2
from pystencils.astnodes import Node, KernelFunction
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
@@ -255,6 +261,30 @@ class CBackend:
result += "else " + false_block
return result
def _print_DestructuringBindingsForFieldClass(self, node: Node):
# Define all undefined symbols
undefined_field_symbols = node.symbols_defined
destructuring_bindings = ["%s = %s.%s%s;" %
(u.name,
u.field_name if hasattr(u, 'field_name') else u.field_names[0],
DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__],
"" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate))
for u in undefined_field_symbols
]
destructuring_bindings.sort() # only for code aesthetics
template = jinja2.Template(
"""{
{% for binding in bindings -%}
{{ binding | indent(3) }}
{% endfor -%}
{{ block | indent(3) }}
}
""")
code = template.render(bindings=destructuring_bindings,
block=self._print(node.body))
return code
# ------------------------------------------ Helper function & classes -------------------------------------------------
Loading