diff --git a/examples/dem.py b/examples/dem.py index 60f86c0c3f899927b0375d71f9afe31a6d8f4a43..b8630d2e28ad01dbbb407093ddc07231d4b93c11 100644 --- a/examples/dem.py +++ b/examples/dem.py @@ -75,7 +75,7 @@ def linear_spring_dashpot(i, j): partial_force = fN + fT apply(force, partial_force) - apply(torque, cross(contact_point(i, j) - position[i], partial_force)) + apply(torque, cross(contact_point(i, j) - position, partial_force)) def euler(i): diff --git a/src/pairs/ir/apply.py b/src/pairs/ir/apply.py index 6ae09dddf81dc77dc7d2c127a06de9ecbb79d3b0..1501fe2a85508cc46fdf04148daf1c60ebb581c2 100644 --- a/src/pairs/ir/apply.py +++ b/src/pairs/ir/apply.py @@ -3,23 +3,27 @@ from pairs.ir.ast_node import ASTNode from pairs.ir.block import pairs_inline from pairs.ir.branches import Filter from pairs.ir.lit import Lit -from pairs.ir.properties import Property +from pairs.ir.mutator import Mutator +from pairs.ir.properties import Property, PropertyAccess from pairs.ir.scalars import ScalarOp -from pairs.ir.vectors import Vector +from pairs.ir.vectors import Vector, VectorAccess, VectorOp from pairs.ir.types import Types from pairs.sim.flags import Flags from pairs.sim.lowerable import FinalLowerable, Lowerable class Apply(Lowerable): - def __init__(self, sim, prop, expr, j): + def __init__(self, sim, prop, expr, i, j): assert isinstance(prop, Property), "Apply(): Destination must of Property type." assert prop.type() == expr.type(), "Apply(): Property and expression must be of same type." assert sim.current_apply_list() is not None, "Apply(): Not used within particle interaction." super().__init__(sim) self._prop = prop self._expr = Lit.cvt(sim, expr) + self._i = i self._j = j + self._expr_i = self.build_expression_with_index(self._expr, self._i) + self._expr_j = self.build_expression_with_index(self._expr, self._j) if sim._compute_half else None self._reduction_variable = None sim.current_apply_list().add(self) sim.add_statement(self) @@ -40,15 +44,76 @@ class Apply(Lowerable): return self._reduction_variable def children(self): - return [self._prop, self._expr, self._reduction_variable, self._j] + return [self._prop, self._i, self._j] + \ + [self._reduction_variable] if self._reduction_variable is not None else [] + \ + [self._expr] if self._expr is not None else [] + \ + [self._expr_i] if self._expr_i is not None else [] + \ + [self._expr_j] if self._expr_j is not None else [] + + def build_expression_with_index(self, expr, index): + return self._build_expression_with_index(expr, index)[0] + + # TODO: This method should comprise all operators and dynamic data types, it would also be + # better to provide a better way to implement it such as a Mutator or Visitor + def _build_expression_with_index(self, expr, index): + if isinstance(expr, (ScalarOp, VectorOp)): + new_lhs, changed_lhs = self._build_expression_with_index(expr.lhs, index) + changed_rhs = False + + if not expr.operator().is_unary(): + new_rhs, changed_rhs = self._build_expression_with_index(expr.rhs, index) + + if changed_lhs or changed_rhs: + if isinstance(expr, ScalarOp): + return (ScalarOp(self.sim, new_lhs, new_rhs, expr.operator(), expr.mem), True) + + if isinstance(expr, VectorOp): + return (VectorOp(self.sim, new_lhs, new_rhs, expr.operator(), expr.mem), True) + + return (expr, False) + + if isinstance(expr, Vector): + values = [] + changed = False + + for value in expr._values: + new_value, changed_value = self._build_expression_with_index(value, index) + values.append(new_value) + changed = changed or changed_value + + if changed: + return (Vector(self.sim, values), True) + + return (expr, False) + + if isinstance(expr, VectorAccess): + new_expr, changed = self._build_expression_with_index(expr.expr, index) + + if changed: + return (VectorAccess(self.sim, new_expr, expr.index), True) + + return (expr, False) + + if isinstance(expr, PropertyAccess): + return (expr, False) + + if isinstance(expr, Property): + return (expr[index], True) + + changed = False + for child in expr.children(): + _, changed_child = self._build_expression_with_index(child, index) + changed = changed or changed_child + + return (expr, changed) @pairs_inline def lower(self): - Assign(self.sim, self._reduction_variable, self._reduction_variable + self._expr) + Assign(self.sim, self._reduction_variable, self._reduction_variable + self._expr_i) if self.sim._compute_half: for _ in Filter(self.sim, ScalarOp.and_op(self._j < self.sim.nlocal, ScalarOp.cmp(self.sim.particle_flags[self._j] & Flags.Fixed, 0))): - Assign(self.sim, self._prop[self._j], self._prop[self._j] - self._expr) + Assign(self.sim, self._prop[self._j], self._prop[self._j] - self._expr_j) diff --git a/src/pairs/ir/arrays.py b/src/pairs/ir/arrays.py index b7e1e4245630323a983aa1bba36115162e7a783b..a6f5da5a0016f686213de510f18edccc10d46d46 100644 --- a/src/pairs/ir/arrays.py +++ b/src/pairs/ir/arrays.py @@ -160,7 +160,7 @@ class ArrayAccess(ASTTerm): self.inlined = True return self - def copy(self): + def copy(self, deep=False): return ArrayAccess(self.sim, self.array, self.partial_indexes) def check_and_set_flat_index(self): diff --git a/src/pairs/ir/features.py b/src/pairs/ir/features.py index 7067073b67fddfbc034a2230c621e0c4f65f1fee..d709336b0971a7361e0cb5159d36bc975605cb28 100644 --- a/src/pairs/ir/features.py +++ b/src/pairs/ir/features.py @@ -172,7 +172,7 @@ class FeaturePropertyAccess(ASTTerm): def __str__(self): return f"FeaturePropertyAccess<{self.feature_prop}, {self.index}>" - def copy(self): + def copy(self, deep=False): return FeaturePropertyAccess(self.sim, self.feature_prop, self.index) def vector_index(self, dimension): diff --git a/src/pairs/ir/lit.py b/src/pairs/ir/lit.py index ff54c94d494142f25d79a1296f00037ce543aaa7..cb45f568ff4af39c0d40015fa3f52362558f9f7f 100644 --- a/src/pairs/ir/lit.py +++ b/src/pairs/ir/lit.py @@ -46,7 +46,7 @@ class Lit(ASTTerm): def __req__(self, other): return self.__cmp__(other) - def copy(self): + def copy(self, deep=False): return Lit(self.sim, self.value) def type(self): diff --git a/src/pairs/ir/matrices.py b/src/pairs/ir/matrices.py index 38e3341f7d19e8048bfb1b76c4c9247a9bb52cf2..25553a9dace45bfead9e2da45869d973312c046d 100644 --- a/src/pairs/ir/matrices.py +++ b/src/pairs/ir/matrices.py @@ -44,16 +44,11 @@ class MatrixOp(ASTTerm): def operator(self): return self.op - def reassign(self, lhs, rhs, op): - self.lhs = Lit.cvt(self.sim, lhs) - self.rhs = Lit.cvt(self.sim, rhs) - self.op = op - - def copy(self): - return MatrixOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem) + def copy(self, deep=False): + if deep: + return MatrixOp(self.sim, self.lhs.copy(True), self.rhs.copy(True), self.op, self.mem) - def match(self, matrix_op): - return self.lhs == matrix_op.lhs and self.rhs == matrix_op.rhs and self.op == matrix_op.operator() + return MatrixOp(self.sim, self.lhs, self.rhs, self.op, self.mem) def add_terminal(self, terminal): self.terminals.add(terminal) diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index bd63fbf5e32abbf6916b3f20bdb78bf906fe2983..42a55640f8dbc8979eb1f42dd198e82938d63567 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -44,6 +44,12 @@ class Mutator: ast_node._expr = self.mutate(ast_node._expr) ast_node._j = self.mutate(ast_node._j) + if ast_node._expr_i is not None: + ast_node._expr_i = self.mutate(ast_node._expr_i) + + if ast_node._expr_j is not None: + ast_node._expr_j = self.mutate(ast_node._expr_j) + if ast_node._reduction_variable is not None: ast_node._reduction_variable = self.mutate(ast_node._reduction_variable) diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py index b4867f8be15acd56e5468f0c0f830163e1a5c4cd..85024f1d5448ad092f31aa77d87aec77520ed9aa 100644 --- a/src/pairs/ir/properties.py +++ b/src/pairs/ir/properties.py @@ -125,7 +125,7 @@ class PropertyAccess(ASTTerm): _acc_class = AccessorClass.from_type(self.prop.type()) return _acc_class(self.sim, self, Lit.cvt(self.sim, index)) - def copy(self): + def copy(self, deep=False): return PropertyAccess(self.sim, self.prop, self.index) def vector_index(self, dimension): @@ -291,7 +291,7 @@ class ContactPropertyAccess(ASTTerm): _acc_class = AccessorClass.from_type(self.contact_prop.type()) return _acc_class(self.sim, self, Lit.cvt(self.sim, index)) - def copy(self): + def copy(self, deep=False): return ContactPropertyAccess(self.sim, self.contact_prop, self.index) def vector_index(self, dimension): diff --git a/src/pairs/ir/quaternions.py b/src/pairs/ir/quaternions.py index 1f2df2a83b792a19f6c78a714d0568aac13daf20..2d6a5cc2d970a6a17e12c8c2a3578286da81db5f 100644 --- a/src/pairs/ir/quaternions.py +++ b/src/pairs/ir/quaternions.py @@ -44,16 +44,11 @@ class QuaternionOp(ASTTerm): def operator(self): return self.op - def reassign(self, lhs, rhs, op): - self.lhs = Lit.cvt(self.sim, lhs) - self.rhs = Lit.cvt(self.sim, rhs) - self.op = op - - def copy(self): - return QuaternionOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem) + def copy(self, deep=False): + if deep: + return QuaternionOp(self.sim, self.lhs.copy(True), self.rhs.copy(True), self.op, self.mem) - def match(self, quat_op): - return self.lhs == quat_op.lhs and self.rhs == quat_op.rhs and self.op == quat_op.operator() + return QuaternionOp(self.sim, self.lhs, self.rhs, self.op, self.mem) def add_terminal(self, terminal): self.terminals.add(terminal) diff --git a/src/pairs/ir/scalars.py b/src/pairs/ir/scalars.py index f0e30f834a5c9eebdf91a6e5a567477e97febff4..1538cf17ab20a807c69f7d6d3b4cd94afd98d145 100644 --- a/src/pairs/ir/scalars.py +++ b/src/pairs/ir/scalars.py @@ -37,22 +37,22 @@ class ScalarOp(ASTTerm): self.scalar_op_type = ScalarOp.infer_type(self.lhs, self.rhs, self.op) self.terminals = set() - def reassign(self, lhs, rhs, op): - self.lhs = Lit.cvt(self.sim, lhs) - self.rhs = Lit.cvt(self.sim, rhs) - self.op = op - self.scalar_op_type = ScalarOp.infer_type(self.lhs, self.rhs, self.op) - def __str__(self): a = f"ScalarOp<{self.lhs.id()}>" if isinstance(self.lhs, ScalarOp) else self.lhs b = f"ScalarOp<{self.rhs.id()}>" if isinstance(self.rhs, ScalarOp) else self.rhs return f"ScalarOp<id={self.id()}, {a} {self.op.symbol()} {b}>" - def copy(self): - return ScalarOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem) + def copy(self, deep=False): + if self.op.is_unary(): + if deep: + return ScalarOp(self.sim, self.lhs.copy(True), None, self.op, self.mem) + + return ScalarOp(self.sim, self.lhs, None, self.op, self.mem) + + if deep: + return ScalarOp(self.sim, self.lhs.copy(True), self.rhs.copy(True), self.op, self.mem) - def match(self, scalar_op): - return self.lhs == scalar_op.lhs and self.rhs == scalar_op.rhs and self.op == scalar_op.operator() + return ScalarOp(self.sim, self.lhs, self.rhs, self.op, self.mem) def x(self): return self.__getitem__(0) diff --git a/src/pairs/ir/variables.py b/src/pairs/ir/variables.py index 4c1d279131939dec541cfe8d5d18e73e56e92475..3d96065c726c838f5de39710a345a8d7ea68a856 100644 --- a/src/pairs/ir/variables.py +++ b/src/pairs/ir/variables.py @@ -55,7 +55,7 @@ class Var(ASTTerm): def __str__(self): return f"Var<{self.var_name}>" - def copy(self): + def copy(self, deep=False): # Terminal copies are just themselves return self diff --git a/src/pairs/ir/vectors.py b/src/pairs/ir/vectors.py index b8d94b16c5e8a3faac9ae4f1223fb1f23558c8b1..759b69ce068ee34c0653092cffec37e3e4ca46ac 100644 --- a/src/pairs/ir/vectors.py +++ b/src/pairs/ir/vectors.py @@ -44,16 +44,17 @@ class VectorOp(ASTTerm): def operator(self): return self.op - def reassign(self, lhs, rhs, op): - self.lhs = Lit.cvt(self.sim, lhs) - self.rhs = Lit.cvt(self.sim, rhs) - self.op = op + def copy(self, deep=False): + if self.op.is_unary(): + if deep: + return VectorOp(self.sim, self.lhs.copy(True), None, self.op, self.mem) + + return VectorOp(self.sim, self.lhs, None, self.op, self.mem) - def copy(self): - return VectorOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem) + if deep: + return VectorOp(self.sim, self.lhs.copy(True), self.rhs.copy(True), self.op, self.mem) - def match(self, vector_op): - return self.lhs == vector_op.lhs and self.rhs == vector_op.rhs and self.op == vector_op.operator() + return VectorOp(self.sim, self.lhs, self.rhs, self.op, self.mem) def x(self): return self.__getitem__(0) @@ -81,6 +82,12 @@ class VectorAccess(ASTTerm): def __str__(self): return f"VectorAccess<{self.expr}, {self.index}>" + def copy(self, deep=False): + if deep: + return VectorAccess(self.sim, self.expr.copy(), self.index.copy()) + + return VectorAccess(self.sim, self.expr, self.index) + def type(self): return Types.Real @@ -103,7 +110,8 @@ class Vector(ASTTerm): self.terminals = set() def __str__(self): - return f"Vector<{self._values}>" + values_str = ", ".join([str(v) for v in self._values]) + return f"Vector<{values_str}>" def __getitem__(self, index): return VectorAccess(self.sim, self, Lit.cvt(self.sim, index)) @@ -114,6 +122,12 @@ class Vector(ASTTerm): def name(self): return f"vec{self.id()}" + self.label_suffix() + def copy(self, deep=False): + if deep: + return Vector(self.sim, [value.copy(True) for value in self._values]) + + return Vector(self.sim, [value for value in self._values]) + def type(self): return Types.Vector diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py index 573ee6a64f3f3f34fe0eaaa1ade6dbd01e65591d..d27f86e218ecd81c43afb6612b87858cfe5cedd3 100644 --- a/src/pairs/mapping/funcs.py +++ b/src/pairs/mapping/funcs.py @@ -151,7 +151,7 @@ class BuildParticleIR(ast.NodeVisitor): if self.keywords.exists(func): if func == 'apply': - args += [self.ctx_symbols['__j__']] + args += [self.ctx_symbols['__i__'], self.ctx_symbols['__j__']] return self.keywords(func, args) diff --git a/src/pairs/mapping/keywords.py b/src/pairs/mapping/keywords.py index 9ae93a69565e8b34f06875a8f0a4f7df0934158a..ad10c960b8bf87b3f828417f7f8c3049727b23f5 100644 --- a/src/pairs/mapping/keywords.py +++ b/src/pairs/mapping/keywords.py @@ -251,9 +251,9 @@ class Keywords: 1.0 - 2.0 * quat[1] * quat[1] - 2.0 * quat[2] * quat[2] ]) def keyword_apply(self, args): - assert len(args) == 3, "apply() keyword requires two parameters." + assert len(args) == 4, "apply() keyword requires two parameters." prop = args[0] expr = args[1] assert isinstance(prop, Property), "apply(): First argument must be a property." assert prop.type() == expr.type(), "apply(): Property and expression must be of same type." - Apply(self.sim, prop, expr, args[2]) + Apply(self.sim, prop, expr, args[2], args[3]) diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index 97c7799ab3a54cc6922e9d52d5ef8ca7c8e2403d..c2d02fc900568b9bbfbedb609ca933ee2bc0a5ea 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -83,7 +83,7 @@ class AddDeviceKernels(Mutator): kernel_name = f"{ast_node.name}_kernel{kernel_id}" kernel = ast_node.sim.find_kernel_by_name(kernel_name) if kernel is None: - kernel_body = Filter(ast_node.sim, ScalarOp.inline(s.iterator < s.max.copy()), s.block) + kernel_body = Filter(ast_node.sim, ScalarOp.inline(s.iterator < s.max.copy(True)), s.block) kernel = Kernel(ast_node.sim, kernel_name, kernel_body, s.iterator) kernel_id += 1