diff --git a/pystencils/kernelcreation.py b/pystencils/kernelcreation.py index 16876bd0ee5ec5a665db7e9777bfa29056cd74fe..e0c635e06220f10198383e3c1fbbe0de7561280d 100644 --- a/pystencils/kernelcreation.py +++ b/pystencils/kernelcreation.py @@ -1,4 +1,5 @@ from types import MappingProxyType +from itertools import combinations import sympy as sp @@ -280,14 +281,18 @@ def create_staggered_kernel(assignments, target='cpu', gpu_exclusive_conditions= if gpu_exclusive_conditions: outer_assignment = None - for assignment in assignments: - direction = assignment.lhs.field.staggered_stencil[assignment.lhs.index[0]] - assignment = SympyAssignment(assignment.lhs, assignment.rhs) - outer_assignment = Conditional(condition(direction), Block([assignment]), outer_assignment) + conditions = {direction: condition(direction) for direction in stencil} + for num_conditions in range(len(stencil)): + for combination in combinations(conditions.values(), num_conditions): + for assignment in assignments: + direction = stencil[assignment.lhs.index[0]] + if conditions[direction] in combination: + assignment = SympyAssignment(assignment.lhs, assignment.rhs) + outer_assignment = Conditional(sp.And(*combination), Block([assignment]), outer_assignment) inner_assignment = [] for assignment in assignments: - direction = assignment.lhs.field.staggered_stencil[assignment.lhs.index[0]] + direction = stencil[assignment.lhs.index[0]] inner_assignment.append(SympyAssignment(assignment.lhs, assignment.rhs)) last_conditional = Conditional(sp.And(*[condition(d) for d in stencil]), Block(inner_assignment), outer_assignment) @@ -303,7 +308,7 @@ def create_staggered_kernel(assignments, target='cpu', gpu_exclusive_conditions= return ast for assignment in assignments: - direction = assignment.lhs.field.staggered_stencil[assignment.lhs.index[0]] + direction = stencil[assignment.lhs.index[0]] sp_assignments = [s for s in subexpressions if not hasattr(s, 'lhs')] + \ [SympyAssignment(s.lhs, s.rhs) for s in subexpressions if hasattr(s, 'lhs')] + \ [SympyAssignment(assignment.lhs, assignment.rhs)]