Skip to content
Snippets Groups Projects

Destructuring field binding

Files

+ 40
9
import sympy as sp
from collections import namedtuple
from collections import namedtuple
from sympy.core import S
from typing import Set
from typing import Set
 
 
import jinja2
 
import sympy as sp
 
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter
from sympy.printing.ccode import C89CodePrinter
from pystencils.cpu.vectorization import vec_any, vec_all
from pystencils.astnodes import (DestructuringBindingsForFieldClass,
from pystencils.fast_approximation import fast_division, fast_sqrt, fast_inv_sqrt
KernelFunction, Node)
 
from pystencils.cpu.vectorization import vec_all, vec_any
 
from pystencils.data_types import (PointerType, VectorType, cast_func,
 
create_type, get_type_of_expression,
 
reinterpret_cast_func, vector_memory_access)
 
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
 
fast_sqrt)
 
from pystencils.integer_functions import (bit_shift_left, bit_shift_right,
 
bitwise_and, bitwise_or, bitwise_xor,
 
int_div, int_power_of_2, modulo_ceil)
 
from pystencils.kernelparameters import FieldPointerSymbol
try:
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
bitwise_or, modulo_ceil, int_div, int_power_of_2
from pystencils.astnodes import Node, KernelFunction
from pystencils.data_types import create_type, PointerType, get_type_of_expression, VectorType, cast_func, \
vector_memory_access, reinterpret_cast_func
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
@@ -224,6 +231,30 @@ class CBackend:
@@ -224,6 +231,30 @@ class CBackend:
result += "else " + false_block
result += "else " + false_block
return result
return result
 
def _print_DestructuringBindingsForFieldClass(self, node: Node):
 
# Define all undefined symbols
 
undefined_field_symbols = node.symbols_defined
 
destructuring_bindings = ["%s = %s.%s%s;" %
 
(u.name,
 
u.field_name if hasattr(u, 'field_name') else u.field_names[0],
 
DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__],
 
"" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate))
 
for u in undefined_field_symbols
 
]
 
destructuring_bindings.sort() # only for code aesthetics
 
template = jinja2.Template(
 
"""{
 
{% for binding in bindings -%}
 
{{ binding | indent(3) }}
 
{% endfor -%}
 
{{ block | indent(3) }}
 
}
 
 
""")
 
code = template.render(bindings=destructuring_bindings,
 
block=self._print(node.body))
 
return code
 
# ------------------------------------------ Helper function & classes -------------------------------------------------
# ------------------------------------------ Helper function & classes -------------------------------------------------
Loading