Skip to content
Snippets Groups Projects
Commit 6aeab41e authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Quality of life improvements for astnodes

parent 0f33273b
Branches
Tags
1 merge request!56Interpolation 24.0.9
import collections.abc
import itertools
import uuid
from typing import Any, List, Optional, Sequence, Set, Union
......@@ -33,7 +35,7 @@ class Node:
raise NotImplementedError()
def subs(self, subs_dict) -> None:
"""Inplace! substitute, similar to sympy's but modifies the AST inplace."""
"""Inplace! Substitute, similar to sympy's but modifies the AST inplace."""
for a in self.args:
a.subs(subs_dict)
......@@ -102,7 +104,8 @@ class Conditional(Node):
result = self.true_block.undefined_symbols
if self.false_block:
result.update(self.false_block.undefined_symbols)
result.update(self.condition_expr.atoms(sp.Symbol))
if hasattr(self.condition_expr, 'atoms'):
result.update(self.condition_expr.atoms(sp.Symbol))
return result
def __str__(self):
......@@ -212,9 +215,16 @@ class KernelFunction(Node):
"""Set of Field instances: fields which are accessed inside this kernel function"""
return set(o.field for o in self.atoms(ResolvedFieldAccess))
def fields_written(self):
assigments = self.atoms(SympyAssignment)
return {a.lhs.field for a in assigments if isinstance(a.lhs, ResolvedFieldAccess)}
@property
def fields_written(self) -> Set['ResolvedFieldAccess']:
assignments = self.atoms(SympyAssignment)
return {a.lhs.field for a in assignments if isinstance(a.lhs, ResolvedFieldAccess)}
@property
def fields_read(self) -> Set['ResolvedFieldAccess']:
assignments = self.atoms(SympyAssignment)
return set().union(itertools.chain.from_iterable([f.field for f in a.rhs.free_symbols if hasattr(f, 'field')]
for a in assignments))
def get_parameters(self) -> Sequence['KernelFunction.Parameter']:
"""Returns list of parameters for this function.
......@@ -283,8 +293,15 @@ class Block(Node):
a.subs(subs_dict)
def insert_front(self, node):
node.parent = self
self._nodes.insert(0, node)
if isinstance(node, collections.abc.Iterable):
node = list(node)
for n in node:
n.parent = self
self._nodes = node + self._nodes
else:
node.parent = self
self._nodes.insert(0, node)
def insert_before(self, new_node, insert_before):
new_node.parent = self
......@@ -485,7 +502,7 @@ class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True):
super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = lhs_symbol
self.rhs = rhs_expr
self.rhs = sp.simplify(rhs_expr)
self._is_const = is_const
self._is_declaration = self.__is_declaration()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment