Skip to content
Snippets Groups Projects
Commit 9b9a4b54 authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'destructuring-field-binding' into 'master'

Destructuring field binding

See merge request !4
parents 8c4a6f1e 2313eda2
No related branches found
No related tags found
No related merge requests found
from typing import Any, List, Optional, Sequence, Set, Union
import jinja2
import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type
from pystencils.field import Field
from pystencils.data_types import TypedSymbol, create_type, cast_func
from pystencils.kernelparameters import FieldStrideSymbol, FieldPointerSymbol, FieldShapeSymbol
from pystencils.kernelparameters import (FieldPointerSymbol, FieldShapeSymbol,
FieldStrideSymbol)
from pystencils.sympyextensions import fast_subs
from typing import List, Set, Optional, Union, Any, Sequence
NodeOrExpr = Union['Node', sp.Expr]
......@@ -130,6 +134,7 @@ class KernelFunction(Node):
defined in pystencils.kernelparameters.
If the parameter is related to one or multiple fields, these fields are referenced in the fields property.
"""
def __init__(self, symbol, fields):
self.symbol = symbol # type: TypedSymbol
self.fields = fields # type: Sequence[Field]
......@@ -582,6 +587,7 @@ class TemporaryMemoryAllocation(Node):
size: number of elements to allocate
align_offset: the align_offset's element is aligned
"""
def __init__(self, typed_symbol: TypedSymbol, size, align_offset):
super(TemporaryMemoryAllocation, self).__init__(parent=None)
self.symbol = typed_symbol
......@@ -639,3 +645,58 @@ class TemporaryMemoryFree(Node):
def early_out(condition):
from pystencils.cpu.vectorization import vec_all
return Conditional(vec_all(condition), Block([SkipIteration()]))
class DestructuringBindingsForFieldClass(Node):
"""
Defines all variables needed for describing a field (shape, pointer, strides)
"""
CLASS_TO_MEMBER_DICT = {
FieldPointerSymbol: "data",
FieldShapeSymbol: "shape",
FieldStrideSymbol: "stride"
}
CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>")
@property
def fields_accessed(self) -> Set['ResolvedFieldAccess']:
"""Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess))
def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__()
self.headers = ['<Field.h>']
self.body = body
@property
def args(self) -> List[NodeOrExpr]:
"""Returns all arguments/children of this node."""
return set()
@property
def symbols_defined(self) -> Set[sp.Symbol]:
"""Set of symbols which are defined by this node."""
undefined_field_symbols = {s for s in self.body.undefined_symbols
if isinstance(s, (FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol))}
return undefined_field_symbols
@property
def undefined_symbols(self) -> Set[sp.Symbol]:
field_map = {f.name: f for f in self.fields_accessed}
undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
for f in corresponding_field_names} | \
(self.body.undefined_symbols - undefined_field_symbols)
def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
self.body.subs(subs_dict)
@property
def func(self):
return self.__class__
def atoms(self, arg_type) -> Set[Any]:
return self.body.atoms(arg_type) | {s for s in self.symbols_defined if isinstance(s, arg_type)}
import sympy as sp
from collections import namedtuple
from sympy.core import S
from typing import Set
import jinja2
import sympy as sp
from sympy.core import S
from sympy.printing.ccode import C89CodePrinter
from pystencils.cpu.vectorization import vec_any, vec_all
from pystencils.astnodes import (DestructuringBindingsForFieldClass,
KernelFunction, Node)
from pystencils.cpu.vectorization import vec_all, vec_any
from pystencils.data_types import (PointerType, VectorType, address_of,
cast_func, create_type, reinterpret_cast_func,
cast_func, create_type,
get_type_of_expression,
vector_memory_access)
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
reinterpret_cast_func, vector_memory_access)
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt,
fast_sqrt)
from pystencils.integer_functions import (bit_shift_left, bit_shift_right,
bitwise_and, bitwise_or, bitwise_xor,
int_div, int_power_of_2, modulo_ceil)
from pystencils.kernelparameters import FieldPointerSymbol
try:
from sympy.printing.ccode import C99CodePrinter as CCodePrinter
except ImportError:
from sympy.printing.ccode import CCodePrinter # for sympy versions < 1.1
from pystencils.integer_functions import bitwise_xor, bit_shift_right, bit_shift_left, bitwise_and, \
bitwise_or, modulo_ceil, int_div, int_power_of_2
from pystencils.astnodes import Node, KernelFunction
__all__ = ['generate_c', 'CustomCodeNode', 'PrintNode', 'get_headers', 'CustomSympyPrinter']
......@@ -255,6 +261,30 @@ class CBackend:
result += "else " + false_block
return result
def _print_DestructuringBindingsForFieldClass(self, node: Node):
# Define all undefined symbols
undefined_field_symbols = node.symbols_defined
destructuring_bindings = ["%s = %s.%s%s;" %
(u.name,
u.field_name if hasattr(u, 'field_name') else u.field_names[0],
DestructuringBindingsForFieldClass.CLASS_TO_MEMBER_DICT[u.__class__],
"" if type(u) == FieldPointerSymbol else ("[%i]" % u.coordinate))
for u in undefined_field_symbols
]
destructuring_bindings.sort() # only for code aesthetics
template = jinja2.Template(
"""{
{% for binding in bindings -%}
{{ binding | indent(3) }}
{% endfor -%}
{{ block | indent(3) }}
}
""")
code = template.render(bindings=destructuring_bindings,
block=self._print(node.body))
return code
# ------------------------------------------ Helper function & classes -------------------------------------------------
......
......@@ -108,7 +108,7 @@ class TypedSymbol(sp.Symbol):
obj = super(TypedSymbol, cls).__xnew__(cls, name)
try:
obj._dtype = create_type(dtype)
except TypeError:
except (TypeError, ValueError):
# on error keep the string
obj._dtype = dtype
return obj
......
......@@ -306,6 +306,10 @@ class Field(AbstractField):
def index_dimensions(self) -> int:
return len(self.shape) - len(self._layout)
@property
def ndim(self) -> int:
return len(self.shape)
@property
def layout(self):
return self._layout
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import sympy
import pystencils
from pystencils.astnodes import DestructuringBindingsForFieldClass
def test_destructuring_field_class():
z, x, y = pystencils.fields("z, y, x: [2d]")
normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])
ast = pystencils.create_kernel(normal_assignments)
print(pystencils.show_code(ast))
ast.body = DestructuringBindingsForFieldClass(ast.body)
print(pystencils.show_code(ast))
def main():
test_destructuring_field_class()
if __name__ == '__main__':
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment