Skip to content
Snippets Groups Projects

Bugfix fields accessed for interpolator access

2 files
+ 21
17
Compare changes
  • Side-by-side
  • Inline
Files
2
+ 5
4
@@ -211,17 +211,18 @@ class KernelFunction(Node):
@@ -211,17 +211,18 @@ class KernelFunction(Node):
return self._body,
return self._body,
@property
@property
def fields_accessed(self) -> Set['ResolvedFieldAccess']:
def fields_accessed(self) -> Set[Field]:
"""Set of Field instances: fields which are accessed inside this kernel function"""
"""Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess))
from pystencils.interpolation_astnodes import InterpolatorAccess
 
return set(o.field for o in itertools.chain(self.atoms(ResolvedFieldAccess), self.atoms(InterpolatorAccess)))
@property
@property
def fields_written(self) -> Set['ResolvedFieldAccess']:
def fields_written(self) -> Set[Field]:
assignments = self.atoms(SympyAssignment)
assignments = self.atoms(SympyAssignment)
return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}
return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}
@property
@property
def fields_read(self) -> Set['ResolvedFieldAccess']:
def fields_read(self) -> Set[Field]:
assignments = self.atoms(SympyAssignment)
assignments = self.atoms(SympyAssignment)
return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
for a in assignments))
for a in assignments))
Loading