Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • const_fix
  • fhennig/compiler-warnings
  • fhennig/v2.0-deprecations
  • fma
  • gpu_bufferfield_fix
  • gpu_liveness_opts
  • holzer-master-patch-46757
  • hyteg
  • improved_comm
  • master
  • target_dh_refactoring
  • v2.0-dev
  • vectorization_sqrt_fix
  • zikeliml/124-rework-tutorials
  • zikeliml/Task-96-dotExporterForAST
  • last/Kerncraft
  • last/LLVM
  • last/OpenCL
  • release/0.2.1
  • release/0.2.10
  • release/0.2.11
  • release/0.2.12
  • release/0.2.13
  • release/0.2.14
  • release/0.2.15
  • release/0.2.2
  • release/0.2.3
  • release/0.2.4
  • release/0.2.6
  • release/0.2.7
  • release/0.2.8
  • release/0.2.9
  • release/0.3.0
  • release/0.3.1
  • release/0.3.2
  • release/0.3.3
  • release/0.3.4
  • release/0.4.0
  • release/0.4.1
  • release/0.4.2
  • release/0.4.3
  • release/0.4.4
  • release/1.0
  • release/1.0.1
  • release/1.1
  • release/1.1.1
  • release/1.2
  • release/1.3
  • release/1.3.1
  • release/1.3.2
  • release/1.3.3
  • release/1.3.4
  • release/1.3.5
  • release/1.3.6
  • release/1.3.7
  • release/2.0.dev0
57 results

Target

Select target project
  • anirudh.jonnalagadda/pystencils
  • hyteg/pystencils
  • jbadwaik/pystencils
  • jngrad/pystencils
  • itischler/pystencils
  • ob28imeq/pystencils
  • hoenig/pystencils
  • Bindgen/pystencils
  • hammer/pystencils
  • da15siwa/pystencils
  • holzer/pystencils
  • alexander.reinauer/pystencils
  • ec93ujoh/pystencils
  • Harke/pystencils
  • seitz/pystencils
  • pycodegen/pystencils
16 results
Select Git revision
  • armneon
  • compare_fix
  • const_fix
  • gpu_liveness_opts
  • hyteg
  • improved_comm
  • jan_fix
  • jan_test
  • master
  • mr_parallel_datahandling_fix
  • philox-simd
  • target_dh_refactoring
  • test_martin
  • test_martin2
  • vectorization_sqrt_fix
  • release/0.2.1
  • release/0.2.10
  • release/0.2.11
  • release/0.2.12
  • release/0.2.13
  • release/0.2.14
  • release/0.2.15
  • release/0.2.2
  • release/0.2.3
  • release/0.2.4
  • release/0.2.6
  • release/0.2.7
  • release/0.2.8
  • release/0.2.9
