diff --git a/src/pystencils_autodiff/autodiff.py b/src/pystencils_autodiff/autodiff.py index 984eaaf808102faebbbc96c6d264003a109ea214..9447df9c311d14c3feec3b62af11af9ef297aa10 100644 --- a/src/pystencils_autodiff/autodiff.py +++ b/src/pystencils_autodiff/autodiff.py @@ -20,6 +20,27 @@ class DiffModes(str, Enum): TF_MAD = 'transposed-forward' +def _has_exclusive_writes(assignment_collection): + """ + Simple check for exclusive (non-overlapping) writes. + I.e. AssignmentCollection can be executed safely in parallel without caring about race conditions. + No writes on same spatial location (considering all possible shifts). + """ + + assignments = assignment_collection.main_assignments + write_field_accesses = [a.lhs for a in assignments if isinstance(a.lhs, ps.Field.Access)] + + exclusive_writes = set() + for a in write_field_accesses: + + if (a.field, a.index) in exclusive_writes: + return False + else: + exclusive_writes.add((a.field, a.index)) + + return True + + def get_jacobian_of_assignments(assignments, diff_variables): """ Calculates the Jacobian of iterable of assignments wrt. diff_variables @@ -180,7 +201,7 @@ Backward: main_assignments = [a for a in backward_assignments if isinstance(a.lhs, ps.Field.Access)] subexpressions = [a for a in backward_assignments if not isinstance(a.lhs, ps.Field.Access)] backward_assignments = ps.AssignmentCollection(main_assignments, subexpressions) - assert backward_assignments.has_exclusive_writes, "Backward assignments don't have exclusive writes." + \ + assert _has_exclusive_writes(backward_assignments), "Backward assignments don't have exclusive writes." + \ " You should consider using 'transposed-forward' mode for resolving those conflicts" self._forward_assignments = forward_assignments @@ -329,7 +350,7 @@ Backward: subexpressions = [a for a in backward_assignments if not isinstance(a.lhs, ps.Field.Access)] backward_assignments = ps.AssignmentCollection(main_assignments, subexpressions) - assert backward_assignments.has_exclusive_writes, "Backward assignments don't have exclusive writes!" + assert _has_exclusive_writes(backward_assignments), "Backward assignments don't have exclusive writes!" self._backward_assignments = backward_assignments self._backward_field_map = {**diff_read_fields, **diff_write_fields}