From 6fdee233fa56beed509153f98e5669b78054d64e Mon Sep 17 00:00:00 2001 From: Rafael Ravedutti <rafaelravedutti@gmail.com> Date: Mon, 7 Aug 2023 18:08:15 +0200 Subject: [PATCH] Add some fixes Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com> --- examples/dem.py | 4 ++-- src/pairs/analysis/bin_ops.py | 5 +++++ src/pairs/code_gen/cgen.py | 5 ++++- src/pairs/transformations/__init__.py | 5 +++++ src/pairs/transformations/expressions.py | 16 ++++++++++++++++ 5 files changed, 32 insertions(+), 3 deletions(-) diff --git a/examples/dem.py b/examples/dem.py index df530f0..357d03a 100644 --- a/examples/dem.py +++ b/examples/dem.py @@ -35,8 +35,8 @@ def linear_spring_dashpot(i, j): f_friction_abs_dynamic = friction_dynamic[i, j] * length(fN) tan_vel_threshold = 1e-8 - cond1 = sticking and length(rel_vel_t) < tan_vel_threshold and fTLS_len < f_friction_abs_static - cond2 = sticking and fTLS_len < f_friction_abs_dynamic + cond1 = sticking == 1 and length(rel_vel_t) < tan_vel_threshold and fTLS_len < f_friction_abs_static + cond2 = sticking == 1 and fTLS_len < f_friction_abs_dynamic f_friction_abs = select(cond1, f_friction_abs_static, f_friction_abs_dynamic) n_sticking = select(cond1 or cond2 or fTLS_len < f_friction_abs_dynamic, 1, 0) diff --git a/src/pairs/analysis/bin_ops.py b/src/pairs/analysis/bin_ops.py index be4980d..62a760f 100644 --- a/src/pairs/analysis/bin_ops.py +++ b/src/pairs/analysis/bin_ops.py @@ -26,6 +26,11 @@ class SetBinOpTerminals(Visitor): self.visit_children(ast_node) self.elems.pop() + def visit_ContactPropertyAccess(self, ast_node): + self.elems.append(ast_node) + self.visit_children(ast_node) + self.elems.pop() + def visit_FeaturePropertyAccess(self, ast_node): self.elems.append(ast_node) self.visit_children(ast_node) diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 5d2bf5b..3f1de2a 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -17,7 +17,7 @@ from pairs.ir.math import Ceil, Sqrt from pairs.ir.memory import Malloc, Realloc from pairs.ir.module import ModuleCall from pairs.ir.particle_attributes import ParticleAttributeList -from pairs.ir.properties import Property, PropertyAccess, RegisterProperty, ReallocProperty, ContactPropertyAccess, RegisterContactProperty +from pairs.ir.properties import Property, PropertyAccess, RegisterProperty, ReallocProperty, ContactProperty, ContactPropertyAccess, RegisterContactProperty from pairs.ir.select import Select from pairs.ir.sizeof import Sizeof from pairs.ir.types import Types @@ -709,6 +709,9 @@ class CGen: expr = self.generate_expression(ast_node.expr) return f"ceil({expr})" + if isinstance(ast_node, ContactProperty): + return ast_node.name() + if isinstance(ast_node, Deref): var = self.generate_expression(ast_node.var) return f"(*{var})" diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py index 942a0f2..22ed5c0 100644 --- a/src/pairs/transformations/__init__.py +++ b/src/pairs/transformations/__init__.py @@ -1,3 +1,4 @@ +import time from pairs.analysis import Analysis from pairs.transformations.blocks import LiftExprOwnerBlocks, MergeAdjacentBlocks from pairs.transformations.devices import AddDeviceCopies, AddDeviceKernels, AddHostReferencesToModules, AddDeviceReferencesToModules @@ -14,11 +15,15 @@ class Transformations: self._module_resizes = None def apply(self, transformation, data=None): + print(f"Applying transformation: {type(transformation).__name__}... ", end="") + start = time.time() transformation.set_ast(self._ast) if data is not None: transformation.set_data(data) self._ast = transformation.mutate() + elapsed = time.time() - start + print(f"{elapsed}s elapsed.") def analysis(self): return Analysis(self._ast) diff --git a/src/pairs/transformations/expressions.py b/src/pairs/transformations/expressions.py index cb8e063..2087b1f 100644 --- a/src/pairs/transformations/expressions.py +++ b/src/pairs/transformations/expressions.py @@ -234,6 +234,22 @@ class AddExpressionDeclarations(Mutator): return ast_node + def mutate_ContactPropertyAccess(self, ast_node): + writing = self.writing + ast_node.contact_prop = self.mutate(ast_node.contact_prop) + self.writing = False + ast_node.index = self.mutate(ast_node.index) + ast_node.expressions = {i: self.mutate(e) for i, e in ast_node.expressions.items()} + self.writing = writing + + if self.writing is False and ast_node.inlined is False: + contact_prop_access_id = id(ast_node) + if contact_prop_access_id not in self.declared_exprs and contact_prop_access_id not in self.params: + self.push_decl(Decl(ast_node.sim, ast_node)) + self.declared_exprs.append(contact_prop_access_id) + + return ast_node + def mutate_FeaturePropertyAccess(self, ast_node): assert self.writing is False, "Cannot change feature property!" ast_node.feature_prop = self.mutate(ast_node.feature_prop) -- GitLab