29 results
Show changes
Showing
with 1126 additions and 496 deletions
from typing import Any, Dict, List, Union, Optional, Set
import sympy
import sympy as sp
from sympy.codegen.rewriting import ReplaceOptim, optimize
from pystencils.assignment import Assignment, AddAugmentedAssignment
import pystencils.astnodes as ast
from pystencils.backends.cbackend import CustomCodeNode
from pystencils.functions import DivFunc
from pystencils.simp import AssignmentCollection
from pystencils.typing import FieldPointerSymbol
class NodeCollection:
def __init__(self, assignments: List[Union[ast.Node, Assignment]],
simplification_hints: Optional[Dict[str, Any]] = None,
bound_fields: Set[sp.Symbol] = None, rhs_fields: Set[sp.Symbol] = None):
def visit(obj):
if isinstance(obj, (list, tuple)):
return [visit(e) for e in obj]
if isinstance(obj, Assignment):
if isinstance(obj.lhs, FieldPointerSymbol):
return ast.SympyAssignment(obj.lhs, obj.rhs, is_const=obj.lhs.dtype.const)
return ast.SympyAssignment(obj.lhs, obj.rhs)
elif isinstance(obj, AddAugmentedAssignment):
return ast.SympyAssignment(obj.lhs, obj.lhs + obj.rhs)
elif isinstance(obj, ast.SympyAssignment):
return obj
elif isinstance(obj, ast.Conditional):
true_block = visit(obj.true_block)
false_block = None if obj.false_block is None else visit(obj.false_block)
return ast.Conditional(obj.condition_expr, true_block=true_block, false_block=false_block)
elif isinstance(obj, ast.Block):
return ast.Block([visit(e) for e in obj.args])
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
raise ValueError("Invalid object in the List of Assignments " + str(type(obj)))
self.all_assignments = visit(assignments)
self.simplification_hints = simplification_hints if simplification_hints else {}
self.bound_fields = bound_fields if bound_fields else {}
self.rhs_fields = rhs_fields if rhs_fields else {}
@staticmethod
def from_assignment_collection(assignment_collection: AssignmentCollection):
return NodeCollection(assignments=assignment_collection.all_assignments,
simplification_hints=assignment_collection.simplification_hints,
bound_fields=assignment_collection.bound_fields,
rhs_fields=assignment_collection.rhs_fields)
def evaluate_terms(self):
evaluate_constant_terms = ReplaceOptim(
lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf()
)
evaluate_pow = ReplaceOptim(
lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8,
lambda p: sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
(DivFunc(sp.Integer(1), p.base) if p.exp == -1 else
DivFunc(sp.Integer(1), sp.UnevaluatedExpr(sp.Mul(*([p.base] * -p.exp), evaluate=False))))
)
sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
def visitor(node):
if isinstance(node, CustomCodeNode):
return node
elif isinstance(node, ast.Block):
return node.func([visitor(child) for child in node.args])
elif isinstance(node, ast.SympyAssignment):
new_lhs = visitor(node.lhs)
new_rhs = visitor(node.rhs)
return node.func(new_lhs, new_rhs, node.is_const, node.use_auto)
elif isinstance(node, ast.Node):
return node.func(*[visitor(child) for child in node.args])
elif isinstance(node, sympy.Basic):
return optimize(node, sympy_optimisations)
else:
raise NotImplementedError(f'{node} {type(node)} has no valid visitor')
self.all_assignments = [visitor(assignment) for assignment in self.all_assignments]
...@@ -34,7 +34,7 @@ def to_placeholder_function(expr, name): ...@@ -34,7 +34,7 @@ def to_placeholder_function(expr, name):
""" """
symbols = list(expr.atoms(sp.Symbol)) symbols = list(expr.atoms(sp.Symbol))
symbols.sort(key=lambda e: e.name) symbols.sort(key=lambda e: e.name)
derivative_symbols = [sp.Symbol("_d{}_d{}".format(name, s.name)) for s in symbols] derivative_symbols = [sp.Symbol(f"_d{name}_d{s.name}") for s in symbols]
derivatives = [sp.diff(expr, s) for s in symbols] derivatives = [sp.diff(expr, s) for s in symbols]
assignments = [Assignment(sp.Symbol(name), expr)] assignments = [Assignment(sp.Symbol(name), expr)]
......
File moved
import copy
import numpy as np
import sympy as sp
from pystencils.typing import TypedSymbol, CastFunc
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import CustomCodeNode
from pystencils.sympyextensions import fast_subs
class RNGBase(CustomCodeNode):
id = 0
def __init__(self, dim, time_step=TypedSymbol("time_step", np.uint32), offsets=None, keys=None):
if keys is None:
keys = (0,) * self._num_keys
if offsets is None:
offsets = (0,) * dim
if len(keys) != self._num_keys:
raise ValueError(f"Provided {len(keys)} keys but need {self._num_keys}")
if len(offsets) != dim:
raise ValueError(f"Provided {len(offsets)} offsets but need {dim}")
coordinates = [LoopOverCoordinate.get_loop_counter_symbol(i) + offsets[i] for i in range(dim)]
if dim < 3:
coordinates.append(0)
self._args = sp.sympify([time_step, *coordinates, *keys])
self.result_symbols = tuple(TypedSymbol(f'random_{self.id}_{i}', self._data_type)
for i in range(self._num_vars))
symbols_read = set.union(*[s.atoms(sp.Symbol) for s in self.args])
super().__init__("", symbols_read=symbols_read, symbols_defined=self.result_symbols)
self.headers = [f'"{self._name.split("_")[0]}_rand.h"']
RNGBase.id += 1
@property
def args(self):
return self._args
def fast_subs(self, subs_dict, skip):
rng = copy.deepcopy(self)
rng._args = [fast_subs(a, subs_dict, skip) for a in rng._args]
return rng
def get_code(self, dialect, vector_instruction_set, print_arg):
code = "\n"
for r in self.result_symbols:
if vector_instruction_set and not self.args[1].atoms(CastFunc):
# this vector RNG has become scalar through substitution
code += f"{r.dtype} {r.name};\n"
else:
code += f"{vector_instruction_set[r.dtype.c_name] if vector_instruction_set else r.dtype} " + \
f"{r.name};\n"
args = [print_arg(a) for a in self.args] + ['' + r.name for r in self.result_symbols]
code += (self._name + "(" + ", ".join(args) + ");\n")
return code
def __repr__(self):
return ", ".join([str(s) for s in self.result_symbols]) + " \\leftarrow " + \
self._name.capitalize() + "_RNG(" + ", ".join([str(a) for a in self.args]) + ")"
def _hashable_content(self):
return (self._name, *self.result_symbols, *self.args)
def __eq__(self, other):
return type(self) is type(other) and self._hashable_content() == other._hashable_content()
def __hash__(self):
return hash(self._hashable_content())
class PhiloxTwoDoubles(RNGBase):
_name = "philox_double2"
_data_type = np.float64
_num_vars = 2
_num_keys = 2
class PhiloxFourFloats(RNGBase):
_name = "philox_float4"
_data_type = np.float32
_num_vars = 4
_num_keys = 2
class AESNITwoDoubles(RNGBase):
_name = "aesni_double2"
_data_type = np.float64
_num_vars = 2
_num_keys = 4
class AESNIFourFloats(RNGBase):
_name = "aesni_float4"
_data_type = np.float32
_num_vars = 4
_num_keys = 4
def random_symbol(assignment_list, dim, seed=TypedSymbol("seed", np.uint32), rng_node=PhiloxTwoDoubles,
time_step=TypedSymbol("time_step", np.uint32), offsets=None):
"""Return a symbol generator for random numbers
Args:
assignment_list: the subexpressions member of an AssignmentCollection, into which helper variables assignments
will be inserted
dim: 2 or 3 for two or three spatial dimensions
seed: an integer or TypedSymbol(..., np.uint32) to seed the random number generator. If you create multiple
symbol generators, please pass them different seeds so you don't get the same stream of random numbers!
rng_node: which random number generator to use (PhiloxTwoDoubles, PhiloxFourFloats, AESNITwoDoubles,
AESNIFourFloats).
time_step: TypedSymbol(..., np.uint32) that indicates the number of the current time step
offsets: tuple of offsets (constant integers or TypedSymbol(..., np.uint32)) that give the global coordinates
of the local origin
"""
counter = 0
while True:
keys = (counter, seed) + (0,) * (rng_node._num_keys - 2)
node = rng_node(dim, keys=keys, time_step=time_step, offsets=offsets)
inserted = False
for symbol in node.result_symbols:
if not inserted:
assignment_list.insert(0, node)
inserted = True
yield symbol
counter += 1
import socket import socket
import time import time
from types import MappingProxyType
from typing import Dict, Iterator, Sequence from typing import Dict, Iterator, Sequence
import blitzdb import blitzdb
import six
from blitzdb.backends.file.backend import serializer_classes
from blitzdb.backends.file.utils import JsonEncoder
from pystencils.cpu.cpujit import get_compiler_config from pystencils.cpu.cpujit import get_compiler_config
from pystencils import CreateKernelConfig, Target, Backend, Field
import json
import sympy as sp
from pystencils.typing import BasicType
class PystencilsJsonEncoder(JsonEncoder):
def default(self, obj):
if isinstance(obj, CreateKernelConfig):
return obj.__dict__
if isinstance(obj, (sp.Float, sp.Rational)):
return float(obj)
if isinstance(obj, sp.Integer):
return int(obj)
if isinstance(obj, (BasicType, MappingProxyType)):
return str(obj)
if isinstance(obj, (Target, Backend, sp.Symbol)):
return obj.name
if isinstance(obj, Field):
return f"pystencils.Field(name = {obj.name}, field_type = {obj.field_type.name}, " \
f"dtype = {str(obj.dtype)}, layout = {obj.layout}, shape = {obj.shape}, " \
f"strides = {obj.strides})"
return JsonEncoder.default(self, obj)
class PystencilsJsonSerializer(object):
@classmethod
def serialize(cls, data):
if six.PY3:
if isinstance(data, bytes):
return json.dumps(data.decode('utf-8'), cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
else:
return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
else:
return json.dumps(data, cls=PystencilsJsonEncoder, ensure_ascii=False).encode('utf-8')
@classmethod
def deserialize(cls, data):
if six.PY3:
return json.loads(data.decode('utf-8'))
else:
return json.loads(data.decode('utf-8'))
class Database: class Database:
...@@ -46,7 +96,7 @@ class Database: ...@@ -46,7 +96,7 @@ class Database:
class SimulationResult(blitzdb.Document): class SimulationResult(blitzdb.Document):
pass pass
def __init__(self, file: str) -> None: def __init__(self, file: str, serializer_info: tuple = None) -> None:
if file.startswith("mongo://"): if file.startswith("mongo://"):
from pymongo import MongoClient from pymongo import MongoClient
db_name = file[len("mongo://"):] db_name = file[len("mongo://"):]
...@@ -57,6 +107,10 @@ class Database: ...@@ -57,6 +107,10 @@ class Database:
self.backend.autocommit = True self.backend.autocommit = True
if serializer_info:
serializer_classes.update({serializer_info[0]: serializer_info[1]})
self.backend.load_config({'serializer_class': serializer_info[0]}, True)
def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None: def save(self, params: Dict, result: Dict, env: Dict = None, **kwargs) -> None:
"""Stores a simulation result in the database. """Stores a simulation result in the database.
...@@ -120,7 +174,7 @@ class Database: ...@@ -120,7 +174,7 @@ class Database:
Returns: Returns:
pandas data frame pandas data frame
""" """
from pandas.io.json import json_normalize from pandas import json_normalize
query_result = self.filter_params(parameter_query) query_result = self.filter_params(parameter_query)
attributes = [e.attributes for e in query_result] attributes = [e.attributes for e in query_result]
...@@ -146,10 +200,15 @@ class Database: ...@@ -146,10 +200,15 @@ class Database:
'cpuCompilerConfig': get_compiler_config(), 'cpuCompilerConfig': get_compiler_config(),
} }
try: try:
from git import Repo, InvalidGitRepositoryError from git import Repo
except ImportError:
return result
try:
from git import InvalidGitRepositoryError
repo = Repo(search_parent_directories=True) repo = Repo(search_parent_directories=True)
result['git_hash'] = str(repo.head.commit) result['git_hash'] = str(repo.head.commit)
except (ImportError, InvalidGitRepositoryError): except InvalidGitRepositoryError:
pass pass
return result return result
......
...@@ -9,6 +9,7 @@ from time import sleep ...@@ -9,6 +9,7 @@ from time import sleep
from typing import Any, Callable, Dict, Optional, Sequence, Tuple from typing import Any, Callable, Dict, Optional, Sequence, Tuple
from pystencils.runhelper import Database from pystencils.runhelper import Database
from pystencils.runhelper.db import PystencilsJsonSerializer
from pystencils.utils import DotDict from pystencils.utils import DotDict
ParameterDict = Dict[str, Any] ParameterDict = Dict[str, Any]
...@@ -54,10 +55,11 @@ class ParameterStudy: ...@@ -54,10 +55,11 @@ class ParameterStudy:
Run = namedtuple("Run", ['parameter_dict', 'weight']) Run = namedtuple("Run", ['parameter_dict', 'weight'])
def __init__(self, run_function: Callable[..., Dict], runs: Sequence = (), def __init__(self, run_function: Callable[..., Dict], runs: Sequence = (),
database_connector: str = './db') -> None: database_connector: str = './db',
serializer_info: tuple = ('pystencils_serializer', PystencilsJsonSerializer)) -> None:
self.runs = list(runs) self.runs = list(runs)
self.run_function = run_function self.run_function = run_function
self.db = Database(database_connector) self.db = Database(database_connector, serializer_info)
def add_run(self, parameter_dict: ParameterDict, weight: int = 1) -> None: def add_run(self, parameter_dict: ParameterDict, weight: int = 1) -> None:
"""Schedule a dictionary of parameters to run in this parameter study. """Schedule a dictionary of parameters to run in this parameter study.
...@@ -215,7 +217,7 @@ class ParameterStudy: ...@@ -215,7 +217,7 @@ class ParameterStudy:
def log_message(self, fmt, *args): def log_message(self, fmt, *args):
return return
print("Listening to connections on {}:{}. Scenarios to simulate: {}".format(ip, port, len(filtered_runs))) print(f"Listening to connections on {ip}:{port}. Scenarios to simulate: {len(filtered_runs)}")
server = HTTPServer((ip, port), ParameterStudyServer) server = HTTPServer((ip, port), ParameterStudyServer)
while len(ParameterStudyServer.currently_running) > 0 or len(ParameterStudyServer.runs) > 0: while len(ParameterStudyServer.currently_running) > 0 or len(ParameterStudyServer.runs) > 0:
server.handle_request() server.handle_request()
...@@ -241,7 +243,7 @@ class ParameterStudy: ...@@ -241,7 +243,7 @@ class ParameterStudy:
from urllib.error import URLError from urllib.error import URLError
import time import time
parameter_update = {} if parameter_update is None else parameter_update parameter_update = {} if parameter_update is None else parameter_update
url = "http://{}:{}".format(server, port) url = f"http://{server}:{port}"
client_name = client_name.format(hostname=socket.gethostname(), pid=os.getpid()) client_name = client_name.format(hostname=socket.gethostname(), pid=os.getpid())
start_time = time.time() start_time = time.time()
while True: while True:
...@@ -265,7 +267,7 @@ class ParameterStudy: ...@@ -265,7 +267,7 @@ class ParameterStudy:
'client_name': client_name} 'client_name': client_name}
urlopen(url + '/result', data=json.dumps(answer).encode()) urlopen(url + '/result', data=json.dumps(answer).encode())
except URLError: except URLError:
print("Cannot connect to server {} retrying in 5 seconds...".format(url)) print(f"Cannot connect to server {url} retrying in 5 seconds...")
sleep(5) sleep(5)
def run_from_command_line(self, argv: Optional[Sequence[str]] = None) -> None: def run_from_command_line(self, argv: Optional[Sequence[str]] = None) -> None:
......
...@@ -2,8 +2,7 @@ import numpy as np ...@@ -2,8 +2,7 @@ import numpy as np
import sympy as sp import sympy as sp
import pystencils as ps import pystencils as ps
import pystencils.jupyter from pystencils.jupyter import make_imshow_animation, display_animation, set_display_mode
import pystencils.plot as plt import pystencils.plot as plt
import pystencils.sympy_gmpy_bug_workaround
__all__ = ['sp', 'np', 'ps', 'plt'] __all__ = ['sp', 'np', 'ps', 'plt', 'make_imshow_animation', 'display_animation', 'set_display_mode']
from .assignment_collection import AssignmentCollection from .assignment_collection import AssignmentCollection
from .simplifications import ( from .simplifications import (
add_subexpressions_for_constants,
add_subexpressions_for_divisions, add_subexpressions_for_field_reads, add_subexpressions_for_divisions, add_subexpressions_for_field_reads,
apply_on_all_subexpressions, apply_to_all_assignments, add_subexpressions_for_sums, apply_on_all_subexpressions, apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions, subexpression_substitution_in_existing_subexpressions,
subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list) subexpression_substitution_in_main_assignments, sympy_cse, sympy_cse_on_assignment_list)
from .subexpression_insertion import (
insert_aliases, insert_zeros, insert_constants,
insert_constant_additions, insert_constant_multiples,
insert_squares, insert_symbol_times_minus_one)
from .simplificationstrategy import SimplificationStrategy from .simplificationstrategy import SimplificationStrategy
__all__ = ['AssignmentCollection', 'SimplificationStrategy', __all__ = ['AssignmentCollection', 'SimplificationStrategy',
'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments', 'sympy_cse', 'sympy_cse_on_assignment_list', 'apply_to_all_assignments',
'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions', 'apply_on_all_subexpressions', 'subexpression_substitution_in_existing_subexpressions',
'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_divisions', 'subexpression_substitution_in_main_assignments', 'add_subexpressions_for_constants',
'add_subexpressions_for_field_reads'] 'add_subexpressions_for_divisions', 'add_subexpressions_for_sums', 'add_subexpressions_for_field_reads',
'insert_aliases', 'insert_zeros', 'insert_constants',
'insert_constant_additions', 'insert_constant_multiples',
'insert_squares', 'insert_symbol_times_minus_one']
import itertools
from copy import copy from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp import sympy as sp
import pystencils
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.simp.simplifications import ( from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs from pystencils.sympyextensions import count_operations, fast_subs
...@@ -16,15 +17,16 @@ class AssignmentCollection: ...@@ -16,15 +17,16 @@ class AssignmentCollection:
These simplification methods can change the subexpressions, but the number and These simplification methods can change the subexpressions, but the number and
left hand side of the main equations themselves is not altered. left hand side of the main equations themselves is not altered.
Additionally a dictionary of simplification hints is stored, which are set by the functions that create Additionally a dictionary of simplification hints is stored, which are set by the functions that create
equation collections to transport information to the simplification system. assignment collections to transport information to the simplification system.
Attributes: Args:
main_assignments: list of assignments main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
subexpressions: list of assignments defining subexpressions used in main equations assignment is a field access. Thus the generated equations write on arrays.
simplification_hints: dict that is used to annotate the equation collection with hints that are subexpressions: List of assignments defining subexpressions used in main equations
simplification_hints: Dict that is used to annotate the assignment collection with hints that are
used by the simplification system. See documentation of the simplification rules for used by the simplification system. See documentation of the simplification rules for
potentially required hints and their meaning. potentially required hints and their meaning.
subexpression_symbol_generator: generator for new symbols that are used when new subexpressions are added subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
used to get new symbols that are unique for this AssignmentCollection used to get new symbols that are unique for this AssignmentCollection
""" """
...@@ -32,9 +34,13 @@ class AssignmentCollection: ...@@ -32,9 +34,13 @@ class AssignmentCollection:
# ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- # ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = {}, subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
simplification_hints: Optional[Dict[str, Any]] = None, simplification_hints: Optional[Dict[str, Any]] = None,
subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
if subexpressions is None:
subexpressions = {}
if isinstance(main_assignments, Dict): if isinstance(main_assignments, Dict):
main_assignments = [Assignment(k, v) main_assignments = [Assignment(k, v)
for k, v in main_assignments.items()] for k, v in main_assignments.items()]
...@@ -42,6 +48,11 @@ class AssignmentCollection: ...@@ -42,6 +48,11 @@ class AssignmentCollection:
subexpressions = [Assignment(k, v) subexpressions = [Assignment(k, v)
for k, v in subexpressions.items()] for k, v in subexpressions.items()]
main_assignments = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
subexpressions = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
self.main_assignments = main_assignments self.main_assignments = main_assignments
self.subexpressions = subexpressions self.subexpressions = subexpressions
...@@ -50,8 +61,11 @@ class AssignmentCollection: ...@@ -50,8 +61,11 @@ class AssignmentCollection:
self.simplification_hints = simplification_hints self.simplification_hints = simplification_hints
ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
if subexpression_symbol_generator is None: if subexpression_symbol_generator is None:
self.subexpression_symbol_generator = SymbolGen() self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
else: else:
self.subexpression_symbol_generator = subexpression_symbol_generator self.subexpression_symbol_generator = subexpression_symbol_generator
...@@ -95,32 +109,70 @@ class AssignmentCollection: ...@@ -95,32 +109,70 @@ class AssignmentCollection:
"""Subexpression and main equations as a single list.""" """Subexpression and main equations as a single list."""
return self.subexpressions + self.main_assignments return self.subexpressions + self.main_assignments
@property
def rhs_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which occur on the rhs of any assignment."""
rhs_symbols = set()
for eq in self.all_assignments:
if isinstance(eq, Assignment):
rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
elif isinstance(eq, pystencils.astnodes.Node):
rhs_symbols.update(eq.undefined_symbols)
return rhs_symbols
@property @property
def free_symbols(self) -> Set[sp.Symbol]: def free_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
free_symbols = set() return self.rhs_symbols - self.bound_symbols
for eq in self.all_assignments:
free_symbols.update(eq.rhs.atoms(sp.Symbol))
return free_symbols - self.bound_symbols
@property @property
def bound_symbols(self) -> Set[sp.Symbol]: def bound_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur on the left hand side of a main assignment or a subexpression.""" """All symbols which occur on the left hand side of a main assignment or a subexpression."""
bound_symbols_set = set([eq.lhs for eq in self.all_assignments]) bound_symbols_set = set(
assert len(bound_symbols_set) == len(self.subexpressions) + len(self.main_assignments), \ [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
)
assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
"Not in SSA form - same symbol assigned multiple times" "Not in SSA form - same symbol assigned multiple times"
bound_symbols_set = bound_symbols_set.union(*[
assignment.symbols_defined for assignment in self.all_assignments
if isinstance(assignment, pystencils.astnodes.Node)
])
return bound_symbols_set return bound_symbols_set
@property
def rhs_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
@property
def free_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.free_symbols if hasattr(s, 'field')}
@property
def bound_fields(self):
"""All field accessed on the left hand side of a main assignment or a subexpression."""
return {s.field for s in self.bound_symbols if hasattr(s, 'field')}
@property @property
def defined_symbols(self) -> Set[sp.Symbol]: def defined_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur as left-hand-sides of one of the main equations""" """All symbols which occur as left-hand-sides of one of the main equations"""
return set([assignment.lhs for assignment in self.main_assignments]) lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
if isinstance(assignment, pystencils.astnodes.Node)]))
@property @property
def operation_count(self): def operation_count(self):
"""See :func:`count_operations` """ """See :func:`count_operations` """
return count_operations(self.all_assignments, only_type=None) return count_operations(self.all_assignments, only_type=None)
def atoms(self, *args):
return set().union(*[a.atoms(*args) for a in self.all_assignments])
def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]: def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
"""Returns all symbols that depend on one of the passed symbols. """Returns all symbols that depend on one of the passed symbols.
...@@ -172,6 +224,7 @@ class AssignmentCollection: ...@@ -172,6 +224,7 @@ class AssignmentCollection:
return {s: func(*args, **kwargs) for s, func in lambdas.items()} return {s: func(*args, **kwargs) for s, func in lambdas.items()}
return f return f
# ---------------------------- Creating new modified collections --------------------------------------------------- # ---------------------------- Creating new modified collections ---------------------------------------------------
def copy(self, def copy(self,
...@@ -225,7 +278,7 @@ class AssignmentCollection: ...@@ -225,7 +278,7 @@ class AssignmentCollection:
own_definitions = set([e.lhs for e in self.main_assignments]) own_definitions = set([e.lhs for e in self.main_assignments])
other_definitions = set([e.lhs for e in other.main_assignments]) other_definitions = set([e.lhs for e in other.main_assignments])
assert len(own_definitions.intersection(other_definitions)) == 0, \ assert len(own_definitions.intersection(other_definitions)) == 0, \
"Cannot new_merged, since both collection define the same symbols" "Cannot merge collections, since both define the same symbols"
own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
substitution_dict = {} substitution_dict = {}
...@@ -233,12 +286,13 @@ class AssignmentCollection: ...@@ -233,12 +286,13 @@ class AssignmentCollection:
processed_other_subexpression_equations = [] processed_other_subexpression_equations = []
for other_subexpression_eq in other.subexpressions: for other_subexpression_eq in other.subexpressions:
if other_subexpression_eq.lhs in own_subexpression_symbols: if other_subexpression_eq.lhs in own_subexpression_symbols:
if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]: new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
continue # exact the same subexpression equation exists already continue # exact the same subexpression equation exists already
else: else:
# different definition - a new name has to be introduced # different definition - a new name has to be introduced
new_lhs = next(self.subexpression_symbol_generator) new_lhs = next(self.subexpression_symbol_generator)
new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict)) new_eq = Assignment(new_lhs, new_rhs)
processed_other_subexpression_equations.append(new_eq) processed_other_subexpression_equations.append(new_eq)
substitution_dict[other_subexpression_eq.lhs] = new_lhs substitution_dict[other_subexpression_eq.lhs] = new_lhs
else: else:
...@@ -261,9 +315,9 @@ class AssignmentCollection: ...@@ -261,9 +315,9 @@ class AssignmentCollection:
if eq.lhs in symbols_to_extract: if eq.lhs in symbols_to_extract:
new_assignments.append(eq) new_assignments.append(eq)
new_sub_expr = [eq for eq in self.subexpressions new_sub_expr = [eq for eq in self.all_assignments
if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract] if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
return AssignmentCollection(new_assignments, new_sub_expr) return self.copy(new_assignments, new_sub_expr)
def new_without_unused_subexpressions(self) -> 'AssignmentCollection': def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
"""Returns new collection that only contains subexpressions required to compute the main assignments.""" """Returns new collection that only contains subexpressions required to compute the main assignments."""
...@@ -286,8 +340,10 @@ class AssignmentCollection: ...@@ -286,8 +340,10 @@ class AssignmentCollection:
new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments] new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
return self.copy(new_eqs, new_subexpressions) return self.copy(new_eqs, new_subexpressions)
def new_without_subexpressions(self, subexpressions_to_keep: Set[sp.Symbol] = set()) -> 'AssignmentCollection': def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
"""Returns a new collection where all subexpressions have been inserted.""" """Returns a new collection where all subexpressions have been inserted."""
if subexpressions_to_keep is None:
subexpressions_to_keep = set()
if len(self.subexpressions) == 0: if len(self.subexpressions) == 0:
return self.copy() return self.copy()
...@@ -296,7 +352,7 @@ class AssignmentCollection: ...@@ -296,7 +352,7 @@ class AssignmentCollection:
kept_subexpressions = [] kept_subexpressions = []
if self.subexpressions[0].lhs in subexpressions_to_keep: if self.subexpressions[0].lhs in subexpressions_to_keep:
substitution_dict = {} substitution_dict = {}
kept_subexpressions = self.subexpressions[0] kept_subexpressions.append(self.subexpressions[0])
else: else:
substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
...@@ -315,6 +371,7 @@ class AssignmentCollection: ...@@ -315,6 +371,7 @@ class AssignmentCollection:
def _repr_html_(self): def _repr_html_(self):
"""Interface to Jupyter notebook, to display as a nicely formatted HTML table""" """Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
def make_html_equation_table(equations): def make_html_equation_table(equations):
no_border = 'style="border:none"' no_border = 'style="border:none"'
html_table = '<table style="border:none; width: 100%; ">' html_table = '<table style="border:none; width: 100%; ">'
...@@ -335,19 +392,19 @@ class AssignmentCollection: ...@@ -335,19 +392,19 @@ class AssignmentCollection:
return result return result
def __repr__(self): def __repr__(self):
return "Equation Collection for " + ",".join([str(eq.lhs) for eq in self.main_assignments]) return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"
def __str__(self): def __str__(self):
result = "Subexpressions:\n" result = "Subexpressions:\n"
for eq in self.subexpressions: for eq in self.subexpressions:
result += "\t{eq}\n".format(eq=eq) result += f"\t{eq}\n"
result += "Main Assignments:\n" result += "Main Assignments:\n"
for eq in self.main_assignments: for eq in self.main_assignments:
result += "\t{eq}\n".format(eq=eq) result += f"\t{eq}\n"
return result return result
def __iter__(self): def __iter__(self):
return self.main_assignments.__iter__() return self.all_assignments.__iter__()
@property @property
def main_assignments_dict(self): def main_assignments_dict(self):
...@@ -393,18 +450,24 @@ class AssignmentCollection: ...@@ -393,18 +450,24 @@ class AssignmentCollection:
def __eq__(self, other): def __eq__(self, other):
return set(self.all_assignments) == set(other.all_assignments) return set(self.all_assignments) == set(other.all_assignments)
def __bool__(self):
return bool(self.all_assignments)
class SymbolGen: class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ...""" """Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self, symbol="xi"): def __init__(self, symbol="xi", dtype=None, ctr=0):
self._ctr = 0 self._ctr = ctr
self._symbol = symbol self._symbol = symbol
self._dtype = dtype
def __iter__(self): def __iter__(self):
return self return self
def __next__(self): def __next__(self):
name = "{}_{}".format(self._symbol, self._ctr) name = f"{self._symbol}_{self._ctr}"
self._ctr += 1 self._ctr += 1
if self._dtype is not None:
return pystencils.TypedSymbol(name, self._dtype)
return sp.Symbol(name) return sp.Symbol(name)
from itertools import chain from itertools import chain
from typing import Callable, List, Sequence, Union from typing import Callable, List, Sequence, Union
from collections import defaultdict
import sympy as sp import sympy as sp
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.astnodes import Node from pystencils.astnodes import Node
from pystencils.field import AbstractField, Field from pystencils.field import Field
from pystencils.sympyextensions import subs_additive from pystencils.sympyextensions import subs_additive, is_constant, recursive_collect
from pystencils.typing import TypedSymbol
def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]: def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]]) -> List[Union[Assignment, Node]]:
...@@ -18,7 +20,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node] ...@@ -18,7 +20,7 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
elif isinstance(e1, Node): elif isinstance(e1, Node):
symbols = e1.symbols_defined symbols = e1.symbols_defined
else: else:
raise NotImplementedError("Cannot sort topologically. Object of type " + type(e1) + " cannot be handled.") raise NotImplementedError(f"Cannot sort topologically. Object of type {type(e1)} cannot be handled.")
for lhs in symbols: for lhs in symbols:
for c2, e2 in enumerate(assignments): for c2, e2 in enumerate(assignments):
...@@ -29,18 +31,18 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node] ...@@ -29,18 +31,18 @@ def sort_assignments_topologically(assignments: Sequence[Union[Assignment, Node]
return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))] return [assignments[i] for i in sp.topological_sort((range(len(assignments)), edges))]
def sympy_cse(ac): def sympy_cse(ac, **kwargs):
"""Searches for common subexpressions inside the equation collection. """Searches for common subexpressions inside the assignment collection.
Searches is done in both the existing subexpressions as well as the assignments themselves. Searches is done in both the existing subexpressions as well as the assignments themselves.
It uses the sympy subexpression detection to do this. Return a new equation collection It uses the sympy subexpression detection to do this. Return a new assignment collection
with the additional subexpressions found with the additional subexpressions found
""" """
symbol_gen = ac.subexpression_symbol_generator symbol_gen = ac.subexpression_symbol_generator
all_assignments = [e for e in chain(ac.subexpressions, ac.main_assignments) if isinstance(e, Assignment)] all_assignments = [e for e in chain(ac.subexpressions, ac.main_assignments) if isinstance(e, Assignment)]
other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)] other_objects = [e for e in chain(ac.subexpressions, ac.main_assignments) if not isinstance(e, Assignment)]
replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen) replacements, new_eq = sp.cse(all_assignments, symbols=symbol_gen, **kwargs)
replacement_eqs = [Assignment(*r) for r in replacements] replacement_eqs = [Assignment(*r) for r in replacements]
...@@ -83,6 +85,39 @@ def subexpression_substitution_in_main_assignments(ac): ...@@ -83,6 +85,39 @@ def subexpression_substitution_in_main_assignments(ac):
return ac.copy(result) return ac.copy(result)
def add_subexpressions_for_constants(ac):
"""Extracts constant factors to subexpressions in the given assignment collection.
SymPy will exclude common factors from a sum only if they are symbols. This simplification
can be applied to exclude common numeric constants from multiple terms of a sum. As a consequence,
the number of multiplications is reduced and in some cases, more common subexpressions can be found.
"""
constants_to_subexp_dict = defaultdict(lambda: next(ac.subexpression_symbol_generator))
def visit(expr):
args = list(expr.args)
if len(args) == 0:
return expr
if isinstance(expr, sp.Add) or isinstance(expr, sp.Mul):
for i, arg in enumerate(args):
if is_constant(arg) and abs(arg) != 1:
if arg < 0:
args[i] = - constants_to_subexp_dict[- arg]
else:
args[i] = constants_to_subexp_dict[arg]
return expr.func(*(visit(a) for a in args))
main_assignments = [Assignment(a.lhs, visit(a.rhs)) for a in ac.main_assignments]
subexpressions = [Assignment(a.lhs, visit(a.rhs)) for a in ac.subexpressions]
symbols_to_collect = set(constants_to_subexp_dict.values())
main_assignments = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in main_assignments]
subexpressions = [Assignment(a.lhs, recursive_collect(a.rhs, symbols_to_collect, True)) for a in subexpressions]
subexpressions = [Assignment(symb, c) for c, symb in constants_to_subexp_dict.items()] + subexpressions
return ac.copy(main_assignments=main_assignments, subexpressions=subexpressions)
def add_subexpressions_for_divisions(ac): def add_subexpressions_for_divisions(ac):
r"""Introduces subexpressions for all divisions which have no constant in the denominator. r"""Introduces subexpressions for all divisions which have no constant in the denominator.
...@@ -112,14 +147,14 @@ def add_subexpressions_for_sums(ac): ...@@ -112,14 +147,14 @@ def add_subexpressions_for_sums(ac):
addends = [] addends = []
def contains_sum(term): def contains_sum(term):
if term.func == sp.add.Add: if term.func == sp.Add:
return True return True
if term.is_Atom: if term.is_Atom:
return False return False
return any([contains_sum(a) for a in term.args]) return any([contains_sum(a) for a in term.args])
def search_addends(term): def search_addends(term):
if term.func == sp.add.Add: if term.func == sp.Add:
if all([not contains_sum(a) for a in term.args]): if all([not contains_sum(a) for a in term.args]):
addends.extend(term.args) addends.extend(term.args)
for a in term.args: for a in term.args:
...@@ -128,18 +163,20 @@ def add_subexpressions_for_sums(ac): ...@@ -128,18 +163,20 @@ def add_subexpressions_for_sums(ac):
for eq in ac.all_assignments: for eq in ac.all_assignments:
search_addends(eq.rhs) search_addends(eq.rhs)
addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, AbstractField.AbstractAccess)] addends = [a for a in addends if not isinstance(a, sp.Symbol) or isinstance(a, Field.Access)]
new_symbol_gen = ac.subexpression_symbol_generator new_symbol_gen = ac.subexpression_symbol_generator
substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)} substitutions = {addend: new_symbol for new_symbol, addend in zip(new_symbol_gen, addends)}
return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False) return ac.new_with_substitutions(substitutions, True, substitute_on_lhs=False)
def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True): def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments=True, data_type=None):
r"""Substitutes field accesses on rhs of assignments with subexpressions r"""Substitutes field accesses on rhs of assignments with subexpressions
Can change semantics of the update rule (which is the goal of this transformation) Can change semantics of the update rule (which is the goal of this transformation)
This is useful if a field should be update in place - all values are loaded before into subexpression variables, This is useful if a field should be update in place - all values are loaded before into subexpression variables,
then the new values are computed and written to the same field in-place. then the new values are computed and written to the same field in-place.
Additionally, if a datatype is given to the function the rhs symbol of the new isolated field read will have
this data type. This is useful for mixed precision kernels
""" """
field_reads = set() field_reads = set()
to_iterate = [] to_iterate = []
...@@ -151,7 +188,17 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments ...@@ -151,7 +188,17 @@ def add_subexpressions_for_field_reads(ac, subexpressions=True, main_assignments
for assignment in to_iterate: for assignment in to_iterate:
if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'): if hasattr(assignment, 'lhs') and hasattr(assignment, 'rhs'):
field_reads.update(assignment.rhs.atoms(Field.Access)) field_reads.update(assignment.rhs.atoms(Field.Access))
substitutions = {fa: next(ac.subexpression_symbol_generator) for fa in field_reads}
if not field_reads:
return ac
substitutions = dict()
for fa in field_reads:
lhs = next(ac.subexpression_symbol_generator)
if data_type is not None:
substitutions.update({fa: TypedSymbol(lhs.name, data_type)})
else:
substitutions.update({fa: lhs})
return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True, return ac.new_with_substitutions(substitutions, add_substitutions_as_subexpressions=True,
substitute_on_lhs=False, sort_topologically=False) substitute_on_lhs=False, sort_topologically=False)
...@@ -172,7 +219,7 @@ def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs): ...@@ -172,7 +219,7 @@ def transform_lhs_and_rhs(assignment_list, transformation, *args, **kwargs):
def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]): def apply_to_all_assignments(operation: Callable[[sp.Expr], sp.Expr]):
"""Applies sympy expand operation to all equations in collection.""" """Applies a given operation to all equations in collection."""
def f(ac): def f(ac):
return ac.copy(transform_rhs(ac.main_assignments, operation)) return ac.copy(transform_rhs(ac.main_assignments, operation))
...@@ -189,3 +236,31 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]): ...@@ -189,3 +236,31 @@ def apply_on_all_subexpressions(operation: Callable[[sp.Expr], sp.Expr]):
f.__name__ = operation.__name__ f.__name__ = operation.__name__
return f return f
# TODO Markus
# make this really work for Assignmentcollections
# this function should ONLY evaluate
# do the optims_c99 elsewhere optionally
# def apply_sympy_optimisations(ac: AssignmentCollection):
# """ Evaluates constant expressions (e.g. :math:`\\sqrt{3}` will be replaced by its floating point representation)
# and applies the default sympy optimisations. See sympy.codegen.rewriting
# """
#
# # Evaluates all constant terms
#
# assignments = ac.all_assignments
#
# evaluate_constant_terms = ReplaceOptim(lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
# lambda p: p.evalf())
#
# sympy_optimisations = [evaluate_constant_terms] + list(optims_c99)
#
# assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
# if hasattr(a, 'lhs')
# else a for a in assignments]
# assignments_nodes = [a.atoms(SympyAssignment) for a in assignments]
# for a in chain.from_iterable(assignments_nodes):
# a.optimize(sympy_optimisations)
#
# return AssignmentCollection(assignments)
...@@ -9,8 +9,8 @@ from pystencils.simp.assignment_collection import AssignmentCollection ...@@ -9,8 +9,8 @@ from pystencils.simp.assignment_collection import AssignmentCollection
class SimplificationStrategy: class SimplificationStrategy:
"""A simplification strategy is an ordered collection of simplification rules. """A simplification strategy is an ordered collection of simplification rules.
Each simplification is a function taking an equation collection, and returning a new simplified Each simplification is a function taking an assignment collection, and returning a new simplified
equation collection. The strategy can nicely print intermediate simplification stages and results assignment collection. The strategy can nicely print intermediate simplification stages and results
to Jupyter notebooks. to Jupyter notebooks.
""" """
...@@ -92,7 +92,7 @@ class SimplificationStrategy: ...@@ -92,7 +92,7 @@ class SimplificationStrategy:
assignment_collection = t(assignment_collection) assignment_collection = t(assignment_collection)
end_time = timeit.default_timer() end_time = timeit.default_timer()
op = assignment_collection.operation_count op = assignment_collection.operation_count
time_str = "%.2f ms" % ((end_time - start_time) * 1000,) time_str = f"{(end_time - start_time) * 1000:.2f} ms"
total = op['adds'] + op['muls'] + op['divs'] total = op['adds'] + op['muls'] + op['divs']
report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total)) report.add(ReportElement(t.__name__, time_str, op['adds'], op['muls'], op['divs'], total))
return report return report
...@@ -129,7 +129,7 @@ class SimplificationStrategy: ...@@ -129,7 +129,7 @@ class SimplificationStrategy:
def _repr_html_(self): def _repr_html_(self):
def print_assignment_collection(title, c): def print_assignment_collection(title, c):
text = '<h5 style="padding-bottom:10px">%s</h5> <div style="padding-left:20px;">' % (title, ) text = f'<h5 style="padding-bottom:10px">{title}</h5> <div style="padding-left:20px;">'
if self.restrict_symbols: if self.restrict_symbols:
text += "\n".join(["$$" + sp.latex(e) + '$$' text += "\n".join(["$$" + sp.latex(e) + '$$'
for e in c.new_filtered(self.restrict_symbols).main_assignments]) for e in c.new_filtered(self.restrict_symbols).main_assignments])
...@@ -151,5 +151,5 @@ class SimplificationStrategy: ...@@ -151,5 +151,5 @@ class SimplificationStrategy:
def __repr__(self): def __repr__(self):
result = "Simplification Strategy:\n" result = "Simplification Strategy:\n"
for t in self._rules: for t in self._rules:
result += " - %s\n" % (t.__name__,) result += f" - {t.__name__}\n"
return result return result
import sympy as sp
from pystencils.sympyextensions import is_constant
# Subexpression Insertion
def insert_subexpressions(ac, selection_callback, skip=None):
"""
Removes a number of subexpressions from an assignment collection by
inserting their right-hand side wherever they occur.
Args:
- selection_callback: Function that is called to qualify subexpressions
for insertion. Should return `True` for any subexpression that is to be
inserted, and `False` otherwise.
- skip: Set of symbols (left-hand sides of subexpressions) that should be
ignored even if qualified by the callback.
"""
if skip is None:
skip = set()
i = 0
while i < len(ac.subexpressions):
exp = ac.subexpressions[i]
if exp.lhs not in skip and selection_callback(exp):
ac = ac.new_with_inserted_subexpression(exp.lhs)
else:
i += 1
return ac
def insert_aliases(ac, **kwargs):
"""Inserts subexpressions that are aliases of other symbols,
i.e. their right-hand side is only another symbol."""
return insert_subexpressions(ac, lambda x: isinstance(x.rhs, sp.Symbol), **kwargs)
def insert_zeros(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is zero."""
zero = sp.Integer(0)
return insert_subexpressions(ac, lambda x: x.rhs == zero, **kwargs)
def insert_constants(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is constant,
i.e. contains no symbols."""
return insert_subexpressions(ac, lambda x: is_constant(x.rhs), **kwargs)
def insert_symbol_times_minus_one(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is just a
negation of another symbol."""
def callback(exp):
rhs = exp.rhs
minus_one = sp.Integer(-1)
atoms = rhs.atoms(sp.Symbol)
return len(atoms) == 1 and rhs == minus_one * atoms.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_multiples(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a constant
multiplied with another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() * symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_constant_additions(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is a sum of a
constant and another symbol."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
numbers = rhs.atoms(sp.Number)
return len(symbols) == 1 and len(numbers) == 1 and \
rhs == numbers.pop() + symbols.pop()
return insert_subexpressions(ac, callback, **kwargs)
def insert_squares(ac, **kwargs):
"""Inserts subexpressions whose right-hand side is another symbol squared."""
def callback(exp):
rhs = exp.rhs
symbols = rhs.atoms(sp.Symbol)
return len(symbols) == 1 and rhs == symbols.pop() ** 2
return insert_subexpressions(ac, callback, **kwargs)
def bind_symbols_to_skip(insertion_function, skip):
return lambda ac: insertion_function(ac, skip=skip)
from pystencils.simp import (SimplificationStrategy, insert_constants, insert_symbol_times_minus_one,
insert_constant_multiples, insert_constant_additions, insert_squares, insert_zeros)
def create_simplification_strategy():
"""
Creates a default simplification `ps.simp.SimplificationStrategy`. The idea behind the default simplification
strategy is to reduce the number of subexpressions by inserting single constants and to evaluate constant
terms beforehand.
"""
s = SimplificationStrategy()
s.add(insert_symbol_times_minus_one)
s.add(insert_constant_multiples)
s.add(insert_constant_additions)
s.add(insert_squares)
s.add(insert_zeros)
s.add(insert_constants)
s.add(lambda ac: ac.new_without_unused_subexpressions())
...@@ -89,9 +89,12 @@ def shift_slice(slices, offset): ...@@ -89,9 +89,12 @@ def shift_slice(slices, offset):
raise ValueError() raise ValueError()
if hasattr(offset, '__len__'): if hasattr(offset, '__len__'):
return [shift_slice_component(k, off) for k, off in zip(slices, offset)] return tuple(shift_slice_component(k, off) for k, off in zip(slices, offset))
else: else:
return [shift_slice_component(k, offset) for k in slices] if isinstance(slices, slice) or isinstance(slices, int) or isinstance(slices, float):
return shift_slice_component(slices, offset)
else:
return tuple(shift_slice_component(k, offset) for k in slices)
def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0): def slice_from_direction(direction_name, dim, normal_offset=0, tangential_offset=0):
......
import sympy
import pystencils
import pystencils.astnodes
x_, y_, z_ = tuple(pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(3))
x_staggered, y_staggered, z_staggered = x_ + 0.5, y_ + 0.5, z_ + 0.5
def x_vector(ndim):
return sympy.Matrix(tuple(pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(ndim)))
def x_staggered_vector(ndim):
return sympy.Matrix(tuple(
pystencils.astnodes.LoopOverCoordinate.get_loop_counter_symbol(i) + 0.5 for i in range(ndim)
))
...@@ -5,6 +5,8 @@ from typing import Sequence ...@@ -5,6 +5,8 @@ from typing import Sequence
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from pystencils.utils import binary_numbers
def inverse_direction(direction): def inverse_direction(direction):
"""Returns inverse i.e. negative of given direction tuple """Returns inverse i.e. negative of given direction tuple
...@@ -16,6 +18,11 @@ def inverse_direction(direction): ...@@ -16,6 +18,11 @@ def inverse_direction(direction):
return tuple([-i for i in direction]) return tuple([-i for i in direction])
def inverse_direction_string(direction):
"""Returns inverse of given direction string"""
return offset_to_direction_string(inverse_direction(direction_string_to_offset(direction)))
def is_valid(stencil, max_neighborhood=None): def is_valid(stencil, max_neighborhood=None):
""" """
Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length. Tests if a nested sequence is a valid stencil i.e. all the inner sequences have the same length.
...@@ -29,6 +36,8 @@ def is_valid(stencil, max_neighborhood=None): ...@@ -29,6 +36,8 @@ def is_valid(stencil, max_neighborhood=None):
True True
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=1) >>> is_valid([(2, 0), (1, 0)], max_neighborhood=1)
False False
>>> is_valid([(2, 0), (1, 0)], max_neighborhood=2)
True
""" """
expected_dim = len(stencil[0]) expected_dim = len(stencil[0])
for d in stencil: for d in stencil:
...@@ -62,8 +71,11 @@ def have_same_entries(s1, s2): ...@@ -62,8 +71,11 @@ def have_same_entries(s1, s2):
Examples: Examples:
>>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)] >>> stencil1 = [(1, 0), (-1, 0), (0, 1), (0, -1)]
>>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)] >>> stencil2 = [(-1, 0), (0, -1), (1, 0), (0, 1)]
>>> stencil3 = [(-1, 0), (0, -1), (1, 0)]
>>> have_same_entries(stencil1, stencil2) >>> have_same_entries(stencil1, stencil2)
True True
>>> have_same_entries(stencil1, stencil3)
False
""" """
if len(s1) != len(s2): if len(s1) != len(s2):
return False return False
...@@ -283,6 +295,38 @@ def direction_string_to_offset(direction: str, dim: int = 3): ...@@ -283,6 +295,38 @@ def direction_string_to_offset(direction: str, dim: int = 3):
return offset[:dim] return offset[:dim]
def adjacent_directions(direction):
"""
Returns all adjacent directions for a direction as tuple of tuples. This is useful for exmple to find all directions
relevant for neighbour communication.
Args:
direction: tuple representing a direction. For example (0, 1, 0) for the northern side
Examples:
>>> adjacent_directions((0, 0, 0))
((0, 0, 0),)
>>> adjacent_directions((0, 1, 0))
((0, 1, 0),)
>>> adjacent_directions((0, 1, 1))
((0, 0, 1), (0, 1, 0), (0, 1, 1))
>>> adjacent_directions((-1, -1))
((-1, -1), (-1, 0), (0, -1))
"""
result = set()
if all(e == 0 for e in direction):
result.add(direction)
return tuple(result)
binary_numbers_list = binary_numbers(len(direction))
for adjacent_direction in binary_numbers_list:
for i, entry in enumerate(direction):
if entry == 0:
adjacent_direction[i] = 0
if entry == -1 and adjacent_direction[i] == 1:
adjacent_direction[i] = -1
if not all(e == 0 for e in adjacent_direction):
result.add(tuple(adjacent_direction))
return tuple(sorted(result))
# -------------------------------------- Visualization ----------------------------------------------------------------- # -------------------------------------- Visualization -----------------------------------------------------------------
...@@ -309,6 +353,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs) ...@@ -309,6 +353,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs)
Args: Args:
stencil: sequence of directions stencil: sequence of directions
axes: optional matplotlib axes axes: optional matplotlib axes
figure: optional matplotlib figure
data: data to annotate the directions with, if none given, the indices are used data: data to annotate the directions with, if none given, the indices are used
textsize: size of annotation text textsize: size of annotation text
""" """
...@@ -322,15 +367,15 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs) ...@@ -322,15 +367,15 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs)
text_box_style = BoxStyle("Round", pad=0.3) text_box_style = BoxStyle("Round", pad=0.3)
head_length = 0.1 head_length = 0.1
max_offsets = [max(abs(d[c]) for d in stencil) for c in (0, 1)] max_offsets = [max(abs(int(d[c])) for d in stencil) for c in (0, 1)]
if data is None: if data is None:
data = list(range(len(stencil))) data = list(range(len(stencil)))
for direction, annotation in zip(stencil, data): for direction, annotation in zip(stencil, data):
assert len(direction) == 2, "Works only for 2D stencils" assert len(direction) == 2, "Works only for 2D stencils"
direction = tuple(int(i) for i in direction)
if not(direction[0] == 0 and direction[1] == 0): if not (direction[0] == 0 and direction[1] == 0):
axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k') axes.arrow(0, 0, direction[0], direction[1], head_width=0.08, head_length=head_length, color='k')
if isinstance(annotation, sp.Basic): if isinstance(annotation, sp.Basic):
...@@ -346,7 +391,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs) ...@@ -346,7 +391,7 @@ def plot_2d(stencil, axes=None, figure=None, data=None, textsize='12', **kwargs)
else: else:
return 0 return 0
text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)] text_position = [direction[c] + position_correction(direction[c]) for c in (0, 1)]
axes.text(*text_position, annotation, verticalalignment='center', axes.text(x=text_position[0], y=text_position[1], s=annotation, verticalalignment='center',
zorder=30, horizontalalignment='center', size=textsize, zorder=30, horizontalalignment='center', size=textsize,
bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0)) bbox=dict(boxstyle=text_box_style, facecolor='#00b6eb', alpha=0.85, linewidth=0))
...@@ -364,6 +409,7 @@ def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs): ...@@ -364,6 +409,7 @@ def plot_3d_slicing(stencil, slice_axis=2, figure=None, data=None, **kwargs):
Args: Args:
stencil: stencil as sequence of directions stencil: stencil as sequence of directions
slice_axis: 0, 1, or 2 indicating the axis to slice through slice_axis: 0, 1, or 2 indicating the axis to slice through
figure: optional matplotlib figure
data: optional data to print as text besides the arrows data: optional data to print as text besides the arrows
""" """
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
...@@ -409,16 +455,17 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'): ...@@ -409,16 +455,17 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs) FancyArrowPatch.__init__(self, (0, 0), (0, 0), *args, **kwargs)
self._verts3d = xs, ys, zs self._verts3d = xs, ys, zs
def draw(self, renderer): def do_3d_projection(self, *_):
xs3d, ys3d, zs3d = self._verts3d xs3d, ys3d, zs3d = self._verts3d
xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, renderer.M) xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
FancyArrowPatch.draw(self, renderer)
return np.min(zs)
if axes is None: if axes is None:
if figure is None: if figure is None:
figure = plt.figure() figure = plt.figure()
axes = figure.gca(projection='3d') axes = figure.add_subplot(projection='3d')
try: try:
axes.set_aspect("equal") axes.set_aspect("equal")
except NotImplementedError: except NotImplementedError:
...@@ -434,10 +481,11 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'): ...@@ -434,10 +481,11 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
r = [-1, 1] r = [-1, 1]
for s, e in combinations(np.array(list(product(r, r, r))), 2): for s, e in combinations(np.array(list(product(r, r, r))), 2):
if np.sum(np.abs(s - e)) == r[1] - r[0]: if np.sum(np.abs(s - e)) == r[1] - r[0]:
axes.plot3D(*zip(s, e), color="k", alpha=0.5) axes.plot(*zip(s, e), color="k", alpha=0.5)
for d, annotation in zip(stencil, data): for d, annotation in zip(stencil, data):
assert len(d) == 3, "Works only for 3D stencils" assert len(d) == 3, "Works only for 3D stencils"
d = tuple(int(i) for i in d)
if not (d[0] == 0 and d[1] == 0 and d[2] == 0): if not (d[0] == 0 and d[1] == 0 and d[2] == 0):
if d[0] == 0: if d[0] == 0:
color = '#348abd' color = '#348abd'
...@@ -457,8 +505,8 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'): ...@@ -457,8 +505,8 @@ def plot_3d(stencil, figure=None, axes=None, data=None, textsize='8'):
else: else:
annotation = str(annotation) annotation = str(annotation)
axes.text(d[0] * text_offset, d[1] * text_offset, d[2] * text_offset, axes.text(x=d[0] * text_offset, y=d[1] * text_offset, z=d[2] * text_offset,
annotation, verticalalignment='center', zorder=30, s=annotation, verticalalignment='center', zorder=30,
size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0)) size=textsize, bbox=dict(boxstyle=text_box_style, facecolor='#777777', alpha=0.6, linewidth=0))
axes.set_xlim([-text_offset * 1.1, text_offset * 1.1]) axes.set_xlim([-text_offset * 1.1, text_offset * 1.1])
......
...@@ -6,10 +6,14 @@ from functools import partial, reduce ...@@ -6,10 +6,14 @@ from functools import partial, reduce
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union from typing import Callable, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union
import sympy as sp import sympy as sp
from sympy import PolynomialError
from sympy.functions import Abs from sympy.functions import Abs
from sympy.core.numbers import Zero
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.data_types import cast_func, get_base_type, get_type_of_expression from pystencils.functions import DivFunc
from pystencils.typing import CastFunc, get_type_of_expression, PointerType, VectorType
from pystencils.typing.typed_sympy import FieldPointerSymbol
T = TypeVar('T') T = TypeVar('T')
...@@ -156,17 +160,23 @@ def fast_subs(expression: T, substitutions: Dict, ...@@ -156,17 +160,23 @@ def fast_subs(expression: T, substitutions: Dict,
if type(expression) is sp.Matrix: if type(expression) is sp.Matrix:
return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions)) return expression.copy().applyfunc(partial(fast_subs, substitutions=substitutions))
def visit(expr): def visit(expr, evaluate=True):
if skip and skip(expr): if skip and skip(expr):
return expr return expr
if hasattr(expr, "fast_subs"): elif hasattr(expr, "fast_subs"):
return expr.fast_subs(substitutions) return expr.fast_subs(substitutions, skip)
if expr in substitutions: elif expr in substitutions:
return substitutions[expr] return substitutions[expr]
if not hasattr(expr, 'args'): elif not hasattr(expr, 'args'):
return expr return expr
param_list = [visit(a) for a in expr.args] elif isinstance(expr, (sp.UnevaluatedExpr, DivFunc)):
return expr if not param_list else expr.func(*param_list) args = [visit(a, False) for a in expr.args]
return expr.func(*args)
else:
param_list = [visit(a, evaluate) for a in expr.args]
if isinstance(expr, (sp.Mul, sp.Add)):
return expr if not param_list else expr.func(*param_list, evaluate=evaluate)
return expr if not param_list else expr.func(*param_list)
if len(substitutions) == 0: if len(substitutions) == 0:
return expression return expression
...@@ -233,6 +243,9 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -233,6 +243,9 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args)) normalized_replacement_match = normalize_match_parameter(required_match_replacement, len(subexpression.args))
if isinstance(subexpression, sp.Number):
return expr.subs({replacement: subexpression})
def visit(current_expr): def visit(current_expr):
if current_expr.is_Add: if current_expr.is_Add:
expr_max_length = max(len(current_expr.args), len(subexpression.args)) expr_max_length = max(len(current_expr.args), len(subexpression.args))
...@@ -260,8 +273,8 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -260,8 +273,8 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
if not param_list: if not param_list:
return current_expr return current_expr
else: else:
if current_expr.func == sp.Mul and sp.numbers.Zero() in param_list: if current_expr.func == sp.Mul and Zero() in param_list:
return sp.numbers.Zero() return sp.simplify(current_expr)
else: else:
return current_expr.func(*param_list, evaluate=False) return current_expr.func(*param_list, evaluate=False)
...@@ -271,7 +284,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr, ...@@ -271,7 +284,7 @@ def subs_additive(expr: sp.Expr, replacement: sp.Expr, subexpression: sp.Expr,
def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol], def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Symbol],
positive: Optional[bool] = None, positive: Optional[bool] = None,
replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr: replace_mixed: Optional[List[Assignment]] = None) -> sp.Expr:
"""Replaces second order mixed terms like x*y by 2*( (x+y)**2 - x**2 - y**2 ). """Replaces second order mixed terms like 4*x*y by 2*( (x+y)**2 - x**2 - y**2 ).
This makes the term longer - simplify usually is undoing these - however this This makes the term longer - simplify usually is undoing these - however this
transformation can be done to find more common sub-expressions transformation can be done to find more common sub-expressions
...@@ -292,7 +305,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym ...@@ -292,7 +305,7 @@ def replace_second_order_products(expr: sp.Expr, search_symbols: Iterable[sp.Sym
if expr.is_Mul: if expr.is_Mul:
distinct_search_symbols = set() distinct_search_symbols = set()
nr_of_search_terms = 0 nr_of_search_terms = 0
other_factors = 1 other_factors = sp.Integer(1)
for t in expr.args: for t in expr.args:
if t in search_symbols: if t in search_symbols:
nr_of_search_terms += 1 nr_of_search_terms += 1
...@@ -343,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order ...@@ -343,7 +356,7 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
factor_count = 0 factor_count = 0
if type(product) is Mul: if type(product) is Mul:
for factor in product.args: for factor in product.args:
if type(factor) == Pow: if type(factor) is Pow:
if factor.args[0] in symbols: if factor.args[0] in symbols:
factor_count += factor.args[1] factor_count += factor.args[1]
if factor in symbols: if factor in symbols:
...@@ -353,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order ...@@ -353,13 +366,13 @@ def remove_higher_order_terms(expr: sp.Expr, symbols: Sequence[sp.Symbol], order
factor_count += product.args[1] factor_count += product.args[1]
return factor_count return factor_count
if type(expr) == Mul or type(expr) == Pow: if type(expr) is Mul or type(expr) is Pow:
if velocity_factors_in_product(expr) <= order: if velocity_factors_in_product(expr) <= order:
return expr return expr
else: else:
return sp.Rational(0, 1) return Zero()
if type(expr) != Add: if type(expr) is not Add:
return expr return expr
for sum_term in expr.args: for sum_term in expr.args:
...@@ -429,7 +442,104 @@ def extract_most_common_factor(term): ...@@ -429,7 +442,104 @@ def extract_most_common_factor(term):
return common_factor, term / common_factor return common_factor, term / common_factor
def count_operations(term: Union[sp.Expr, List[sp.Expr]], def recursive_collect(expr, symbols, order_by_occurences=False):
"""Applies sympy.collect recursively for a list of symbols, collecting symbol 2 in the coefficients of symbol 1,
and so on.
``expr`` must be rewritable as a polynomial in the given ``symbols``.
It it is not, ``recursive_collect`` will fail quietly, returning the original expression.
Args:
expr: A sympy expression.
symbols: A sequence of symbols
order_by_occurences: If True, during recursive descent, always collect the symbol occuring
most often in the expression.
"""
if order_by_occurences:
symbols = list(expr.atoms(sp.Symbol) & set(symbols))
symbols = sorted(symbols, key=expr.count, reverse=True)
if len(symbols) == 0:
return expr
symbol = symbols[0]
collected = expr.collect(symbol)
try:
collected_poly = sp.Poly(collected, symbol)
except PolynomialError:
return expr
coeffs = collected_poly.all_coeffs()[::-1]
rec_sum = sum(symbol**i * recursive_collect(c, symbols[1:], order_by_occurences) for i, c in enumerate(coeffs))
return rec_sum
def summands(expr):
return set(expr.args) if isinstance(expr, sp.Add) else {expr}
def simplify_by_equality(expr, a, b, c):
"""
Uses the equality a = b + c, where a and b must be symbols, to simplify expr
by attempting to express additive combinations of two quantities by the third.
This works on expressions that are reducible to the form
:math:`a * (...) + b * (...) + c * (...)`,
without any mixed terms of a, b and c.
"""
if not isinstance(a, sp.Symbol) or not isinstance(b, sp.Symbol):
raise ValueError("a and b must be symbols.")
c = sp.sympify(c)
if not (isinstance(c, sp.Symbol) or is_constant(c)):
raise ValueError("c must be either a symbol or a constant!")
expr = sp.sympify(expr)
expr_expanded = sp.expand(expr)
a_coeff = expr_expanded.coeff(a, 1)
expr_expanded -= (a * a_coeff).expand()
b_coeff = expr_expanded.coeff(b, 1)
expr_expanded -= (b * b_coeff).expand()
if isinstance(c, sp.Symbol):
c_coeff = expr_expanded.coeff(c, 1)
rest = expr_expanded - (c * c_coeff).expand()
else:
c_coeff = expr_expanded / c
rest = 0
a_summands = summands(a_coeff)
b_summands = summands(b_coeff)
c_summands = summands(c_coeff)
# replace b + c by a
b_plus_c_coeffs = b_summands & c_summands
for coeff in b_plus_c_coeffs:
rest += a * coeff
b_summands -= b_plus_c_coeffs
c_summands -= b_plus_c_coeffs
# replace a - b by c
neg_b_summands = {-x for x in b_summands}
a_minus_b_coeffs = a_summands & neg_b_summands
for coeff in a_minus_b_coeffs:
rest += c * coeff
a_summands -= a_minus_b_coeffs
b_summands -= {-x for x in a_minus_b_coeffs}
# replace a - c by b
neg_c_summands = {-x for x in c_summands}
a_minus_c_coeffs = a_summands & neg_c_summands
for coeff in a_minus_c_coeffs:
rest += b * coeff
a_summands -= a_minus_c_coeffs
c_summands -= {-x for x in a_minus_c_coeffs}
# put it back together
return (rest + a * sum(a_summands) + b * sum(b_summands) + c * sum(c_summands)).expand()
def count_operations(term: Union[sp.Expr, List[sp.Expr], List[Assignment]],
only_type: Optional[str] = 'real') -> Dict[str, int]: only_type: Optional[str] = 'real') -> Dict[str, int]:
"""Counts the number of additions, multiplications and division. """Counts the number of additions, multiplications and division.
...@@ -444,7 +554,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -444,7 +554,6 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0, result = {'adds': 0, 'muls': 0, 'divs': 0, 'sqrts': 0,
'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0} 'fast_sqrts': 0, 'fast_inv_sqrts': 0, 'fast_div': 0}
if isinstance(term, Sequence): if isinstance(term, Sequence):
for element in term: for element in term:
r = count_operations(element, only_type) r = count_operations(element, only_type)
...@@ -454,16 +563,20 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -454,16 +563,20 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
elif isinstance(term, Assignment): elif isinstance(term, Assignment):
term = term.rhs term = term.rhs
if hasattr(term, 'evalf'):
term = term.evalf()
def check_type(e): def check_type(e):
if only_type is None: if only_type is None:
return True return True
if isinstance(e, FieldPointerSymbol) and only_type == "real":
return only_type == "int"
try: try:
base_type = get_base_type(get_type_of_expression(e)) base_type = get_type_of_expression(e)
except ValueError: except ValueError:
return False return False
if isinstance(base_type, VectorType):
return False
if isinstance(base_type, PointerType):
return only_type == 'int'
if only_type == 'int' and (base_type.is_int() or base_type.is_uint()): if only_type == 'int' and (base_type.is_int() or base_type.is_uint()):
return True return True
if only_type == 'real' and (base_type.is_float()): if only_type == 'real' and (base_type.is_float()):
...@@ -492,7 +605,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -492,7 +605,7 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
visit_children = False visit_children = False
elif t.is_integer: elif t.is_integer:
pass pass
elif isinstance(t, cast_func): elif isinstance(t, CastFunc):
visit_children = False visit_children = False
visit(t.args[0]) visit(t.args[0])
elif t.func is fast_sqrt: elif t.func is fast_sqrt:
...@@ -503,18 +616,22 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -503,18 +616,22 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
result['fast_div'] += 1 result['fast_div'] += 1
elif t.func is sp.Pow: elif t.func is sp.Pow:
if check_type(t.args[0]): if check_type(t.args[0]):
visit_children = False visit_children = True
if t.exp.is_integer and t.exp.is_number: if t.exp.is_integer and t.exp.is_number:
if t.exp >= 0: if t.exp >= 0:
result['muls'] += int(t.exp) - 1 result['muls'] += int(t.exp) - 1
else: else:
result['muls'] -= 1 if result['muls'] > 0:
result['muls'] -= 1
result['divs'] += 1 result['divs'] += 1
result['muls'] += (-int(t.exp)) - 1 result['muls'] += (-int(t.exp)) - 1
elif sp.nsimplify(t.exp) == sp.Rational(1, 2): elif sp.nsimplify(t.exp) == sp.Rational(1, 2):
result['sqrts'] += 1 result['sqrts'] += 1
elif sp.nsimplify(t.exp) == -sp.Rational(1, 2):
result["sqrts"] += 1
result["divs"] += 1
else: else:
warnings.warn("Cannot handle exponent", t.exp, " of sp.Pow node") warnings.warn(f"Cannot handle exponent {t.exp} of sp.Pow node")
else: else:
warnings.warn("Counting operations: only integer exponents are supported in Pow, " warnings.warn("Counting operations: only integer exponents are supported in Pow, "
"counting will be inaccurate") "counting will be inaccurate")
...@@ -522,10 +639,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]], ...@@ -522,10 +639,12 @@ def count_operations(term: Union[sp.Expr, List[sp.Expr]],
for child_term, condition in t.args: for child_term, condition in t.args:
visit(child_term) visit(child_term)
visit_children = False visit_children = False
elif isinstance(t, sp.Rel): elif isinstance(t, (sp.Rel, sp.UnevaluatedExpr)):
pass pass
elif isinstance(t, DivFunc):
result["divs"] += 1
else: else:
warnings.warn("Unknown sympy node of type " + str(t.func) + " counting will be inaccurate") warnings.warn(f"Unknown sympy node of type {str(t.func)} counting will be inaccurate")
if visit_children: if visit_children:
for a in t.args: for a in t.args:
......
File moved
import hashlib import hashlib
import pickle import pickle
import warnings import warnings
from collections import OrderedDict, defaultdict, namedtuple from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from types import MappingProxyType from types import MappingProxyType
from typing import Set
import numpy as np
import sympy as sp import sympy as sp
from sympy.logic.boolalg import Boolean
import pystencils as ps
import pystencils.astnodes as ast import pystencils.astnodes as ast
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.data_types import ( from pystencils.typing import (CastFunc, PointerType, StructType, TypedSymbol, get_base_type,
PointerType, StructType, TypedSymbol, cast_func, collate_types, create_type, get_base_type, ReinterpretCastFunc, get_next_parent_of_type, parents_of_type)
get_type_of_expression, pointer_arithmetic_func, reinterpret_cast_func) from pystencils.field import Field, FieldType
from pystencils.field import AbstractField, Field, FieldType from pystencils.typing import FieldPointerSymbol
from pystencils.kernelparameters import FieldPointerSymbol from pystencils.sympyextensions import fast_subs
from pystencils.simp.assignment_collection import AssignmentCollection from pystencils.simp.assignment_collection import AssignmentCollection
from pystencils.slicing import normalize_slice from pystencils.slicing import normalize_slice
from pystencils.integer_functions import int_div
class NestedScopes: class NestedScopes:
...@@ -99,6 +100,45 @@ def generic_visit(term, visitor): ...@@ -99,6 +100,45 @@ def generic_visit(term, visitor):
return visitor(term) return visitor(term)
def iterate_loops_by_depth(node, nesting_depth):
"""Iterate all LoopOverCoordinate nodes in the given AST of the specified nesting depth.
Args:
node: Root node of the abstract syntax tree
nesting_depth: Nesting depth of the loops the pragmas should be applied to.
Outermost loop has depth 0.
A depth of -1 indicates the innermost loops.
Returns: Iterable listing all loop nodes of given nesting depth.
"""
from pystencils.astnodes import LoopOverCoordinate
def _internal_default(node, nesting_depth):
isloop = isinstance(node, LoopOverCoordinate)
if nesting_depth < 0: # here, a negative value indicates end of descent
return
elif nesting_depth == 0 and isloop:
yield node
else:
next_depth = nesting_depth - 1 if isloop else nesting_depth
for arg in node.args:
yield from _internal_default(arg, next_depth)
def _internal_innermost(node):
if isinstance(node, LoopOverCoordinate) and node.is_innermost_loop:
yield node
else:
for arg in node.args:
yield from _internal_innermost(arg)
if nesting_depth >= 0:
yield from _internal_default(node, nesting_depth)
elif nesting_depth == -1:
yield from _internal_innermost(node)
else:
raise ValueError(f"Invalid nesting depth: {nesting_depth}. Choose a nonnegative number, or -1.")
def unify_shape_symbols(body, common_shape, fields): def unify_shape_symbols(body, common_shape, fields):
"""Replaces symbols for array sizes to ensure they are represented by the same unique symbol. """Replaces symbols for array sizes to ensure they are represented by the same unique symbol.
...@@ -123,9 +163,10 @@ def unify_shape_symbols(body, common_shape, fields): ...@@ -123,9 +163,10 @@ def unify_shape_symbols(body, common_shape, fields):
body.subs(substitutions) body.subs(substitutions)
def get_common_shape(field_set): def get_common_field(field_set):
"""Takes a set of pystencils Fields and returns their common spatial shape if it exists. Otherwise """Takes a set of pystencils Fields, checks if a common spatial shape exists and returns one
ValueError is raised""" representative field, that can be used for shape information etc. in the kernel creation.
If the fields have different shapes ValueError is raised"""
nr_of_fixed_shaped_fields = 0 nr_of_fixed_shaped_fields = 0
for f in field_set: for f in field_set:
if f.has_fixed_shape: if f.has_fixed_shape:
...@@ -135,7 +176,7 @@ def get_common_shape(field_set): ...@@ -135,7 +176,7 @@ def get_common_shape(field_set):
fixed_field_names = ",".join([f.name for f in field_set if f.has_fixed_shape]) fixed_field_names = ",".join([f.name for f in field_set if f.has_fixed_shape])
var_field_names = ",".join([f.name for f in field_set if not f.has_fixed_shape]) var_field_names = ",".join([f.name for f in field_set if not f.has_fixed_shape])
msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n" msg = "Mixing fixed-shaped and variable-shape fields in a single kernel is not possible\n"
msg += "Variable shaped: %s \nFixed shaped: %s" % (var_field_names, fixed_field_names) msg += f"Variable shaped: {var_field_names} \nFixed shaped: {fixed_field_names}"
raise ValueError(msg) raise ValueError(msg)
shape_set = set([f.spatial_shape for f in field_set]) shape_set = set([f.spatial_shape for f in field_set])
...@@ -143,8 +184,9 @@ def get_common_shape(field_set): ...@@ -143,8 +184,9 @@ def get_common_shape(field_set):
if len(shape_set) != 1: if len(shape_set) != 1:
raise ValueError("Differently sized field accesses in loop body: " + str(shape_set)) raise ValueError("Differently sized field accesses in loop body: " + str(shape_set))
shape = list(sorted(shape_set, key=lambda e: str(e[0])))[0] # Sort the fields by their name to ensure that always the same field is returned
return shape reference_field = sorted(field_set, key=lambda e: str(e))[0]
return reference_field
def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None): def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_order=None):
...@@ -162,24 +204,38 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -162,24 +204,38 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
tuple of loop-node, ghost_layer_info tuple of loop-node, ghost_layer_info
""" """
# find correct ordering by inspecting participating FieldAccesses # find correct ordering by inspecting participating FieldAccesses
field_accesses = body.atoms(AbstractField.AbstractAccess) absolut_accesses_only = False
field_accesses = body.atoms(Field.Access)
field_accesses = {e for e in field_accesses if not e.is_absolute_access} field_accesses = {e for e in field_accesses if not e.is_absolute_access}
if len(field_accesses) == 0: # when kernel contains only absolute accesses
absolut_accesses_only = True
# exclude accesses to buffers from field_list, because buffers are treated separately # exclude accesses to buffers from field_list, because buffers are treated separately
field_list = [e.field for e in field_accesses if not FieldType.is_buffer(e.field)] field_list = [e.field for e in field_accesses if not (FieldType.is_buffer(e.field) or FieldType.is_custom(e.field))]
if len(field_list) == 0: # when kernel contains only custom fields
field_list = [e.field for e in field_accesses if not (FieldType.is_buffer(e.field))]
fields = set(field_list) fields = set(field_list)
if loop_order is None: if loop_order is None:
loop_order = get_optimal_loop_ordering(fields) loop_order = get_optimal_loop_ordering(fields)
shape = get_common_shape(fields) if absolut_accesses_only:
unify_shape_symbols(body, common_shape=shape, fields=fields) absolut_access_fields = {e.field for e in body.atoms(Field.Access)}
common_field = get_common_field(absolut_access_fields)
common_shape = common_field.spatial_shape
else:
common_field = get_common_field(fields)
common_shape = common_field.spatial_shape
unify_shape_symbols(body, common_shape=common_shape, fields=fields)
if iteration_slice is not None: if iteration_slice is not None:
iteration_slice = normalize_slice(iteration_slice, shape) iteration_slice = normalize_slice(iteration_slice, common_shape)
if ghost_layers is None: if ghost_layers is None:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses]) if absolut_accesses_only:
required_ghost_layers = 0
else:
required_ghost_layers = max([fa.required_ghost_layers for fa in field_accesses])
ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order) ghost_layers = [(required_ghost_layers, required_ghost_layers)] * len(loop_order)
if isinstance(ghost_layers, int): if isinstance(ghost_layers, int):
ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order) ghost_layers = [(ghost_layers, ghost_layers)] * len(loop_order)
...@@ -188,7 +244,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -188,7 +244,7 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
for i, loop_coordinate in enumerate(reversed(loop_order)): for i, loop_coordinate in enumerate(reversed(loop_order)):
if iteration_slice is None: if iteration_slice is None:
begin = ghost_layers[loop_coordinate][0] begin = ghost_layers[loop_coordinate][0]
end = shape[loop_coordinate] - ghost_layers[loop_coordinate][1] end = common_shape[loop_coordinate] - ghost_layers[loop_coordinate][1]
new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1) new_loop = ast.LoopOverCoordinate(current_body, loop_coordinate, begin, end, 1)
current_body = ast.Block([new_loop]) current_body = ast.Block([new_loop])
else: else:
...@@ -205,6 +261,28 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or ...@@ -205,6 +261,28 @@ def make_loop_over_domain(body, iteration_slice=None, ghost_layers=None, loop_or
return current_body, ghost_layers return current_body, ghost_layers
def get_common_indexed_element(indexed_elements: Set[sp.IndexedBase]) -> sp.IndexedBase:
assert len(indexed_elements) > 0, "indexed_elements can not be empty"
shape_set = {s.shape for s in indexed_elements}
if len(shape_set) != 1:
for shape in shape_set:
assert not isinstance(shape, int), "If indexed elements are used, they must all have the same shape"
return sorted(indexed_elements, key=lambda e: str(e))[0]
def add_outer_loop_over_indexed_elements(loop_node: ast.Block) -> ast.Block:
indexed_elements = loop_node.atoms(sp.Indexed)
if len(indexed_elements) == 0:
return loop_node
reference_element = get_common_indexed_element(indexed_elements)
index = reference_element.indices[0].atoms(TypedSymbol)
assert len(index) == 1, "index expressions must only contain one symbol representing the index"
new_loop = ast.LoopOverCoordinate(loop_node, 0, 0,
reference_element.shape[0], 1, custom_loop_ctr=index.pop())
return ast.Block([new_loop])
def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
r""" r"""
Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]` Addressing elements in structured arrays is done with :math:`ptr\left[ \sum_i c_i \cdot s_i \right]`
...@@ -239,7 +317,7 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr): ...@@ -239,7 +317,7 @@ def create_intermediate_base_pointer(field_access, coordinates, previous_ptr):
if coordinate_id < field.spatial_dimensions: if coordinate_id < field.spatial_dimensions:
offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id] offset += field.strides[coordinate_id] * field_access.offsets[coordinate_id]
if type(field_access.offsets[coordinate_id]) is int: if field_access.offsets[coordinate_id].is_Integer:
name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id]) name += "_%d%d" % (coordinate_id, field_access.offsets[coordinate_id])
else: else:
list_to_hash.append(field_access.offsets[coordinate_id]) list_to_hash.append(field_access.offsets[coordinate_id])
...@@ -300,6 +378,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime ...@@ -300,6 +378,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime
if elem in specified_coordinates: if elem in specified_coordinates:
raise ValueError("Coordinate %d specified two times" % (elem,)) raise ValueError("Coordinate %d specified two times" % (elem,))
specified_coordinates.add(elem) specified_coordinates.add(elem)
for element in spec_group: for element in spec_group:
if type(element) is int: if type(element) is int:
add_new_element(element) add_new_element(element)
...@@ -320,7 +399,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime ...@@ -320,7 +399,7 @@ def parse_base_pointer_info(base_pointer_specification, loop_order, spatial_dime
index = int(element[len("index"):]) index = int(element[len("index"):])
add_new_element(spatial_dimensions + index) add_new_element(spatial_dimensions + index)
else: else:
raise ValueError("Unknown specification %s" % (element,)) raise ValueError(f"Unknown specification {element}")
result.append(new_group) result.append(new_group)
...@@ -339,7 +418,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -339,7 +418,7 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
ast_node: ast before any field accesses are resolved ast_node: ast before any field accesses are resolved
loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes) loop_counters: for CPU kernels: leave to default 'None' (can be determined from loop nodes)
for GPU kernels: list of 'loop counters' from inner to outer loop for GPU kernels: list of 'loop counters' from inner to outer loop
loop_iterations: number of iterations of each loop from inner to outer, for CPU kernels leave to default loop_iterations: iteration slice for each loop from inner to outer, for CPU kernels leave to default
Returns: Returns:
base buffer index - required by 'resolve_buffer_accesses' function base buffer index - required by 'resolve_buffer_accesses' function
...@@ -351,26 +430,46 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None): ...@@ -351,26 +430,46 @@ def get_base_buffer_index(ast_node, loop_counters=None, loop_iterations=None):
assert len(loops) == len(parents_of_innermost_loop) assert len(loops) == len(parents_of_innermost_loop)
assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop)) assert all(l1 is l2 for l1, l2 in zip(loops, parents_of_innermost_loop))
loop_iterations = [(l.stop - l.start) / l.step for l in loops] loop_counters = [loop.loop_counter_symbol for loop in loops]
loop_counters = [l.loop_counter_symbol for l in loops] loop_iterations = [slice(loop.start, loop.stop, loop.step) for loop in loops]
actual_sizes = list()
actual_steps = list()
for ctr, s in zip(loop_counters, loop_iterations):
if s.step != 1:
if (s.stop - s.start) % s.step == 0:
actual_sizes.append((s.stop - s.start) // s.step)
else:
actual_sizes.append(int_div((s.stop - s.start), s.step))
if (ctr - s.start) % s.step == 0:
actual_steps.append((ctr - s.start) // s.step)
else:
actual_steps.append(int_div((ctr - s.start), s.step))
else:
actual_sizes.append(s.stop - s.start)
actual_steps.append(ctr - s.start)
field_accesses = ast_node.atoms(AbstractField.AbstractAccess) field_accesses = ast_node.atoms(Field.Access)
buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)} buffer_accesses = {fa for fa in field_accesses if FieldType.is_buffer(fa.field)}
loop_counters = [v * len(buffer_accesses) for v in loop_counters] buffer_index_size = len(buffer_accesses)
base_buffer_index = loop_counters[0] base_buffer_index = actual_steps[0]
stride = 1 actual_stride = 1
for idx, var in enumerate(loop_counters[1:]): for idx, actual_step in enumerate(actual_steps[1:]):
cur_stride = loop_iterations[idx] cur_stride = actual_sizes[idx]
stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride actual_stride *= int(cur_stride) if isinstance(cur_stride, float) else cur_stride
base_buffer_index += var * stride base_buffer_index += actual_stride * actual_step
return base_buffer_index return base_buffer_index * buffer_index_size
def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=set()): def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=None):
if read_only_field_names is None:
read_only_field_names = set()
def visit_sympy_expr(expr, enclosing_block, sympy_assignment): def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, AbstractField.AbstractAccess): if isinstance(expr, Field.Access):
field_access = expr field_access = expr
# Do not apply transformation if field is not a buffer # Do not apply transformation if field is not a buffer
...@@ -413,7 +512,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s ...@@ -413,7 +512,7 @@ def resolve_buffer_accesses(ast_node, base_buffer_index, read_only_field_names=s
return visit_node(ast_node) return visit_node(ast_node)
def resolve_field_accesses(ast_node, read_only_field_names=set(), def resolve_field_accesses(ast_node, read_only_field_names=None,
field_to_base_pointer_info=MappingProxyType({}), field_to_base_pointer_info=MappingProxyType({}),
field_to_fixed_coordinates=MappingProxyType({})): field_to_fixed_coordinates=MappingProxyType({})):
""" """
...@@ -430,11 +529,13 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -430,11 +529,13 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
Returns Returns
transformed AST transformed AST
""" """
if read_only_field_names is None:
read_only_field_names = set()
field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0])) field_to_base_pointer_info = OrderedDict(sorted(field_to_base_pointer_info.items(), key=lambda pair: pair[0]))
field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0])) field_to_fixed_coordinates = OrderedDict(sorted(field_to_fixed_coordinates.items(), key=lambda pair: pair[0]))
def visit_sympy_expr(expr, enclosing_block, sympy_assignment): def visit_sympy_expr(expr, enclosing_block, sympy_assignment):
if isinstance(expr, AbstractField.AbstractAccess): if isinstance(expr, Field.Access):
field_access = expr field_access = expr
field = field_access.field field = field_access.field
...@@ -452,7 +553,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -452,7 +553,10 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
else: else:
base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))] base_pointer_info = [list(range(field.index_dimensions + field.spatial_dimensions))]
field_ptr = FieldPointerSymbol(field.name, field.dtype, const=field.name in read_only_field_names) field_ptr = FieldPointerSymbol(
field.name,
field.dtype,
const=field.name in read_only_field_names)
def create_coordinate_dict(group_param): def create_coordinate_dict(group_param):
coordinates = {} coordinates = {}
...@@ -473,6 +577,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -473,6 +577,8 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if isinstance(field.dtype, StructType): if isinstance(field.dtype, StructType):
assert field.index_dimensions == 1 assert field.index_dimensions == 1
accessed_field_name = field_access.index[0] accessed_field_name = field_access.index[0]
if isinstance(accessed_field_name, sp.Symbol):
accessed_field_name = accessed_field_name.name
assert isinstance(accessed_field_name, str) assert isinstance(accessed_field_name, str)
coordinates[e] = field.dtype.get_element_offset(accessed_field_name) coordinates[e] = field.dtype.get_element_offset(accessed_field_name)
else: else:
...@@ -486,7 +592,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -486,7 +592,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
coord_dict = create_coordinate_dict(group) coord_dict = create_coordinate_dict(group)
new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer) new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer)
if new_ptr not in enclosing_block.symbols_defined: if new_ptr not in enclosing_block.symbols_defined:
new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False) new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False, use_auto=False)
enclosing_block.insert_before(new_assignment, sympy_assignment) enclosing_block.insert_before(new_assignment, sympy_assignment)
last_pointer = new_ptr last_pointer = new_ptr
...@@ -496,15 +602,21 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -496,15 +602,21 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
field_access.offsets, field_access.index) field_access.offsets, field_access.index)
if isinstance(get_base_type(field_access.field.dtype), StructType): if isinstance(get_base_type(field_access.field.dtype), StructType):
new_type = field_access.field.dtype.get_element_type(field_access.index[0]) accessed_field_name = field_access.index[0]
result = reinterpret_cast_func(result, new_type) if isinstance(accessed_field_name, sp.Symbol):
accessed_field_name = accessed_field_name.name
new_type = field_access.field.dtype.get_element_type(accessed_field_name)
result = ReinterpretCastFunc(result, new_type)
return visit_sympy_expr(result, enclosing_block, sympy_assignment) return visit_sympy_expr(result, enclosing_block, sympy_assignment)
else: else:
if isinstance(expr, ast.ResolvedFieldAccess): if isinstance(expr, ast.ResolvedFieldAccess):
return expr return expr
new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args] if hasattr(expr, 'args'):
new_args = [visit_sympy_expr(e, enclosing_block, sympy_assignment) for e in expr.args]
else:
new_args = []
kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {} kwargs = {'evaluate': False} if type(expr) in (sp.Add, sp.Mul, sp.Piecewise) else {}
return expr.func(*new_args, **kwargs) if new_args else expr return expr.func(*new_args, **kwargs) if new_args else expr
...@@ -522,7 +634,9 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(), ...@@ -522,7 +634,9 @@ def resolve_field_accesses(ast_node, read_only_field_names=set(),
if sub_ast.false_block: if sub_ast.false_block:
visit_node(sub_ast.false_block) visit_node(sub_ast.false_block)
else: else:
for i, a in enumerate(sub_ast.args): if isinstance(sub_ast, (bool, int, float)):
return
for a in sub_ast.args:
visit_node(a) visit_node(a)
return visit_node(ast_node) return visit_node(ast_node)
...@@ -542,21 +656,65 @@ def move_constants_before_loop(ast_node): ...@@ -542,21 +656,65 @@ def move_constants_before_loop(ast_node):
""" """
assert isinstance(node.parent, ast.Block) assert isinstance(node.parent, ast.Block)
def modifies_or_declares(node: ast.Node, symbol_names: Set[str]) -> bool:
if isinstance(node, (ps.Assignment, ast.SympyAssignment)):
if isinstance(node.lhs, ast.ResolvedFieldAccess):
return node.lhs.typed_symbol.name in symbol_names
else:
return node.lhs.name in symbol_names
elif isinstance(node, ast.Block):
for arg in node.args:
if isinstance(arg, ast.SympyAssignment) and arg.is_declaration:
continue
if modifies_or_declares(arg, symbol_names):
return True
return False
elif isinstance(node, ast.LoopOverCoordinate):
return modifies_or_declares(node.body, symbol_names)
elif isinstance(node, ast.Conditional):
return (
modifies_or_declares(node.true_block, symbol_names)
or (node.false_block and modifies_or_declares(node.false_block, symbol_names))
)
elif isinstance(node, ast.KernelFunction):
return False
else:
defs = {s.name for s in node.symbols_defined}
return bool(symbol_names.intersection(defs))
dependencies = {s.name for s in node.undefined_symbols}
last_block = node.parent last_block = node.parent
last_block_child = node last_block_child = node
element = node.parent element = node.parent
prev_element = node prev_element = node
while element: while element:
if isinstance(element, ast.Block): if isinstance(element, (ast.Conditional, ast.KernelFunction)):
# Never move out of Conditionals or KernelFunctions.
break
elif isinstance(element, ast.Block):
last_block = element last_block = element
last_block_child = prev_element last_block_child = prev_element
if isinstance(element, ast.Conditional): if any(modifies_or_declares(sibling, dependencies) for sibling in element.args):
break # The node depends on one of the statements in this block.
# Do not move further out.
break
elif isinstance(element, ast.LoopOverCoordinate):
if element.loop_counter_symbol.name in dependencies:
# The node depends on the loop counter.
# Do not move out of this loop.
break
else: else:
critical_symbols = element.symbols_defined raise NotImplementedError(f'Due to defensive programming we handle only specific expressions.\n'
if node.undefined_symbols.intersection(critical_symbols): f'The expression {element} of type {type(element)} is not known yet.')
break
# No dependencies to symbols defined/modified within the current element.
# We can move the node up one level and in front of the current element.
prev_element = element prev_element = element
element = element.parent element = element.parent
return last_block, last_block_child return last_block, last_block_child
...@@ -580,13 +738,7 @@ def move_constants_before_loop(ast_node): ...@@ -580,13 +738,7 @@ def move_constants_before_loop(ast_node):
get_blocks(ast_node, all_blocks) get_blocks(ast_node, all_blocks)
for block in all_blocks: for block in all_blocks:
children = block.take_child_nodes() children = block.take_child_nodes()
# Every time a symbol can be replaced in the current block because the assignment
# was found in a parent block, but with a different lhs symbol (same rhs)
# the outer symbol is inserted here as key.
substitute_variables = {}
for child in children: for child in children:
# Before traversing the next child, all symbols are substituted first.
child.subs(substitute_variables)
if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments if not isinstance(child, ast.SympyAssignment): # only move SympyAssignments
block.append(child) block.append(child)
...@@ -602,23 +754,21 @@ def move_constants_before_loop(ast_node): ...@@ -602,23 +754,21 @@ def move_constants_before_loop(ast_node):
exists_already = False exists_already = False
if not exists_already: if not exists_already:
rhs_identical = check_if_assignment_already_in_block(child, target, True) target.insert_before(child, child_to_insert_before)
if rhs_identical:
# there is already an assignment out there with the same rhs
# -> replace all lhs symbols in this block with the lhs of the outer assignment
# -> remove the local assignment (do not re-append child to the former block)
substitute_variables[child.lhs] = rhs_identical.lhs
else:
target.insert_before(child, child_to_insert_before)
elif exists_already and exists_already.rhs == child.rhs: elif exists_already and exists_already.rhs == child.rhs:
pass if target.args.index(exists_already) > target.args.index(child_to_insert_before):
assert target.args.count(exists_already) == 1
assert target.args.count(child_to_insert_before) == 1
target.args.remove(exists_already)
target.insert_before(exists_already, child_to_insert_before)
else: else:
# this variable already exists in outer block, but with different rhs # this variable already exists in outer block, but with different rhs
# -> symbol has to be renamed # -> symbol has to be renamed
assert isinstance(child.lhs, TypedSymbol) assert isinstance(child.lhs, TypedSymbol)
new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype) new_symbol = TypedSymbol(sp.Dummy().name, child.lhs.dtype)
target.insert_before(ast.SympyAssignment(new_symbol, child.rhs), child_to_insert_before) target.insert_before(ast.SympyAssignment(new_symbol, child.rhs, is_const=child.is_const),
substitute_variables[child.lhs] = new_symbol child_to_insert_before)
block.append(ast.SympyAssignment(child.lhs, new_symbol, is_const=child.is_const))
def split_inner_loop(ast_node: ast.Node, symbol_groups): def split_inner_loop(ast_node: ast.Node, symbol_groups):
...@@ -632,16 +782,16 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -632,16 +782,16 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
and which no symbol in a symbol group depends on, are not updated! and which no symbol in a symbol group depends on, are not updated!
""" """
all_loops = ast_node.atoms(ast.LoopOverCoordinate) all_loops = ast_node.atoms(ast.LoopOverCoordinate)
inner_loop = [l for l in all_loops if l.is_innermost_loop] inner_loop = [loop for loop in all_loops if loop.is_innermost_loop]
assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?" assert len(inner_loop) == 1, "Error in AST: multiple innermost loops. Was split transformation already called?"
inner_loop = inner_loop[0] inner_loop = inner_loop[0]
assert type(inner_loop.body) is ast.Block assert type(inner_loop.body) is ast.Block
outer_loop = [l for l in all_loops if l.is_outermost_loop] outer_loop = [loop for loop in all_loops if loop.is_outermost_loop]
assert len(outer_loop) == 1, "Error in AST, multiple outermost loops." assert len(outer_loop) == 1, "Error in AST, multiple outermost loops."
outer_loop = outer_loop[0] outer_loop = outer_loop[0]
symbols_with_temporary_array = OrderedDict() symbols_with_temporary_array = OrderedDict()
assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args) assignment_map = OrderedDict((a.lhs, a) for a in inner_loop.body.args if hasattr(a, 'lhs'))
assignment_groups = [] assignment_groups = []
for symbol_group in symbol_groups: for symbol_group in symbol_groups:
...@@ -655,32 +805,36 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups): ...@@ -655,32 +805,36 @@ def split_inner_loop(ast_node: ast.Node, symbol_groups):
if s in assignment_map: # if there is no assignment inside the loop body it is independent already if s in assignment_map: # if there is no assignment inside the loop body it is independent already
for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol): for new_symbol in assignment_map[s].rhs.atoms(sp.Symbol):
if not isinstance(new_symbol, AbstractField.AbstractAccess) and \ if not isinstance(new_symbol, Field.Access) and \
new_symbol not in symbols_with_temporary_array: new_symbol not in symbols_with_temporary_array:
symbols_to_process.append(new_symbol) symbols_to_process.append(new_symbol)
symbols_resolved.add(s) symbols_resolved.add(s)
for symbol in symbol_group: for symbol in symbol_group:
if not isinstance(symbol, AbstractField.AbstractAccess): if not isinstance(symbol, Field.Access):
assert type(symbol) is TypedSymbol assert type(symbol) is TypedSymbol
new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype)) new_ts = TypedSymbol(symbol.name, PointerType(symbol.dtype))
symbols_with_temporary_array[symbol] = sp.IndexedBase(new_ts, symbols_with_temporary_array[symbol] = sp.IndexedBase(
shape=(1,))[inner_loop.loop_counter_symbol] new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
assignment_group = [] assignment_group = []
for assignment in inner_loop.body.args: for assignment in inner_loop.body.args:
if assignment.lhs in symbols_resolved: if assignment.lhs in symbols_resolved:
new_rhs = assignment.rhs.subs(symbols_with_temporary_array.items()) # use fast_subs here because it checks if multiplications should be evaluated or not
if not isinstance(assignment.lhs, AbstractField.AbstractAccess) and assignment.lhs in symbol_group: new_rhs = fast_subs(assignment.rhs, symbols_with_temporary_array)
if not isinstance(assignment.lhs, Field.Access) and assignment.lhs in symbol_group:
assert type(assignment.lhs) is TypedSymbol assert type(assignment.lhs) is TypedSymbol
new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype)) new_ts = TypedSymbol(assignment.lhs.name, PointerType(assignment.lhs.dtype))
new_lhs = sp.IndexedBase(new_ts, shape=(1,))[inner_loop.loop_counter_symbol] new_lhs = sp.IndexedBase(new_ts, shape=(1, ))[inner_loop.loop_counter_symbol]
else: else:
new_lhs = assignment.lhs new_lhs = assignment.lhs
assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs)) assignment_group.append(ast.SympyAssignment(new_lhs, new_rhs))
assignment_groups.append(assignment_group) assignment_groups.append(assignment_group)
new_loops = [inner_loop.new_loop_with_different_body(ast.Block(group)) for group in assignment_groups] new_loops = [
inner_loop.new_loop_with_different_body(ast.Block(group))
for group in assignment_groups
]
inner_loop.parent.replace(inner_loop, ast.Block(new_loops)) inner_loop.parent.replace(inner_loop, ast.Block(new_loops))
for tmp_array in symbols_with_temporary_array: for tmp_array in symbols_with_temporary_array:
...@@ -697,7 +851,8 @@ def cut_loop(loop_node, cutting_points): ...@@ -697,7 +851,8 @@ def cut_loop(loop_node, cutting_points):
One loop is transformed into len(cuttingPoints)+1 new loops that range from One loop is transformed into len(cuttingPoints)+1 new loops that range from
old_begin to cutting_points[1], ..., cutting_points[-1] to old_end old_begin to cutting_points[1], ..., cutting_points[-1] to old_end
Modifies the ast in place Modifies the ast in place. Note Issue #5783 of SymPy. Deepcopy will evaluate mul
https://github.com/sympy/sympy/issues/5783
Returns: Returns:
list of new loop nodes list of new loop nodes
...@@ -715,8 +870,9 @@ def cut_loop(loop_node, cutting_points): ...@@ -715,8 +870,9 @@ def cut_loop(loop_node, cutting_points):
elif new_end - new_start == 0: elif new_end - new_start == 0:
pass pass
else: else:
new_loop = ast.LoopOverCoordinate(deepcopy(loop_node.body), loop_node.coordinate_to_loop_over, new_loop = ast.LoopOverCoordinate(
new_start, new_end, loop_node.step) deepcopy(loop_node.body), loop_node.coordinate_to_loop_over,
new_start, new_end, loop_node.step)
new_loops.append(new_loop) new_loops.append(new_loop)
new_start = new_end new_start = new_end
loop_node.parent.replace(loop_node, new_loops) loop_node.parent.replace(loop_node, new_loops)
...@@ -734,11 +890,16 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa ...@@ -734,11 +890,16 @@ def simplify_conditionals(node: ast.Node, loop_counter_simplification: bool = Fa
This analysis needs the integer set library (ISL) islpy, so it is not done by This analysis needs the integer set library (ISL) islpy, so it is not done by
default. default.
""" """
from sympy.codegen.rewriting import ReplaceOptim, optimize
remove_casts = ReplaceOptim(lambda e: isinstance(e, CastFunc), lambda p: p.expr)
for conditional in node.atoms(ast.Conditional): for conditional in node.atoms(ast.Conditional):
conditional.condition_expr = sp.simplify(conditional.condition_expr) # TODO simplify conditional before the type system! Casts make it very hard here
if conditional.condition_expr == sp.true: condition_expression = optimize(conditional.condition_expr, [remove_casts])
condition_expression = sp.simplify(condition_expression)
if condition_expression == sp.true:
conditional.parent.replace(conditional, [conditional.true_block]) conditional.parent.replace(conditional, [conditional.true_block])
elif conditional.condition_expr == sp.false: elif condition_expression == sp.false:
conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else []) conditional.parent.replace(conditional, [conditional.false_block] if conditional.false_block else [])
elif loop_counter_simplification: elif loop_counter_simplification:
try: try:
...@@ -764,247 +925,19 @@ def cleanup_blocks(node: ast.Node) -> None: ...@@ -764,247 +925,19 @@ def cleanup_blocks(node: ast.Node) -> None:
cleanup_blocks(a) cleanup_blocks(a)
class KernelConstraintsCheck: def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction, include_first=True) -> None:
"""Checks if the input to create_kernel is valid. """Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last or
first and last element"""
Test the following conditions:
- SSA Form for pure symbols:
- Every pure symbol may occur only once as left-hand-side of an assignment
- Every pure symbol that is read, may not be written to later
- Independence / Parallelization condition:
- a field that is written may only be read at exact the same spatial position
(Pure symbols are symbols that are not Field.Accesses)
"""
FieldAndIndex = namedtuple('FieldAndIndex', ['field', 'index'])
def __init__(self, type_for_symbol, check_independence_condition):
self._type_for_symbol = type_for_symbol
self.scopes = NestedScopes()
self._field_writes = defaultdict(set)
self.fields_read = set()
self.check_independence_condition = check_independence_condition
def process_assignment(self, assignment):
# for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1
new_rhs = self.process_expression(assignment.rhs)
new_lhs = self._process_lhs(assignment.lhs)
return ast.SympyAssignment(new_lhs, new_rhs)
def process_expression(self, rhs, type_constants=True):
self._update_accesses_rhs(rhs)
if isinstance(rhs, AbstractField.AbstractAccess):
self.fields_read.add(rhs.field)
self.fields_read.update(rhs.indirect_addressing_fields)
return rhs
elif isinstance(rhs, TypedSymbol):
return rhs
elif isinstance(rhs, sp.Symbol):
return TypedSymbol(rhs.name, self._type_for_symbol[rhs.name])
elif type_constants and isinstance(rhs, np.generic):
return cast_func(rhs, create_type(rhs.dtype))
elif type_constants and isinstance(rhs, sp.Number):
return cast_func(rhs, create_type(self._type_for_symbol['_constant']))
elif isinstance(rhs, sp.Mul):
new_args = [self.process_expression(arg, type_constants) if arg not in (-1, 1) else arg for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
elif isinstance(rhs, sp.Indexed):
return rhs
elif isinstance(rhs, cast_func):
return cast_func(self.process_expression(rhs.args[0], type_constants=False), rhs.dtype)
else:
if isinstance(rhs, sp.Pow):
# don't process exponents -> they should remain integers
return sp.Pow(self.process_expression(rhs.args[0], type_constants), rhs.args[1])
else:
new_args = [self.process_expression(arg, type_constants) for arg in rhs.args]
return rhs.func(*new_args) if new_args else rhs
@property
def fields_written(self):
return set(k.field for k, v in self._field_writes.items() if len(v))
def _process_lhs(self, lhs):
assert isinstance(lhs, sp.Symbol)
self._update_accesses_lhs(lhs)
if not isinstance(lhs, AbstractField.AbstractAccess) and not isinstance(lhs, TypedSymbol):
return TypedSymbol(lhs.name, self._type_for_symbol[lhs.name])
else:
return lhs
def _update_accesses_lhs(self, lhs):
if isinstance(lhs, AbstractField.AbstractAccess):
fai = self.FieldAndIndex(lhs.field, lhs.index)
self._field_writes[fai].add(lhs.offsets)
if len(self._field_writes[fai]) > 1:
raise ValueError("Field {} is written at two different locations".format(lhs.field.name))
elif isinstance(lhs, sp.Symbol):
if self.scopes.is_defined_locally(lhs):
raise ValueError("Assignments not in SSA form, multiple assignments to {}".format(lhs.name))
if lhs in self.scopes.free_parameters:
raise ValueError("Symbol {} is written, after it has been read".format(lhs.name))
self.scopes.define_symbol(lhs)
def _update_accesses_rhs(self, rhs):
if isinstance(rhs, AbstractField.AbstractAccess) and self.check_independence_condition:
writes = self._field_writes[self.FieldAndIndex(rhs.field, rhs.index)]
for write_offset in writes:
assert len(writes) == 1
if write_offset != rhs.offsets:
raise ValueError("Violation of loop independence condition. Field "
"{} is read at {} and written at {}".format(rhs.field, rhs.offsets, write_offset))
self.fields_read.add(rhs.field)
elif isinstance(rhs, sp.Symbol):
self.scopes.access_symbol(rhs)
def add_types(eqs, type_for_symbol, check_independence_condition):
"""Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`.
Additionally returns sets of all fields which are read/written
Args:
eqs: list of equations
type_for_symbol: dict mapping symbol names to types. Types are strings of C types like 'int' or 'double'
check_independence_condition: check that loop iterations are independent - this has to be skipped for indexed
kernels
Returns:
``fields_read, fields_written, typed_equations`` set of read fields, set of written fields,
list of equations where symbols have been replaced by typed symbols
"""
if isinstance(type_for_symbol, str) or not hasattr(type_for_symbol, '__getitem__'):
type_for_symbol = typing_from_sympy_inspection(eqs, type_for_symbol)
check = KernelConstraintsCheck(type_for_symbol, check_independence_condition)
def visit(obj):
if isinstance(obj, list) or isinstance(obj, tuple):
return [visit(e) for e in obj]
if isinstance(obj, sp.Eq) or isinstance(obj, ast.SympyAssignment) or isinstance(obj, Assignment):
return check.process_assignment(obj)
elif isinstance(obj, ast.Conditional):
check.scopes.push()
false_block = None if obj.false_block is None else visit(obj.false_block)
result = ast.Conditional(check.process_expression(obj.condition_expr, type_constants=False),
true_block=visit(obj.true_block), false_block=false_block)
check.scopes.pop()
return result
elif isinstance(obj, ast.Block):
check.scopes.push()
result = ast.Block([visit(e) for e in obj.args])
check.scopes.pop()
return result
elif isinstance(obj, ast.Node) and not isinstance(obj, ast.LoopOverCoordinate):
return obj
else:
raise ValueError("Invalid object in kernel " + str(type(obj)))
typed_equations = visit(eqs)
return check.fields_read, check.fields_written, typed_equations
def insert_casts(node):
"""Checks the types and inserts casts and pointer arithmetic where necessary.
Args:
node: the head node of the ast
Returns:
modified AST
"""
def cast(zipped_args_types, target_dtype):
"""
Adds casts to the arguments if their type differs from the target type
:param zipped_args_types: a zipped list of args and types
:param target_dtype: The target data type
:return: args with possible casts
"""
casted_args = []
for argument, data_type in zipped_args_types:
if data_type.numpy_dtype != target_dtype.numpy_dtype: # ignoring const
casted_args.append(cast_func(argument, target_dtype))
else:
casted_args.append(argument)
return casted_args
def pointer_arithmetic(expr_args):
"""
Creates a valid pointer arithmetic function
:param expr_args: Arguments of the add expression
:return: pointer_arithmetic_func
"""
pointer = None
new_args = []
for arg, data_type in expr_args:
if data_type.func is PointerType:
assert pointer is None
pointer = arg
for arg, data_type in expr_args:
if arg != pointer:
assert data_type.is_int() or data_type.is_uint()
new_args.append(arg)
new_args = sp.Add(*new_args) if len(new_args) > 0 else new_args
return pointer_arithmetic_func(pointer, new_args)
if isinstance(node, sp.AtomicExpr) or isinstance(node, cast_func):
return node
args = []
for arg in node.args:
args.append(insert_casts(arg))
# TODO indexed, LoopOverCoordinate
if node.func in (sp.Add, sp.Mul, sp.Or, sp.And, sp.Pow, sp.Eq, sp.Ne, sp.Lt, sp.Le, sp.Gt, sp.Ge):
# TODO optimize pow, don't cast integer on double
types = [get_type_of_expression(arg) for arg in args]
assert len(types) > 0
target = collate_types(types)
zipped = list(zip(args, types))
if target.func is PointerType:
assert node.func is sp.Add
return pointer_arithmetic(zipped)
else:
return node.func(*cast(zipped, target))
elif node.func is ast.SympyAssignment:
lhs = args[0]
rhs = args[1]
target = get_type_of_expression(lhs)
if target.func is PointerType:
return node.func(*args) # TODO fix, not complete
else:
return node.func(lhs, *cast([(rhs, get_type_of_expression(rhs))], target))
elif node.func is ast.ResolvedFieldAccess:
return node
elif node.func is ast.Block:
for old_arg, new_arg in zip(node.args, args):
node.replace(old_arg, new_arg)
return node
elif node.func is ast.LoopOverCoordinate:
for old_arg, new_arg in zip(node.args, args):
node.replace(old_arg, new_arg)
return node
elif node.func is sp.Piecewise:
expressions = [expr for (expr, _) in args]
types = [get_type_of_expression(expr) for expr in expressions]
target = collate_types(types)
zipped = list(zip(expressions, types))
casted_expressions = cast(zipped, target)
args = [arg.func(*[expr, arg.cond]) for (arg, expr) in zip(args, casted_expressions)]
return node.func(*args)
def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -> None:
"""Removes conditionals of a kernel that iterates over staggered positions by splitting the loops at last element"""
all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop] all_inner_loops = [l for l in function_node.atoms(ast.LoopOverCoordinate) if l.is_innermost_loop]
assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop" assert len(all_inner_loops) == 1, "Transformation works only on kernels with exactly one inner loop"
inner_loop = all_inner_loops.pop() inner_loop = all_inner_loops.pop()
for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True): for loop in parents_of_type(inner_loop, ast.LoopOverCoordinate, include_current=True):
cut_loop(loop, [loop.stop - 1]) if include_first:
cut_loop(loop, [loop.start + 1, loop.stop - 1])
else:
cut_loop(loop, [loop.stop - 1])
simplify_conditionals(function_node.body, loop_counter_simplification=True) simplify_conditionals(function_node.body, loop_counter_simplification=True)
cleanup_blocks(function_node.body) cleanup_blocks(function_node.body)
...@@ -1014,61 +947,6 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) - ...@@ -1014,61 +947,6 @@ def remove_conditionals_in_staggered_kernel(function_node: ast.KernelFunction) -
# --------------------------------------- Helper Functions ------------------------------------------------------------- # --------------------------------------- Helper Functions -------------------------------------------------------------
def typing_from_sympy_inspection(eqs, default_type="double"):
"""
Creates a default symbol name to type mapping.
If a sympy Boolean is assigned to a symbol it is assumed to be 'bool' otherwise the default type, usually ('double')
Args:
eqs: list of equations
default_type: the type for non-boolean symbols
Returns:
dictionary, mapping symbol name to type
"""
result = defaultdict(lambda: default_type)
for eq in eqs:
if isinstance(eq, ast.Conditional):
result.update(typing_from_sympy_inspection(eq.true_block.args))
if eq.false_block:
result.update(typing_from_sympy_inspection(eq.false_block.args))
elif isinstance(eq, ast.Node) and not isinstance(eq, ast.SympyAssignment):
continue
else:
from pystencils.cpu.vectorization import vec_all, vec_any
if isinstance(eq.rhs, vec_all) or isinstance(eq.rhs, vec_any):
result[eq.lhs.name] = "bool"
# problematic case here is when rhs is a symbol: then it is impossible to decide here without
# further information what type the left hand side is - default fallback is the dict value then
if isinstance(eq.rhs, Boolean) and not isinstance(eq.rhs, sp.Symbol):
result[eq.lhs.name] = "bool"
return result
def get_next_parent_of_type(node, parent_type):
"""Returns the next parent node of given type or None, if root is reached.
Traverses the AST nodes parents until a parent of given type was found.
If no such parent is found, None is returned
"""
parent = node.parent
while parent is not None:
if isinstance(parent, parent_type):
return parent
parent = parent.parent
return None
def parents_of_type(node, parent_type, include_current=False):
"""Generator for all parent nodes of given type"""
parent = node if include_current else node.parent
while parent is not None:
if isinstance(parent, parent_type):
yield parent
parent = parent.parent
def get_optimal_loop_ordering(fields): def get_optimal_loop_ordering(fields):
""" """
Determines the optimal loop order for a given set of fields. Determines the optimal loop order for a given set of fields.
...@@ -1084,13 +962,17 @@ def get_optimal_loop_ordering(fields): ...@@ -1084,13 +962,17 @@ def get_optimal_loop_ordering(fields):
ref_field = next(iter(fields)) ref_field = next(iter(fields))
for field in fields: for field in fields:
if field.spatial_dimensions != ref_field.spatial_dimensions: if field.spatial_dimensions != ref_field.spatial_dimensions:
raise ValueError("All fields have to have the same number of spatial dimensions. Spatial field dimensions: " raise ValueError(
+ str({f.name: f.spatial_shape for f in fields})) "All fields have to have the same number of spatial dimensions. Spatial field dimensions: "
+ str({f.name: f.spatial_shape
for f in fields}))
layouts = set([field.layout for field in fields]) layouts = set([field.layout for field in fields])
if len(layouts) > 1: if len(layouts) > 1:
raise ValueError("Due to different layout of the fields no optimal loop ordering exists " raise ValueError(
+ str({f.name: f.layout for f in fields})) "Due to different layout of the fields no optimal loop ordering exists "
+ str({f.name: f.layout
for f in fields}))
layout = list(layouts)[0] layout = list(layouts)[0]
return list(layout) return list(layout)
...@@ -1110,13 +992,13 @@ def get_loop_hierarchy(ast_node): ...@@ -1110,13 +992,13 @@ def get_loop_hierarchy(ast_node):
return reversed(result) return reversed(result)
def get_loop_counter_symbol_hierarchy(astNode): def get_loop_counter_symbol_hierarchy(ast_node):
"""Determines the loop counter symbols around a given AST node. """Determines the loop counter symbols around a given AST node.
:param astNode: the AST node :param ast_node: the AST node
:return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop :return: list of loop counter symbols, where the first list entry is the symbol of the innermost loop
""" """
result = [] result = []
node = astNode node = ast_node
while node is not None: while node is not None:
node = get_next_parent_of_type(node, ast.LoopOverCoordinate) node = get_next_parent_of_type(node, ast.LoopOverCoordinate)
if node: if node:
...@@ -1135,7 +1017,9 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: ...@@ -1135,7 +1017,9 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
""" """
inner_loops = [] inner_loops = []
inner_loop_counters = set() inner_loop_counters = set()
for loop in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment): for loop in filtered_tree_iteration(ast_node,
ast.LoopOverCoordinate,
stop_type=ast.SympyAssignment):
if loop.is_innermost_loop: if loop.is_innermost_loop:
inner_loops.append(loop) inner_loops.append(loop)
inner_loop_counters.add(loop.coordinate_to_loop_over) inner_loop_counters.add(loop.coordinate_to_loop_over)
...@@ -1146,8 +1030,10 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None: ...@@ -1146,8 +1030,10 @@ def replace_inner_stride_with_one(ast_node: ast.KernelFunction) -> None:
inner_loop_counter = inner_loop_counters.pop() inner_loop_counter = inner_loop_counters.pop()
parameters = ast_node.get_parameters() parameters = ast_node.get_parameters()
stride_params = [p.symbol for p in parameters stride_params = [
if p.is_field_stride and p.symbol.coordinate == inner_loop_counter] p.symbol for p in parameters
if p.is_field_stride and p.symbol.coordinate == inner_loop_counter
]
subs_dict = {stride_param: 1 for stride_param in stride_params} subs_dict = {stride_param: 1 for stride_param in stride_params}
if subs_dict: if subs_dict:
ast_node.subs(subs_dict) ast_node.subs(subs_dict)
...@@ -1158,17 +1044,23 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: ...@@ -1158,17 +1044,23 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
Args: Args:
ast_node: kernel function node before vectorization transformation has been applied ast_node: kernel function node before vectorization transformation has been applied
block_size: sequence defining block size in x, y, (z) direction block_size: sequence defining block size in x, y, (z) direction.
If chosen as zero the direction will not be used for blocking.
Returns: Returns:
number of dimensions blocked number of dimensions blocked
""" """
loops = [l for l in filtered_tree_iteration(ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)] loops = [
l for l in filtered_tree_iteration(
ast_node, ast.LoopOverCoordinate, stop_type=ast.SympyAssignment)
]
body = ast_node.body body = ast_node.body
coordinates = [] coordinates = []
coordinates_taken_into_account = 0
loop_starts = {} loop_starts = {}
loop_stops = {} loop_stops = {}
for loop in loops: for loop in loops:
coord = loop.coordinate_to_loop_over coord = loop.coordinate_to_loop_over
if coord not in coordinates: if coord not in coordinates:
...@@ -1177,26 +1069,36 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int: ...@@ -1177,26 +1069,36 @@ def loop_blocking(ast_node: ast.KernelFunction, block_size) -> int:
loop_stops[coord] = loop.stop loop_stops[coord] = loop.stop
else: else:
assert loop.start == loop_starts[coord] and loop.stop == loop_stops[coord], \ assert loop.start == loop_starts[coord] and loop.stop == loop_stops[coord], \
"Multiple loops over coordinate {} with different loop bounds".format(coord) f"Multiple loops over coordinate {coord} with different loop bounds"
# Create the outer loops that iterate over the blocks # Create the outer loops that iterate over the blocks
outer_loop = None outer_loop = None
for coord in reversed(coordinates): for coord in reversed(coordinates):
if block_size[coord] == 0:
continue
coordinates_taken_into_account += 1
body = ast.Block([outer_loop]) if outer_loop else body body = ast.Block([outer_loop]) if outer_loop else body
outer_loop = ast.LoopOverCoordinate(body, coord, loop_starts[coord], loop_stops[coord], outer_loop = ast.LoopOverCoordinate(body,
step=block_size[coord], is_block_loop=True) coord,
loop_starts[coord],
loop_stops[coord],
step=block_size[coord],
is_block_loop=True)
ast_node.body = ast.Block([outer_loop]) ast_node.body = ast.Block([outer_loop])
# modify the existing loops to only iterate within one block # modify the existing loops to only iterate within one block
for inner_loop in loops: for inner_loop in loops:
coord = inner_loop.coordinate_to_loop_over coord = inner_loop.coordinate_to_loop_over
if block_size[coord] == 0:
continue
block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord) block_ctr = ast.LoopOverCoordinate.get_block_loop_counter_symbol(coord)
loop_range = inner_loop.stop - inner_loop.start loop_range = inner_loop.stop - inner_loop.start
if sp.sympify(loop_range).is_number and loop_range % block_size[coord] == 0: if sp.sympify(
loop_range).is_number and loop_range % block_size[coord] == 0:
stop = block_ctr + block_size[coord] stop = block_ctr + block_size[coord]
else: else:
stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord]) stop = sp.Min(inner_loop.stop, block_ctr + block_size[coord])
inner_loop.start = block_ctr inner_loop.start = block_ctr
inner_loop.stop = stop inner_loop.stop = stop
return len(coordinates) return coordinates_taken_into_account