diff --git a/pyproject.toml b/pyproject.toml index 5ef106e59e61780f49518cc33c6fc0dd21daf853..cc33e2b655af59480ec512af6c60a43207a85d31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ authors = [ ] license = { file = "COPYING.txt" } requires-python = ">=3.10" -dependencies = ["sympy>=1.6,<=1.11.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml"] +dependencies = ["sympy>=1.9,<=1.12.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml"] classifiers = [ "Development Status :: 4 - Beta", "Framework :: Jupyter", diff --git a/src/pystencils/__init__.py b/src/pystencils/__init__.py index c39cd3b826c733da9dceb1944f0cce5038c348e0..ac3801518fbe6a0e2db9b901eeee4013698c0f70 100644 --- a/src/pystencils/__init__.py +++ b/src/pystencils/__init__.py @@ -15,7 +15,7 @@ from .config import ( OpenMpConfig, ) from .kernel_decorator import kernel, kernel_config -from .kernelcreation import create_kernel +from .kernelcreation import create_kernel, create_staggered_kernel from .backend.kernelfunction import KernelFunction from .slicing import make_slice from .spatial_coordinates import ( @@ -28,12 +28,13 @@ from .spatial_coordinates import ( z_, z_staggered, ) -from .sympyextensions import Assignment, AssignmentCollection, AddAugmentedAssignment -from .sympyextensions.astnodes import assignment_from_stencil +from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil +from .simp import AssignmentCollection from .sympyextensions.typed_sympy import TypedSymbol from .sympyextensions.math import SymbolCreator from .datahandling import create_data_handling + __all__ = [ "Field", "FieldType", @@ -48,6 +49,7 @@ __all__ = [ "VectorizationConfig", "OpenMpConfig", "create_kernel", + "create_staggered_kernel", "KernelFunction", "Target", "show_code", diff --git a/src/pystencils/assignment.py b/src/pystencils/assignment.py new file mode 100644 index 0000000000000000000000000000000000000000..af32bc664d7e64a4d57bbe6cf782b86af6904a8b --- /dev/null +++ b/src/pystencils/assignment.py @@ -0,0 +1,110 @@ +import numpy as np +import sympy as sp +from sympy.codegen.ast import Assignment, AugmentedAssignment +from sympy.codegen.ast import AddAugmentedAssignment as SpAddAugAssignment +from sympy.printing.latex import LatexPrinter + +__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'assignment_from_stencil'] + + +def print_assignment_latex(printer, expr): + binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else '' + """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer""" + printed_lhs = printer.doprint(expr.lhs) + printed_rhs = printer.doprint(expr.rhs) + return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}" + + +def assignment_str(assignment): + op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else '←' + return fr"{assignment.lhs} {op} {assignment.rhs}" + + +_old_new = sp.codegen.ast.Assignment.__new__ + + +# TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults +def _Assignment__new__(cls, lhs, rhs, *args, **kwargs): + if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)): + assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!' + return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs)) + return _old_new(cls, lhs, rhs, *args, **kwargs) + + +Assignment.__str__ = assignment_str +Assignment.__new__ = _Assignment__new__ +LatexPrinter._print_Assignment = print_assignment_latex + +AugmentedAssignment.__str__ = assignment_str +LatexPrinter._print_AugmentedAssignment = print_assignment_latex + +sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self)) + +# Re-Export +AddAugmentedAssignment = SpAddAugAssignment + + +def assignment_from_stencil(stencil_array, input_field, output_field, + normalization_factor=None, order='visual') -> Assignment: + """Creates an assignment + + Args: + stencil_array: nested list of numpy array defining the stencil weights + input_field: field or field access, defining where the stencil should be applied to + output_field: field or field access where the result is written to + normalization_factor: optional normalization factor for the stencil + order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'. + For details see examples + + Returns: + Assignment that can be used to create a kernel + + Examples: + >>> import pystencils as ps + >>> f, g = ps.fields("f, g: [2D]") + >>> stencil = [[0, 2, 0], + ... [3, 4, 5], + ... [0, 6, 0]] + + By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down + >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0]) + >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output + True + + 'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc. + >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0]) + >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output + True + + You can also pass field accesses to apply the stencil at an already shifted position: + >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0]) + >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output + True + """ + from pystencils.field import Field + + stencil_array = np.array(stencil_array) + if order == 'visual': + stencil_array = np.swapaxes(stencil_array, 0, 1) + stencil_array = np.flip(stencil_array, axis=1) + elif order == 'numpy': + pass + else: + raise ValueError("'order' has to be either 'visual' or 'numpy'") + + if isinstance(input_field, Field): + input_field = input_field.center + if isinstance(output_field, Field): + output_field = output_field.center + + rhs = 0 + offset = tuple(s // 2 for s in stencil_array.shape) + + for index, factor in np.ndenumerate(stencil_array): + shift = tuple(i - o for i, o in zip(index, offset)) + rhs += factor * input_field.get_shifted(*shift) + + if normalization_factor: + rhs *= normalization_factor + + return Assignment(output_field, rhs) diff --git a/src/pystencils/backend/kernelcreation/analysis.py b/src/pystencils/backend/kernelcreation/analysis.py index a72191b5b6f001c21e745c892a991c6058fb0048..05aa7992819be66ed7816882f9eb96372e143d64 100644 --- a/src/pystencils/backend/kernelcreation/analysis.py +++ b/src/pystencils/backend/kernelcreation/analysis.py @@ -9,7 +9,8 @@ import sympy as sp from .context import KernelCreationContext from ...field import Field -from ...sympyextensions import Assignment, AssignmentCollection +from ...assignment import Assignment +from ...simp import AssignmentCollection from ..exceptions import PsInternalCompilerError, KernelConstraintsError diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index 59fa04b3bdb6af98ac584ef7d166167c64423862..f81ed586bc60b6d831ea4b03ff37612841fc7e78 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -7,7 +7,8 @@ import sympy.core.relational import sympy.logic.boolalg from sympy.codegen.ast import AssignmentBase, AugmentedAssignment -from ...sympyextensions.astnodes import Assignment, AssignmentCollection +from ...assignment import Assignment +from ...simp import AssignmentCollection from ...sympyextensions import ( integer_functions, ConditionalFieldAccess, diff --git a/src/pystencils/backend/kernelcreation/iteration_space.py b/src/pystencils/backend/kernelcreation/iteration_space.py index 2a3d2774e03160fe2012f68ecb3ded4803353304..5208c906caa86a3b687057d777908daaf5f6e2ab 100644 --- a/src/pystencils/backend/kernelcreation/iteration_space.py +++ b/src/pystencils/backend/kernelcreation/iteration_space.py @@ -6,7 +6,7 @@ from functools import reduce from operator import mul from ...defaults import DEFAULTS -from ...sympyextensions import AssignmentCollection +from ...simp import AssignmentCollection from ...field import Field, FieldType from ..symbols import PsSymbol diff --git a/src/pystencils/boundaries/boundaryconditions.py b/src/pystencils/boundaries/boundaryconditions.py index f52573bca41723fc6e71b208227f72cbbc343e5b..cf6a3e82454d2ddbdfe98344479b8871d4d7b53b 100644 --- a/src/pystencils/boundaries/boundaryconditions.py +++ b/src/pystencils/boundaries/boundaryconditions.py @@ -1,6 +1,6 @@ from typing import Any, List, Tuple, Sequence -from pystencils.sympyextensions import Assignment +from pystencils.assignment import Assignment from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo from pystencils.types import create_type diff --git a/src/pystencils/boundaries/boundaryhandling.py b/src/pystencils/boundaries/boundaryhandling.py index 57a1cd95f4cb2e166359e60684a2fe9687e9d320..f171d56091f69fbd1bef2a2edc4cf844a96a9f40 100644 --- a/src/pystencils/boundaries/boundaryhandling.py +++ b/src/pystencils/boundaries/boundaryhandling.py @@ -4,7 +4,7 @@ import numpy as np import sympy as sp from pystencils import create_kernel, CreateKernelConfig, Target -from pystencils.sympyextensions import Assignment +from pystencils.assignment import Assignment from pystencils.boundaries.createindexlist import ( create_boundary_index_array, numpy_data_type_for_boundary_object) from pystencils.sympyextensions import TypedSymbol diff --git a/src/pystencils/fd/finitedifferences.py b/src/pystencils/fd/finitedifferences.py index 9c4116ee56eff340b222132a66ecf37740af1b31..f34a448ed8401a9b4d7e9f145c4883203f4d1ca9 100644 --- a/src/pystencils/fd/finitedifferences.py +++ b/src/pystencils/fd/finitedifferences.py @@ -7,7 +7,7 @@ from pystencils.fd import Diff from pystencils.fd.derivative import diff_args from pystencils.fd.spatial import fd_stencils_standard from pystencils.field import Field -from pystencils.sympyextensions import AssignmentCollection +from pystencils.simp import AssignmentCollection from pystencils.sympyextensions.math import fast_subs FieldOrFieldAccess = Union[Field, Field.Access] diff --git a/src/pystencils/kernel_decorator.py b/src/pystencils/kernel_decorator.py index deb94eec0eb5d77f2b67411a45b8c9b23da53e0f..ce0a31d546acf30a418bbc0b2554a2f932a9e180 100644 --- a/src/pystencils/kernel_decorator.py +++ b/src/pystencils/kernel_decorator.py @@ -5,7 +5,7 @@ from typing import Callable, Union, List, Dict, Tuple import sympy as sp -from .sympyextensions import Assignment +from .assignment import Assignment from .sympyextensions.math import SymbolCreator from pystencils.config import CreateKernelConfig diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 3cda5aa46313d46251ef9c73c6348e2f65c1af54..154cb2307bdaaec0167632d4c77531632b1fe9e8 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -152,3 +152,7 @@ def create_kernel_function( return KernelFunction( body, target_spec, function_name, params, req_headers, ctx.constraints, jit ) + + +def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclusive_conditions=False, **kwargs): + raise NotImplementedError("Staggered kernels are not yet implemented for pystencils 2.0") diff --git a/src/pystencils/placeholder_function.py b/src/pystencils/placeholder_function.py index 00acb17bd71cdd7cfb628d89e5e1c85034c449ce..e9a3a0aba7c7b2b81e050029d0b06899a247c1ee 100644 --- a/src/pystencils/placeholder_function.py +++ b/src/pystencils/placeholder_function.py @@ -2,7 +2,7 @@ from typing import List import sympy as sp -from pystencils.sympyextensions import Assignment +from .assignment import Assignment from pystencils.sympyextensions import is_constant from pystencils.sympyextensions.astnodes import generic_visit diff --git a/src/pystencils/sympyextensions/rng.py b/src/pystencils/rng.py similarity index 97% rename from src/pystencils/sympyextensions/rng.py rename to src/pystencils/rng.py index 859669a6ac35e97a13646efcf0274446bc379988..d6c6cd2741ee3e7442bd9fa4a96f4e9983d524e3 100644 --- a/src/pystencils/sympyextensions/rng.py +++ b/src/pystencils/rng.py @@ -2,10 +2,9 @@ import copy import numpy as np import sympy as sp -from pystencils.sympyextensions import TypedSymbol, CastFunc +from .sympyextensions import TypedSymbol, CastFunc, fast_subs # from pystencils.sympyextensions.astnodes import LoopOverCoordinate # TODO nbackend: replace # from pystencils.backends.cbackend import CustomCodeNode # TODO nbackend: replace -from pystencils.sympyextensions import fast_subs # class RNGBase(CustomCodeNode): TODO nbackend: replace diff --git a/src/pystencils/simp/__init__.py b/src/pystencils/simp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6c553af8bcb16d43e77ddf2515db73a9bde5f1af --- /dev/null +++ b/src/pystencils/simp/__init__.py @@ -0,0 +1,45 @@ +from .assignment_collection import AssignmentCollection +from .simplifications import ( + add_subexpressions_for_constants, + add_subexpressions_for_divisions, + add_subexpressions_for_field_reads, + add_subexpressions_for_sums, + apply_on_all_subexpressions, + apply_to_all_assignments, + subexpression_substitution_in_existing_subexpressions, + subexpression_substitution_in_main_assignments, + sympy_cse, + sympy_cse_on_assignment_list, +) +from .subexpression_insertion import ( + insert_aliases, + insert_zeros, + insert_constants, + insert_constant_additions, + insert_constant_multiples, + insert_squares, + insert_symbol_times_minus_one, +) +from .simplificationstrategy import SimplificationStrategy + +__all__ = [ + "AssignmentCollection", + "SimplificationStrategy", + "sympy_cse", + "sympy_cse_on_assignment_list", + "apply_to_all_assignments", + "apply_on_all_subexpressions", + "subexpression_substitution_in_existing_subexpressions", + "subexpression_substitution_in_main_assignments", + "add_subexpressions_for_constants", + "add_subexpressions_for_divisions", + "add_subexpressions_for_sums", + "add_subexpressions_for_field_reads", + "insert_aliases", + "insert_zeros", + "insert_constants", + "insert_constant_additions", + "insert_constant_multiples", + "insert_squares", + "insert_symbol_times_minus_one", +] diff --git a/src/pystencils/simp/assignment_collection.py b/src/pystencils/simp/assignment_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ba8715431d96fb2a09a01e45872def421fe94f --- /dev/null +++ b/src/pystencils/simp/assignment_collection.py @@ -0,0 +1,476 @@ +import itertools +from copy import copy +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union + +import sympy as sp + +import pystencils +from ..assignment import Assignment +from .simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) +from ..sympyextensions import count_operations, fast_subs + + +class AssignmentCollection: + """ + A collection of equations with subexpression definitions, also represented as assignments, + that are used in the main equations. AssignmentCollection can be passed to simplification methods. + These simplification methods can change the subexpressions, but the number and + left hand side of the main equations themselves is not altered. + Additionally a dictionary of simplification hints is stored, which are set by the functions that create + assignment collections to transport information to the simplification system. + + Args: + main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each + assignment is a field access. Thus the generated equations write on arrays. + subexpressions: List of assignments defining subexpressions used in main equations + simplification_hints: Dict that is used to annotate the assignment collection with hints that are + used by the simplification system. See documentation of the simplification rules for + potentially required hints and their meaning. + subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added + used to get new symbols that are unique for this AssignmentCollection + + """ + + __match_args__ = ("main_assignments", "subexpressions") + + # ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- + + def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], + subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None, + simplification_hints: Optional[Dict[str, Any]] = None, + subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: + + if subexpressions is None: + subexpressions = {} + + if isinstance(main_assignments, Dict): + main_assignments = [Assignment(k, v) + for k, v in main_assignments.items()] + if isinstance(subexpressions, Dict): + subexpressions = [Assignment(k, v) + for k, v in subexpressions.items()] + + main_assignments = list(itertools.chain.from_iterable( + [(a if isinstance(a, Iterable) else [a]) for a in main_assignments])) + subexpressions = list(itertools.chain.from_iterable( + [(a if isinstance(a, Iterable) else [a]) for a in subexpressions])) + + self.main_assignments = main_assignments + self.subexpressions = subexpressions + + if simplification_hints is None: + simplification_hints = {} + + self.simplification_hints = simplification_hints + + ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name] + max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0 + + if subexpression_symbol_generator is None: + self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr) + else: + self.subexpression_symbol_generator = subexpression_symbol_generator + + def add_simplification_hint(self, key: str, value: Any) -> None: + """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet.""" + assert key not in self.simplification_hints, "This hint already exists" + self.simplification_hints[key] = value + + def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol: + """Adds a subexpression to current collection. + + Args: + rhs: right hand side of new subexpression + lhs: optional left hand side of new subexpression. If None a new unique symbol is generated. + topological_sort: sort the subexpressions topologically after insertion, to make sure that + definition of a symbol comes before its usage. If False, subexpression is appended. + + Returns: + left hand side symbol (which could have been generated) + """ + if lhs is None: + lhs = next(self.subexpression_symbol_generator) + eq = Assignment(lhs, rhs) + self.subexpressions.append(eq) + if topological_sort: + self.topological_sort(sort_subexpressions=True, + sort_main_assignments=False) + return lhs + + def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: + """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition.""" + if sort_subexpressions: + self.subexpressions = sort_assignments_topologically(self.subexpressions) + if sort_main_assignments: + self.main_assignments = sort_assignments_topologically(self.main_assignments) + + # ---------------------------------------------- Properties ------------------------------------------------------- + + @property + def all_assignments(self) -> List[Assignment]: + """Subexpression and main equations as a single list.""" + return self.subexpressions + self.main_assignments + + @property + def rhs_symbols(self) -> Set[sp.Symbol]: + """All symbols used in the assignment collection, which occur on the rhs of any assignment.""" + rhs_symbols = set() + for eq in self.all_assignments: + if isinstance(eq, Assignment): + rhs_symbols.update(eq.rhs.atoms(sp.Symbol)) + # elif isinstance(eq, pystencils.astnodes.Node): # TODO remove or replace + # rhs_symbols.update(eq.undefined_symbols) + + return rhs_symbols + + @property + def free_symbols(self) -> Set[sp.Symbol]: + """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" + return self.rhs_symbols - self.bound_symbols + + @property + def bound_symbols(self) -> Set[sp.Symbol]: + """All symbols which occur on the left hand side of a main assignment or a subexpression.""" + bound_symbols_set = set( + [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)] + ) + + assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \ + "Not in SSA form - same symbol assigned multiple times" + + # bound_symbols_set = bound_symbols_set.union(*[ + # assignment.symbols_defined for assignment in self.all_assignments + # if isinstance(assignment, pystencils.astnodes.Node) + # ]) TODO: replace? + + return bound_symbols_set + + @property + def rhs_fields(self): + """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" + return {s.field for s in self.rhs_symbols if hasattr(s, 'field')} + + @property + def free_fields(self): + """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" + return {s.field for s in self.free_symbols if hasattr(s, 'field')} + + @property + def bound_fields(self): + """All field accessed on the left hand side of a main assignment or a subexpression.""" + return {s.field for s in self.bound_symbols if hasattr(s, 'field')} + + @property + def defined_symbols(self) -> Set[sp.Symbol]: + """All symbols which occur as left-hand-sides of one of the main equations""" + lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]) + return lhs_set + # return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments + # if isinstance(assignment, pystencils.astnodes.Node)])) TODO + + @property + def operation_count(self): + """See :func:`count_operations` """ + return count_operations(self.all_assignments, only_type=None) + + def atoms(self, *args): + return set().union(*[a.atoms(*args) for a in self.all_assignments]) + + def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]: + """Returns all symbols that depend on one of the passed symbols. + + A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when + 'b' is required to compute 'a'. + """ + + queue = list(symbols) + + def add_symbols_from_expr(expr): + dependent_symbols = expr.atoms(sp.Symbol) + for ds in dependent_symbols: + queue.append(ds) + + handled_symbols = set() + assignment_dict = {e.lhs: e.rhs for e in self.all_assignments} + + while len(queue) > 0: + e = queue.pop(0) + if e in handled_symbols: + continue + if e in assignment_dict: + add_symbols_from_expr(assignment_dict[e]) + handled_symbols.add(e) + + return handled_symbols + + def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None): + """Returns a python function to evaluate this equation collection. + + Args: + symbols: symbol(s) which are the parameter for the created function + fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify + module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy' + + Examples: + >>> a, b, c, d = sp.symbols("a b c d") + >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)], + ... subexpressions=[Assignment(b, a + b / 2)]) + >>> python_function = ac.lambdify([a], fixed_symbols={b: 2}) + >>> python_function(4) + {c: 6, d: 18} + """ + assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self + assignments = assignments.new_without_subexpressions().main_assignments + lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments} + + def f(*args, **kwargs): + return {s: func(*args, **kwargs) for s, func in lambdas.items()} + + return f + + # ---------------------------- Creating new modified collections --------------------------------------------------- + + def copy(self, + main_assignments: Optional[List[Assignment]] = None, + subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection': + """Returns a copy with optionally replaced main_assignments and/or subexpressions.""" + + res = copy(self) + res.simplification_hints = self.simplification_hints.copy() + res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator) + + if main_assignments is not None: + res.main_assignments = main_assignments + else: + res.main_assignments = self.main_assignments.copy() + + if subexpressions is not None: + res.subexpressions = subexpressions + else: + res.subexpressions = self.subexpressions.copy() + + return res + + def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False, + substitute_on_lhs: bool = True, + sort_topologically: bool = True) -> 'AssignmentCollection': + """Returns new object, where terms are substituted according to the passed substitution dict. + + Args: + substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions + add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions + substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments + sort_topologically: if subexpressions are added as substitutions and this parameters is true, + the subexpressions are sorted topologically after insertion + Returns: + New AssignmentCollection where substitutions have been applied, self is not altered. + """ + transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs + transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions) + transformed_assignments = transform(self.main_assignments, fast_subs, substitutions) + + if add_substitutions_as_subexpressions: + transformed_subexpressions = [Assignment(b, a) for a, b in + substitutions.items()] + transformed_subexpressions + if sort_topologically: + transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions) + return self.copy(transformed_assignments, transformed_subexpressions) + + def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection': + """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" + own_definitions = set([e.lhs for e in self.main_assignments]) + other_definitions = set([e.lhs for e in other.main_assignments]) + assert len(own_definitions.intersection(other_definitions)) == 0, \ + "Cannot merge collections, since both define the same symbols" + + own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} + substitution_dict = {} + + processed_other_subexpression_equations = [] + for other_subexpression_eq in other.subexpressions: + if other_subexpression_eq.lhs in own_subexpression_symbols: + new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict) + if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]: + continue # exact the same subexpression equation exists already + else: + # different definition - a new name has to be introduced + new_lhs = next(self.subexpression_symbol_generator) + new_eq = Assignment(new_lhs, new_rhs) + processed_other_subexpression_equations.append(new_eq) + substitution_dict[other_subexpression_eq.lhs] = new_lhs + else: + processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict)) + + processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments] + return self.copy(self.main_assignments + processed_other_main_assignments, + self.subexpressions + processed_other_subexpression_equations) + + def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection': + """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions. + + Returns: + new AssignmentCollection, self is not altered + """ + symbols_to_extract = set(symbols_to_extract) + dependent_symbols = self.dependent_symbols(symbols_to_extract) + new_assignments = [] + for eq in self.all_assignments: + if eq.lhs in symbols_to_extract: + new_assignments.append(eq) + + new_sub_expr = [eq for eq in self.all_assignments + if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract] + return self.copy(new_assignments, new_sub_expr) + + def new_without_unused_subexpressions(self) -> 'AssignmentCollection': + """Returns new collection that only contains subexpressions required to compute the main assignments.""" + all_lhs = [eq.lhs for eq in self.main_assignments] + return self.new_filtered(all_lhs) + + def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection': + """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere.""" + new_subexpressions = [] + subs_dict = None + for se in self.subexpressions: + if se.lhs == symbol: + subs_dict = {se.lhs: se.rhs} + else: + new_subexpressions.append(se) + if subs_dict is None: + return self + + new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions] + new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments] + return self.copy(new_eqs, new_subexpressions) + + def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection': + """Returns a new collection where all subexpressions have been inserted.""" + if subexpressions_to_keep is None: + subexpressions_to_keep = set() + if len(self.subexpressions) == 0: + return self.copy() + + subexpressions_to_keep = set(subexpressions_to_keep) + + kept_subexpressions = [] + if self.subexpressions[0].lhs in subexpressions_to_keep: + substitution_dict = {} + kept_subexpressions.append(self.subexpressions[0]) + else: + substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} + + subexpression = [e for e in self.subexpressions] + for i in range(1, len(subexpression)): + subexpression[i] = fast_subs(subexpression[i], substitution_dict) + if subexpression[i].lhs in subexpressions_to_keep: + kept_subexpressions.append(subexpression[i]) + else: + substitution_dict[subexpression[i].lhs] = subexpression[i].rhs + + new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments] + return self.copy(new_assignment, kept_subexpressions) + + # ----------------------------------------- Display and Printing ------------------------------------------------- + + def _repr_html_(self): + """Interface to Jupyter notebook, to display as a nicely formatted HTML table""" + + def make_html_equation_table(equations): + no_border = 'style="border:none"' + html_table = '<table style="border:none; width: 100%; ">' + line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> ' + for eq in equations: + format_dict = {'eq': sp.latex(eq), + 'nb': no_border, } + html_table += line.format(**format_dict) + html_table += "</table>" + return html_table + + result = "" + if len(self.subexpressions) > 0: + result += "<div>Subexpressions:</div>" + result += make_html_equation_table(self.subexpressions) + result += "<div>Main Assignments:</div>" + result += make_html_equation_table(self.main_assignments) + return result + + def __repr__(self): + return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}" + + def __str__(self): + result = "Subexpressions:\n" + for eq in self.subexpressions: + result += f"\t{eq}\n" + result += "Main Assignments:\n" + for eq in self.main_assignments: + result += f"\t{eq}\n" + return result + + def __iter__(self): + return self.all_assignments.__iter__() + + @property + def main_assignments_dict(self): + return {a.lhs: a.rhs for a in self.main_assignments} + + @property + def subexpressions_dict(self): + return {a.lhs: a.rhs for a in self.subexpressions} + + def set_main_assignments_from_dict(self, main_assignments_dict): + self.main_assignments = [Assignment(k, v) + for k, v in main_assignments_dict.items()] + + def set_sub_expressions_from_dict(self, sub_expressions_dict): + self.subexpressions = [Assignment(k, v) + for k, v in sub_expressions_dict.items()] + + def find(self, *args, **kwargs): + return set.union( + *[a.find(*args, **kwargs) for a in self.all_assignments] + ) + + def match(self, *args, **kwargs): + rtn = {} + for a in self.all_assignments: + partial_result = a.match(*args, **kwargs) + if partial_result: + rtn.update(partial_result) + return rtn + + def subs(self, *args, **kwargs): + return AssignmentCollection( + main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments], + subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions] + ) + + def replace(self, *args, **kwargs): + return AssignmentCollection( + main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments], + subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions] + ) + + def __eq__(self, other): + return set(self.all_assignments) == set(other.all_assignments) + + def __bool__(self): + return bool(self.all_assignments) + + +class SymbolGen: + """Default symbol generator producing number symbols ζ_0, ζ_1, ...""" + + def __init__(self, symbol="xi", dtype=None, ctr=0): + self._ctr = ctr + self._symbol = symbol + self._dtype = dtype + + def __iter__(self): + return self + + def __next__(self): + name = f"{self._symbol}_{self._ctr}" + self._ctr += 1 + if self._dtype is not None: + return pystencils.TypedSymbol(name, self._dtype) + return sp.Symbol(name) diff --git a/src/pystencils/sympyextensions/simplifications.py b/src/pystencils/simp/simplifications.py similarity index 97% rename from src/pystencils/sympyextensions/simplifications.py rename to src/pystencils/simp/simplifications.py index cdcad81e772d7c837cfd761a287b777f94397078..73d80ecd4524a4e2e3890b3c62c8d7307373b29a 100644 --- a/src/pystencils/sympyextensions/simplifications.py +++ b/src/pystencils/simp/simplifications.py @@ -4,9 +4,9 @@ from collections import defaultdict import sympy as sp -from .astnodes import Assignment -from .math import subs_additive, is_constant, recursive_collect -from .typed_sympy import TypedSymbol +from ..assignment import Assignment +from ..sympyextensions.math import subs_additive, is_constant, recursive_collect +from ..sympyextensions.typed_sympy import TypedSymbol # TODO rewrite with SymPy AST @@ -55,7 +55,7 @@ def sympy_cse(ac, **kwargs): def sympy_cse_on_assignment_list(assignments: List[Assignment]) -> List[Assignment]: """Extracts common subexpressions from a list of assignments.""" - from pystencils.sympyextensions import AssignmentCollection + from pystencils.simp import AssignmentCollection ec = AssignmentCollection([], assignments) return sympy_cse(ec).all_assignments diff --git a/src/pystencils/sympyextensions/simplificationstrategy.py b/src/pystencils/simp/simplificationstrategy.py similarity index 99% rename from src/pystencils/sympyextensions/simplificationstrategy.py rename to src/pystencils/simp/simplificationstrategy.py index b76d12711db5234aa9d29a7fd5608e2bab010178..22ffa34d04bc2731f615bd685137c8abebf9d58b 100644 --- a/src/pystencils/sympyextensions/simplificationstrategy.py +++ b/src/pystencils/simp/simplificationstrategy.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Optional, Sequence import sympy as sp -from .astnodes import AssignmentCollection +from ..simp import AssignmentCollection class SimplificationStrategy: diff --git a/src/pystencils/sympyextensions/subexpression_insertion.py b/src/pystencils/simp/subexpression_insertion.py similarity index 98% rename from src/pystencils/sympyextensions/subexpression_insertion.py rename to src/pystencils/simp/subexpression_insertion.py index 8cedad4665033e91a6812d4c9c23634c8d208c70..0c5aca1ddde2c4f5384fbfb004cada3c5397e92b 100644 --- a/src/pystencils/sympyextensions/subexpression_insertion.py +++ b/src/pystencils/simp/subexpression_insertion.py @@ -1,5 +1,5 @@ import sympy as sp -from .math import is_constant +from ..sympyextensions.math import is_constant # Subexpression Insertion diff --git a/src/pystencils/sympyextensions/__init__.py b/src/pystencils/sympyextensions/__init__.py index fd1145bcbc0422247873c9c3c4864a50f80db48d..0b8b3690d59a74342dc8510fba531928c7625539 100644 --- a/src/pystencils/sympyextensions/__init__.py +++ b/src/pystencils/sympyextensions/__init__.py @@ -1,34 +1,5 @@ -from .astnodes import ( - Assignment, - AugmentedAssignment, - AddAugmentedAssignment, - AssignmentCollection, - SymbolGen, - ConditionalFieldAccess -) +from .astnodes import ConditionalFieldAccess from .typed_sympy import TypedSymbol, CastFunc -from .simplificationstrategy import SimplificationStrategy -from .simplifications import ( - sympy_cse, - sympy_cse_on_assignment_list, - apply_to_all_assignments, - apply_on_all_subexpressions, - subexpression_substitution_in_existing_subexpressions, - subexpression_substitution_in_main_assignments, - add_subexpressions_for_constants, - add_subexpressions_for_divisions, - add_subexpressions_for_sums, - add_subexpressions_for_field_reads -) -from .subexpression_insertion import ( - insert_aliases, - insert_zeros, - insert_constants, - insert_constant_additions, - insert_constant_multiples, - insert_squares, - insert_symbol_times_minus_one, -) from .math import ( prod, @@ -59,32 +30,9 @@ from .math import ( __all__ = [ - "Assignment", - "AugmentedAssignment", - "AddAugmentedAssignment", - "AssignmentCollection", - "SymbolGen", "ConditionalFieldAccess", "TypedSymbol", "CastFunc", - "SimplificationStrategy", - "sympy_cse", - "sympy_cse_on_assignment_list", - "apply_to_all_assignments", - "apply_on_all_subexpressions", - "subexpression_substitution_in_existing_subexpressions", - "subexpression_substitution_in_main_assignments", - "add_subexpressions_for_constants", - "add_subexpressions_for_divisions", - "add_subexpressions_for_sums", - "add_subexpressions_for_field_reads", - "insert_aliases", - "insert_zeros", - "insert_constants", - "insert_constant_additions", - "insert_constant_multiples", - "insert_squares", - "insert_symbol_times_minus_one", "remove_higher_order_terms", "prod", "remove_small_floats", diff --git a/src/pystencils/sympyextensions/astnodes.py b/src/pystencils/sympyextensions/astnodes.py index 68613b41a315959e97c3e86908a48ffd5892c9f2..74906cc5ccf9150feac1fd291e3363c90408d295 100644 --- a/src/pystencils/sympyextensions/astnodes.py +++ b/src/pystencils/sympyextensions/astnodes.py @@ -1,592 +1,4 @@ -from copy import copy -import itertools -import uuid -from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union - import sympy as sp -from sympy.codegen.ast import Assignment, AugmentedAssignment -from sympy.codegen.ast import AddAugmentedAssignment as SpAddAugAssignment -from sympy.printing.latex import LatexPrinter -import numpy as np - -from .math import count_operations, fast_subs -from .simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) -from .typed_sympy import create_type, TypedSymbol - - -def print_assignment_latex(printer, expr): - binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else '' - """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer""" - printed_lhs = printer.doprint(expr.lhs) - printed_rhs = printer.doprint(expr.rhs) - return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}" - - -def assignment_str(assignment): - op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else '←' - return fr"{assignment.lhs} {op} {assignment.rhs}" - - -_old_new = sp.codegen.ast.Assignment.__new__ - - -# TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults -def _Assignment__new__(cls, lhs, rhs, *args, **kwargs): - if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)): - assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!' - return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs)) - return _old_new(cls, lhs, rhs, *args, **kwargs) - - -Assignment.__str__ = assignment_str -Assignment.__new__ = _Assignment__new__ -LatexPrinter._print_Assignment = print_assignment_latex - -AugmentedAssignment.__str__ = assignment_str -LatexPrinter._print_AugmentedAssignment = print_assignment_latex - -sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self)) - -# Re-Export -AddAugmentedAssignment = SpAddAugAssignment - - -def assignment_from_stencil(stencil_array, input_field, output_field, - normalization_factor=None, order='visual') -> Assignment: - """Creates an assignment - - Args: - stencil_array: nested list of numpy array defining the stencil weights - input_field: field or field access, defining where the stencil should be applied to - output_field: field or field access where the result is written to - normalization_factor: optional normalization factor for the stencil - order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'. - For details see examples - - Returns: - Assignment that can be used to create a kernel - - Examples: - >>> import pystencils as ps - >>> f, g = ps.fields("f, g: [2D]") - >>> stencil = [[0, 2, 0], - ... [3, 4, 5], - ... [0, 6, 0]] - - By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down - >>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0]) - >>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output - True - - 'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc. - >>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0]) - >>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output - True - - You can also pass field accesses to apply the stencil at an already shifted position: - >>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0]) - >>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output - True - """ - from pystencils.field import Field - - stencil_array = np.array(stencil_array) - if order == 'visual': - stencil_array = np.swapaxes(stencil_array, 0, 1) - stencil_array = np.flip(stencil_array, axis=1) - elif order == 'numpy': - pass - else: - raise ValueError("'order' has to be either 'visual' or 'numpy'") - - if isinstance(input_field, Field): - input_field = input_field.center - if isinstance(output_field, Field): - output_field = output_field.center - - rhs = 0 - offset = tuple(s // 2 for s in stencil_array.shape) - - for index, factor in np.ndenumerate(stencil_array): - shift = tuple(i - o for i, o in zip(index, offset)) - rhs += factor * input_field.get_shifted(*shift) - - if normalization_factor: - rhs *= normalization_factor - - return Assignment(output_field, rhs) - - -class AssignmentCollection: - """ - A collection of equations with subexpression definitions, also represented as assignments, - that are used in the main equations. AssignmentCollection can be passed to simplification methods. - These simplification methods can change the subexpressions, but the number and - left hand side of the main equations themselves is not altered. - Additionally a dictionary of simplification hints is stored, which are set by the functions that create - assignment collections to transport information to the simplification system. - - Args: - main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each - assignment is a field access. Thus the generated equations write on arrays. - subexpressions: List of assignments defining subexpressions used in main equations - simplification_hints: Dict that is used to annotate the assignment collection with hints that are - used by the simplification system. See documentation of the simplification rules for - potentially required hints and their meaning. - subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added - used to get new symbols that are unique for this AssignmentCollection - - """ - - __match_args__ = ("main_assignments", "subexpressions") - - # ------------------------------- Creation & Inplace Manipulation -------------------------------------------------- - - def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]], - subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None, - simplification_hints: Optional[Dict[str, Any]] = None, - subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None: - - if subexpressions is None: - subexpressions = {} - - if isinstance(main_assignments, Dict): - main_assignments = [Assignment(k, v) - for k, v in main_assignments.items()] - if isinstance(subexpressions, Dict): - subexpressions = [Assignment(k, v) - for k, v in subexpressions.items()] - - main_assignments = list(itertools.chain.from_iterable( - [(a if isinstance(a, Iterable) else [a]) for a in main_assignments])) - subexpressions = list(itertools.chain.from_iterable( - [(a if isinstance(a, Iterable) else [a]) for a in subexpressions])) - - self.main_assignments = main_assignments - self.subexpressions = subexpressions - - if simplification_hints is None: - simplification_hints = {} - - self.simplification_hints = simplification_hints - - ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name] - max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0 - - if subexpression_symbol_generator is None: - self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr) - else: - self.subexpression_symbol_generator = subexpression_symbol_generator - - def add_simplification_hint(self, key: str, value: Any) -> None: - """Adds an entry to the simplification_hints dictionary and checks that is does not exist yet.""" - assert key not in self.simplification_hints, "This hint already exists" - self.simplification_hints[key] = value - - def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol: - """Adds a subexpression to current collection. - - Args: - rhs: right hand side of new subexpression - lhs: optional left hand side of new subexpression. If None a new unique symbol is generated. - topological_sort: sort the subexpressions topologically after insertion, to make sure that - definition of a symbol comes before its usage. If False, subexpression is appended. - - Returns: - left hand side symbol (which could have been generated) - """ - if lhs is None: - lhs = next(self.subexpression_symbol_generator) - eq = Assignment(lhs, rhs) - self.subexpressions.append(eq) - if topological_sort: - self.topological_sort(sort_subexpressions=True, - sort_main_assignments=False) - return lhs - - def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None: - """Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition.""" - if sort_subexpressions: - self.subexpressions = sort_assignments_topologically(self.subexpressions) - if sort_main_assignments: - self.main_assignments = sort_assignments_topologically(self.main_assignments) - - # ---------------------------------------------- Properties ------------------------------------------------------- - - @property - def all_assignments(self) -> List[Assignment]: - """Subexpression and main equations as a single list.""" - return self.subexpressions + self.main_assignments - - @property - def rhs_symbols(self) -> Set[sp.Symbol]: - """All symbols used in the assignment collection, which occur on the rhs of any assignment.""" - rhs_symbols = set() - for eq in self.all_assignments: - if isinstance(eq, Assignment): - rhs_symbols.update(eq.rhs.atoms(sp.Symbol)) - # TODO rewrite with SymPy AST - # elif isinstance(eq, pystencils.astnodes.Node): - # rhs_symbols.update(eq.undefined_symbols) - - return rhs_symbols - - @property - def free_symbols(self) -> Set[sp.Symbol]: - """All symbols used in the assignment collection, which do not occur as left hand sides in any assignment.""" - return self.rhs_symbols - self.bound_symbols - - @property - def bound_symbols(self) -> Set[sp.Symbol]: - """All symbols which occur on the left hand side of a main assignment or a subexpression.""" - bound_symbols_set = set( - [assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)] - ) - - assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \ - "Not in SSA form - same symbol assigned multiple times" - - # TODO rewrite with SymPy AST - # bound_symbols_set = bound_symbols_set.union(*[ - # assignment.symbols_defined for assignment in self.all_assignments - # if isinstance(assignment, pystencils.astnodes.Node) - # ]) - - return bound_symbols_set - - @property - def rhs_fields(self): - """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" - return {s.field for s in self.rhs_symbols if hasattr(s, 'field')} - - @property - def free_fields(self): - """All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment.""" - return {s.field for s in self.free_symbols if hasattr(s, 'field')} - - @property - def bound_fields(self): - """All field accessed on the left hand side of a main assignment or a subexpression.""" - return {s.field for s in self.bound_symbols if hasattr(s, 'field')} - - @property - def defined_symbols(self) -> Set[sp.Symbol]: - """All symbols which occur as left-hand-sides of one of the main equations""" - lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)]) - return lhs_set - # TODO rewrite with SymPy AST - # return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments - # if isinstance(assignment, pystencils.astnodes.Node)])) - - @property - def operation_count(self): - """See :func:`count_operations` """ - return count_operations(self.all_assignments, only_type=None) - - def atoms(self, *args): - return set().union(*[a.atoms(*args) for a in self.all_assignments]) - - def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]: - """Returns all symbols that depend on one of the passed symbols. - - A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when - 'b' is required to compute 'a'. - """ - - queue = list(symbols) - - def add_symbols_from_expr(expr): - dependent_symbols = expr.atoms(sp.Symbol) - for ds in dependent_symbols: - queue.append(ds) - - handled_symbols = set() - assignment_dict = {e.lhs: e.rhs for e in self.all_assignments} - - while len(queue) > 0: - e = queue.pop(0) - if e in handled_symbols: - continue - if e in assignment_dict: - add_symbols_from_expr(assignment_dict[e]) - handled_symbols.add(e) - - return handled_symbols - - def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None): - """Returns a python function to evaluate this equation collection. - - Args: - symbols: symbol(s) which are the parameter for the created function - fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify - module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy' - - Examples: - >>> a, b, c, d = sp.symbols("a b c d") - >>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)], - ... subexpressions=[Assignment(b, a + b / 2)]) - >>> python_function = ac.lambdify([a], fixed_symbols={b: 2}) - >>> python_function(4) - {c: 6, d: 18} - """ - assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self - assignments = assignments.new_without_subexpressions().main_assignments - lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments} - - def f(*args, **kwargs): - return {s: func(*args, **kwargs) for s, func in lambdas.items()} - - return f - - # ---------------------------- Creating new modified collections --------------------------------------------------- - - def copy(self, - main_assignments: Optional[List[Assignment]] = None, - subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection': - """Returns a copy with optionally replaced main_assignments and/or subexpressions.""" - - res = copy(self) - res.simplification_hints = self.simplification_hints.copy() - res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator) - - if main_assignments is not None: - res.main_assignments = main_assignments - else: - res.main_assignments = self.main_assignments.copy() - - if subexpressions is not None: - res.subexpressions = subexpressions - else: - res.subexpressions = self.subexpressions.copy() - - return res - - def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False, - substitute_on_lhs: bool = True, - sort_topologically: bool = True) -> 'AssignmentCollection': - """Returns new object, where terms are substituted according to the passed substitution dict. - - Args: - substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions - add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions - substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments - sort_topologically: if subexpressions are added as substitutions and this parameters is true, - the subexpressions are sorted topologically after insertion - Returns: - New AssignmentCollection where substitutions have been applied, self is not altered. - """ - transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs - transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions) - transformed_assignments = transform(self.main_assignments, fast_subs, substitutions) - - if add_substitutions_as_subexpressions: - transformed_subexpressions = [Assignment(b, a) for a, b in - substitutions.items()] + transformed_subexpressions - if sort_topologically: - transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions) - return self.copy(transformed_assignments, transformed_subexpressions) - - def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection': - """Returns a new collection which contains self and other. Subexpressions are renamed if they clash.""" - own_definitions = set([e.lhs for e in self.main_assignments]) - other_definitions = set([e.lhs for e in other.main_assignments]) - assert len(own_definitions.intersection(other_definitions)) == 0, \ - "Cannot merge collections, since both define the same symbols" - - own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions} - substitution_dict = {} - - processed_other_subexpression_equations = [] - for other_subexpression_eq in other.subexpressions: - if other_subexpression_eq.lhs in own_subexpression_symbols: - if other_subexpression_eq.rhs == own_subexpression_symbols[other_subexpression_eq.lhs]: - continue # exact the same subexpression equation exists already - else: - # different definition - a new name has to be introduced - new_lhs = next(self.subexpression_symbol_generator) - new_eq = Assignment(new_lhs, fast_subs(other_subexpression_eq.rhs, substitution_dict)) - processed_other_subexpression_equations.append(new_eq) - substitution_dict[other_subexpression_eq.lhs] = new_lhs - else: - processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict)) - - processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments] - return self.copy(self.main_assignments + processed_other_main_assignments, - self.subexpressions + processed_other_subexpression_equations) - - def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection': - """Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions. - - Returns: - new AssignmentCollection, self is not altered - """ - symbols_to_extract = set(symbols_to_extract) - dependent_symbols = self.dependent_symbols(symbols_to_extract) - new_assignments = [] - for eq in self.all_assignments: - if eq.lhs in symbols_to_extract: - new_assignments.append(eq) - - new_sub_expr = [eq for eq in self.all_assignments - if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract] - return self.copy(new_assignments, new_sub_expr) - - def new_without_unused_subexpressions(self) -> 'AssignmentCollection': - """Returns new collection that only contains subexpressions required to compute the main assignments.""" - all_lhs = [eq.lhs for eq in self.main_assignments] - return self.new_filtered(all_lhs) - - def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection': - """Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere.""" - new_subexpressions = [] - subs_dict = None - for se in self.subexpressions: - if se.lhs == symbol: - subs_dict = {se.lhs: se.rhs} - else: - new_subexpressions.append(se) - if subs_dict is None: - return self - - new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions] - new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments] - return self.copy(new_eqs, new_subexpressions) - - def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection': - """Returns a new collection where all subexpressions have been inserted.""" - if subexpressions_to_keep is None: - subexpressions_to_keep = set() - if len(self.subexpressions) == 0: - return self.copy() - - subexpressions_to_keep = set(subexpressions_to_keep) - - kept_subexpressions = [] - if self.subexpressions[0].lhs in subexpressions_to_keep: - substitution_dict = {} - kept_subexpressions.append(self.subexpressions[0]) - else: - substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs} - - subexpression = [e for e in self.subexpressions] - for i in range(1, len(subexpression)): - subexpression[i] = fast_subs(subexpression[i], substitution_dict) - if subexpression[i].lhs in subexpressions_to_keep: - kept_subexpressions.append(subexpression[i]) - else: - substitution_dict[subexpression[i].lhs] = subexpression[i].rhs - - new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments] - return self.copy(new_assignment, kept_subexpressions) - - # ----------------------------------------- Display and Printing ------------------------------------------------- - - def _repr_html_(self): - """Interface to Jupyter notebook, to display as a nicely formatted HTML table""" - - def make_html_equation_table(equations): - no_border = 'style="border:none"' - html_table = '<table style="border:none; width: 100%; ">' - line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> ' - for eq in equations: - format_dict = {'eq': sp.latex(eq), - 'nb': no_border, } - html_table += line.format(**format_dict) - html_table += "</table>" - return html_table - - result = "" - if len(self.subexpressions) > 0: - result += "<div>Subexpressions:</div>" - result += make_html_equation_table(self.subexpressions) - result += "<div>Main Assignments:</div>" - result += make_html_equation_table(self.main_assignments) - return result - - def __repr__(self): - return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}" - - def __str__(self): - result = "Subexpressions:\n" - for eq in self.subexpressions: - result += f"\t{eq}\n" - result += "Main Assignments:\n" - for eq in self.main_assignments: - result += f"\t{eq}\n" - return result - - def __iter__(self): - return self.all_assignments.__iter__() - - @property - def main_assignments_dict(self): - return {a.lhs: a.rhs for a in self.main_assignments} - - @property - def subexpressions_dict(self): - return {a.lhs: a.rhs for a in self.subexpressions} - - def set_main_assignments_from_dict(self, main_assignments_dict): - self.main_assignments = [Assignment(k, v) - for k, v in main_assignments_dict.items()] - - def set_sub_expressions_from_dict(self, sub_expressions_dict): - self.subexpressions = [Assignment(k, v) - for k, v in sub_expressions_dict.items()] - - def find(self, *args, **kwargs): - return set.union( - *[a.find(*args, **kwargs) for a in self.all_assignments] - ) - - def match(self, *args, **kwargs): - rtn = {} - for a in self.all_assignments: - partial_result = a.match(*args, **kwargs) - if partial_result: - rtn.update(partial_result) - return rtn - - def subs(self, *args, **kwargs): - return AssignmentCollection( - main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments], - subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions] - ) - - def replace(self, *args, **kwargs): - return AssignmentCollection( - main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments], - subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions] - ) - - def __eq__(self, other): - return set(self.all_assignments) == set(other.all_assignments) - - def __bool__(self): - return bool(self.all_assignments) - - -class SymbolGen: - """Default symbol generator producing number symbols ζ_0, ζ_1, ...""" - - def __init__(self, symbol="xi", dtype=None, ctr=0): - self._ctr = ctr - self._symbol = symbol - self._dtype = dtype - - def __iter__(self): - return self - - def __next__(self): - name = f"{self._symbol}_{self._ctr}" - self._ctr += 1 - if self._dtype is not None: - return TypedSymbol(name, self._dtype) - return sp.Symbol(name) - - -def get_dummy_symbol(dtype='bool'): - return TypedSymbol(f'dummy{uuid.uuid4().hex}', create_type(dtype)) class ConditionalFieldAccess(sp.Function): @@ -618,6 +30,8 @@ class ConditionalFieldAccess(sp.Function): def generic_visit(term, visitor): + from pystencils import AssignmentCollection, Assignment + if isinstance(term, AssignmentCollection): new_main_assignments = generic_visit(term.main_assignments, visitor) new_subexpressions = generic_visit(term.subexpressions, visitor) diff --git a/src/pystencils/sympyextensions/fast_approximation.py b/src/pystencils/sympyextensions/fast_approximation.py index 9088348fb3ff45d36d2f02d0ac7c2244a3d51c03..d9656025e7c69217e2f82eeacdea8c1fe872db8a 100644 --- a/src/pystencils/sympyextensions/fast_approximation.py +++ b/src/pystencils/sympyextensions/fast_approximation.py @@ -2,7 +2,8 @@ from typing import List, Union import sympy as sp -from pystencils.sympyextensions import AssignmentCollection, Assignment +from ..assignment import Assignment +from ..simp import AssignmentCollection # noinspection PyPep8Naming diff --git a/src/pystencils/sympyextensions/math.py b/src/pystencils/sympyextensions/math.py index a2df9458e5038e8ab088f223058fa6f77d893593..1a006efe6cb95ec7196476387f49094a1a8ee9c5 100644 --- a/src/pystencils/sympyextensions/math.py +++ b/src/pystencils/sympyextensions/math.py @@ -10,7 +10,7 @@ from sympy import PolynomialError from sympy.functions import Abs from sympy.core.numbers import Zero -from .astnodes import Assignment +from ..assignment import Assignment from .typed_sympy import CastFunc, FieldPointerSymbol from ..types import PsPointerType, PsVectorType diff --git a/tests/nbackend/kernelcreation/test_domain_kernels.py b/tests/nbackend/kernelcreation/test_domain_kernels.py index 9ce2f661d840641d28774134070fc7050e90e6d1..c9cc81abbe988be4d9c92dc1d833260f38691433 100644 --- a/tests/nbackend/kernelcreation/test_domain_kernels.py +++ b/tests/nbackend/kernelcreation/test_domain_kernels.py @@ -2,7 +2,7 @@ import sympy as sp import numpy as np from pystencils import fields, Field, AssignmentCollection -from pystencils.sympyextensions.astnodes import assignment_from_stencil +from pystencils.assignment import assignment_from_stencil from pystencils.kernelcreation import create_kernel diff --git a/tests/symbolics/test_address_of.py b/tests/symbolics/test_address_of.py index da11ecbe5374b95801f2de027b4db4df9e2fa04d..99f33ddbdfa7054bf5f27c08848640ee03f64555 100644 --- a/tests/symbolics/test_address_of.py +++ b/tests/symbolics/test_address_of.py @@ -6,7 +6,7 @@ import pystencils from pystencils.types import PsPointerType, create_type from pystencils.sympyextensions.pointers import AddressOf from pystencils.sympyextensions.typed_sympy import CastFunc -from pystencils.sympyextensions import sympy_cse +from pystencils.simp import sympy_cse import sympy as sp diff --git a/tests/symbolics/test_conditional_field_access.py b/tests/symbolics/test_conditional_field_access.py index bd384a95948511ede2d65222b69a81479c717a30..e18ffc56a4b0a95c30ff2e9e2d4affa5567654ac 100644 --- a/tests/symbolics/test_conditional_field_access.py +++ b/tests/symbolics/test_conditional_field_access.py @@ -16,7 +16,7 @@ import sympy as sp import pystencils as ps from pystencils import Field, x_vector from pystencils.sympyextensions.astnodes import ConditionalFieldAccess -from pystencils.sympyextensions import sympy_cse +from pystencils.simp import sympy_cse def add_fixed_constant_boundary_handling(assignments, with_cse): diff --git a/tests/test_fvm.py b/tests/test_fvm.py index e10874cbdf27abe161510e57242f1232f60a6d53..0b103e52554f767a0948509e8cae084c1af9e124 100644 --- a/tests/test_fvm.py +++ b/tests/test_fvm.py @@ -3,7 +3,7 @@ import pystencils as ps import numpy as np import pytest from itertools import product -from pystencils.sympyextensions.rng import random_symbol +from pystencils.rng import random_symbol from pystencils.sympyextensions.astnodes import SympyAssignment from pystencils.node_collection import NodeCollection diff --git a/tests/test_random.py b/tests/test_random.py index 6c3a888db86ff7e3b3ed100961643e3cdf36231b..63cd61494b675acbc1675277ff9eb9b1f57ddb3e 100644 --- a/tests/test_random.py +++ b/tests/test_random.py @@ -5,7 +5,7 @@ import pytest import pystencils as ps from pystencils.sympyextensions.astnodes import SympyAssignment from pystencils.node_collection import NodeCollection -from pystencils.sympyextensions.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol +from pystencils.rng import PhiloxFourFloats, PhiloxTwoDoubles, AESNIFourFloats, AESNITwoDoubles, random_symbol from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets from pystencils.cpu.cpujit import get_compiler_config from pystencils.typing import TypedSymbol