From c21d4621ea558dacd8e9a8395ed0d7114006a547 Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Tue, 29 Aug 2023 10:28:59 +0200 Subject: [PATCH] Distinguish between SymPy and pystencils Assignement better --- pystencils/__init__.py | 4 +- pystencils/assignment.py | 13 ++- pystencils/astnodes.py | 23 ++++- pystencils/backends/cbackend.py | 18 ++-- pystencils/cpu/kernelcreation.py | 2 +- pystencils/kernelcreation.py | 15 ++-- pystencils/node_collection.py | 83 +++++++++---------- pystencils/transformations.py | 2 +- pystencils/typing/leaf_typing.py | 13 ++- pystencils/typing/transformations.py | 10 ++- pystencils_tests/test_augmented_assignment.py | 35 ++++++++ pystencils_tests/test_modulo.py | 2 +- 12 files changed, 139 insertions(+), 81 deletions(-) create mode 100644 pystencils_tests/test_augmented_assignment.py diff --git a/pystencils/__init__.py b/pystencils/__init__.py index 92fdda9c5..0003a8b9a 100644 --- a/pystencils/__init__.py +++ b/pystencils/__init__.py @@ -2,7 +2,7 @@ from .enums import Backend, Target from . import fd from . import stencil as stencil -from .assignment import Assignment, assignment_from_stencil +from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil from .typing.typed_sympy import TypedSymbol from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .field import Field, FieldType, fields @@ -24,7 +24,7 @@ __all__ = ['Field', 'FieldType', 'fields', 'Target', 'Backend', 'show_code', 'to_dot', 'get_code_obj', 'get_code_str', 'AssignmentCollection', - 'Assignment', + 'Assignment', 'AddAugmentedAssignment', 'assignment_from_stencil', 'SymbolCreator', 'create_data_handling', diff --git a/pystencils/assignment.py b/pystencils/assignment.py index c3ae4b436..d0e849954 100644 --- a/pystencils/assignment.py +++ b/pystencils/assignment.py @@ -1,20 +1,22 @@ import numpy as np import sympy as sp -from sympy.codegen.ast import Assignment +from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment from sympy.printing.latex import LatexPrinter -__all__ = ['Assignment', 'assignment_from_stencil'] +__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'assignment_from_stencil'] def print_assignment_latex(printer, expr): + binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else '' """sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer""" printed_lhs = printer.doprint(expr.lhs) printed_rhs = printer.doprint(expr.rhs) - return fr"{printed_lhs} \leftarrow {printed_rhs}" + return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}" def assignment_str(assignment): - return fr"{assignment.lhs} ↠{assignment.rhs}" + op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else 'â†' + return fr"{assignment.lhs} {op} {assignment.rhs}" _old_new = sp.codegen.ast.Assignment.__new__ @@ -32,6 +34,9 @@ Assignment.__str__ = assignment_str Assignment.__new__ = _Assignment__new__ LatexPrinter._print_Assignment = print_assignment_latex +AugmentedAssignment.__str__ = assignment_str +LatexPrinter._print_AugmentedAssignment = print_assignment_latex + sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self)) diff --git a/pystencils/astnodes.py b/pystencils/astnodes.py index c9d66ae26..0c49160c8 100644 --- a/pystencils/astnodes.py +++ b/pystencils/astnodes.py @@ -561,10 +561,10 @@ class SympyAssignment(Node): def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False): super(SympyAssignment, self).__init__(parent=None) self._lhs_symbol = sp.sympify(lhs_symbol) - self.rhs = sp.sympify(rhs_expr) + self._rhs = sp.sympify(rhs_expr) self._is_const = is_const self._is_declaration = self.__is_declaration() - self.use_auto = use_auto + self._use_auto = use_auto def __is_declaration(self): from pystencils.typing import CastFunc @@ -578,15 +578,28 @@ class SympyAssignment(Node): def lhs(self): return self._lhs_symbol + @property + def rhs(self): + return self._rhs + @lhs.setter def lhs(self, new_value): self._lhs_symbol = new_value self._is_declaration = self.__is_declaration() + @rhs.setter + def rhs(self, new_rhs_expr): + self._rhs = new_rhs_expr + def subs(self, subs_dict): self.lhs = fast_subs(self.lhs, subs_dict) self.rhs = fast_subs(self.rhs, subs_dict) + def fast_subs(self, subs_dict, skip=None): + self.lhs = fast_subs(self.lhs, subs_dict, skip) + self.rhs = fast_subs(self.rhs, subs_dict, skip) + return self + def optimize(self, optimizations): try: from sympy.codegen.rewriting import optimize @@ -596,7 +609,7 @@ class SympyAssignment(Node): @property def args(self): - return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)] + return [self._lhs_symbol, self.rhs] @property def symbols_defined(self): @@ -627,6 +640,10 @@ class SympyAssignment(Node): def is_const(self): return self._is_const + @property + def use_auto(self): + return self._use_auto + def replace(self, child, replacement): if child == self.lhs: replacement.parent = self diff --git a/pystencils/backends/cbackend.py b/pystencils/backends/cbackend.py index 7665f9dfc..cc1de06c0 100644 --- a/pystencils/backends/cbackend.py +++ b/pystencils/backends/cbackend.py @@ -262,19 +262,17 @@ class CBackend: return f"{prefix}{loop_str}\n{self._print(node.body)}" def _print_SympyAssignment(self, node): + printed_lhs = self.sympy_printer.doprint(node.lhs) + printed_rhs = self.sympy_printer.doprint(node.rhs) + if node.is_declaration: if node.use_auto: - data_type = 'auto ' + data_type = 'auto' else: + data_type = self._print(node.lhs.dtype).replace(' const', '') if node.is_const: - prefix = 'const ' - else: - prefix = '' - data_type = prefix + self._print(node.lhs.dtype).replace(' const', '') + " " - - return "%s%s = %s;" % (data_type, - self.sympy_printer.doprint(node.lhs), - self.sympy_printer.doprint(node.rhs)) + data_type = f'const {data_type}' + return f"{data_type} {printed_lhs} = {printed_rhs};" else: lhs_type = get_type_of_expression(node.lhs) # TOOD: this should have been typed printed_mask = "" @@ -350,7 +348,7 @@ class CBackend: code += f"\nif ({flushcond}) {{\n\t{code2}\n}} else {{\n\t{code1}\n}}" return pre_code + code else: - return f"{self.sympy_printer.doprint(node.lhs)} = {self.sympy_printer.doprint(node.rhs)};" + return f"{printed_lhs} = {printed_rhs};" def _print_NontemporalFence(self, _): if 'streamFence' in self._vector_instruction_set: diff --git a/pystencils/cpu/kernelcreation.py b/pystencils/cpu/kernelcreation.py index 4cf0955a5..c93d5ed72 100644 --- a/pystencils/cpu/kernelcreation.py +++ b/pystencils/cpu/kernelcreation.py @@ -18,7 +18,7 @@ from pystencils.transformations import ( resolve_field_accesses, split_inner_loop) -def create_kernel(assignments: Union[AssignmentCollection, NodeCollection], +def create_kernel(assignments: Union[NodeCollection], config: CreateKernelConfig) -> KernelFunction: """Creates an abstract syntax tree for a kernel function, by taking a list of update rules. diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index 6e13bcfbd..385f42d2f 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -5,7 +5,7 @@ from typing import Union, List import sympy as sp from pystencils.config import CreateKernelConfig -from pystencils.assignment import Assignment +from pystencils.assignment import Assignment, AddAugmentedAssignment from pystencils.astnodes import Node, Block, Conditional, LoopOverCoordinate, SympyAssignment from pystencils.cpu.vectorization import vectorize from pystencils.enums import Target, Backend @@ -19,7 +19,10 @@ from pystencils.transformations import ( loop_blocking, move_constants_before_loop, remove_conditionals_in_staggered_kernel) -def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCollection, List[Node], NodeCollection], *, +def create_kernel(assignments: Union[Assignment, List[Assignment], + AddAugmentedAssignment, List[AddAugmentedAssignment], + AssignmentCollection, List[Node], NodeCollection], + *, config: CreateKernelConfig = None, **kwargs): """ Creates abstract syntax tree (AST) of kernel, using a list of update equations. @@ -59,7 +62,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol setattr(config, k, v) # ---- Normalizing parameters - if isinstance(assignments, Assignment): + if isinstance(assignments, (Assignment, AddAugmentedAssignment)): assignments = [assignments] assert assignments, "Assignments must not be empty!" if isinstance(assignments, list): @@ -86,13 +89,13 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol def create_domain_kernel(assignments: NodeCollection, *, config: CreateKernelConfig): """ - Creates abstract syntax tree (AST) of kernel, using a list of update equations. + Creates abstract syntax tree (AST) of kernel, using a NodeCollection. Note that `create_domain_kernel` is a lower level function which shoul be accessed by not providing `index_fields` to create_kernel Args: - assignments: can be a single assignment, sequence of assignments or an `AssignmentCollection` + assignments: `pystencils.node_collection.NodeCollection` containing all assignements and nodes to be processed config: CreateKernelConfig which includes the needed configuration Returns: @@ -187,7 +190,7 @@ def create_indexed_kernel(assignments: NodeCollection, *, config: CreateKernelCo to create_kernel Args: - assignments: can be a single assignment, sequence of assignments or an `AssignmentCollection` + assignments: `pystencils.node_collection.NodeCollection` containing all assignements and nodes to be processed config: CreateKernelConfig which includes the needed configuration Returns: diff --git a/pystencils/node_collection.py b/pystencils/node_collection.py index 227e1a10d..352406566 100644 --- a/pystencils/node_collection.py +++ b/pystencils/node_collection.py @@ -1,8 +1,9 @@ -from typing import List, Union +from collections.abc import Iterable +from typing import Any, Dict, List, Union, Optional, Set import sympy import sympy as sp -from sympy.codegen import Assignment +from sympy.codegen.ast import Assignment, AddAugmentedAssignment from sympy.codegen.rewriting import ReplaceOptim, optimize from pystencils.astnodes import Block, Node, SympyAssignment @@ -12,33 +13,32 @@ from pystencils.simp import AssignmentCollection class NodeCollection: - def __init__(self, assignments: List[Union[Node, Assignment]]): - self.all_assignments = assignments - - if all((isinstance(a, Assignment) for a in assignments)): - self.is_Nodes = False - self.is_Assignments = True - elif all((isinstance(n, Node) for n in assignments)): - self.is_Nodes = True - self.is_Assignments = False - else: - raise ValueError(f'The list "{assignments}" is mixed. Pass either a list of "pystencils.Assignments" ' - f'or a list of "pystencils.astnodes.Node') + def __init__(self, assignments: List[Union[Node, Assignment]], + simplification_hints: Optional[Dict[str, Any]] = None, + bound_fields: Set[sp.Symbol] = None, rhs_fields: Set[sp.Symbol] = None): + nodes = list() + assignments = [assignments, ] if not isinstance(assignments, Iterable) else assignments + for assignment in assignments: + if isinstance(assignment, Assignment): + nodes.append(SympyAssignment(assignment.lhs, assignment.rhs)) + elif isinstance(assignment, AddAugmentedAssignment): + nodes.append(SympyAssignment(assignment.lhs, assignment.lhs + assignment.rhs)) + elif isinstance(assignment, Node): + nodes.append(assignment) + else: + raise ValueError(f"Unknown node in the AssignmentCollection: {assignment}") - self.simplification_hints = {} + self.all_assignments = nodes + self.simplification_hints = simplification_hints if simplification_hints else {} + self.bound_fields = bound_fields if bound_fields else {} + self.rhs_fields = rhs_fields if rhs_fields else {} @staticmethod def from_assignment_collection(assignment_collection: AssignmentCollection): - nodes = list() - for assignemt in assignment_collection.all_assignments: - if isinstance(assignemt, Assignment): - nodes.append(SympyAssignment(assignemt.lhs, assignemt.rhs)) - elif isinstance(assignemt, Node): - nodes.append(assignemt) - else: - raise ValueError(f"Unknown node in the AssignmentCollection: {assignemt}") - - return NodeCollection(nodes) + return NodeCollection(assignments=assignment_collection.all_assignments, + simplification_hints=assignment_collection.simplification_hints, + bound_fields=assignment_collection.bound_fields, + rhs_fields=assignment_collection.rhs_fields) def evaluate_terms(self): evaluate_constant_terms = ReplaceOptim( @@ -54,21 +54,20 @@ class NodeCollection: ) sympy_optimisations = [evaluate_constant_terms, evaluate_pow] - if self.is_Nodes: - def visitor(node): - if isinstance(node, CustomCodeNode): - return node - elif isinstance(node, Block): - return node.func([visitor(child) for child in node.args]) - elif isinstance(node, Node): - return node.func(*[visitor(child) for child in node.args]) - elif isinstance(node, sympy.Basic): - return optimize(node, sympy_optimisations) - else: - raise NotImplementedError(f'{node} {type(node)} has no valid visitor') + def visitor(node): + if isinstance(node, CustomCodeNode): + return node + elif isinstance(node, Block): + return node.func([visitor(child) for child in node.args]) + elif isinstance(node, SympyAssignment): + new_lhs = visitor(node.lhs) + new_rhs = visitor(node.rhs) + return node.func(new_lhs, new_rhs, node.is_const, node.use_auto) + elif isinstance(node, Node): + return node.func(*[visitor(child) for child in node.args]) + elif isinstance(node, sympy.Basic): + return optimize(node, sympy_optimisations) + else: + raise NotImplementedError(f'{node} {type(node)} has no valid visitor') - self.all_assignments = [visitor(assignment) for assignment in self.all_assignments] - else: - self.all_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations)) - if hasattr(a, 'lhs') - else a for a in self.all_assignments] + self.all_assignments = [visitor(assignment) for assignment in self.all_assignments] diff --git a/pystencils/transformations.py b/pystencils/transformations.py index e07d871e9..5cde907b5 100644 --- a/pystencils/transformations.py +++ b/pystencils/transformations.py @@ -520,7 +520,7 @@ def resolve_field_accesses(ast_node, read_only_field_names=None, coord_dict = create_coordinate_dict(group) new_ptr, offset = create_intermediate_base_pointer(field_access, coord_dict, last_pointer) if new_ptr not in enclosing_block.symbols_defined: - new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False) + new_assignment = ast.SympyAssignment(new_ptr, last_pointer + offset, is_const=False, use_auto=False) enclosing_block.insert_before(new_assignment, sympy_assignment) last_pointer = new_ptr diff --git a/pystencils/typing/leaf_typing.py b/pystencils/typing/leaf_typing.py index ecb82bab8..6c30a6abf 100644 --- a/pystencils/typing/leaf_typing.py +++ b/pystencils/typing/leaf_typing.py @@ -10,7 +10,6 @@ from sympy.core.relational import Relational from sympy.functions.elementary.piecewise import ExprCondPair from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction from sympy.functions.elementary.hyperbolic import HyperbolicFunction -from sympy.codegen import Assignment from sympy.logic.boolalg import BooleanFunction from sympy.logic.boolalg import BooleanAtom @@ -51,7 +50,7 @@ class TypeAdder: def visit(self, obj): if isinstance(obj, (list, tuple)): return [self.visit(e) for e in obj] - if isinstance(obj, (sp.Eq, ast.SympyAssignment, Assignment)): + if isinstance(obj, ast.SympyAssignment): return self.process_assignment(obj) elif isinstance(obj, ast.Conditional): condition, condition_type = self.figure_out_type(obj.condition_expr) @@ -67,7 +66,7 @@ class TypeAdder: else: raise ValueError("Invalid object in kernel " + str(type(obj))) - def process_assignment(self, assignment: Union[sp.Eq, ast.SympyAssignment, Assignment]) -> ast.SympyAssignment: + def process_assignment(self, assignment: ast.SympyAssignment) -> ast.SympyAssignment: # for checks it is crucial to process rhs before lhs to catch e.g. a = a + 1 new_rhs, rhs_type = self.figure_out_type(assignment.rhs) @@ -81,11 +80,11 @@ class TypeAdder: assert isinstance(new_lhs, (Field.Access, TypedSymbol)) if lhs_type != rhs_type: - logging.warning(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype ' - f'rhs: "{new_rhs}" of type "{rhs_type}".') - return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type)) + logging.debug(f'Lhs"{new_lhs} of type "{lhs_type}" is assigned with a different datatype ' + f'rhs: "{new_rhs}" of type "{rhs_type}".') + return ast.SympyAssignment(new_lhs, CastFunc(new_rhs, lhs_type), assignment.is_const, assignment.use_auto) else: - return ast.SympyAssignment(new_lhs, new_rhs) + return ast.SympyAssignment(new_lhs, new_rhs, assignment.is_const, assignment.use_auto) # Type System Specification # - Defined Types: TypedSymbol, Field, Field.Access, ...? diff --git a/pystencils/typing/transformations.py b/pystencils/typing/transformations.py index 74ecf19f1..43e69eb28 100644 --- a/pystencils/typing/transformations.py +++ b/pystencils/typing/transformations.py @@ -1,17 +1,19 @@ from typing import List +from pystencils.astnodes import Node from pystencils.config import CreateKernelConfig from pystencils.typing.leaf_typing import TypeAdder -from sympy.codegen import Assignment -def add_types(eqs: List[Assignment], config: CreateKernelConfig): +def add_types(node_list: List[Node], config: CreateKernelConfig): """Traverses AST and replaces every :class:`sympy.Symbol` by a :class:`pystencils.typedsymbol.TypedSymbol`. + The AST needs to be a pystencils AST. Thus, in the list of nodes every entry must be inherited from + `pystencils.astnodes.Node` Additionally returns sets of all fields which are read/written Args: - eqs: list of equations + node_list: List of pystencils Nodes. config: CreateKernelConfig Returns: @@ -22,4 +24,4 @@ def add_types(eqs: List[Assignment], config: CreateKernelConfig): default_number_float=config.default_number_float, default_number_int=config.default_number_int) - return check.visit(eqs) + return check.visit(node_list) diff --git a/pystencils_tests/test_augmented_assignment.py b/pystencils_tests/test_augmented_assignment.py new file mode 100644 index 000000000..43fa7e8e1 --- /dev/null +++ b/pystencils_tests/test_augmented_assignment.py @@ -0,0 +1,35 @@ +import pytest +import pystencils as ps + + +@pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU]) +def test_add_augmented_assignment(target): + if target == ps.Target.GPU: + pytest.importorskip("cupy") + + domain_size = (5, 5) + dh = ps.create_data_handling(domain_size=domain_size, periodicity=True, default_target=target) + + f = dh.add_array("f", values_per_cell=1) + dh.fill(f.name, 0.0) + + g = dh.add_array("g", values_per_cell=1) + dh.fill(g.name, 1.0) + + up = ps.AddAugmentedAssignment(f.center, g.center) + + config = ps.CreateKernelConfig(target=dh.default_target) + ast = ps.create_kernel(up, config=config) + + kernel = ast.compile() + for i in range(10): + dh.run_kernel(kernel) + + if target == ps.Target.GPU: + dh.all_to_cpu() + + result = dh.gather_array(f.name) + + for x in range(domain_size[0]): + for y in range(domain_size[1]): + assert result[x, y] == 10 diff --git a/pystencils_tests/test_modulo.py b/pystencils_tests/test_modulo.py index 959daddb9..5a32acf5c 100644 --- a/pystencils_tests/test_modulo.py +++ b/pystencils_tests/test_modulo.py @@ -10,7 +10,7 @@ from pystencils.astnodes import LoopOverCoordinate, Conditional, Block, SympyAss def test_mod(target, iteration_slice): if target == ps.Target.GPU: pytest.importorskip("cupy") - dh = ps.create_data_handling(domain_size=(5, 5), periodicity=True, default_target=ps.Target.CPU) + dh = ps.create_data_handling(domain_size=(5, 5), periodicity=True, default_target=target) loop_ctrs = [LoopOverCoordinate.get_loop_counter_symbol(i) for i in range(dh.dim)] cond = [sp.Eq(sp.Mod(loop_ctrs[i], 2), 1) for i in range(dh.dim)] -- GitLab