From c7c04d1e7adc858b3795a023154b0742fbce7276 Mon Sep 17 00:00:00 2001
From: Rafael Ravedutti <rafaelravedutti@gmail.com>
Date: Wed, 1 Nov 2023 17:51:17 +0100
Subject: [PATCH] Use single-precision

Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
---
 examples/dem.py                      | 20 +++---
 runtime/contact_property.hpp         |  8 +--
 runtime/pairs_common.hpp             |  9 ++-
 runtime/property.hpp                 | 22 +++----
 runtime/read_from_file.hpp           |  6 +-
 src/pairs/__init__.py                | 10 ++-
 src/pairs/code_gen/cgen.py           | 99 ++++++++++++++--------------
 src/pairs/ir/cast.py                 |  6 ++
 src/pairs/ir/lit.py                  |  2 +-
 src/pairs/ir/math.py                 |  8 +--
 src/pairs/ir/matrices.py             |  2 +-
 src/pairs/ir/quaternions.py          |  2 +-
 src/pairs/ir/scalars.py              |  2 +-
 src/pairs/ir/types.py                | 27 ++++----
 src/pairs/ir/vectors.py              |  2 +-
 src/pairs/mapping/keywords.py        |  2 +-
 src/pairs/sim/comm.py                |  8 +--
 src/pairs/sim/domain_partitioning.py |  4 +-
 src/pairs/sim/grid.py                |  4 +-
 src/pairs/sim/interaction.py         |  4 +-
 src/pairs/sim/simulation.py          |  6 +-
 21 files changed, 140 insertions(+), 113 deletions(-)

diff --git a/examples/dem.py b/examples/dem.py
index 07b68d7..0d86063 100644
--- a/examples/dem.py
+++ b/examples/dem.py
@@ -112,7 +112,7 @@ restitutionCoefficient = 0.1
 collisionTime_SI = 5e-4
 poissonsRatio = 0.22
 timeSteps = 10000
-visSpacing = 100
+visSpacing = 1
 denseBottomLayer = False
 bottomLayerOffsetFactor = 1.0
 kappa = 2.0 * (1.0 - poissonsRatio) / (2.0 - poissonsRatio) # from Thornton et al
@@ -131,26 +131,26 @@ frictionDynamic = frictionCoefficient
 psim = pairs.simulation("dem", timesteps=timeSteps)
 #psim = pairs.simulation("dem", debug=True, timesteps=timeSteps)
 psim.add_position('position')
-psim.add_property('mass', pairs.double(), 1.0)
+psim.add_property('mass', pairs.real(), 1.0)
 psim.add_property('linear_velocity', pairs.vector())
 psim.add_property('angular_velocity', pairs.vector())
 psim.add_property('force', pairs.vector(), volatile=True)
 psim.add_property('torque', pairs.vector(), volatile=True)
-psim.add_property('radius', pairs.double(), 1.0)
+psim.add_property('radius', pairs.real(), 1.0)
 psim.add_property('normal', pairs.vector())
 psim.add_property('inv_inertia', pairs.matrix())
 psim.add_property('rotation_matrix', pairs.matrix())
 psim.add_property('rotation_quat', pairs.quaternion())
 psim.add_feature('type', ntypes)
-#psim.add_feature_property('type', 'stiffness_norm', pairs.double(), [stiffnessNorm for i in range(ntypes * ntypes)])
-#psim.add_feature_property('type', 'stiffness_tan', pairs.double(), [stiffnessTan for i in range(ntypes * ntypes)])
-#psim.add_feature_property('type', 'damping_norm', pairs.double(), [dampingNorm for i in range(ntypes * ntypes)])
-#psim.add_feature_property('type', 'damping_tan', pairs.double(), [dampingTan for i in range(ntypes * ntypes)])
-psim.add_feature_property('type', 'friction_static', pairs.double(), [frictionStatic for i in range(ntypes * ntypes)])
-psim.add_feature_property('type', 'friction_dynamic', pairs.double(), [frictionDynamic for i in range(ntypes * ntypes)])
+#psim.add_feature_property('type', 'stiffness_norm', pairs.real(), [stiffnessNorm for i in range(ntypes * ntypes)])
+#psim.add_feature_property('type', 'stiffness_tan', pairs.real(), [stiffnessTan for i in range(ntypes * ntypes)])
+#psim.add_feature_property('type', 'damping_norm', pairs.real(), [dampingNorm for i in range(ntypes * ntypes)])
+#psim.add_feature_property('type', 'damping_tan', pairs.real(), [dampingTan for i in range(ntypes * ntypes)])
+psim.add_feature_property('type', 'friction_static', pairs.real(), [frictionStatic for i in range(ntypes * ntypes)])
+psim.add_feature_property('type', 'friction_dynamic', pairs.real(), [frictionDynamic for i in range(ntypes * ntypes)])
 psim.add_contact_property('is_sticking', pairs.int32(), 0)
 psim.add_contact_property('tangential_spring_displacement', pairs.vector(), [0.0, 0.0, 0.0])
-psim.add_contact_property('impact_velocity_magnitude', pairs.double(), 0.0)
+psim.add_contact_property('impact_velocity_magnitude', pairs.real(), 0.0)
 
 psim.set_domain([0.0, 0.0, 0.0, domainSize_SI[0], domainSize_SI[1], domainSize_SI[2]])
 psim.pbc([True, True, False])
