diff --git a/examples/dem.py b/examples/dem.py index df530f063a31ec6e5c3066d890ac0013dd2d0f86..357d03a7099657440d9a756768e41aecbe257442 100644 --- a/examples/dem.py +++ b/examples/dem.py @@ -35,8 +35,8 @@ def linear_spring_dashpot(i, j): f_friction_abs_dynamic = friction_dynamic[i, j] * length(fN) tan_vel_threshold = 1e-8 - cond1 = sticking and length(rel_vel_t) < tan_vel_threshold and fTLS_len < f_friction_abs_static - cond2 = sticking and fTLS_len < f_friction_abs_dynamic + cond1 = sticking == 1 and length(rel_vel_t) < tan_vel_threshold and fTLS_len < f_friction_abs_static + cond2 = sticking == 1 and fTLS_len < f_friction_abs_dynamic f_friction_abs = select(cond1, f_friction_abs_static, f_friction_abs_dynamic) n_sticking = select(cond1 or cond2 or fTLS_len < f_friction_abs_dynamic, 1, 0) diff --git a/src/pairs/analysis/bin_ops.py b/src/pairs/analysis/bin_ops.py index be4980d2db0cf99efcf298726a42f3d02bbcd347..62a760fbd4031c1afdf2b15b6604b3f765b40b87 100644 --- a/src/pairs/analysis/bin_ops.py +++ b/src/pairs/analysis/bin_ops.py @@ -26,6 +26,11 @@ class SetBinOpTerminals(Visitor): self.visit_children(ast_node) self.elems.pop() + def visit_ContactPropertyAccess(self, ast_node): + self.elems.append(ast_node) + self.visit_children(ast_node) + self.elems.pop() + def visit_FeaturePropertyAccess(self, ast_node): self.elems.append(ast_node) self.visit_children(ast_node) diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 5d2bf5b5e54528de2772f9fdb5d719c4a7e1b29f..3f1de2aeb0102c2efdc6d07629e6ec3680d9b0aa 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -17,7 +17,7 @@ from pairs.ir.math import Ceil, Sqrt from pairs.ir.memory import Malloc, Realloc from pairs.ir.module import ModuleCall from pairs.ir.particle_attributes import ParticleAttributeList -from pairs.ir.properties import Property, PropertyAccess, RegisterProperty, ReallocProperty, ContactPropertyAccess, RegisterContactProperty +from pairs.ir.properties import Property, PropertyAccess, RegisterProperty, ReallocProperty, ContactProperty, ContactPropertyAccess, RegisterContactProperty from pairs.ir.select import Select from pairs.ir.sizeof import Sizeof from pairs.ir.types import Types @@ -709,6 +709,9 @@ class CGen: expr = self.generate_expression(ast_node.expr) return f"ceil({expr})" + if isinstance(ast_node, ContactProperty): + return ast_node.name() + if isinstance(ast_node, Deref): var = self.generate_expression(ast_node.var) return f"(*{var})" diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py index 942a0f244755423c86166d6be05c8ea6b21aedbf..22ed5c0619003d3586624d0ab635ff65d9cb91d2 100644 --- a/src/pairs/transformations/__init__.py +++ b/src/pairs/transformations/__init__.py @@ -1,3 +1,4 @@ +import time from pairs.analysis import Analysis from pairs.transformations.blocks import LiftExprOwnerBlocks, MergeAdjacentBlocks from pairs.transformations.devices import AddDeviceCopies, AddDeviceKernels, AddHostReferencesToModules, AddDeviceReferencesToModules @@ -14,11 +15,15 @@ class Transformations: self._module_resizes = None def apply(self, transformation, data=None): + print(f"Applying transformation: {type(transformation).__name__}... ", end="") + start = time.time() transformation.set_ast(self._ast) if data is not None: transformation.set_data(data) self._ast = transformation.mutate() + elapsed = time.time() - start + print(f"{elapsed}s elapsed.") def analysis(self): return Analysis(self._ast) diff --git a/src/pairs/transformations/expressions.py b/src/pairs/transformations/expressions.py index cb8e0638da7be3dcef6eda2e78286b783bd374ef..2087b1f315cf04bfba5c3245dd0c830ab17b25d6 100644 --- a/src/pairs/transformations/expressions.py +++ b/src/pairs/transformations/expressions.py @@ -234,6 +234,22 @@ class AddExpressionDeclarations(Mutator): return ast_node + def mutate_ContactPropertyAccess(self, ast_node): + writing = self.writing + ast_node.contact_prop = self.mutate(ast_node.contact_prop) + self.writing = False + ast_node.index = self.mutate(ast_node.index) + ast_node.expressions = {i: self.mutate(e) for i, e in ast_node.expressions.items()} + self.writing = writing + + if self.writing is False and ast_node.inlined is False: + contact_prop_access_id = id(ast_node) + if contact_prop_access_id not in self.declared_exprs and contact_prop_access_id not in self.params: + self.push_decl(Decl(ast_node.sim, ast_node)) + self.declared_exprs.append(contact_prop_access_id) + + return ast_node + def mutate_FeaturePropertyAccess(self, ast_node): assert self.writing is False, "Cannot change feature property!" ast_node.feature_prop = self.mutate(ast_node.feature_prop)