From a6db87e9f1eaa82cf45326768126839dac1a040b Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Fri, 21 Apr 2023 03:29:47 +0200
Subject: [PATCH] Update code for features

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 examples/{eam.py => dem.py} |  0
 src/pairs/ir/features.py    | 87 ++++++++++---------------------------
 src/pairs/ir/properties.py  | 24 +++++++---
 src/pairs/sim/simulation.py |  5 +--
 4 files changed, 45 insertions(+), 71 deletions(-)
 rename examples/{eam.py => dem.py} (100%)

diff --git a/examples/eam.py b/examples/dem.py
similarity index 100%
rename from examples/eam.py
rename to examples/dem.py
diff --git a/src/pairs/ir/features.py b/src/pairs/ir/features.py
index 5e82e2a..5434cff 100644
--- a/src/pairs/ir/features.py
+++ b/src/pairs/ir/features.py
@@ -18,15 +18,11 @@ class Features:
         self.features.append(f)
         return p
 
-    def add_property(self, feature, prop):
-        self.feature_properties.append([feature, prop])
-        return prop
-
     def nfeatures(self):
         return len(self.features)
 
     def find(self, f_name):
-        prop = [f for f in self.features if f.name() == f_name]
+        feature = [f for f in self.features if f.name() == f_name]
         if feature:
             return feature[0]
 
@@ -43,6 +39,7 @@ class Feature(ASTNode):
         super().__init__(sim)
         self.feature_id = Feature.last_feature_id
         self.feature_name = name
+        self.feature_count = self.sim.add_var(f"count_{self.feature_name}", Types.Int32)
         Feature.last_feature_id += 1
 
     def __str__(self):
@@ -54,37 +51,30 @@ class Feature(ASTNode):
     def name(self):
         return self.feature_name
 
-    #def __getitem__(self, expr):
-    #    return PropertyAccess(self.sim, self, expr)
+    def count(self):
+        return self.feature_count
 
 
-class FeaturePropertyAccess(ASTTerm, VectorExpression):
-    last_prop_acc = 0
+class FeatureAccess(ASTTerm):
+    last_feat_acc = 0
 
     def new_id():
-        PropertyAccess.last_prop_acc += 1
-        return PropertyAccess.last_prop_acc - 1
+        PropertyAccess.last_feat_acc += 1
+        return PropertyAccess.last_feat_acc - 1
 
-    def __init__(self, sim, prop, index):
+    def __init__(self, sim, feature, index):
         super().__init__(sim)
-        self.acc_id = PropertyAccess.new_id()
-        self.prop = prop
+        self.acc_id = FeatureAccess.new_id()
+        self.feature = feature
         self.index = Lit.cvt(sim, index)
         self.inlined = False
         self.terminals = set()
 
     def __str__(self):
-        return f"PropertyAccess<{self.prop}, {self.index}>"
-
-    def vector_index(self, v_index):
-        sizes = self.prop.sizes()
-        layout = self.prop.layout()
-        index = self.index * sizes[0] + v_index if layout == Layouts.AoS else \
-                v_index * sizes[1] + self.index if layout == Layouts.SoA else \
-                None
+        return f"FeatureAccess<{self.feature}, {self.index}>"
 
-        assert index is not None, "Invalid data layout"
-        return index
+    def copy(self):
+        return FeatureAccess(self.sim, self.feature, self.index)
 
     def inline_rec(self):
         self.inlined = True
@@ -93,6 +83,15 @@ class FeaturePropertyAccess(ASTTerm, VectorExpression):
     def propagate_through(self):
         return []
 
+    def set(self, other):
+        return self.sim.add_statement(Assign(self.sim, self, other))
+
+    def add(self, other):
+        return self.sim.add_statement(Assign(self.sim, self, self + other))
+
+    def sub(self, other):
+        return self.sim.add_statement(Assign(self.sim, self, self - other))
+
     def id(self):
         return self.acc_id
 
@@ -103,42 +102,4 @@ class FeaturePropertyAccess(ASTTerm, VectorExpression):
         self.terminals.add(terminal)
 
     def children(self):