diff --git a/runtime/contact_property.hpp b/runtime/contact_property.hpp
index 088322a..04950c8 100644
--- a/runtime/contact_property.hpp
+++ b/runtime/contact_property.hpp
@@ -37,10 +37,10 @@ public:
     layout_t getLayout() { return layout; }
     size_t getElemSize() {
         return  (type == Prop_Integer) ? sizeof(int) :
-                (type == Prop_Float) ? sizeof(double) :
-                (type == Prop_Vector) ? sizeof(double) :
-                (type == Prop_Matrix) ? sizeof(double) :
-                (type == Prop_Quaternion) ? sizeof(double) : 0;
+                (type == Prop_Real) ? sizeof(real_t) :
+                (type == Prop_Vector) ? sizeof(real_t) :
+                (type == Prop_Matrix) ? sizeof(real_t) :
+                (type == Prop_Quaternion) ? sizeof(real_t) : 0;
     }
 };
 
diff --git a/runtime/pairs_common.hpp b/runtime/pairs_common.hpp
index 9a10b46..4da8dea 100644
--- a/runtime/pairs_common.hpp
+++ b/runtime/pairs_common.hpp
@@ -2,7 +2,12 @@
 
 #pragma once
 
-typedef double real_t;
+//#ifdef USE_DOUBLE_PRECISION
+//typedef double real_t;
+//#else
+typedef float real_t;
+//#endif
+
 typedef int array_t;
 typedef int property_t;
 typedef int layout_t;
@@ -10,7 +15,7 @@ typedef int layout_t;
 enum PropertyType {
     Prop_Invalid = -1,
     Prop_Integer = 0,
-    Prop_Float,
+    Prop_Real,
     Prop_Vector,
     Prop_Matrix,
     Prop_Quaternion
diff --git a/runtime/property.hpp b/runtime/property.hpp
index 16016b8..58b5594 100644
--- a/runtime/property.hpp
+++ b/runtime/property.hpp
@@ -37,10 +37,10 @@ public:
     layout_t getLayout() { return layout; }
     size_t getElemSize() {
         return  (type == Prop_Integer) ? sizeof(int) :
-                (type == Prop_Float) ? sizeof(double) :
-                (type == Prop_Vector) ? sizeof(double) :
-                (type == Prop_Matrix) ? sizeof(double) :
-                (type == Prop_Quaternion) ? sizeof(double) : 0;
+                (type == Prop_Real) ? sizeof(real_t) :
+                (type == Prop_Vector) ? sizeof(real_t) :
+                (type == Prop_Matrix) ? sizeof(real_t) :
+                (type == Prop_Quaternion) ? sizeof(real_t) : 0;
     }
 };
 
