Skip to content
Snippets Groups Projects

Destructuring field binding

Merged Stephan Seitz requested to merge seitz/pystencils:destructuring-field-binding into master
3 files
+ 123
5
Compare changes
  • Side-by-side
  • Inline
Files
3
  • DestructuringBindingsForFieldClass defines all field-related variables
    in its subordinated block.
    However, it leaves a TypedSymbol of type 'Field' for each field
    undefined.
    By that trick we can generate kernels that accept structs as
    kernelparameters.
    Either to include a pystencils specific Field struct of the following
    definition:
    
    ```cpp
    template<DTYPE_T, DIMENSION>
    struct Field
    {
        DTYPE_T* data;
        std::array<DTYPE_T, DIMENSION> shape;
        std::array<DTYPE_T, DIMENSION> stride;
    }
    
    or to be able to destructure user defined types like `pybind11::array`,
    `at::Tensor`, `tensorflow::Tensor`
    
    ```
import sympy as sp
from collections import namedtuple
from collections import namedtuple
from sympy.core import S
from typing import Set
from typing import Set
 
 
import jinja2
 
import sympy as sp
 
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter
from sympy.printing.ccode import C89CodePrinter
from pystencils.cpu.vectorization import vec_any, vec_all
from pystencils.cpu.vectorization import vec_any, vec_all
@@ -11,6 +13,12 @@ from pystencils.data_types import (PointerType, VectorType, address_of,
@@ -11,6 +13,12 @@ from pystencils.data_types import (PointerType, VectorType, address_of,
vector_memory_access)
vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
 
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
 
fast_sqrt)
 
from pystencils.integer_functions import (bit_shift_left, bit_shift_right,
 
bitwise_and, bitwise_or, bitwise_xor,
 
int_div, int_power_of_2, modulo_ceil)
 
from pystencils.kernelparameters import FieldPointerSymbol
try:
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
except ImportError:
@@ -255,6 +263,29 @@ class CBackend:
@@ -255,6 +263,29 @@ class CBackend:
result += "else " + false_block
result += "else " + false_block
return result
return result
 
def _print_DestructuringBindingsForFieldClass(self, node: Node):
 
# Define all undefined symbols
 
undefined_field_symbols = node.symbols_defined
 
destructuring_bindings = ["%s = %s.%s%s;" %
 
(u.name,
 
u.field_name if hasattr(u, 'field_name') else u.field_names[0],
 
DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__],
 
"" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate))
 
for u in undefined_field_symbols
 
]
 
template = jinja2.Template(
 
"""{
 
{% for binding in bindings -%}
 
{{ binding | indent(3) }}
 
{% endfor -%}
 
{{ block | indent(3) }}
 
}
 
 
""")
 
code = template.render(bindings=destructuring_bindings,
 
block=self._print(node.body))
 
return code
 
# ------------------------------------------ Helper function & classes -------------------------------------------------
# ------------------------------------------ Helper function & classes -------------------------------------------------
Loading