-        return [self.prop, self.index] + list(super().children())
-
-    def __getitem__(self, index):
-        super().__getitem__(index)
-        return VectorAccess(self.sim, self, Lit.cvt(self.sim, index))
-
-
-class RegisterProperty(ASTNode):
-    def __init__(self, sim, prop, sizes):
-        super().__init__(sim)
-        self.prop = prop
-        self.sizes_list = [Lit.cvt(sim, s) for s in sizes]
-        self.sim.add_statement(self)
-
-    def property(self):
-        return self.prop
-
-    def sizes(self):
-        return self.sizes_list
-
-    def __str__(self):
-        return f"RegisterProperty<{self.prop.name()}>"
-
-
-class ReallocProperty(ASTNode):
-    def __init__(self, sim, prop, sizes):
-        super().__init__(sim)
-        self.prop = prop
-        self.sizes_list = [Lit.cvt(sim, s) for s in sizes]
-        self.sim.add_statement(self)
-
-    def property(self):
-        return self.prop
-
-    def sizes(self):
-        return self.sizes_list
-
-    def __str__(self):
-        return f"ReallocProperty<{self.prop.name()}>"
+        return [self.prop, self.index]
diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py
index 0bf42df..e43d386 100644
--- a/src/pairs/ir/properties.py
+++ b/src/pairs/ir/properties.py
@@ -14,8 +14,8 @@ class Properties:
         self.capacities = []
         self.defs = {}
 
-    def add(self, p_name, p_type, p_value, p_volatile, p_layout=Layouts.AoS):
-        p = Property(self.sim, p_name, p_type, p_value, p_volatile, p_layout)
+    def add(self, p_name, p_type, p_value, p_volatile, p_layout=Layouts.AoS, p_feature=None):
+        p = Property(self.sim, p_name, p_type, p_value, p_volatile, p_layout, p_feature)
         self.props.append(p)
         self.defs[p_name] = p_value
         return p
@@ -52,12 +52,13 @@ class Properties:
 class Property(ASTNode):
     last_prop_id = 0
 
-    def __init__(self, sim, name, dtype, default, volatile, layout=Layouts.AoS):
+    def __init__(self, sim, name, dtype, default, volatile, layout=Layouts.AoS, feature=None):
         super().__init__(sim)
         self.prop_id = Property.last_prop_id
         self.prop_name = name
         self.prop_type = dtype
         self.prop_layout = layout
+        self.prop_feature = feature
         self.default_value = default
         self.volatile = volatile
         self.device_flag = False
@@ -78,6 +79,9 @@ class Property(ASTNode):
     def layout(self):
         return self.prop_layout
 
+    def feature(self):
+        return self.prop_feature
+
     def default(self):
         return self.default_value
 
@@ -85,7 +89,8 @@ class Property(ASTNode):
         return 1 if self.prop_type != Types.Vector else 2
 
     def sizes(self):
-        return [self.sim.particle_capacity] if self.prop_type != Types.Vector else [self.sim.ndims(), self.sim.particle_capacity]
+        return [self.sim.particle_capacity] if self.prop_type != Types.Vector \
+               else [self.sim.ndims(), self.sim.particle_capacity]
 
     def __getitem__(self, expr):
         return PropertyAccess(self.sim, self, expr)
@@ -102,7 +107,16 @@ class PropertyAccess(ASTTerm, VectorExpression):
         super().__init__(sim)
         self.acc_id = PropertyAccess.new_id()
         self.prop = prop
-        self.index = Lit.cvt(sim, index)
+
+        if prop.feature() == None:
+            assert isinstance(index, int), "Only one index must be used for feature property!"
+            self.index = Lit.cvt(sim, index)
+
+        else:
+            assert isinstance(index, tuple), "Two indexes must be used for feature property!"
+            feature = self.prop.feature()
+            self.index = Lit.cvt(sim, feature[index[0]] * feature.count() + feature[index[1]])
+
         self.inlined = False
         self.terminals = set()
 
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 52b8edc..d6ccee4 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -113,9 +113,8 @@ class Simulation:
     def add_feature_property(self, feature_name, prop_name, prop_type):
         feature = self.feature(feature_name)
         assert feature is not None, f"Feature not found: {feature_name}"
-        feature_prop = feature.add_property(prop_name, prop_type)
-        self.features.add_property(feature, feature_prop)
-        return feature_prop
+        assert self.property(prop_name) is None, f"Property already defined: {prop_name}"
+        return self.properties.add(prop_name, prop_type, value, vol, feature=feature)
 
     def property(self, prop_name):
         return self.properties.find(prop_name)
-- 
GitLab