@@ -51,14 +51,14 @@ public:
 
 class FloatProperty : public Property {
 public:
-    inline double &operator()(int i) { return static_cast<double *>(h_ptr)[i]; }
+    inline real_t &operator()(int i) { return static_cast<real_t *>(h_ptr)[i]; }
 };
 
 class VectorProperty : public Property {
 public:
-    double &operator()(int i, int j) {
+    real_t &operator()(int i, int j) {
         PAIRS_ASSERT(type != Prop_Invalid && layout != Invalid && sx > 0 && sy > 0);
-        double *dptr = static_cast<double *>(h_ptr);
+        real_t *dptr = static_cast<real_t *>(h_ptr);
         if(layout == AoS) { return dptr[i * sy + j]; }
         if(layout == SoA) { return dptr[j * sx + i]; }
         PAIRS_ERROR("VectorProperty::operator[]: Invalid data layout!");
@@ -68,9 +68,9 @@ public:
 
 class MatrixProperty : public Property {
 public:
-    double &operator()(int i, int j) {
+    real_t &operator()(int i, int j) {
         PAIRS_ASSERT(type != Prop_Invalid && layout != Invalid && sx > 0 && sy > 0);
-        double *dptr = static_cast<double *>(h_ptr);
+        real_t *dptr = static_cast<real_t *>(h_ptr);
         if(layout == AoS) { return dptr[i * sy + j]; }
         if(layout == SoA) { return dptr[j * sx + i]; }
         PAIRS_ERROR("MatrixProperty::operator[]: Invalid data layout!");
@@ -80,9 +80,9 @@ public:
 
 class QuaternionProperty : public Property {
 public:
-    double &operator()(int i, int j) {
+    real_t &operator()(int i, int j) {
         PAIRS_ASSERT(type != Prop_Invalid && layout != Invalid && sx > 0 && sy > 0);
-        double *dptr = static_cast<double *>(h_ptr);
+        real_t *dptr = static_cast<real_t *>(h_ptr);
         if(layout == AoS) { return dptr[i * sy + j]; }
         if(layout == SoA) { return dptr[j * sx + i]; }
         PAIRS_ERROR("QuaternionProperty::operator[]: Invalid data layout!");
diff --git a/runtime/read_from_file.hpp b/runtime/read_from_file.hpp
index fc40714..0173b46 100644
--- a/runtime/read_from_file.hpp
+++ b/runtime/read_from_file.hpp
@@ -10,7 +10,7 @@
 
 namespace pairs {
 
-void read_grid_data(PairsSimulation *ps, const char *filename, double *grid_buffer) {
+void read_grid_data(PairsSimulation *ps, const char *filename, real_t *grid_buffer) {
     std::ifstream in_file(filename, std::ifstream::in);
     std::string line;
 
@@ -92,7 +92,7 @@ size_t read_particle_data(PairsSimulation *ps, const char *filename, const prope
                     if(prop.getName() == "flags") {
                         flags = int_ptr(n);
                     }
-                } else if(prop_type == Prop_Float) {
+                } else if(prop_type == Prop_Real) {
                     auto float_ptr = ps->getAsFloatProperty(prop);
                     float_ptr(n) = std::stod(in0);
                 } else {
@@ -147,7 +147,7 @@ size_t read_feature_data(PairsSimulation *ps, const char *filename, const int fe
                 } else if(prop_type == Prop_Integer) {
                     auto int_ptr = ps->getAsIntegerFeatureProperty(prop);
                     int_ptr(i, j) = std::stoi(in0);
-                } else if(prop_type == Prop_Float) {
+                } else if(prop_type == Prop_Real) {
                     auto float_ptr = ps->getAsFloatFeatureProperty(prop);
                     float_ptr(i, j) = std::stod(in0);
                 } else {
diff --git a/src/pairs/__init__.py b/src/pairs/__init__.py
index ba01adc..7bf0d68 100644
--- a/src/pairs/__init__.py
+++ b/src/pairs/__init__.py
@@ -5,8 +5,8 @@ from pairs.sim.shapes import Shapes
 from pairs.sim.simulation import Simulation
 
 
-def simulation(ref, dims=3, timesteps=100, debug=False):
-    return Simulation(CGen(ref, debug), dims, timesteps)
+def simulation(ref, dims=3, timesteps=100, double_prec=False, debug=False):
+    return Simulation(CGen(ref, debug), dims, timesteps, double_prec)
 
 def target_cpu():
     return Target(Target.Backend_CPP, Target.Feature_CPU)
@@ -17,9 +17,15 @@ def target_gpu():
 def int32():
     return Types.Int32
 
+def float():
+    return Types.Float
+
 def double():
     return Types.Double
 
+def real():
+    return Types.Real
+
 def vector():
     return Types.Vector
 
diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py
index eae9daa..e94d724 100644
--- a/src/pairs/code_gen/cgen.py
+++ b/src/pairs/code_gen/cgen.py
@@ -49,6 +49,9 @@ class CGen:
     def assign_target(self, target):
         self.target = target
 
+    def real_type(self):
+        return Types.c_keyword(self.sim, Types.Real)
+
     def generate_program(self, ast_node):
         ext = ".cu" if self.target.is_gpu() else ".cpp"
         self.print = Printer(self.ref + ext)
@@ -80,14 +83,14 @@ class CGen:
             for array in self.sim.arrays.statics():
                 if array.device_flag:
                     t = array.type()
-                    tkw = Types.c_keyword(t)
+                    tkw = Types.c_keyword(self.sim, t)
                     size = self.generate_expression(ScalarOp.inline(array.alloc_size()))
                     self.print(f"__constant__ {tkw} d_{array.name()}[{size}];")
 
             for feature_prop in self.sim.feature_properties:
                 if feature_prop.device_flag:
                     t = feature_prop.type()
-                    tkw = Types.c_keyword(t)
+                    tkw = Types.c_keyword(self.sim, t)
                     size = feature_prop.array_size()
                     self.print(f"__constant__ {tkw} d_{feature_prop.name()}[{size}];")
 
@@ -117,17 +120,17 @@ class CGen:
         else:
             module_params = "PairsSimulation *pairs"
             for var in module.read_only_variables():
-                type_kw = Types.c_keyword(var.type())
+                type_kw = Types.c_keyword(self.sim, var.type())
                 decl = f"{type_kw} {var.name()}"
                 module_params += f", {decl}"
 
             for var in module.write_variables():
-                type_kw = Types.c_keyword(var.type())
+                type_kw = Types.c_keyword(self.sim, var.type())
                 decl = f"{type_kw} *{var.name()}"
                 module_params += f", {decl}"
 
             for array in module.arrays():
-                type_kw = Types.c_keyword(array.type())
+                type_kw = Types.c_keyword(self.sim, array.type())
                 decl = f"{type_kw} *{array.name()}"
                 module_params += f", {decl}"
 
@@ -136,7 +139,7 @@ class CGen:
                     module_params += f", {decl}"
 
             for prop in module.properties():
-                type_kw = Types.c_keyword(prop.type())
+                type_kw = Types.c_keyword(self.sim, prop.type())
                 decl = f"{type_kw} *{prop.name()}"
                 module_params += f", {decl}"
 
@@ -145,7 +148,7 @@ class CGen:
                     module_params += f", {decl}"
 
             for contact_prop in module.contact_properties():
-                type_kw = Types.c_keyword(contact_prop.type())
+                type_kw = Types.c_keyword(self.sim, contact_prop.type())
                 decl = f"{type_kw} *{contact_prop.name()}"
                 module_params += f", {decl}"
 
@@ -154,7 +157,7 @@ class CGen:
                     module_params += f", {decl}"
 
             for feature_prop in module.feature_properties():
-                type_kw = Types.c_keyword(feature_prop.type())
+                type_kw = Types.c_keyword(self.sim, feature_prop.type())
                 decl = f"{type_kw} *{feature_prop.name()}"
                 module_params += f", {decl}"
 
@@ -175,42 +178,42 @@ class CGen:
     def generate_kernel(self, kernel):
         kernel_params = "int range_start"
         for var in kernel.read_only_variables():
-            type_kw = Types.c_keyword(var.type())
+            type_kw = Types.c_keyword(self.sim, var.type())
             decl = f"{type_kw} {var.name()}"
             kernel_params += f", {decl}"
 
         for var in kernel.write_variables():
-            type_kw = Types.c_keyword(var.type())
+            type_kw = Types.c_keyword(self.sim, var.type())
             decl = f"{type_kw} *{var.name()}"
             kernel_params += f", {decl}"
 
         for array in kernel.arrays():
-            type_kw = Types.c_keyword(array.type())
+            type_kw = Types.c_keyword(self.sim, array.type())
             decl = f"{type_kw} *{array.name()}"
             kernel_params += f", {decl}"
 
         for prop in kernel.properties():
-            type_kw = Types.c_keyword(prop.type())
+            type_kw = Types.c_keyword(self.sim, prop.type())
             decl = f"{type_kw} *{prop.name()}"
             kernel_params += f", {decl}"
 
         for contact_prop in kernel.contact_properties():
-            type_kw = Types.c_keyword(contact_prop.type())
+            type_kw = Types.c_keyword(self.sim, contact_prop.type())
             decl = f"{type_kw} *{contact_prop.name()}"
             kernel_params += f", {decl}"
 
         for feature_prop in kernel.feature_properties():
-            type_kw = Types.c_keyword(feature_prop.type())
+            type_kw = Types.c_keyword(self.sim, feature_prop.type())
             decl = f"{type_kw} *{feature_prop.name()}"
             kernel_params += f", {decl}"
 
         for array_access in kernel.array_accesses():
-            type_kw = Types.c_keyword(array_access.type())
+            type_kw = Types.c_keyword(self.sim, array_access.type())
             decl = f"{type_kw} {array_access.name()}"
             kernel_params += f", {decl}"
 
         for scalar_op in kernel.scalar_ops():
-            type_kw = Types.c_keyword(scalar_op.type())
+            type_kw = Types.c_keyword(self.sim, scalar_op.type())
             decl = f"{type_kw} {scalar_op.name()}"
             kernel_params += f", {decl}"
 
@@ -226,7 +229,7 @@ class CGen:
     def generate_statement(self, ast_node):
         if isinstance(ast_node, DeclareStaticArray):
             t = ast_node.array.type()
-            tkw = Types.c_keyword(t)
+            tkw = Types.c_keyword(self.sim, t)
             size = self.generate_expression(ScalarOp.inline(ast_node.array.alloc_size()))
             if ast_node.array.init_value is not None:
                 v_str = str(ast_node.array.init_value)
@@ -266,7 +269,7 @@ class CGen:
             if isinstance(ast_node.elem, ArrayAccess):
                 array_access = ast_node.elem
                 array_name = self.generate_expression(array_access.array)
-                tkw = Types.c_keyword(array_access.type())
+                tkw = Types.c_keyword(self.sim, array_access.type())
                 acc_index = self.generate_expression(array_access.flat_index)
                 acc_ref = array_access.name()
                 self.print(f"const {tkw} {acc_ref} = {array_name}[{acc_index}];")
@@ -275,7 +278,7 @@ class CGen:
                 atomic_add = ast_node.elem
                 elem = self.generate_expression(atomic_add.elem)
                 value = self.generate_expression(atomic_add.value)
-                tkw = Types.c_keyword(atomic_add.type())
+                tkw = Types.c_keyword(self.sim, atomic_add.type())
                 acc_ref = atomic_add.name()
                 prefix = "" if ast_node.elem.device_flag else "host_"
 
@@ -295,10 +298,10 @@ class CGen:
                 if not contact_prop_access.is_scalar():
                     for dim in contact_prop_access.indexes_to_generate():
                         expr = self.generate_expression(contact_prop_access.vector_index(dim))
-                        self.print(f"const double {acc_ref}_{dim} = {prop_name}[{expr}];")
+                        self.print(f"const {self.real_type()} {acc_ref}_{dim} = {prop_name}[{expr}];")
 
                 else:
-                    tkw = Types.c_keyword(contact_prop_access.type())
+                    tkw = Types.c_keyword(self.sim, contact_prop_access.type())
                     acc_index = self.generate_expression(contact_prop_access.index)
                     self.print(f"const {tkw} {acc_ref} = {prop_name}[{acc_index}];")
 
@@ -311,10 +314,10 @@ class CGen:
                 if not feature_prop_access.is_scalar():
                     for dim in feature_prop_access.indexes_to_generate():
                         expr = self.generate_expression(feature_prop_access.vector_index(dim))
-                        self.print(f"const double {acc_ref}_{dim} = {prop_name}[{expr}];")
+                        self.print(f"const {self.real_type()} {acc_ref}_{dim} = {prop_name}[{expr}];")
 
                 else:
-                    tkw = Types.c_keyword(feature_prop_access.type())
+                    tkw = Types.c_keyword(self.sim, feature_prop_access.type())
                     acc_index = self.generate_expression(feature_prop_access.index)
                     self.print(f"const {tkw} {acc_ref} = {prop_name}[{acc_index}];")
 
@@ -326,9 +329,9 @@ class CGen:
                 if not prop_access.is_scalar():
                     for dim in prop_access.indexes_to_generate():
                         expr = self.generate_expression(prop_access.vector_index(dim))
-                        self.print(f"const double {acc_ref}_{dim} = {prop_name}[{expr}];")
+                        self.print(f"const {self.real_type()} {acc_ref}_{dim} = {prop_name}[{expr}];")
                 else:
-                    tkw = Types.c_keyword(prop_access.type())
+                    tkw = Types.c_keyword(self.sim, prop_access.type())
                     index_g = self.generate_expression(prop_access.index)
                     self.print(f"const {tkw} {acc_ref} = {prop_name}[{index_g}];")
 
@@ -336,7 +339,7 @@ class CGen:
                 quaternion = ast_node.elem
                 for i in quaternion.indexes_to_generate():
                     expr = self.generate_expression(quaternion.get_value(i))
-                    self.print(f"const double {quaternion.name()}_{i} = {expr};")
+                    self.print(f"const {self.real_type()} {quaternion.name()}_{i} = {expr};")
 
             if isinstance(ast_node.elem, QuaternionOp):
                 quat_op = ast_node.elem
@@ -346,9 +349,9 @@ class CGen:
                     operator = quat_op.operator()
 
                     if operator.is_unary():
-                        self.print(f"const double {quat_op.name()}_{dim} = {operator.symbol()}({lhs});")
+                        self.print(f"const {self.real_type()} {quat_op.name()}_{dim} = {operator.symbol()}({lhs});")
                     else:
-                        self.print(f"const double {quat_op.name()}_{dim} = {lhs} {operator.symbol()} {rhs};")
+                        self.print(f"const {self.real_type()} {quat_op.name()}_{dim} = {lhs} {operator.symbol()} {rhs};")
 
             if isinstance(ast_node.elem, ScalarOp):
                 scalar_op = ast_node.elem
@@ -356,7 +359,7 @@ class CGen:
                     lhs = self.generate_expression(scalar_op.lhs, scalar_op.mem)
                     rhs = self.generate_expression(scalar_op.rhs)
                     operator = scalar_op.operator()
-                    tkw = Types.c_keyword(scalar_op.type())
+                    tkw = Types.c_keyword(self.sim, scalar_op.type())
 
                     if operator.is_unary():
                         self.print(f"const {tkw} {scalar_op.name()} = {operator.symbol()}({lhs});")
@@ -372,25 +375,25 @@ class CGen:
                         cond = self.generate_expression(select.cond, index=dim)
                         expr_if = self.generate_expression(select.expr_if, index=dim)
                         expr_else = self.generate_expression(select.expr_else, index=dim)
-                        self.print(f"const double {acc_ref}_{dim} = ({cond}) ? ({expr_if}) : ({expr_else});")
+                        self.print(f"const {self.real_type()} {acc_ref}_{dim} = ({cond}) ? ({expr_if}) : ({expr_else});")
                 else:
                     cond = self.generate_expression(select.cond)
                     expr_if = self.generate_expression(select.expr_if)
                     expr_else = self.generate_expression(select.expr_else)
-                    tkw = Types.c_keyword(select.type())
+                    tkw = Types.c_keyword(self.sim, select.type())
                     self.print(f"const {tkw} {acc_ref} = ({cond}) ? ({expr_if}) : ({expr_else});")
 
             if isinstance(ast_node.elem, MathFunction):
                 math_func = ast_node.elem
                 params = ", ".join([str(self.generate_expression(p)) for p in math_func.parameters()])
-                tkw = Types.c_keyword(math_func.type())
+                tkw = Types.c_keyword(self.sim, math_func.type())
                 self.print(f"const {tkw} {math_func.name()} = {math_func.function_name()}({params});")
 
             if isinstance(ast_node.elem, Matrix):
                 matrix = ast_node.elem
                 for i in matrix.indexes_to_generate():
                     expr = self.generate_expression(matrix.get_value(i))
-                    self.print(f"const double {matrix.name()}_{i} = {expr};")
+                    self.print(f"const {self.real_type()} {matrix.name()}_{i} = {expr};")
 
             if isinstance(ast_node.elem, MatrixOp):
                 matrix_op = ast_node.elem
@@ -400,15 +403,15 @@ class CGen:
                     operator = vector_op.operator()
 
                     if operator.is_unary():
-                        self.print(f"const double {matrix_op.name()}_{dim} = {operator.symbol()}({lhs});")
+                        self.print(f"const {self.real_type()} {matrix_op.name()}_{dim} = {operator.symbol()}({lhs});")
                     else:
-                        self.print(f"const double {matrix_op.name()}_{dim} = {lhs} {operator.symbol()} {rhs};")
+                        self.print(f"const {self.real_type()} {matrix_op.name()}_{dim} = {lhs} {operator.symbol()} {rhs};")
 
             if isinstance(ast_node.elem, Vector):
                 vector = ast_node.elem
                 for dim in vector.indexes_to_generate():
                     expr = self.generate_expression(vector.get_value(dim))
-                    self.print(f"const double {vector.name()}_{dim} = {expr};")
+                    self.print(f"const {self.real_type()} {vector.name()}_{dim} = {expr};")
 
             if isinstance(ast_node.elem, VectorOp):
                 vector_op = ast_node.elem
@@ -418,9 +421,9 @@ class CGen:
                     operator = vector_op.operator()
 
                     if operator.is_unary():
-                        self.print(f"const double {vector_op.name()}_{dim} = {operator.symbol()}({lhs});")
+                        self.print(f"const {self.real_type()} {vector_op.name()}_{dim} = {operator.symbol()}({lhs});")
                     else:
-                        self.print(f"const double {vector_op.name()}_{dim} = {lhs} {operator.symbol()} {rhs};")
+                        self.print(f"const {self.real_type()} {vector_op.name()}_{dim} = {lhs} {operator.symbol()} {rhs};")
 
         if isinstance(ast_node, Branch):
             cond = self.generate_expression(ast_node.cond)
@@ -509,7 +512,7 @@ class CGen:
 
 
         if isinstance(ast_node, Malloc):
-            tkw = Types.c_keyword(ast_node.array.type())
+            tkw = Types.c_keyword(self.sim, ast_node.array.type())
             size = self.generate_expression(ast_node.size)
             array_name = ast_node.array.name()
 
@@ -607,7 +610,7 @@ class CGen:
             self.print(f"PAIRS_DEBUG(\"{ast_node.string}\\n\");")
 
         if isinstance(ast_node, Realloc):
-            tkw = Types.c_keyword(ast_node.array.type())
+            tkw = Types.c_keyword(self.sim, ast_node.array.type())
             size = self.generate_expression(ast_node.size)
             array_name = ast_node.array.name()
             self.print(f"{array_name} = ({tkw} *) realloc({array_name}, {size});")
@@ -618,7 +621,7 @@ class CGen:
             a = ast_node.array()
             ptr = a.name()
             d_ptr = f"d_{ptr}" if self.target.is_gpu() and a.device_flag else "nullptr"
-            tkw = Types.c_keyword(a.type())
+            tkw = Types.c_keyword(self.sim, a.type())
             size = self.generate_expression(ast_node.size())
 
             if a.is_static():
@@ -637,7 +640,7 @@ class CGen:
             p = ast_node.property()
             ptr = p.name()
             d_ptr = f"d_{ptr}" if self.target.is_gpu() and p.device_flag else "nullptr"
-            tkw = Types.c_keyword(p.type())
+            tkw = Types.c_keyword(self.sim, p.type())
             ptype = Types.c_property_keyword(p.type())
             assert ptype != "Prop_Invalid", "Invalid property type!"
 
@@ -656,7 +659,7 @@ class CGen:
             p = ast_node.property()
             ptr = p.name()
             d_ptr = f"d_{ptr}" if self.target.is_gpu() and p.device_flag else "nullptr"
-            tkw = Types.c_keyword(p.type())
+            tkw = Types.c_keyword(self.sim, p.type())
             ptype = Types.c_property_keyword(p.type())
             assert ptype != "Prop_Invalid", "Invalid property type!"
 
@@ -677,7 +680,7 @@ class CGen:
             d_ptr = f"&d_{ptr}" if self.target.is_gpu() and fp.device_flag else "nullptr"
             array_size = fp.array_size()
             nkinds = fp.feature().nkinds()
-            tkw = Types.c_keyword(fp.type())
+            tkw = Types.c_keyword(self.sim, fp.type())
             fptype = Types.c_property_keyword(fp.type())
             assert fptype != "Prop_Invalid", "Invalid feature property type!"
 
@@ -710,7 +713,7 @@ class CGen:
             #self.print(f"pairs->reallocArray({a.id()}, (void **) &{ptr}, (void **) &d_{ptr}, {size});")
 
         if isinstance(ast_node, DeclareVariable):
-            tkw = Types.c_keyword(ast_node.var.type())
+            tkw = Types.c_keyword(self.sim, ast_node.var.type())
 
             if ast_node.var.is_scalar():
                 var = self.generate_expression(ast_node.var)
@@ -772,7 +775,7 @@ class CGen:
             return f"{ast_node.name()}({params})"
 
         if isinstance(ast_node, Cast):
-            tkw = Types.c_keyword(ast_node.cast_type)
+            tkw = Types.c_keyword(self.sim, ast_node.cast_type)
             expr = self.generate_expression(ast_node.expr)
             return f"({tkw})({expr})"
 
@@ -808,7 +811,7 @@ class CGen:
                 return ast_node.value[index]
 
             if isinstance(ast_node.value, float) and math.isinf(ast_node.value):
-                return "std::numeric_limits<double>::infinity()"
+                return f"std::numeric_limits<{self.real_type()}>::infinity()"
 
             return ast_node.value
 
@@ -876,7 +879,7 @@ class CGen:
 
         if isinstance(ast_node, Sizeof):
             assert mem is False, "Sizeof expression is not lvalue!"
-            tkw = Types.c_keyword(ast_node.data_type)
+            tkw = Types.c_keyword(self.sim, ast_node.data_type)
             return f"sizeof({tkw})"
 
         if isinstance(ast_node, Select):
diff --git a/src/pairs/ir/cast.py b/src/pairs/ir/cast.py
index 399de63..fa2820d 100644
--- a/src/pairs/ir/cast.py
+++ b/src/pairs/ir/cast.py
@@ -19,6 +19,12 @@ class Cast(ASTTerm):
     def uint64(sim, expr):
         return Cast(sim, expr, Types.UInt64)
 
+    def real(sim, expr):
+        return Cast(sim, expr, Types.Real)
+
+    def float(sim, expr):
+        return Cast(sim, expr, Types.Float)
+
     def double(sim, expr):
         return Cast(sim, expr, Types.Double)
 
diff --git a/src/pairs/ir/lit.py b/src/pairs/ir/lit.py
index a5d4208..ff54c94 100644
--- a/src/pairs/ir/lit.py
+++ b/src/pairs/ir/lit.py
@@ -22,7 +22,7 @@ class Lit(ASTTerm):
         else:
             scalar_mapping = {
                 int: Types.Int32,
-                float: Types.Double,
+                float: Types.Real,
                 bool: Types.Boolean,
                 str: Types.String,
             }
diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py
index 9dde4d2..e85aa06 100644
--- a/src/pairs/ir/math.py
+++ b/src/pairs/ir/math.py
@@ -62,7 +62,7 @@ class Sqrt(MathFunction):
         return f"Sqrt<{self._params}>"
 
     def function_name(self):
-        return "sqrt"
+        return "sqrt" if self.sim.use_double_precision() else "sqrtf"
 
     def type(self):
         return self._params[0].type()
@@ -92,7 +92,7 @@ class Sin(MathFunction):
         return f"Sin<{self._params}>"
 
     def function_name(self):
-        return "sin"
+        return "sin" if self.sim.use_double_precision() else "sinf"
 
     def type(self):
         return self._params[0].type()
@@ -107,7 +107,7 @@ class Cos(MathFunction):
         return f"Cos<{self._params}>"
 
     def function_name(self):
-        return "cos"
+        return "cos" if self.sim.use_double_precision() else "cosf"
 
     def type(self):
         return self._params[0].type()
@@ -123,7 +123,7 @@ class Ceil(MathFunction):
         return f"Ceil<{self._params}>"
 
     def function_name(self):
-        return "ceil"
+        return "ceil" if self.sim.use_double_precision() else "ceilf"
 
     def type(self):
         return Types.Int32
diff --git a/src/pairs/ir/matrices.py b/src/pairs/ir/matrices.py
index aff53ad..38e3341 100644
--- a/src/pairs/ir/matrices.py
+++ b/src/pairs/ir/matrices.py
@@ -73,7 +73,7 @@ class MatrixAccess(ASTTerm):
         return f"MatrixAccess<{self.expr}, {self.index}>"
 
     def type(self):
-        return Types.Double
+        return Types.Real
 
     def children(self):
         return [self.expr]
diff --git a/src/pairs/ir/quaternions.py b/src/pairs/ir/quaternions.py
index 5406973..1f2df2a 100644
--- a/src/pairs/ir/quaternions.py
+++ b/src/pairs/ir/quaternions.py
@@ -73,7 +73,7 @@ class QuaternionAccess(ASTTerm):
         return f"QuaternionAccess<{self.expr}, {self.index}>"
 
     def type(self):
-        return Types.Double
+        return Types.Real
 
     def children(self):
         return [self.expr]
diff --git a/src/pairs/ir/scalars.py b/src/pairs/ir/scalars.py
index 12c2462..f0e30f8 100644
--- a/src/pairs/ir/scalars.py
+++ b/src/pairs/ir/scalars.py
@@ -87,7 +87,7 @@ class ScalarOp(ASTTerm):
             return Types.Quaternion
 
         if Types.is_real(lhs_type) or Types.is_real(rhs_type):
-            return Types.Double
+            return Types.Real
 
         if Types.is_integer(lhs_type) or Types.is_integer(rhs_type):
             if isinstance(lhs, Lit) or Lit.is_literal(lhs):
diff --git a/src/pairs/ir/types.py b/src/pairs/ir/types.py
index 74a17d9..ea1d40c 100644
--- a/src/pairs/ir/types.py
+++ b/src/pairs/ir/types.py
@@ -3,19 +3,22 @@ class Types:
     Int32 = 0
     Int64 = 1
     UInt64 = 2
-    Float = 3
-    Double = 4
-    Boolean = 5
-    String = 6
-    Vector = 7
-    Array = 8
-    Matrix = 9
-    Quaternion = 10
+    Real = 3
+    Float = 4
+    Double = 5
+    Boolean = 6
+    String = 7
+    Vector = 8
+    Array = 9
+    Matrix = 10
+    Quaternion = 11
 
-    def c_keyword(t):
+    def c_keyword(sim, t):
+        real_kw = 'double' if sim.use_double_precision() else 'float'
         return (
-            'double' if t in (Types.Double, Types.Vector, Types.Matrix, Types.Quaternion)
+            real_kw if t in (Types.Real, Types.Vector, Types.Matrix, Types.Quaternion)
             else 'float' if t == Types.Float
+            else 'double' if t == Types.Double
             else 'int' if t == Types.Int32
             else 'long long int' if t == Types.Int64
             else 'unsigned long long int' if t == Types.UInt64
@@ -25,7 +28,7 @@ class Types:
 
     def c_property_keyword(t):
         return "Prop_Integer"      if t == Types.Int32 else \
-               "Prop_Float"        if t == Types.Double else \
+               "Prop_Real"         if t == Types.Real else \
                "Prop_Vector"       if t == Types.Vector else \
                "Prop_Matrix"       if t == Types.Matrix else \
                "Prop_Quaternion"   if t == Types.Quaternion else \
@@ -35,7 +38,7 @@ class Types:
         return t in (Types.Int32, Types.Int64, Types.UInt64)
 
     def is_real(t):
-        return t in (Types.Float, Types.Double)
+        return t in (Types.Float, Types.Double, Types.Real)
 
     def is_scalar(t):
         return t not in (Types.Vector, Types.Matrix, Types.Quaternion)
diff --git a/src/pairs/ir/vectors.py b/src/pairs/ir/vectors.py
index 53e5f57..b8d94b1 100644
--- a/src/pairs/ir/vectors.py
+++ b/src/pairs/ir/vectors.py
@@ -82,7 +82,7 @@ class VectorAccess(ASTTerm):
         return f"VectorAccess<{self.expr}, {self.index}>"
 
     def type(self):
-        return Types.Double
+        return Types.Real
 
     def children(self):
         return [self.expr]
diff --git a/src/pairs/mapping/keywords.py b/src/pairs/mapping/keywords.py
index b7379af..3754454 100644
--- a/src/pairs/mapping/keywords.py
+++ b/src/pairs/mapping/keywords.py
@@ -48,7 +48,7 @@ class Keywords:
     def keyword_sqrt(self, args):
         assert len(args) == 1, "sqrt() keyword requires one parameter!"
         value = args[0]
-        assert value.type() == Types.Double, "sqrt(): Value must be a real."
+        assert value.type() == Types.Real, "sqrt(): Value must be a real."
         return Sqrt(self.sim, value)
 
     def keyword_skip_when(self, args):
diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py
index 0618606..888cf2b 100644
--- a/src/pairs/sim/comm.py
+++ b/src/pairs/sim/comm.py
@@ -23,14 +23,14 @@ class Comm:
         self.neigh_capacity = sim.add_var('neigh_capacity', Types.Int32, 6)
         self.nsend          = sim.add_array('nsend', [self.neigh_capacity], Types.Int32)
         self.send_offsets   = sim.add_array('send_offsets', [self.neigh_capacity], Types.Int32)
-        self.send_buffer    = sim.add_array('send_buffer', [self.send_capacity, self.elem_capacity], Types.Double)
+        self.send_buffer    = sim.add_array('send_buffer', [self.send_capacity, self.elem_capacity], Types.Real)
         self.send_map       = sim.add_array('send_map', [self.send_capacity], Types.Int32)
         self.exchg_flag     = sim.add_array('exchg_flag', [sim.particle_capacity], Types.Int32)
         self.exchg_copy_to  = sim.add_array('exchg_copy_to', [self.send_capacity], Types.Int32)
         self.send_mult      = sim.add_array('send_mult', [self.send_capacity, sim.ndims()], Types.Int32)
         self.nrecv          = sim.add_array('nrecv', [self.neigh_capacity], Types.Int32)
         self.recv_offsets   = sim.add_array('recv_offsets', [self.neigh_capacity], Types.Int32)
-        self.recv_buffer    = sim.add_array('recv_buffer', [self.recv_capacity, self.elem_capacity], Types.Double)
+        self.recv_buffer    = sim.add_array('recv_buffer', [self.recv_capacity, self.elem_capacity], Types.Real)
         self.recv_map       = sim.add_array('recv_map', [self.recv_capacity], Types.Int32)
         self.recv_mult      = sim.add_array('recv_mult', [self.recv_capacity, sim.ndims()], Types.Int32)
 
@@ -227,7 +227,7 @@ class PackGhostParticles(Lowerable):
                     p_offset += nelems
 
                 else:
-                    cast_fn = lambda x: Cast(self.sim, x, Types.Double) if p.type() != Types.Double else x
+                    cast_fn = lambda x: Cast(self.sim, x, Types.Real) if p.type() != Types.Real else x
                     Assign(self.sim, send_buffer[i][p_offset], cast_fn(p[m]))
                     p_offset += 1
 
@@ -264,7 +264,7 @@ class UnpackGhostParticles(Lowerable):
                     p_offset += nelems
 
                 else:
-                    cast_fn = lambda x: Cast(self.sim, x, p.type()) if p.type() != Types.Double else x
+                    cast_fn = lambda x: Cast(self.sim, x, p.type()) if p.type() != Types.Real else x
                     Assign(self.sim, p[nlocal + i], cast_fn(recv_buffer[i][p_offset]))
                     p_offset += 1
 
diff --git a/src/pairs/sim/domain_partitioning.py b/src/pairs/sim/domain_partitioning.py
index 08ce9a4..654e98c 100644
--- a/src/pairs/sim/domain_partitioning.py
+++ b/src/pairs/sim/domain_partitioning.py
@@ -14,7 +14,7 @@ class DimensionRanges:
         self.sim = sim
         self.neighbor_ranks = sim.add_static_array('neighbor_ranks', [sim.ndims() * 2], Types.Int32)
         self.pbc            = sim.add_static_array('pbc', [sim.ndims() * 2], Types.Int32)
-        self.subdom         = sim.add_static_array('subdom', [sim.ndims() * 2], Types.Double)
+        self.subdom         = sim.add_static_array('subdom', [sim.ndims() * 2], Types.Real)
 
     def number_of_steps(self):
         return self.sim.ndims()
@@ -27,7 +27,7 @@ class DimensionRanges:
             return
 
         # Particles with one of the following flags are ignored
-        flags_to_exclude = (Flags.Infinite | Flags.Fixed | Flags.Global)
+        flags_to_exclude = (Flags.Infinite | Flags.Global)
 
         for i in For(self.sim, 0, self.sim.nlocal + self.sim.nghost):
             for _ in Filter(self.sim, ScalarOp.cmp(self.sim.particle_flags[i] & flags_to_exclude, 0)):
diff --git a/src/pairs/sim/grid.py b/src/pairs/sim/grid.py
index 805ade2..43b0a31 100644
--- a/src/pairs/sim/grid.py
+++ b/src/pairs/sim/grid.py
@@ -58,8 +58,8 @@ class MutableGrid(Grid):
         self.id = MutableGrid.last_id
         prefix = f"grid{self.id}_"
         super().__init__(sim, [
-            (sim.add_var(f"{prefix}d{d}_min", Types.Double),
-             sim.add_var(f"{prefix}d{d}_max", Types.Double))
+            (sim.add_var(f"{prefix}d{d}_min", Types.Real),
+             sim.add_var(f"{prefix}d{d}_max", Types.Real))
             for d in range(ndims)
         ])
 
diff --git a/src/pairs/sim/interaction.py b/src/pairs/sim/interaction.py
index 160029e..ba9c30b 100644
--- a/src/pairs/sim/interaction.py
+++ b/src/pairs/sim/interaction.py
@@ -93,8 +93,8 @@ class InteractionData:
         self._i = sim.add_symbol(Types.Int32)
         self._j = sim.add_symbol(Types.Int32)
         self._delta = sim.add_symbol(Types.Vector)
-        self._squared_distance = sim.add_symbol(Types.Double)
-        self._penetration_depth = sim.add_symbol(Types.Double)
+        self._squared_distance = sim.add_symbol(Types.Real)
+        self._penetration_depth = sim.add_symbol(Types.Real)
         self._contact_point = sim.add_symbol(Types.Vector)
         self._contact_normal = sim.add_symbol(Types.Vector)
         self._shape = shape
diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py
index 39ade23..3308dcb 100644
--- a/src/pairs/sim/simulation.py
+++ b/src/pairs/sim/simulation.py
@@ -32,7 +32,7 @@ from pairs.transformations import Transformations
 
 
 class Simulation:
-    def __init__(self, code_gen, dims=3, timesteps=100):
+    def __init__(self, code_gen, dims=3, timesteps=100, double_prec=False):
         self.code_gen = code_gen
         self.code_gen.assign_simulation(self)
         self.position_prop = None
@@ -65,6 +65,7 @@ class Simulation:
         self._check_properties_resize = False
         self._resizes_to_check = {}
         self._module_name = None
+        self._double_prec = double_prec
         self.dims = dims
         self.ntimesteps = timesteps
         self.expr_id = 0
@@ -76,6 +77,9 @@ class Simulation:
         self._dom_part = DimensionRanges(self)
         self._pbc = [True for _ in range(dims)]
 
+    def use_double_precision(self):
+        return self._double_prec
+
     def max_shapes(self):
         return 2
 
-- 
GitLab