Skip to content
Snippets Groups Projects

Destructuring field binding

Merged Stephan Seitz requested to merge seitz/pystencils:destructuring-field-binding into master
3 files
+ 15
3
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 10
2
from typing import Any, List, Optional, Sequence, Set, Union
import jinja2
import sympy as sp
from pystencils.data_types import TypedSymbol, cast_func, create_type
@@ -661,7 +662,12 @@ class DestructuringBindingsForFieldClass(Node):
FieldShapeSymbol: "shape",
FieldStrideSymbol: "stride"
}
CLASS_NAME = "Field"
CLASS_NAME_TEMPLATE = jinja2.Template("Field<{{ dtype }}, {{ ndim }}>")
@property
def fields_accessed(self) -> Set['ResolvedFieldAccess']:
"""Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess))
def __init__(self, body):
super(DestructuringBindingsForFieldClass, self).__init__()
@@ -682,10 +688,12 @@ class DestructuringBindingsForFieldClass(Node):
@property
def undefined_symbols(self) -> Set[sp.Symbol]:
field_map = {f.name: f for f in self.fields_accessed}
undefined_field_symbols = self.symbols_defined
corresponding_field_names = {s.field_name for s in undefined_field_symbols if hasattr(s, 'field_name')}
corresponding_field_names |= {s.field_names[0] for s in undefined_field_symbols if hasattr(s, 'field_names')}
return {TypedSymbol(f, self.CLASS_NAME + '&') for f in corresponding_field_names} | \
return {TypedSymbol(f, self.CLASS_NAME_TEMPLATE.render(dtype=field_map[f].dtype, ndim=field_map[f].ndim) + '&')
for f in corresponding_field_names} | \
(self.body.undefined_symbols - undefined_field_symbols)
def subs(self, subs_dict) -> None:
Loading