Skip to content
Snippets Groups Projects
Commit 110c9094 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fix constraint check

parent 01221072
Branches
Tags
1 merge request!292Rebase of pystencils Type System
import itertools import itertools
import logging
import warnings import warnings
from typing import Union, List from typing import Union, List
...@@ -65,14 +64,18 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol ...@@ -65,14 +64,18 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
assignments = [assignments] assignments = [assignments]
assert assignments, "Assignments must not be empty!" assert assignments, "Assignments must not be empty!"
if isinstance(assignments, list): if isinstance(assignments, list):
if all((isinstance(a, Assignment) for a in assignments)): assignments = NodeCollection(assignments)
assignments = AssignmentCollection(assignments) elif isinstance(assignments, AssignmentCollection):
elif all((isinstance(n, Node) for n in assignments)): # TODO check and doku
assignments = NodeCollection(assignments) # --- applying first default simplifications
logging.warning('Using Nodes is experimental and not fully tested. Double check your generated code!') try:
else: if config.default_assignment_simplifications:
raise ValueError(f'The list "{assignments}" is mixed. Pass either a list of "pystencils.Assignments" ' simplification = create_simplification_strategy()
f'or a list of "pystencils.astnodes.Node') assignments = simplification(assignments)
except Exception as e:
warnings.warn(f"It was not possible to apply the default pystencils optimisations to the "
f"AssignmentCollection due to the following problem :{e}")
assignments = NodeCollection(assignments.all_assignments)
if config.index_fields: if config.index_fields:
return create_indexed_kernel(assignments, config=config) return create_indexed_kernel(assignments, config=config)
...@@ -80,7 +83,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol ...@@ -80,7 +83,7 @@ def create_kernel(assignments: Union[Assignment, List[Assignment], AssignmentCol
return create_domain_kernel(assignments, config=config) return create_domain_kernel(assignments, config=config)
def create_domain_kernel(assignments: Union[AssignmentCollection, NodeCollection], *, config: CreateKernelConfig): def create_domain_kernel(assignments: NodeCollection, *, config: CreateKernelConfig):
""" """
Creates abstract syntax tree (AST) of kernel, using a list of update equations. Creates abstract syntax tree (AST) of kernel, using a list of update equations.
...@@ -96,10 +99,11 @@ def create_domain_kernel(assignments: Union[AssignmentCollection, NodeCollection ...@@ -96,10 +99,11 @@ def create_domain_kernel(assignments: Union[AssignmentCollection, NodeCollection
>>> import pystencils as ps >>> import pystencils as ps
>>> import numpy as np >>> import numpy as np
>>> from pystencils.kernelcreation import create_domain_kernel >>> from pystencils.kernelcreation import create_domain_kernel
>>> from pystencils.node_collection import NodeCollection
>>> s, d = ps.fields('s, d: [2D]') >>> s, d = ps.fields('s, d: [2D]')
>>> assignment = ps.Assignment(d[0,0], s[0, 1] + s[0, -1] + s[1, 0] + s[-1, 0]) >>> assignment = ps.Assignment(d[0,0], s[0, 1] + s[0, -1] + s[1, 0] + s[-1, 0])
>>> kernel_config = ps.CreateKernelConfig(cpu_openmp=True) >>> kernel_config = ps.CreateKernelConfig(cpu_openmp=True)
>>> kernel_ast = create_domain_kernel(ps.AssignmentCollection([assignment]), config=kernel_config) >>> kernel_ast = create_domain_kernel(NodeCollection([assignment]), config=kernel_config)
>>> kernel = kernel_ast.compile() >>> kernel = kernel_ast.compile()
>>> d_arr = np.zeros([5, 5]) >>> d_arr = np.zeros([5, 5])
>>> kernel(d=d_arr, s=np.ones([5, 5])) >>> kernel(d=d_arr, s=np.ones([5, 5]))
...@@ -110,21 +114,8 @@ def create_domain_kernel(assignments: Union[AssignmentCollection, NodeCollection ...@@ -110,21 +114,8 @@ def create_domain_kernel(assignments: Union[AssignmentCollection, NodeCollection
[0., 4., 4., 4., 0.], [0., 4., 4., 4., 0.],
[0., 0., 0., 0., 0.]]) [0., 0., 0., 0., 0.]])
""" """
# --- applying first default simplifications
if isinstance(assignments, AssignmentCollection):
try:
if config.default_assignment_simplifications and isinstance(assignments, AssignmentCollection):
simplification = create_simplification_strategy()
assignments = simplification(assignments)
except Exception as e:
warnings.warn(f"It was not possible to apply the default pystencils optimisations to the "
f"AssignmentCollection due to the following problem :{e}")
assignments.evaluate_terms()
# --- eval # --- eval
# TODO split apply_sympy_optimisations and do the eval here assignments.evaluate_terms()
# FUTURE WORK from here we shouldn't NEED sympy # FUTURE WORK from here we shouldn't NEED sympy
# --- check constrains # --- check constrains
...@@ -132,12 +123,8 @@ def create_domain_kernel(assignments: Union[AssignmentCollection, NodeCollection ...@@ -132,12 +123,8 @@ def create_domain_kernel(assignments: Union[AssignmentCollection, NodeCollection
check_double_write_condition=not config.allow_double_writes) check_double_write_condition=not config.allow_double_writes)
check.visit(assignments) check.visit(assignments)
if isinstance(assignments, AssignmentCollection): assignments.bound_fields = check.fields_written
assert assignments.bound_fields == check.fields_written, f'WTF' assignments.rhs_fields = check.fields_read
assert assignments.rhs_fields == check.fields_read, f'WTF'
else:
assignments.bound_fields = check.fields_written
assignments.rhs_fields = check.fields_read
# ---- Creating ast # ---- Creating ast
ast = None ast = None
......
from typing import List import logging
from typing import List, Union
import sympy as sp
from sympy.codegen import Assignment
from sympy.codegen.rewriting import ReplaceOptim, optimize
from pystencils.astnodes import Node from pystencils.astnodes import Node
from pystencils.functions import DivFunc
# TODO ABC for NodeCollection and AssignmentCollection
class NodeCollection: class NodeCollection:
def __init__(self, nodes: List[Node]): def __init__(self, assignments: List[Union[Node, Assignment]]):
self.nodes = nodes self.all_assignments = assignments
self.bound_fields = None
self.rhs_fields = None if all((isinstance(a, Assignment) for a in assignments)):
self.is_Nodes = False
self.is_Assignments = True
elif all((isinstance(n, Node) for n in assignments)):
self.is_Nodes = True
self.is_Assignments = False
logging.warning('Using Nodes is experimental and not fully tested. Double check your generated code!')
else:
raise ValueError(f'The list "{assignments}" is mixed. Pass either a list of "pystencils.Assignments" '
f'or a list of "pystencils.astnodes.Node')
self.simplification_hints = () self.simplification_hints = ()
@property def evaluate_terms(self):
def all_assignments(self):
return self.nodes # There is no visitor implemented now so working with nodes does not work
if self.is_Nodes:
return
evaluate_constant_terms = ReplaceOptim(
lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf())
evaluate_pow = ReplaceOptim(
lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8,
lambda p: (
sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
DivFunc(sp.Integer(1), sp.Mul(*([p.base] * -p.exp), evaluate=False))
))
sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
self.all_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
if hasattr(a, 'lhs')
else a for a in self.all_assignments]
...@@ -3,11 +3,9 @@ from copy import copy ...@@ -3,11 +3,9 @@ from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp import sympy as sp
from sympy.codegen.rewriting import ReplaceOptim, optimize
import pystencils import pystencils
from pystencils.assignment import Assignment from pystencils.assignment import Assignment
from pystencils.functions import DivFunc
from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs) from pystencils.simp.simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from pystencils.sympyextensions import count_operations, fast_subs from pystencils.sympyextensions import count_operations, fast_subs
...@@ -341,8 +339,10 @@ class AssignmentCollection: ...@@ -341,8 +339,10 @@ class AssignmentCollection:
new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments] new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
return self.copy(new_eqs, new_subexpressions) return self.copy(new_eqs, new_subexpressions)
def new_without_subexpressions(self, subexpressions_to_keep: Set[sp.Symbol] = set()) -> 'AssignmentCollection': def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
"""Returns a new collection where all subexpressions have been inserted.""" """Returns a new collection where all subexpressions have been inserted."""
if subexpressions_to_keep is None:
subexpressions_to_keep = set()
if len(self.subexpressions) == 0: if len(self.subexpressions) == 0:
return self.copy() return self.copy()
...@@ -365,30 +365,6 @@ class AssignmentCollection: ...@@ -365,30 +365,6 @@ class AssignmentCollection:
new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments] new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
return self.copy(new_assignment, kept_subexpressions) return self.copy(new_assignment, kept_subexpressions)
def evaluate_terms(self):
evaluate_constant_terms = ReplaceOptim(
lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf())
evaluate_pow = ReplaceOptim(
lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8,
lambda p: (
sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
DivFunc(sp.Integer(1), sp.Mul(*([p.base] * -p.exp), evaluate=False))
))
sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
self.subexpressions = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
if hasattr(a, 'lhs')
else a for a in self.subexpressions]
self.main_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
if hasattr(a, 'lhs')
else a for a in self.main_assignments]
# ----------------------------------------- Display and Printing ------------------------------------------------- # ----------------------------------------- Display and Printing -------------------------------------------------
def _repr_html_(self): def _repr_html_(self):
......
...@@ -177,12 +177,7 @@ def get_type_of_expression(expr, ...@@ -177,12 +177,7 @@ def get_type_of_expression(expr,
else: else:
forbid_collation_to_complex = expr.is_real is True forbid_collation_to_complex = expr.is_real is True
forbid_collation_to_float = expr.is_integer is True forbid_collation_to_float = expr.is_integer is True
return collate_types( return collate_types(types)
types,
forbid_collation_to_complex=forbid_collation_to_complex,
forbid_collation_to_float=forbid_collation_to_float,
default_float_type=default_float_type,
default_int_type=default_int_type)
else: else:
if expr.is_integer: if expr.is_integer:
return create_type(default_int_type) return create_type(default_int_type)
......
import numpy as np import numpy as np
import pytest import pytest
import pystencils
import sympy as sp import sympy as sp
from pystencils import Assignment, Field, create_kernel, fields from pystencils import Assignment, Field, create_kernel, fields
...@@ -104,13 +106,20 @@ def test_loop_independence_checks(): ...@@ -104,13 +106,20 @@ def test_loop_independence_checks():
Assignment(g[0, 0], f[1, 0])]) Assignment(g[0, 0], f[1, 0])])
assert 'Field g is written at two different locations' in str(e.value) assert 'Field g is written at two different locations' in str(e.value)
# This is allowed - because only one element of g is accessed # This is not allowed - because this is not SSA (it can be overwritten with allow_double_writes)
with pytest.raises(ValueError) as e:
create_kernel([Assignment(g[0, 2], f[0, 1]),
Assignment(g[0, 2], 2 * g[0, 2])])
# This is allowed - because allow_double_writes is True now
create_kernel([Assignment(g[0, 2], f[0, 1]), create_kernel([Assignment(g[0, 2], f[0, 1]),
Assignment(g[0, 2], 2 * g[0, 2])]) Assignment(g[0, 2], 2 * g[0, 2])],
config=pystencils.CreateKernelConfig(allow_double_writes=True))
create_kernel([Assignment(v[0, 2](1), f[0, 1]), with pytest.raises(ValueError) as e:
Assignment(v[0, 1](0), 4), create_kernel([Assignment(v[0, 2](1), f[0, 1]),
Assignment(v[0, 2](1), 2 * v[0, 2](1))]) Assignment(v[0, 1](0), 4),
Assignment(v[0, 2](1), 2 * v[0, 2](1))])
with pytest.raises(ValueError) as e: with pytest.raises(ValueError) as e:
create_kernel([Assignment(g[0, 1], 3), create_kernel([Assignment(g[0, 1], 3),
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment