From c5f88f1e2a8a4171cadf61ffd8fc37dfc834ad34 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Sat, 18 Nov 2023 00:27:54 +0100
Subject: [PATCH] Fix j-reduction for torque

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 examples/dem.py                      |  2 +-
 src/pairs/ir/apply.py                | 77 +++++++++++++++++++++++++---
 src/pairs/ir/arrays.py               |  2 +-
 src/pairs/ir/features.py             |  2 +-
 src/pairs/ir/lit.py                  |  2 +-
 src/pairs/ir/matrices.py             | 13 ++---
 src/pairs/ir/mutator.py              |  6 +++
 src/pairs/ir/properties.py           |  4 +-
 src/pairs/ir/quaternions.py          | 13 ++---
 src/pairs/ir/scalars.py              | 20 ++++----
 src/pairs/ir/variables.py            |  2 +-
 src/pairs/ir/vectors.py              | 32 ++++++++----
 src/pairs/mapping/funcs.py           |  2 +-
 src/pairs/mapping/keywords.py        |  4 +-
 src/pairs/transformations/devices.py |  2 +-
 15 files changed, 129 insertions(+), 54 deletions(-)

diff --git a/examples/dem.py b/examples/dem.py
index 60f86c0..b8630d2 100644
--- a/examples/dem.py
+++ b/examples/dem.py
@@ -75,7 +75,7 @@ def linear_spring_dashpot(i, j):
     partial_force = fN + fT
 
     apply(force, partial_force)
-    apply(torque, cross(contact_point(i, j) - position[i], partial_force))
+    apply(torque, cross(contact_point(i, j) - position, partial_force))
 
 
 def euler(i):
diff --git a/src/pairs/ir/apply.py b/src/pairs/ir/apply.py
index 6ae09dd..1501fe2 100644
--- a/src/pairs/ir/apply.py
+++ b/src/pairs/ir/apply.py
@@ -3,23 +3,27 @@ from pairs.ir.ast_node import ASTNode
 from pairs.ir.block import pairs_inline
 from pairs.ir.branches import Filter
 from pairs.ir.lit import Lit
-from pairs.ir.properties import Property
+from pairs.ir.mutator import Mutator
+from pairs.ir.properties import Property, PropertyAccess
 from pairs.ir.scalars import ScalarOp
-from pairs.ir.vectors import Vector
+from pairs.ir.vectors import Vector, VectorAccess, VectorOp
 from pairs.ir.types import Types
 from pairs.sim.flags import Flags
 from pairs.sim.lowerable import FinalLowerable, Lowerable
 
 
 class Apply(Lowerable):
-    def __init__(self, sim, prop, expr, j):
+    def __init__(self, sim, prop, expr, i, j):
         assert isinstance(prop, Property), "Apply(): Destination must of Property type."
         assert prop.type() == expr.type(), "Apply(): Property and expression must be of same type."
         assert sim.current_apply_list() is not None, "Apply(): Not used within particle interaction."
         super().__init__(sim)
         self._prop = prop
         self._expr = Lit.cvt(sim, expr)
+        self._i = i
         self._j = j
+        self._expr_i = self.build_expression_with_index(self._expr, self._i)
+        self._expr_j = self.build_expression_with_index(self._expr, self._j) if sim._compute_half else None
         self._reduction_variable = None
         sim.current_apply_list().add(self)
         sim.add_statement(self)
@@ -40,15 +44,76 @@ class Apply(Lowerable):
         return self._reduction_variable
 
     def children(self):
-        return [self._prop, self._expr, self._reduction_variable, self._j]
+        return [self._prop, self._i, self._j] + \
+               [self._reduction_variable] if self._reduction_variable is not None else [] + \
+               [self._expr] if self._expr is not None else [] + \
+               [self._expr_i] if self._expr_i is not None else [] + \
+               [self._expr_j] if self._expr_j is not None else []
+
+    def build_expression_with_index(self, expr, index):
+        return self._build_expression_with_index(expr, index)[0]
+
+    # TODO: This method should comprise all operators and dynamic data types, it would also be
+    # better to provide a better way to implement it such as a Mutator or Visitor
+    def _build_expression_with_index(self, expr, index):
+        if isinstance(expr, (ScalarOp, VectorOp)):
+            new_lhs, changed_lhs = self._build_expression_with_index(expr.lhs, index)
+            changed_rhs = False
+
+            if not expr.operator().is_unary():
+                new_rhs, changed_rhs = self._build_expression_with_index(expr.rhs, index)
+
+            if changed_lhs or changed_rhs:
+                if isinstance(expr, ScalarOp):
+                    return (ScalarOp(self.sim, new_lhs, new_rhs, expr.operator(), expr.mem), True)
+
+                if isinstance(expr, VectorOp):
+                    return (VectorOp(self.sim, new_lhs, new_rhs, expr.operator(), expr.mem), True)
+
+            return (expr, False)
+
+        if isinstance(expr, Vector):
+            values = []
+            changed = False
+
+            for value in expr._values:
+                new_value, changed_value = self._build_expression_with_index(value, index)
+                values.append(new_value)
+                changed = changed or changed_value
+
+            if changed:
+                return (Vector(self.sim, values), True)
+
+            return (expr, False)
+
+        if isinstance(expr, VectorAccess):
+            new_expr, changed = self._build_expression_with_index(expr.expr, index)
+
+            if changed:
+                return (VectorAccess(self.sim, new_expr, expr.index), True)
+
+            return (expr, False)
+
+        if isinstance(expr, PropertyAccess):
+            return (expr, False)
+
+        if isinstance(expr, Property):
+            return (expr[index], True)
+
+        changed = False
+        for child in expr.children():
+            _, changed_child = self._build_expression_with_index(child, index)
+            changed = changed or changed_child
+
+        return (expr, changed)
 
     @pairs_inline
     def lower(self):
-        Assign(self.sim, self._reduction_variable, self._reduction_variable + self._expr)
+        Assign(self.sim, self._reduction_variable, self._reduction_variable + self._expr_i)
 
         if self.sim._compute_half:
             for _ in Filter(self.sim,
                             ScalarOp.and_op(self._j < self.sim.nlocal,
                                             ScalarOp.cmp(self.sim.particle_flags[self._j] & Flags.Fixed, 0))):
 
-                Assign(self.sim, self._prop[self._j], self._prop[self._j] - self._expr)
+                Assign(self.sim, self._prop[self._j], self._prop[self._j] - self._expr_j)
diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py
index b7e1e42..a6f5da5 100644
--- a/src/pairs/ir/arrays.py
+++ b/src/pairs/ir/arrays.py
@@ -160,7 +160,7 @@ class ArrayAccess(ASTTerm):
         self.inlined = True
         return self
 
-    def copy(self):
+    def copy(self, deep=False):
         return ArrayAccess(self.sim, self.array, self.partial_indexes)
 
     def check_and_set_flat_index(self):
diff --git a/src/pairs/ir/features.py b/src/pairs/ir/features.py
index 7067073..d709336 100644
--- a/src/pairs/ir/features.py
+++ b/src/pairs/ir/features.py
@@ -172,7 +172,7 @@ class FeaturePropertyAccess(ASTTerm):
     def __str__(self):
         return f"FeaturePropertyAccess<{self.feature_prop}, {self.index}>"
 
-    def copy(self):
+    def copy(self, deep=False):
         return FeaturePropertyAccess(self.sim, self.feature_prop, self.index)
 
     def vector_index(self, dimension):
diff --git a/src/pairs/ir/lit.py b/src/pairs/ir/lit.py
index ff54c94..cb45f56 100644
--- a/src/pairs/ir/lit.py
+++ b/src/pairs/ir/lit.py
@@ -46,7 +46,7 @@ class Lit(ASTTerm):
     def __req__(self, other):
         return self.__cmp__(other)
 
-    def copy(self):
+    def copy(self, deep=False):
         return Lit(self.sim, self.value)
 
     def type(self):
diff --git a/src/pairs/ir/matrices.py b/src/pairs/ir/matrices.py
index 38e3341..25553a9 100644
--- a/src/pairs/ir/matrices.py
+++ b/src/pairs/ir/matrices.py
@@ -44,16 +44,11 @@ class MatrixOp(ASTTerm):
     def operator(self):
         return self.op
 
-    def reassign(self, lhs, rhs, op):
-        self.lhs = Lit.cvt(self.sim, lhs)
-        self.rhs = Lit.cvt(self.sim, rhs)
-        self.op = op
-
-    def copy(self):
-        return MatrixOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem)
+    def copy(self, deep=False):
+        if deep:
+            return MatrixOp(self.sim, self.lhs.copy(True), self.rhs.copy(True), self.op, self.mem)
 
-    def match(self, matrix_op):
-        return self.lhs == matrix_op.lhs and self.rhs == matrix_op.rhs and self.op == matrix_op.operator()
+        return MatrixOp(self.sim, self.lhs, self.rhs, self.op, self.mem)
 
     def add_terminal(self, terminal):
         self.terminals.add(terminal)
diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py
index bd63fbf..42a5564 100644
--- a/src/pairs/ir/mutator.py
+++ b/src/pairs/ir/mutator.py
@@ -44,6 +44,12 @@ class Mutator:
         ast_node._expr = self.mutate(ast_node._expr)
         ast_node._j = self.mutate(ast_node._j)
 
+        if ast_node._expr_i is not None:
+            ast_node._expr_i = self.mutate(ast_node._expr_i)
+
+        if ast_node._expr_j is not None:
+            ast_node._expr_j = self.mutate(ast_node._expr_j)
+
         if ast_node._reduction_variable is not None:
             ast_node._reduction_variable = self.mutate(ast_node._reduction_variable)
 
diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py
index b4867f8..85024f1 100644
--- a/src/pairs/ir/properties.py
+++ b/src/pairs/ir/properties.py
@@ -125,7 +125,7 @@ class PropertyAccess(ASTTerm):
         _acc_class = AccessorClass.from_type(self.prop.type())
         return _acc_class(self.sim, self, Lit.cvt(self.sim, index))
 
-    def copy(self):
+    def copy(self, deep=False):
         return PropertyAccess(self.sim, self.prop, self.index)
 
     def vector_index(self, dimension):
@@ -291,7 +291,7 @@ class ContactPropertyAccess(ASTTerm):
         _acc_class = AccessorClass.from_type(self.contact_prop.type())
         return _acc_class(self.sim, self, Lit.cvt(self.sim, index))
 
-    def copy(self):
+    def copy(self, deep=False):
         return ContactPropertyAccess(self.sim, self.contact_prop, self.index)
 
     def vector_index(self, dimension):
diff --git a/src/pairs/ir/quaternions.py b/src/pairs/ir/quaternions.py
index 1f2df2a..2d6a5cc 100644
--- a/src/pairs/ir/quaternions.py
+++ b/src/pairs/ir/quaternions.py
@@ -44,16 +44,11 @@ class QuaternionOp(ASTTerm):
     def operator(self):
         return self.op
 
-    def reassign(self, lhs, rhs, op):
-        self.lhs = Lit.cvt(self.sim, lhs)
-        self.rhs = Lit.cvt(self.sim, rhs)
-        self.op = op
-
-    def copy(self):
-        return QuaternionOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem)
+    def copy(self, deep=False):
+        if deep:
+            return QuaternionOp(self.sim, self.lhs.copy(True), self.rhs.copy(True), self.op, self.mem)
 
-    def match(self, quat_op):
-        return self.lhs == quat_op.lhs and self.rhs == quat_op.rhs and self.op == quat_op.operator()
+        return QuaternionOp(self.sim, self.lhs, self.rhs, self.op, self.mem)
 
     def add_terminal(self, terminal):
         self.terminals.add(terminal)
diff --git a/src/pairs/ir/scalars.py b/src/pairs/ir/scalars.py
index f0e30f8..1538cf1 100644
--- a/src/pairs/ir/scalars.py
+++ b/src/pairs/ir/scalars.py
@@ -37,22 +37,22 @@ class ScalarOp(ASTTerm):
         self.scalar_op_type = ScalarOp.infer_type(self.lhs, self.rhs, self.op)
         self.terminals = set()
 
-    def reassign(self, lhs, rhs, op):
-        self.lhs = Lit.cvt(self.sim, lhs)
-        self.rhs = Lit.cvt(self.sim, rhs)
-        self.op = op
-        self.scalar_op_type = ScalarOp.infer_type(self.lhs, self.rhs, self.op)
-
     def __str__(self):
         a = f"ScalarOp<{self.lhs.id()}>" if isinstance(self.lhs, ScalarOp) else self.lhs
         b = f"ScalarOp<{self.rhs.id()}>" if isinstance(self.rhs, ScalarOp) else self.rhs
         return f"ScalarOp<id={self.id()}, {a} {self.op.symbol()} {b}>"
 
-    def copy(self):
-        return ScalarOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem)
+    def copy(self, deep=False):
+        if self.op.is_unary():
+            if deep:
+                return ScalarOp(self.sim, self.lhs.copy(True), None, self.op, self.mem)
+
+            return ScalarOp(self.sim, self.lhs, None, self.op, self.mem)
+
+        if deep:
+            return ScalarOp(self.sim, self.lhs.copy(True), self.rhs.copy(True), self.op, self.mem)
 
-    def match(self, scalar_op):
-        return self.lhs == scalar_op.lhs and self.rhs == scalar_op.rhs and self.op == scalar_op.operator()
+        return ScalarOp(self.sim, self.lhs, self.rhs, self.op, self.mem)
 
     def x(self):
         return self.__getitem__(0)
diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py
index 4c1d279..3d96065 100644
--- a/src/pairs/ir/variables.py
+++ b/src/pairs/ir/variables.py
@@ -55,7 +55,7 @@ class Var(ASTTerm):
     def __str__(self):
         return f"Var<{self.var_name}>"
 
-    def copy(self):
+    def copy(self, deep=False):
         # Terminal copies are just themselves
         return self
 
diff --git a/src/pairs/ir/vectors.py b/src/pairs/ir/vectors.py
index b8d94b1..759b69c 100644
--- a/src/pairs/ir/vectors.py
+++ b/src/pairs/ir/vectors.py
@@ -44,16 +44,17 @@ class VectorOp(ASTTerm):
     def operator(self):
         return self.op
 
-    def reassign(self, lhs, rhs, op):
-        self.lhs = Lit.cvt(self.sim, lhs)
-        self.rhs = Lit.cvt(self.sim, rhs)
-        self.op = op
+    def copy(self, deep=False):
+        if self.op.is_unary():
+            if deep:
+                return VectorOp(self.sim, self.lhs.copy(True), None, self.op, self.mem)
+
+            return VectorOp(self.sim, self.lhs, None, self.op, self.mem)
 
-    def copy(self):
-        return VectorOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem)
+        if deep:
+            return VectorOp(self.sim, self.lhs.copy(True), self.rhs.copy(True), self.op, self.mem)
 
-    def match(self, vector_op):
-        return self.lhs == vector_op.lhs and self.rhs == vector_op.rhs and self.op == vector_op.operator()
+        return VectorOp(self.sim, self.lhs, self.rhs, self.op, self.mem)
 
     def x(self):
         return self.__getitem__(0)
@@ -81,6 +82,12 @@ class VectorAccess(ASTTerm):
     def __str__(self):
         return f"VectorAccess<{self.expr}, {self.index}>"
 
+    def copy(self, deep=False):
+        if deep:
+            return VectorAccess(self.sim, self.expr.copy(), self.index.copy())
+
+        return VectorAccess(self.sim, self.expr, self.index)
+
     def type(self):
         return Types.Real
 
@@ -103,7 +110,8 @@ class Vector(ASTTerm):
         self.terminals = set()
 
     def __str__(self):
-        return f"Vector<{self._values}>"
+        values_str = ", ".join([str(v) for v in self._values])
+        return f"Vector<{values_str}>"
 
     def __getitem__(self, index):
         return VectorAccess(self.sim, self, Lit.cvt(self.sim, index))
@@ -114,6 +122,12 @@ class Vector(ASTTerm):
     def name(self):
         return f"vec{self.id()}" + self.label_suffix()
 
+    def copy(self, deep=False):
+        if deep:
+            return Vector(self.sim, [value.copy(True) for value in self._values])
+
+        return Vector(self.sim, [value for value in self._values])
+
     def type(self):
         return Types.Vector
 
diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py
index 573ee6a..d27f86e 100644
--- a/src/pairs/mapping/funcs.py
+++ b/src/pairs/mapping/funcs.py
@@ -151,7 +151,7 @@ class BuildParticleIR(ast.NodeVisitor):
 
         if self.keywords.exists(func):
             if func == 'apply':
-                args += [self.ctx_symbols['__j__']]
+                args += [self.ctx_symbols['__i__'], self.ctx_symbols['__j__']]
 
             return self.keywords(func, args)
 
diff --git a/src/pairs/mapping/keywords.py b/src/pairs/mapping/keywords.py
index 9ae93a6..ad10c96 100644
--- a/src/pairs/mapping/keywords.py
+++ b/src/pairs/mapping/keywords.py
@@ -251,9 +251,9 @@ class Keywords:
                                   1.0 - 2.0 * quat[1] * quat[1] - 2.0 * quat[2] * quat[2] ])
 
     def keyword_apply(self, args):
-        assert len(args) == 3, "apply() keyword requires two parameters."
+        assert len(args) == 4, "apply() keyword requires two parameters."
         prop = args[0]
         expr = args[1]
         assert isinstance(prop, Property), "apply(): First argument must be a property."
         assert prop.type() == expr.type(), "apply(): Property and expression must be of same type."
-        Apply(self.sim, prop, expr, args[2])
+        Apply(self.sim, prop, expr, args[2], args[3])
diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py
index 97c7799..c2d02fc 100644
--- a/src/pairs/transformations/devices.py
+++ b/src/pairs/transformations/devices.py
@@ -83,7 +83,7 @@ class AddDeviceKernels(Mutator):
                         kernel_name = f"{ast_node.name}_kernel{kernel_id}"
                         kernel = ast_node.sim.find_kernel_by_name(kernel_name)
                         if kernel is None:
-                            kernel_body = Filter(ast_node.sim, ScalarOp.inline(s.iterator < s.max.copy()), s.block)
+                            kernel_body = Filter(ast_node.sim, ScalarOp.inline(s.iterator < s.max.copy(True)), s.block)
                             kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator)
                             kernel_id += 1
 
-- 
GitLab