diff --git a/examples/dem.py b/examples/dem.py index c7b2f0b2345eaa93d52cb774ac4b4c36d5f301d9..550c3902218c6f1aea246d5d89c42b67377a426a 100644 --- a/examples/dem.py +++ b/examples/dem.py @@ -3,6 +3,11 @@ import pairs import sys +def update_mass_and_inertia(i): + mass[i] = (4.0 / 3.0) * pi * radius[i] * radius[i] * radius[i] * densityParticle_SI + inv_inertia[i] = diagonal_matrix(1.0 / (0.4 * mass[i] * radius[i] * radius[i])) + + def linear_spring_dashpot(i, j): delta_ij = -penetration_depth(i, j) skip_when(delta_ij < 0.0) @@ -60,6 +65,11 @@ def linear_spring_dashpot(i, j): def euler(i): linear_velocity[i] += dt * force[i] / mass[i] position[i] += dt * linear_velocity[i] + wdot = rotation_matrix[i] * (inv_inertia[i] * torque[i]) * transposed(rotation_matrix[i]) + phi = angular_velocity[i] * dt + wdot * dt * dt + rotation_quat[i] = quaternion(phi, length(phi)) * rotation_quat[i] + rotation_matrix[i] = quaternion_to_rotation_matrix(rotation_quat[i]) + angular_velocity[i] += wdot * dt def gravity(i): @@ -72,28 +82,6 @@ target = sys.argv[1] if len(sys.argv[1]) > 1 else "none" if target != 'cpu' and target != 'gpu': print(f"Invalid target, use {cmd} <cpu/gpu>") - -# BedGeneration { -# domainSize_SI < 0.8, 0.015, 0.2 >; -# blocks < 3, 3, 1 >; -# diameter_SI 0.0029; -# gravity_SI 9.81; -# densityFluid_SI 1000; -# densityParticle_SI 2550; -# generationSpacing_SI 0.005; -# initialVelocity_SI 1; -# dt_SI 5e-5; -# frictionCoefficient 0.5; -# restitutionCoefficient 0.1; -# collisionTime_SI 5e-4; -# poissonsRatio 0.22; -# timeSteps 10000; -# visSpacing 100; -# outFileName spheres_out.dat; -# denseBottomLayer False; -# bottomLayerOffsetFactor 1.0; -#} - # Config file parameters domainSize_SI = [0.8, 0.015, 0.2] blocks = [3, 3, 1] @@ -131,7 +119,8 @@ dampingTan = math.sqrt(kappa) * dampingNorm frictionStatic = frictionCoefficient # TODO: check if this is correct frictionDynamic = frictionCoefficient -psim = pairs.simulation("dem", debug=True, timesteps=timeSteps) +psim = pairs.simulation("dem", timesteps=timeSteps) +#psim = pairs.simulation("dem", debug=True, timesteps=timeSteps) psim.add_position('position') psim.add_property('mass', pairs.double(), 1.0) psim.add_property('linear_velocity', pairs.vector()) @@ -140,6 +129,9 @@ psim.add_property('force', pairs.vector(), volatile=True) psim.add_property('torque', pairs.vector(), volatile=True) psim.add_property('radius', pairs.double(), 1.0) psim.add_property('normal', pairs.vector()) +psim.add_property('inv_inertia', pairs.matrix()) +psim.add_property('rotation_matrix', pairs.matrix()) +psim.add_property('rotation_quat', pairs.quaternion()) psim.add_feature('type', ntypes) psim.add_feature_property('type', 'stiffness_norm', pairs.double(), [stiffnessNorm for i in range(ntypes * ntypes)]) psim.add_feature_property('type', 'stiffness_tan', pairs.double(), [stiffnessTan for i in range(ntypes * ntypes)]) @@ -168,6 +160,9 @@ psim.read_particle_data( ['type', 'mass', 'position', 'normal', 'flags'], pairs.halfspace()) +psim.setup(update_mass_and_inertia, {'densityParticle_SI': densityParticle_SI, + 'pi': math.pi }) + psim.build_neighbor_lists(linkedCellWidth + skin) psim.vtk_output(f"output/dem_{target}", frequency=visSpacing) psim.compute(gravity, symbols={'densityParticle_SI': densityParticle_SI, diff --git a/src/pairs/__init__.py b/src/pairs/__init__.py index abaeba14a10b7ba59d5c4e0a9055d4edc2c20c62..ba01adc7f930a6fc72d2475d706a1daa6c2556b3 100644 --- a/src/pairs/__init__.py +++ b/src/pairs/__init__.py @@ -23,6 +23,12 @@ def double(): def vector(): return Types.Vector +def matrix(): + return Types.Matrix + +def quaternion(): + return Types.Quaternion + def sphere(): return Shapes.Sphere diff --git a/src/pairs/analysis/blocks.py b/src/pairs/analysis/blocks.py index 699abbe7748c7c00bc2981137071ace4c701c53f..0d2b06157b87a54330fab9894814ecc5226c14ac 100644 --- a/src/pairs/analysis/blocks.py +++ b/src/pairs/analysis/blocks.py @@ -83,6 +83,16 @@ class DiscoverBlockVariants(Visitor): self.visit_children(ast_node) self.visited_accesses.add(ast_node) + def visit_MatrixAccess(self, ast_node): + if ast_node not in self.visited_accesses: + self.visit_children(ast_node) + self.visited_accesses.add(ast_node) + + def visit_QuaternionAccess(self, ast_node): + if ast_node not in self.visited_accesses: + self.visit_children(ast_node) + self.visited_accesses.add(ast_node) + def visit_Var(self, ast_node): self.push_variant(ast_node) @@ -202,3 +212,19 @@ class DetermineExpressionsOwnership(Visitor): def visit_VectorOp(self, ast_node): self.visit_children(ast_node) self.update_ownership(ast_node) + + def visit_Matrix(self, ast_node): + self.visit_children(ast_node) + self.update_ownership(ast_node) + + def visit_MatrixOp(self, ast_node): + self.visit_children(ast_node) + self.update_ownership(ast_node) + + def visit_Quaternion(self, ast_node): + self.visit_children(ast_node) + self.update_ownership(ast_node) + + def visit_QuaternionOp(self, ast_node): + self.visit_children(ast_node) + self.update_ownership(ast_node) diff --git a/src/pairs/analysis/devices.py b/src/pairs/analysis/devices.py index 0c004e8eb969fcdb2e9259957f5208d63978523b..a3625af4257533bd0dabd2b88a6abb9bb250dc44 100644 --- a/src/pairs/analysis/devices.py +++ b/src/pairs/analysis/devices.py @@ -1,5 +1,7 @@ from pairs.ir.arrays import ArrayAccess from pairs.ir.scalars import ScalarOp +from pairs.ir.quaternions import QuaternionOp +from pairs.ir.matrices import MatrixOp from pairs.ir.visitor import Visitor from pairs.ir.vectors import VectorOp @@ -12,6 +14,8 @@ class FetchKernelReferences(Visitor): self.kernel_used_array_accesses = {} self.kernel_used_scalar_ops = {} self.kernel_used_vector_ops = {} + self.kernel_used_matrix_ops = {} + self.kernel_used_quat_ops = {} self.writing = False def visit_ArrayAccess(self, ast_node): @@ -50,12 +54,17 @@ class FetchKernelReferences(Visitor): self.kernel_used_array_accesses[kernel_id] = [] self.kernel_used_scalar_ops[kernel_id] = [] self.kernel_used_vector_ops[kernel_id] = [] + self.kernel_used_matrix_ops[kernel_id] = [] + self.kernel_used_quat_ops[kernel_id] = [] self.kernel_stack.append(ast_node) self.visit_children(ast_node) self.kernel_stack.pop() + ast_node.add_array_access([a for a in self.kernel_used_array_accesses[kernel_id] if a not in self.kernel_decls[kernel_id]]) ast_node.add_scalar_op([b for b in self.kernel_used_scalar_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place]) ast_node.add_vector_op([b for b in self.kernel_used_vector_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place]) + ast_node.add_matrix_op([b for b in self.kernel_used_matrix_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place]) + ast_node.add_quaternion_op([b for b in self.kernel_used_quat_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place]) def visit_PropertyAccess(self, ast_node): # Visit property and save current writing state @@ -88,7 +97,7 @@ class FetchKernelReferences(Visitor): self.writing = writing_state def visit_Decl(self, ast_node): - if isinstance(ast_node.elem, (ArrayAccess, ScalarOp, VectorOp)): + if isinstance(ast_node.elem, (ArrayAccess, ScalarOp, VectorOp, MatrixOp, QuaternionOp)): for k in self.kernel_stack: self.kernel_decls[k.kernel_id].append(ast_node.elem) @@ -105,6 +114,18 @@ class FetchKernelReferences(Visitor): self.visit_children(ast_node) + def visit_MatrixOp(self, ast_node): + for k in self.kernel_stack: + self.kernel_used_matrix_ops[k.kernel_id].append(ast_node) + + self.visit_children(ast_node) + + def visit_QuaternionOp(self, ast_node): + for k in self.kernel_stack: + self.kernel_used_quat_ops[k.kernel_id].append(ast_node) + + self.visit_children(ast_node) + def visit_Array(self, ast_node): for k in self.kernel_stack: k.add_array(ast_node, self.writing) diff --git a/src/pairs/analysis/expressions.py b/src/pairs/analysis/expressions.py index 8f7b8d3e871e4573d7b8c56ca4dd9dfb9f04e313..ca1b42f197b1c7420b3e327cbab794037ec44311 100644 --- a/src/pairs/analysis/expressions.py +++ b/src/pairs/analysis/expressions.py @@ -1,4 +1,6 @@ from pairs.ir.scalars import ScalarOp +from pairs.ir.quaternions import QuaternionOp +from pairs.ir.matrices import MatrixOp from pairs.ir.vectors import VectorOp from pairs.ir.visitor import Visitor @@ -52,6 +54,18 @@ class DetermineExpressionsTerminals(Visitor): def visit_VectorOp(self, ast_node): self.traverse_expression(ast_node) + def visit_Matrix(self, ast_node): + self.traverse_expression(ast_node) + + def visit_MatrixOp(self, ast_node): + self.traverse_expression(ast_node) + + def visit_Quaternion(self, ast_node): + self.traverse_expression(ast_node) + + def visit_QuaternionOp(self, ast_node): + self.traverse_expression(ast_node) + def visit_Array(self, ast_node): self.push_terminal(ast_node) @@ -87,13 +101,21 @@ class ResetInPlaceOperations(Visitor): ast_node.in_place = True self.visit_children(ast_node) + def visit_MatrixOp(self, ast_node): + ast_node.in_place = True + self.visit_children(ast_node) + + def visit_QuaternionOp(self, ast_node): + ast_node.in_place = True + self.visit_children(ast_node) + class DetermineInPlaceOperations(Visitor): def __init__(self, ast=None): super().__init__(ast) def visit_Decl(self, ast_node): - if isinstance(ast_node.elem, (ScalarOp, VectorOp)): + if isinstance(ast_node.elem, (ScalarOp, VectorOp, MatrixOp, QuaternionOp)): ast_node.elem.in_place = False diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 9fa1dc01320c6108b38ceb2d9007e2e4c0845f0c..289cc38cce5d04e13433399df3adb8c8072d040c 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -14,7 +14,9 @@ from pairs.ir.kernel import KernelLaunch from pairs.ir.layouts import Layouts from pairs.ir.lit import Lit from pairs.ir.loops import For, Iter, While, Continue +from pairs.ir.quaternions import Quaternion, QuaternionAccess, QuaternionOp from pairs.ir.math import MathFunction +from pairs.ir.matrices import Matrix, MatrixAccess, MatrixOp from pairs.ir.memory import Malloc, Realloc from pairs.ir.module import ModuleCall from pairs.ir.particle_attributes import ParticleAttributeList @@ -237,10 +239,10 @@ class CGen: self.print(f"{tkw} {ast_node.array.name()}[{size}];") if isinstance(ast_node, Assign): - if ast_node._dest.is_vector(): - for dim in range(self.sim.ndims()): - dest = self.generate_expression(ast_node._dest, mem=True, index=dim) - src = self.generate_expression(ast_node._src, index=dim) + if not Types.is_scalar(ast_node._dest.type()): + for e in range(Types.number_of_elements(self.sim, ast_node._dest.type())): + dest = self.generate_expression(ast_node._dest, mem=True, index=e) + src = self.generate_expression(ast_node._src, index=e) self.print(f"{dest} = {src};") else: @@ -328,6 +330,24 @@ class CGen: index_g = self.generate_expression(prop_access.index) self.print(f"const {tkw} {acc_ref} = {prop_name}[{index_g}];") + if isinstance(ast_node.elem, Quaternion): + quaternion = ast_node.elem + for i in quaternion.indexes_to_generate(): + expr = self.generate_expression(quaternion.get_value(i)) + self.print(f"const double q{quaternion.id()}_{i} = {expr};") + + if isinstance(ast_node.elem, QuaternionOp): + quat_op = ast_node.elem + for i in quat_op.indexes_to_generate(): + lhs = self.generate_expression(quat_op.lhs, quat_op.mem, index=dim) + rhs = self.generate_expression(quat_op.rhs, index=dim) + operator = quat_op.operator() + + if operator.is_unary(): + self.print(f"const double e{quat_op.id()}_{dim} = {operator.symbol()}({lhs});") + else: + self.print(f"const double e{quat_op.id()}_{dim} = {lhs} {operator.symbol()} {rhs};") + if isinstance(ast_node.elem, ScalarOp): scalar_op = ast_node.elem if scalar_op.inlined is False: @@ -365,6 +385,24 @@ class CGen: tkw = Types.c_keyword(math_func.type()) self.print(f"const {tkw} {acc_ref} = {math_func.function_name()}({params});") + if isinstance(ast_node.elem, Matrix): + matrix = ast_node.elem + for i in matrix.indexes_to_generate(): + expr = self.generate_expression(matrix.get_value(i)) + self.print(f"const double m{matrix.id()}_{i} = {expr};") + + if isinstance(ast_node.elem, MatrixOp): + matrix_op = ast_node.elem + for i in matrix_op.indexes_to_generate(): + lhs = self.generate_expression(matrix_op.lhs, matrix_op.mem, index=i) + rhs = self.generate_expression(matrix_op.rhs, index=i) + operator = vector_op.operator() + + if operator.is_unary(): + self.print(f"const double e{matrix_op.id()}_{dim} = {operator.symbol()}({lhs});") + else: + self.print(f"const double e{matrix_op.id()}_{dim} = {lhs} {operator.symbol()} {rhs};") + if isinstance(ast_node.elem, Vector): vector = ast_node.elem for dim in vector.indexes_to_generate(): @@ -599,17 +637,10 @@ class CGen: ptr = p.name() d_ptr = f"d_{ptr}" if self.target.is_gpu() and p.device_flag else "nullptr" tkw = Types.c_keyword(p.type()) - ptype = "Prop_Integer" if p.type() == Types.Int32 else \ - "Prop_Float" if p.type() == Types.Double else \ - "Prop_Vector" if p.type() == Types.Vector else \ - "Prop_Invalid" - + ptype = Types.c_property_keyword(p.type()) assert ptype != "Prop_Invalid", "Invalid property type!" - playout = "AoS" if p.layout() == Layouts.AoS else \ - "SoA" if p.layout() == Layouts.SoA else \ - "Invalid" - + playout = Layouts.c_keyword(p.layout()) sizes = ", ".join([str(self.generate_expression(ScalarOp.inline(size))) for size in ast_node.sizes()]) if self.target.is_gpu() and p.device_flag: @@ -625,17 +656,10 @@ class CGen: ptr = p.name() d_ptr = f"d_{ptr}" if self.target.is_gpu() and p.device_flag else "nullptr" tkw = Types.c_keyword(p.type()) - ptype = "Prop_Integer" if p.type() == Types.Int32 else \ - "Prop_Float" if p.type() == Types.Double else \ - "Prop_Vector" if p.type() == Types.Vector else \ - "Prop_Invalid" - + ptype = Types.c_property_keyword(p.type()) assert ptype != "Prop_Invalid", "Invalid property type!" - playout = "AoS" if p.layout() == Layouts.AoS else \ - "SoA" if p.layout() == Layouts.SoA else \ - "Invalid" - + playout = Layouts.c_keyword(p.layout()) sizes = ", ".join([str(self.generate_expression(ScalarOp.inline(size))) for size in ast_node.sizes()]) if self.target.is_gpu() and p.device_flag: @@ -653,11 +677,7 @@ class CGen: array_size = fp.array_size() nkinds = fp.feature().nkinds() tkw = Types.c_keyword(fp.type()) - fptype = "Prop_Integer" if fp.type() == Types.Int32 else \ - "Prop_Float" if fp.type() == Types.Double else \ - "Prop_Vector" if fp.type() == Types.Vector else \ - "Prop_Invalid" - + fptype = Types.c_property_keyword(fp.type()) assert fptype != "Prop_Invalid", "Invalid feature property type!" self.print(f"{tkw} {ptr}[{array_size}];") @@ -691,17 +711,18 @@ class CGen: if isinstance(ast_node, DeclareVariable): tkw = Types.c_keyword(ast_node.var.type()) - if ast_node.var.type() == Types.Vector: - for dim in range(self.sim.ndims()): - var = self.generate_expression(ast_node.var, index=dim) - init = self.generate_expression(ast_node.var.init_value(), index=dim) - self.print(f"{tkw} {var} = {init};") - - else: + if ast_node.var.is_scalar(): var = self.generate_expression(ast_node.var) init = self.generate_expression(ast_node.var.init_value()) self.print(f"{tkw} {var} = {init};") + else: + for i in range(Types.number_of_elements(self.sim, ast_node.var.type())): + var = self.generate_expression(ast_node.var, index=i) + init = self.generate_expression(ast_node.var.init_value(), index=i) + self.print(f"{tkw} {var} = {init};") + + if not self.kernel_context and self.target.is_gpu() and ast_node.var.device_flag: self.print(f"RuntimeVar<{tkw}> rv_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));") #self.print(f"{tkw} *d_{ast_node.var.name()} = pairs->addDeviceVariable(&({ast_node.var.name()}));") @@ -781,8 +802,8 @@ class CGen: if ast_node.type() == Types.String: return f"\"{ast_node.value}\"" - if ast_node.type() == Types.Vector: - assert index is not None, "Index must be set for vector literals!" + if not ast_node.is_scalar(): + assert index is not None, "Index must be set for non-scalar literals." return ast_node.value[index] return ast_node.value @@ -846,7 +867,7 @@ class CGen: assert mem is False, "Select expression is not lvalue!" if ast_node.inlined is True: - assert ast_node.type() != Types.Vector, "Vector operations cannot be inlined!" + assert ast_node.is_scalar(), "Only scalar operations can be inlined!" cond = self.generate_expression(ast_node.cond, index=index) expr_if = self.generate_expression(ast_node.expr_if, index=index) expr_else = self.generate_expression(ast_node.expr_else, index=index) @@ -867,13 +888,35 @@ class CGen: if isinstance(ast_node, VectorAccess): return self.generate_expression(ast_node.expr, mem, self.generate_expression(ast_node.index)) + if isinstance(ast_node, MatrixAccess): + return self.generate_expression(ast_node.expr, mem, self.generate_expression(ast_node.index)) + + if isinstance(ast_node, QuaternionAccess): + return self.generate_expression(ast_node.expr, mem, self.generate_expression(ast_node.index)) + if isinstance(ast_node, Vector): assert index is not None, "Index must be set for vector." return f"v{ast_node.id()}_{index}" + if isinstance(ast_node, Matrix): + assert index is not None, "Index must be set for matrix." + return f"m{ast_node.id()}_{index}" + + if isinstance(ast_node, Quaternion): + assert index is not None, "Index must be set for quaternion." + return f"q{ast_node.id()}_{index}" + if isinstance(ast_node, VectorOp): assert index is not None, "Index must be set for vector operation." return f"e{ast_node.id()}_{index}" + if isinstance(ast_node, MatrixOp): + assert index is not None, "Index must be set for matrix operation." + return f"e{ast_node.id()}_{index}" + + if isinstance(ast_node, QuaternionOp): + assert index is not None, "Index must be set for quaternion operation." + return f"e{ast_node.id()}_{index}" + if isinstance(ast_node, ZeroVector): return "0.0" diff --git a/src/pairs/ir/accessor_class.py b/src/pairs/ir/accessor_class.py new file mode 100644 index 0000000000000000000000000000000000000000..746a05caa96fa1d4ab80f6f601e368b65e371d2e --- /dev/null +++ b/src/pairs/ir/accessor_class.py @@ -0,0 +1,12 @@ +from pairs.ir.matrices import MatrixAccess +from pairs.ir.quaternions import QuaternionAccess +from pairs.ir.types import Types +from pairs.ir.vectors import VectorAccess + + +class AccessorClass: + def from_type(t): + return VectorAccess if t == Types.Vector else \ + MatrixAccess if t == Types.Matrix else \ + QuaternionAccess if t == Types.Quaternion else \ + None diff --git a/src/pairs/ir/assign.py b/src/pairs/ir/assign.py index 30f1a1b5e85ffa7091e6438f69f6672142e78185..6d73c54b93802a5d7f1f519ffd8696306843d055 100644 --- a/src/pairs/ir/assign.py +++ b/src/pairs/ir/assign.py @@ -10,14 +10,16 @@ class Assign(ASTNode): self._dest = dest self._src = Lit.cvt(sim, src) - # When vector assignments occur, all indexes for the dest + # When non-scalar assignments occur, all indexes for the dest # and source terms must be generated - if isinstance(self._dest, ASTTerm) and self._dest.type() == Types.Vector: - for dim in range(sim.ndims()): - self._dest.add_index_to_generate(dim) + if isinstance(self._dest, ASTTerm) and not Types.is_scalar(self._dest.type()): + for elem in range(Types.number_of_elements(self.sim, self._dest.type())): + self._dest.add_index_to_generate(elem) - if isinstance(self._src, ASTTerm) and self._src.type() == Types.Vector: - self._src.add_index_to_generate(dim) + if isinstance(self._src, ASTTerm) and not Types.is_scalar(self._src.type()): + assert self._dest.type() == self._src.type(), \ + "Non-scalar types must match for assignments." + self._src.add_index_to_generate(elem) sim.add_statement(self) diff --git a/src/pairs/ir/ast_term.py b/src/pairs/ir/ast_term.py index be0dd1ba00f2293ebf7875add53056bcb4398dca..778bdfc01da3d1d0476cd77e324a94e693ed6463 100644 --- a/src/pairs/ir/ast_term.py +++ b/src/pairs/ir/ast_term.py @@ -18,6 +18,9 @@ class ASTTerm(ASTNode): def __sub__(self, other): return self._class_type(self.sim, self, other, Operators.Sub) + def __rsub__(self, other): + return self._class_type(self.sim, other, self, Operators.Sub) + def __mul__(self, other): return self._class_type(self.sim, self, other, Operators.Mul) @@ -69,9 +72,18 @@ class ASTTerm(ASTNode): def inv(self): return self._class_type(self.sim, 1.0, self, Operators.Div) + def is_scalar(self): + return self.type() not in [Types.Vector, Types.Matrix, Types.Quaternion] + def is_vector(self): return self.type() == Types.Vector + def is_matrix(self): + return self.type() == Types.Matrix + + def is_quaternion(self): + return self.type() == Types.Quaternion + def indexes_to_generate(self): return self._indexes_to_generate @@ -81,5 +93,5 @@ class ASTTerm(ASTNode): self._indexes_to_generate.add(integer_index) for child in self.children(): - if isinstance(child, ASTTerm) and child.type() == Types.Vector: + if isinstance(child, ASTTerm) and not Types.is_scalar(child.type()): child.add_index_to_generate(integer_index) diff --git a/src/pairs/ir/features.py b/src/pairs/ir/features.py index 4fc25b63a4928f6da6d367cc7202db73686db7f0..5a2a29270a5ce788b49b1c7f2ae9a40853bebbb1 100644 --- a/src/pairs/ir/features.py +++ b/src/pairs/ir/features.py @@ -1,11 +1,12 @@ +from pairs.ir.accessor_class import AccessorClass from pairs.ir.ast_node import ASTNode from pairs.ir.ast_term import ASTTerm from pairs.ir.declaration import Decl from pairs.ir.scalars import ScalarOp from pairs.ir.layouts import Layouts from pairs.ir.lit import Lit +from pairs.ir.operator_class import OperatorClass from pairs.ir.types import Types -from pairs.ir.vectors import VectorAccess, VectorOp class Features: @@ -122,15 +123,16 @@ class FeatureProperty(ASTNode): return self.feature_prop_layout def ndims(self): - return 1 if self.feature_prop_type != Types.Vector else 2 + return 1 if Types.is_scalar(self.prop_type) else 2 def sizes(self): - return [self.feature_prop_feature.nkinds()] if self.feature_prop_type != Types.Vector \ - else [self.sim.ndims(), self.feature_prop_feature.nkinds()] + return [self.feature_prop_feature.nkinds()] if Types.is_scalar(self.feature_prop_type) \ + else [Types.number_of_elements(self.sim, self.feature_prop_type), + self.feature_prop_feature.nkinds()] def array_size(self): nelems = self.feature_prop_feature.nkinds() * \ - (1 if self.feature_prop_type != Types.Vector else self.sim.ndims()) + Types.number_of_elements(self.sim, self.feature_prop_type) return nelems * nelems def __getitem__(self, expr): @@ -146,7 +148,7 @@ class FeaturePropertyAccess(ASTTerm): def __init__(self, sim, feature_prop, index): assert isinstance(index, tuple), "Two indexes must be used for feature property access!" - super().__init__(sim, ScalarOp if feature_prop.type() != Types.Vector else VectorOp) + super().__init__(sim, OperatorClass.from_type(feature_prop.type())) self.acc_id = FeaturePropertyAccess.new_id() self.feature_prop = feature_prop feature = self.feature_prop.feature() @@ -155,15 +157,15 @@ class FeaturePropertyAccess(ASTTerm): self.terminals = set() self.vector_indexes = {} - if feature_prop.type() == Types.Vector: + if not Types.is_scalar(feature_prop.type()): sizes = feature_prop.sizes() layout = feature_prop.layout() - for dim in range(self.sim.ndims()): + for elem in range(Types.number_of_elements(feature_prop.type())): if layout == Layouts.AoS: - self.vector_indexes[dim] = self.index * sizes[0] + dim + self.vector_indexes[elem] = self.index * sizes[0] + elem elif layout == Layouts.SoA: - self.vector_indexes[dim] = dim * sizes[1] + self.index + self.vector_indexes[elem] = elem * sizes[1] + self.index else: raise Exception("Invalid data layout.") @@ -194,7 +196,8 @@ class FeaturePropertyAccess(ASTTerm): def __getitem__(self, index): super().__getitem__(index) - return VectorAccess(self.sim, self, Lit.cvt(self.sim, index)) + _acc_class = AccessorClass.from_type(self.feature_prop.type()) + return _acc_class(self.sim, self, Lit.cvt(self.sim, index)) class RegisterFeatureProperty(ASTNode): diff --git a/src/pairs/ir/kernel.py b/src/pairs/ir/kernel.py index 8e1e6bdf446b51df98bf4c7ceb35cc61d90302a2..bb193bae26bd2f93344d1ddec3c19d16e4d3a6de 100644 --- a/src/pairs/ir/kernel.py +++ b/src/pairs/ir/kernel.py @@ -3,7 +3,9 @@ from pairs.ir.ast_node import ASTNode from pairs.ir.scalars import ScalarOp from pairs.ir.features import FeatureProperty from pairs.ir.lit import Lit +from pairs.ir.matrices import MatrixOp from pairs.ir.properties import Property, ContactProperty +from pairs.ir.quaternions import QuaternionOp from pairs.ir.variables import Var from pairs.ir.vectors import VectorOp @@ -84,7 +86,7 @@ class Kernel(ASTNode): array_list = array if isinstance(array, list) else [array] character = 'w' if write else 'r' for a in array_list: - assert isinstance(a, Array), "Kernel.add_array(): given element is not of type Array!" + assert isinstance(a, Array), "Kernel.add_array(): Element is not of type Array." self._arrays[a] = character if a not in self._arrays else self._arrays[a] + character def add_variable(self, variable, write=False): @@ -92,47 +94,59 @@ class Kernel(ASTNode): character = 'w' if write else 'r' for v in variable_list: if not v.temporary(): - assert isinstance(v, Var), "Kernel.add_variable(): given element is not of type Var!" + assert isinstance(v, Var), "Kernel.add_variable(): Element is not of type Var." self._variables[v] = character if v not in self._variables else self._variables[v] + character def add_property(self, prop, write=False): prop_list = prop if isinstance(prop, list) else [prop] character = 'w' if write else 'r' for p in prop_list: - assert isinstance(p, Property), "Kernel.add_property(): given element is not of type Property!" + assert isinstance(p, Property), "Kernel.add_property(): Element is not of type Property." self._properties[p] = character if p not in self._properties else self._properties[p] + character def add_contact_property(self, contact_prop, write=False): contact_prop_list = contact_prop if isinstance(contact_prop, list) else [contact_prop] character = 'w' if write else 'r' for cp in contact_prop_list: - assert isinstance(cp, ContactProperty), "Kernel.add_contact_property(): given element is not of type ContactProperty!" + assert isinstance(cp, ContactProperty), "Kernel.add_contact_property(): Element is not of type ContactProperty." self._contact_properties[cp] = character if cp not in self._contact_properties else self._contact_properties[cp] + character def add_feature_property(self, feature_prop): feature_prop_list = feature_prop if isinstance(feature_prop, list) else [feature_prop] for fp in feature_prop_list: - assert isinstance(fp, FeatureProperty), "Kernel.add_feature_property(): given element is not of type FeatureProperty!" + assert isinstance(fp, FeatureProperty), "Kernel.add_feature_property(): Element is not of type FeatureProperty." self._feature_properties[fp] = 'r' def add_array_access(self, array_access): array_access_list = array_access if isinstance(array_access, list) else [array_access] for a in array_access_list: - assert isinstance(a, ArrayAccess), "Kernel.add_array_access(): given element is not of type ArrayAccess!" + assert isinstance(a, ArrayAccess), "Kernel.add_array_access(): Element is not of type ArrayAccess." self._array_accesses.add(a) def add_scalar_op(self, scalar_op): scalar_op_list = scalar_op if isinstance(scalar_op, list) else [scalar_op] for b in scalar_op_list: - assert isinstance(b, ScalarOp), "Kernel.add_scalar_op(): given element is not of type ScalarOp!" + assert isinstance(b, ScalarOp), "Kernel.add_scalar_op(): Element is not of type ScalarOp." self._scalar_ops.append(b) def add_vector_op(self, vector_op): vector_op_list = vector_op if isinstance(vector_op, list) else [vector_op] for b in vector_op_list: - assert isinstance(b, VectorOp), "Kernel.add_vector_op(): given element is not of type VectorOp!" + assert isinstance(b, VectorOp), "Kernel.add_vector_op(): Element is not of type VectorOp." self._vector_ops.append(b) + def add_matrix_op(self, matrix_op): + matrix_op_list = matrix_op if isinstance(matrix_op, list) else [matrix_op] + for b in matrix_op_list: + assert isinstance(b, MatrixOp), "Kernel.add_matrix_op(): Element is not of type MatrixOp." + self._matrix_ops.append(b) + + def add_quaternion_op(self, quat_op): + quat_op_list = vector_op if isinstance(quat_op, list) else [quat_op] + for b in vector_op_list: + assert isinstance(b, QuaternionOp), "Kernel.add_quaternion_op(): Element is not of type QuaternionOp." + self._quat_ops.append(b) + def children(self): return [self._block] diff --git a/src/pairs/ir/layouts.py b/src/pairs/ir/layouts.py index 87a0b45c1694baa4637bed5c84ad01be1cccf746..6f1cac9dc870ff5c120a4f50576258d814922e8b 100644 --- a/src/pairs/ir/layouts.py +++ b/src/pairs/ir/layouts.py @@ -2,3 +2,8 @@ class Layouts: Invalid = -1 AoS = 0 SoA = 1 + + def c_keyword(layout): + return "AoS" if layout == Layouts.AoS else \ + "SoA" if layout == Layouts.SoA else \ + "Invalid" diff --git a/src/pairs/ir/lit.py b/src/pairs/ir/lit.py index 241000723f510d0ec1fde04ee509c1f79b8e085a..a5d420894259cb28b7e26e30ae328aa5c9a85e77 100644 --- a/src/pairs/ir/lit.py +++ b/src/pairs/ir/lit.py @@ -10,20 +10,28 @@ class Lit(ASTTerm): return Lit(sim, a) if Lit.is_literal(a) else a def __init__(self, sim, value): - type_mapping = { - int: Types.Int32, - float: Types.Double, - bool: Types.Boolean, - str: Types.String, - list: Types.Vector - } - - self.lit_type = type_mapping.get(type(value), Types.Invalid) - assert self.lit_type != Types.Invalid, "Invalid literal type!" - - from pairs.ir.scalars import ScalarOp - from pairs.ir.vectors import VectorOp - super().__init__(sim, VectorOp if self.lit_type == Types.Vector else ScalarOp) + if isinstance(value, list): + non_scalar_mapping = { + sim.ndims(): Types.Vector, + sim.ndims() * sim.ndims(): Types.Matrix, + sim.ndims() + 1: Types.Quaternion + } + + self.lit_type = non_scalar_mapping.get(len(value), Types.Invalid) + + else: + scalar_mapping = { + int: Types.Int32, + float: Types.Double, + bool: Types.Boolean, + str: Types.String, + } + + self.lit_type = scalar_mapping.get(type(value), Types.Invalid) + + assert self.lit_type != Types.Invalid, "Invalid literal type." + from pairs.ir.operator_class import OperatorClass + super().__init__(sim, OperatorClass.from_type(self.lit_type)) self.value = value def __str__(self): diff --git a/src/pairs/ir/math.py b/src/pairs/ir/math.py index 70a893d7b9d58cdfe73732cb8e664de955156e7c..b1cab83a481d578e719445b2c4da64fa9a6141fa 100644 --- a/src/pairs/ir/math.py +++ b/src/pairs/ir/math.py @@ -65,6 +65,51 @@ class Sqrt(MathFunction): return self._params[0].type() +class Abs(MathFunction): + def __init__(self, sim, expr): + super().__init__(sim) + self._params = [expr] + + def __str__(self): + return f"Abs<{self._params}>" + + def function_name(self): + return "fabs" + + def type(self): + return self._params[0].type() + + +class Sin(MathFunction): + def __init__(self, sim, expr): + super().__init__(sim) + self._params = [expr] + + def __str__(self): + return f"Sin<{self._params}>" + + def function_name(self): + return "sin" + + def type(self): + return self._params[0].type() + + +class Cos(MathFunction): + def __init__(self, sim, expr): + super().__init__(sim) + self._params = [expr] + + def __str__(self): + return f"Cos<{self._params}>" + + def function_name(self): + return "cos" + + def type(self): + return self._params[0].type() + + class Ceil(MathFunction): def __init__(self, sim, expr): assert Types.is_real(expr.type()), "Expression must be of real type!" diff --git a/src/pairs/ir/matrices.py b/src/pairs/ir/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..88baf50dd034287e3791682e8521c4cbaf74d41b --- /dev/null +++ b/src/pairs/ir/matrices.py @@ -0,0 +1,113 @@ +from pairs.ir.ast_node import ASTNode +from pairs.ir.ast_term import ASTTerm +from pairs.ir.scalars import ScalarOp +from pairs.ir.lit import Lit +from pairs.ir.types import Types + + +class MatrixOp(ASTTerm): + last_matrix_op = 0 + + def new_id(): + MatrixOp.last_matrix_op += 1 + return MatrixOp.last_matrix_op - 1 + + def __init__(self, sim, lhs, rhs, op, mem=False): + assert lhs.type() == Types.Matrix or rhs.type() == Types.Matrix, \ + "MatrixOp(): At least one matrix operand is required." + super().__init__(sim, MatrixOp) + self._id = MatrixOp.new_id() + self.lhs = Lit.cvt(sim, lhs) + self.rhs = Lit.cvt(sim, rhs) + self.op = op + self.mem = mem + self.in_place = False + self.terminals = set() + + def __str__(self): + a = f"MatrixOp<{self.lhs.id()}>" if isinstance(self.lhs, MatrixOp) else self.lhs + b = f"MatrixOp<{self.rhs.id()}>" if isinstance(self.rhs, MatrixOp) else self.rhs + return f"MatrixOp<id={self.id()}, {a} {self.op.symbol()} {b}>" + + def __getitem__(self, index): + return MatrixAccess(self.sim, self, Lit.cvt(self.sim, index)) + + def id(self): + return self._id + + def type(self): + return Types.Matrix + + def operator(self): + return self.op + + def reassign(self, lhs, rhs, op): + self.lhs = Lit.cvt(self.sim, lhs) + self.rhs = Lit.cvt(self.sim, rhs) + self.op = op + + def copy(self): + return MatrixOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem) + + def match(self, matrix_op): + return self.lhs == matrix_op.lhs and self.rhs == matrix_op.rhs and self.op == matrix_op.operator() + + def add_terminal(self, terminal): + self.terminals.add(terminal) + + def children(self): + return [self.lhs, self.rhs] if not self.op.is_unary() else [self.lhs] + + +class MatrixAccess(ASTTerm): + def __init__(self, sim, expr, index): + super().__init__(sim, ScalarOp) + self.expr = expr + self.index = index + expr.add_index_to_generate(index) + + def __str__(self): + return f"MatrixAccess<{self.expr}, {self.index}>" + + def type(self): + return Types.Double + + def children(self): + return [self.expr] + + +class Matrix(ASTTerm): + last_matrix = 0 + + def new_id(): + Matrix.last_matrix += 1 + return Matrix.last_matrix - 1 + + def __init__(self, sim, values): + assert isinstance(values, list) and len(values) == sim.ndims() * sim.ndims(), \ + "Matrix(): Given list is invalid!" + super().__init__(sim, MatrixOp) + self._id = Matrix.new_id() + self._values = [Lit.cvt(sim, v) for v in values] + self.terminals = set() + + def __str__(self): + return f"Matrix<{self._values}>" + + def __getitem__(self, index): + return MatrixAccess(self.sim, self, Lit.cvt(self.sim, index)) + + def id(self): + return self._id + + def type(self): + return Types.Matrix + + def get_value(self, dimension): + return self._values[dimension] + + def add_terminal(self, terminal): + self.terminals.add(terminal) + + def children(self): + return self._values diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index a08136bb5981356d412b46e7a67773c3156bab73..8c7cfd609d40d72506ed93af1ec1a976146d79ee 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -141,6 +141,22 @@ class Mutator: ast_node.vector_indexes = {d: self.mutate(i) for d, i in ast_node.vector_indexes.items()} return ast_node + def mutate_Quaternion(self, ast_node): + ast_node._values = [self.mutate(v) for v in ast_node._values] + return ast_node + + def mutate_QuaternionAccess(self, ast_node): + ast_node.expr = self.mutate(ast_node.expr) + return ast_node + + def mutate_QuaternionOp(self, ast_node): + ast_node.lhs = self.mutate(ast_node.lhs) + + if not ast_node.operator().is_unary(): + ast_node.rhs = self.mutate(ast_node.rhs) + + return ast_node + def mutate_Malloc(self, ast_node): ast_node.array = self.mutate(ast_node.array) ast_node.size = self.mutate(ast_node.size) @@ -150,6 +166,22 @@ class Mutator: ast_node._params = [self.mutate(p) for p in ast_node._params] return ast_node + def mutate_Matrix(self, ast_node): + ast_node._values = [self.mutate(v) for v in ast_node._values] + return ast_node + + def mutate_MatrixAccess(self, ast_node): + ast_node.expr = self.mutate(ast_node.expr) + return ast_node + + def mutate_MatrixOp(self, ast_node): + ast_node.lhs = self.mutate(ast_node.lhs) + + if not ast_node.operator().is_unary(): + ast_node.rhs = self.mutate(ast_node.rhs) + + return ast_node + def mutate_Module(self, ast_node): ast_node._block = self.mutate(ast_node._block) return ast_node diff --git a/src/pairs/ir/operator_class.py b/src/pairs/ir/operator_class.py new file mode 100644 index 0000000000000000000000000000000000000000..2c4fe6be310f540ddaa026227a2b661447b55e8c --- /dev/null +++ b/src/pairs/ir/operator_class.py @@ -0,0 +1,25 @@ +from pairs.ir.matrices import MatrixOp +from pairs.ir.quaternions import QuaternionOp +from pairs.ir.scalars import ScalarOp +from pairs.ir.types import Types +from pairs.ir.vectors import VectorOp + + +class OperatorClass: + def from_type(t): + return VectorOp if t == Types.Vector else \ + MatrixOp if t == Types.Matrix else \ + QuaternionOp if t == Types.Quaternion else \ + ScalarOp + + def from_type_list(type_list): + if Types.Quaternion in type_list: + return OperatorClass.from_type(Types.Quaternion) + + if Types.Matrix in type_list: + return OperatorClass.from_type(Types.Matrix) + + if Types.Vector in type_list: + return OperatorClass.from_type(Types.Vector) + + return OperatorClass.from_type(type_list[0]) diff --git a/src/pairs/ir/properties.py b/src/pairs/ir/properties.py index 712156ef1c7a0fbd06b50b0e0b259a9e2e5a2542..396510df2d850afdb535ce517003785b1040947c 100644 --- a/src/pairs/ir/properties.py +++ b/src/pairs/ir/properties.py @@ -1,11 +1,12 @@ +from pairs.ir.accessor_class import AccessorClass from pairs.ir.ast_node import ASTNode from pairs.ir.ast_term import ASTTerm from pairs.ir.declaration import Decl from pairs.ir.layouts import Layouts from pairs.ir.lit import Lit +from pairs.ir.operator_class import OperatorClass from pairs.ir.scalars import ScalarOp from pairs.ir.types import Types -from pairs.ir.vectors import VectorAccess, VectorOp class Properties: @@ -79,11 +80,11 @@ class Property(ASTNode): return self.default_value def ndims(self): - return 1 if self.prop_type != Types.Vector else 2 + return 1 if Types.is_scalar(self.prop_type) else 2 def sizes(self): - return [self.sim.particle_capacity] if self.prop_type != Types.Vector \ - else [self.sim.ndims(), self.sim.particle_capacity] + return [self.sim.particle_capacity] if Types.is_scalar(self.prop_type) \ + else [Types.number_of_elements(self.sim, self.prop_type), self.sim.particle_capacity] class PropertyAccess(ASTTerm): @@ -94,7 +95,7 @@ class PropertyAccess(ASTTerm): return PropertyAccess.last_prop_acc - 1 def __init__(self, sim, prop, index): - super().__init__(sim, ScalarOp if prop.type() != Types.Vector else VectorOp) + super().__init__(sim, OperatorClass.from_type(prop.type())) self.acc_id = PropertyAccess.new_id() self.prop = prop self.index = Lit.cvt(sim, index) @@ -102,15 +103,15 @@ class PropertyAccess(ASTTerm): self.terminals = set() self.vector_indexes = {} - if prop.type() == Types.Vector: + if not Types.is_scalar(prop.type()): sizes = prop.sizes() layout = prop.layout() - for dim in range(self.sim.ndims()): + for elem in range(Types.number_of_elements(sim, prop.type())): if layout == Layouts.AoS: - self.vector_indexes[dim] = self.index * sizes[0] + dim + self.vector_indexes[elem] = self.index * sizes[0] + elem elif layout == Layouts.SoA: - self.vector_indexes[dim] = dim * sizes[1] + self.index + self.vector_indexes[elem] = elem * sizes[1] + self.index else: raise Exception("Invalid data layout.") @@ -118,7 +119,8 @@ class PropertyAccess(ASTTerm): return f"PropertyAccess<{self.prop}, {self.index}>" def __getitem__(self, index): - return VectorAccess(self.sim, self, Lit.cvt(self.sim, index)) + _acc_class = AccessorClass.from_type(self.prop.type()) + return _acc_class(self.sim, self, Lit.cvt(self.sim, index)) def copy(self): return PropertyAccess(self.sim, self.prop, self.index) @@ -239,12 +241,12 @@ class ContactProperty(ASTNode): return self.contact_prop_default def ndims(self): - return 1 if self.contact_prop_type != Types.Vector else 2 + return 1 if Types.is_scalar(self.contact_prop_type) else 2 def sizes(self): neighbor_list_sizes = [self.sim.particle_capacity, self.sim.neighbor_capacity] - return neighbor_list_sizes if self.contact_prop_type != Types.Vector \ - else [self.sim.ndims()] + neighbor_list_sizes + return neighbor_list_sizes if Types.is_scalar(self.contact_prop_type) \ + else [Types.number_of_elements(self.sim, self.contact_prop_type)] + neighbor_list_sizes class ContactPropertyAccess(ASTTerm): @@ -256,7 +258,7 @@ class ContactPropertyAccess(ASTTerm): def __init__(self, sim, contact_prop, index): assert isinstance(index, tuple), "Two indexes must be used for contact property access!" - super().__init__(sim, ScalarOp if contact_prop.type() != Types.Vector else VectorOp) + super().__init__(sim, OperatorClass.from_type(contact_prop.type())) self.acc_id = ContactPropertyAccess.new_id() self.contact_prop = contact_prop self.index = index[0] * self.sim.neighbor_capacity + index[1] @@ -264,15 +266,15 @@ class ContactPropertyAccess(ASTTerm): self.terminals = set() self.vector_indexes = {} - if contact_prop.type() == Types.Vector: + if not Types.is_scalar(contact_prop.type()): sizes = contact_prop.sizes() layout = contact_prop.layout() - for dim in range(self.sim.ndims()): + for elem in range(Types.number_of_elements(sim, contact_prop.type())): if layout == Layouts.AoS: - self.vector_indexes[dim] = self.index * sizes[0] + dim + self.vector_indexes[elem] = self.index * sizes[0] + elem elif layout == Layouts.SoA: - self.vector_indexes[dim] = dim * sizes[1] + self.index + self.vector_indexes[elem] = elem * sizes[1] + self.index else: raise Exception("Invalid data layout.") @@ -280,7 +282,8 @@ class ContactPropertyAccess(ASTTerm): return f"ContactPropertyAccess<{self.contact_prop}, {self.index}>" def __getitem__(self, index): - return VectorAccess(self.sim, self, Lit.cvt(self.sim, index)) + _acc_class = AccessorClass.from_type(self.contact_prop.type()) + return _acc_class(self.sim, self, Lit.cvt(self.sim, index)) def copy(self): return ContactPropertyAccess(self.sim, self.contact_prop, self.index) diff --git a/src/pairs/ir/quaternions.py b/src/pairs/ir/quaternions.py new file mode 100644 index 0000000000000000000000000000000000000000..ef9a7e02f27a375982a87ddf0eb2b24bbb77c6cf --- /dev/null +++ b/src/pairs/ir/quaternions.py @@ -0,0 +1,113 @@ +from pairs.ir.ast_node import ASTNode +from pairs.ir.ast_term import ASTTerm +from pairs.ir.scalars import ScalarOp +from pairs.ir.lit import Lit +from pairs.ir.types import Types + + +class QuaternionOp(ASTTerm): + last_quaternion_op = 0 + + def new_id(): + QuaternionOp.last_quaternion_op += 1 + return QuaternionOp.last_quaternion_op - 1 + + def __init__(self, sim, lhs, rhs, op, mem=False): + assert lhs.type() == Types.Quaternion or rhs.type() == Types.Quaternion, \ + "QuaternionOp(): At least one quaternion operand is required." + super().__init__(sim, QuaternionOp) + self._id = QuaternionOp.new_id() + self.lhs = Lit.cvt(sim, lhs) + self.rhs = Lit.cvt(sim, rhs) + self.op = op + self.mem = mem + self.in_place = False + self.terminals = set() + + def __str__(self): + a = f"QuaternionOp<{self.lhs.id()}>" if isinstance(self.lhs, QuaternionOp) else self.lhs + b = f"QuaternionOp<{self.rhs.id()}>" if isinstance(self.rhs, QuaternionOp) else self.rhs + return f"QuaternionOp<id={self.id()}, {a} {self.op.symbol()} {b}>" + + def __getitem__(self, index): + return QuaternionAccess(self.sim, self, Lit.cvt(self.sim, index)) + + def id(self): + return self._id + + def type(self): + return Types.Quaternion + + def operator(self): + return self.op + + def reassign(self, lhs, rhs, op): + self.lhs = Lit.cvt(self.sim, lhs) + self.rhs = Lit.cvt(self.sim, rhs) + self.op = op + + def copy(self): + return QuaternionOp(self.sim, self.lhs.copy(), self.rhs.copy(), self.op, self.mem) + + def match(self, quat_op): + return self.lhs == quat_op.lhs and self.rhs == quat_op.rhs and self.op == quat_op.operator() + + def add_terminal(self, terminal): + self.terminals.add(terminal) + + def children(self): + return [self.lhs, self.rhs] if not self.op.is_unary() else [self.lhs] + + +class QuaternionAccess(ASTTerm): + def __init__(self, sim, expr, index): + super().__init__(sim, ScalarOp) + self.expr = expr + self.index = index + expr.add_index_to_generate(index) + + def __str__(self): + return f"QuaternionAccess<{self.expr}, {self.index}>" + + def type(self): + return Types.Double + + def children(self): + return [self.expr] + + +class Quaternion(ASTTerm): + last_quat = 0 + + def new_id(): + Quaternion.last_quat += 1 + return Quaternion.last_quat - 1 + + def __init__(self, sim, values): + assert isinstance(values, list) and len(values) == sim.ndims() + 1, \ + "Quaternion(): Given list is invalid!" + super().__init__(sim, QuaternionOp) + self._id = Quaternion.new_id() + self._values = [Lit.cvt(sim, v) for v in values] + self.terminals = set() + + def __str__(self): + return f"Quaternion<{self._values}>" + + def __getitem__(self, index): + return QuaternionAccess(self.sim, self, Lit.cvt(self.sim, index)) + + def id(self): + return self._id + + def type(self): + return Types.Quaternion + + def get_value(self, dimension): + return self._values[dimension] + + def add_terminal(self, terminal): + self.terminals.add(terminal) + + def children(self): + return self._values diff --git a/src/pairs/ir/scalars.py b/src/pairs/ir/scalars.py index 929a559d319f4b295889f021bc9e467329b18641..1e33274122c07f25dde5114e9f10f11dda440c85 100644 --- a/src/pairs/ir/scalars.py +++ b/src/pairs/ir/scalars.py @@ -80,6 +80,12 @@ class ScalarOp(ASTTerm): if lhs_type == Types.Vector or rhs_type == Types.Vector: return Types.Vector + if lhs_type == Types.Matrix or rhs_type == Types.Matrix: + return Types.Matrix + + if lhs_type == Types.Quaternion or rhs_type == Types.Quaternion: + return Types.Quaternion + if Types.is_real(lhs_type) or Types.is_real(rhs_type): return Types.Double diff --git a/src/pairs/ir/select.py b/src/pairs/ir/select.py index 629ad12e6449e206bb62bf2ec0f2add913ae6159..9bc43e10d849c0aa60ea2c288a4adcd9f9e01c81 100644 --- a/src/pairs/ir/select.py +++ b/src/pairs/ir/select.py @@ -1,9 +1,10 @@ +from pairs.ir.accessor_class import AccessorClass from pairs.ir.ast_node import ASTNode from pairs.ir.ast_term import ASTTerm from pairs.ir.scalars import ScalarOp from pairs.ir.lit import Lit +from pairs.ir.operator_class import OperatorClass from pairs.ir.types import Types -from pairs.ir.vectors import VectorAccess, VectorOp class Select(ASTTerm): @@ -14,13 +15,13 @@ class Select(ASTTerm): return Select.last_select - 1 def __init__(self, sim, cond, expr_if, expr_else): - super().__init__(sim, ScalarOp if expr_if.type() != Types.Vector else VectorOp) - self.select_id = Select.new_id() + super().__init__(sim, OperatorClass.from_type(Lit.cvt(sim, expr_if).type())) self.cond = Lit.cvt(sim, cond) - #self.expr_if = ScalarOp.inline(Lit.cvt(sim, expr_if)) - #self.expr_else = ScalarOp.inline(Lit.cvt(sim, expr_else)) self.expr_if = Lit.cvt(sim, expr_if) self.expr_else = Lit.cvt(sim, expr_else) + self.select_id = Select.new_id() + #self.expr_if = ScalarOp.inline(Lit.cvt(sim, expr_if)) + #self.expr_else = ScalarOp.inline(Lit.cvt(sim, expr_else)) self.terminals = set() self.inlined = False assert self.expr_if.type() == self.expr_else.type(), "Select: expressions must be of same type!" @@ -28,6 +29,10 @@ class Select(ASTTerm): def __str__(self): return f"Select<{self.cond}, {self.expr_if}, {self.expr_else}>" + def __getitem__(self, index): + _acc_class = AccessorClass.from_type(self.expr_if.type()) + return _acc_class(self.sim, self, Lit.cvt(self.sim, index)) + def id(self): return self.select_id @@ -54,6 +59,3 @@ class Select(ASTTerm): def children(self): return [self.cond, self.expr_if, self.expr_else] - - def __getitem__(self, index): - return VectorAccess(self.sim, self, Lit.cvt(self.sim, index)) diff --git a/src/pairs/ir/symbols.py b/src/pairs/ir/symbols.py index 6f2182d85c70374e1a84a3e1e55f1209c4012d05..72b9487cfd40f89c7b54d5be04ee47a5d6325b5c 100644 --- a/src/pairs/ir/symbols.py +++ b/src/pairs/ir/symbols.py @@ -2,13 +2,13 @@ from pairs.ir.ast_node import ASTNode from pairs.ir.ast_term import ASTTerm from pairs.ir.scalars import ScalarOp from pairs.ir.lit import Lit +from pairs.ir.operator_class import OperatorClass from pairs.ir.types import Types -from pairs.ir.vectors import VectorAccess, VectorOp class Symbol(ASTTerm): def __init__(self, sim, sym_type): - super().__init__(sim, ScalarOp if sym_type != Types.Vector else VectorOp) + super().__init__(sim, OperatorClass.from_type(sym_type)) self.sym_type = sym_type self.assign_to = None @@ -22,14 +22,13 @@ class Symbol(ASTTerm): return self.sym_type def __getitem__(self, index): - #return VectorAccess(self.sim, self, Lit.cvt(self.sim, index)) return SymbolAccess(self.sim, self, Lit.cvt(self.sim, index)) class SymbolAccess(ASTTerm): def __init__(self, sim, symbol, index): - assert symbol.type() == Types.Vector, "Only vector symbols can be indexed!" - super().__init__(sim, ScalarOp if symbol.type() != Types.Vector else VectorOp) + assert not Types.is_scalar(symbol.type()), "Scalar symbols cannot be indexed." + super().__init__(sim, OperatorClass.from_type(symbol.type())) self._symbol = symbol self._index = index symbol.add_index_to_generate(index) @@ -44,7 +43,7 @@ class SymbolAccess(ASTTerm): return self._index def type(self): - if self._symbol.type() == Types.Vector: + if not Types.is_scalar(self._symbol.type()): return Types.Float return self._symbol.type() diff --git a/src/pairs/ir/types.py b/src/pairs/ir/types.py index 92862ad050c75ee47bac9ac7becb2456df4780f9..8bf91b106138c8a3ad411c43e72e9c633c21e0c3 100644 --- a/src/pairs/ir/types.py +++ b/src/pairs/ir/types.py @@ -9,10 +9,12 @@ class Types: String = 6 Vector = 7 Array = 8 + Matrix = 9 + Quaternion = 10 def c_keyword(t): return ( - 'double' if t == Types.Double or t == Types.Vector + 'double' if t in (Types.Double, Types.Vector, Types.Matrix, Types.Quaternion) else 'float' if t == Types.Float else 'int' if t == Types.Int32 else 'long long int' if t == Types.Int64 @@ -21,8 +23,25 @@ class Types: else '<invalid type>' ) + def c_property_keyword(t): + ptype = "Prop_Integer" if t == Types.Int32 else \ + "Prop_Float" if t == Types.Double else \ + "Prop_Vector" if t == Types.Vector else \ + "Prop_Matrix" if t == Types.Matrix else \ + "Prop_Quaternion" if t == Types.Quaternion else \ + "Prop_Invalid" + def is_integer(t): - return t == Types.Int32 or t == Types.Int64 or t == Types.UInt64 + return t in (Types.Int32, Types.Int64, Types.UInt64) def is_real(t): - return t == Types.Float or t == Types.Double + return t in (Types.Float, Types.Double) + + def is_scalar(t): + return t not in (Types.Vector, Types.Matrix, Types.Quaternion) + + def number_of_elements(sim, t): + return sim.ndims() if t == Types.Vector else \ + sim.ndims() * sim.ndims() if t == Types.Matrix else \ + sim.ndims() + 1 if t == Types.Quaternion else \ + 1 diff --git a/src/pairs/ir/vectors.py b/src/pairs/ir/vectors.py index e6ed7c9b613f7dbb66fdd4288a499153cfe05134..825cb9716f91718794d33a2352db4dc2266cd97b 100644 --- a/src/pairs/ir/vectors.py +++ b/src/pairs/ir/vectors.py @@ -96,11 +96,11 @@ class Vector(ASTTerm): assert isinstance(values, list) and len(values) == sim.ndims(), "Vector(): Given list is invalid!" super().__init__(sim, VectorOp) self._id = Vector.new_id() - self._values = values + self._values = [Lit.cvt(sim, v) for v in values] self.terminals = set() def __str__(self): - return f"Vector<{self.values}>" + return f"Vector<{self._values}>" def __getitem__(self, index): return VectorAccess(self.sim, self, Lit.cvt(self.sim, index)) diff --git a/src/pairs/mapping/funcs.py b/src/pairs/mapping/funcs.py index c57a209bbe70718a9bb7115060662363476748da..8fde06758bf7d5515626962e7f30a191eb19cd02 100644 --- a/src/pairs/mapping/funcs.py +++ b/src/pairs/mapping/funcs.py @@ -5,9 +5,9 @@ from pairs.ir.branches import Branch, Filter from pairs.ir.lit import Lit from pairs.ir.loops import ParticleFor from pairs.ir.operators import Operators +from pairs.ir.operator_class import OperatorClass from pairs.ir.scalars import ScalarOp from pairs.ir.types import Types -from pairs.ir.vectors import VectorOp from pairs.mapping.keywords import Keywords from pairs.sim.flags import Flags from pairs.sim.interaction import ParticleInteraction @@ -101,7 +101,7 @@ class BuildParticleIR(ast.NodeVisitor): def visit_AugAssign(self, node): lhs = self.visit(node.target) rhs = self.visit(node.value) - op_class = VectorOp if Types.Vector in [lhs.type(), rhs.type()] else ScalarOp + op_class = OperatorClass.from_type_list([lhs.type(), rhs.type()]) bin_op = op_class(self.sim, lhs, rhs, BuildParticleIR.get_binary_op(node.op)) if isinstance(lhs, UndefinedSymbol): @@ -115,7 +115,16 @@ class BuildParticleIR(ast.NodeVisitor): assert not isinstance(lhs, UndefinedSymbol), f"Undefined lhs used in BinOp: {lhs.symbol_id}" rhs = self.visit(node.right) assert not isinstance(rhs, UndefinedSymbol), f"Undefined rhs used in BinOp: {rhs.symbol_id}" - op_class = VectorOp if Types.Vector in [lhs.type(), rhs.type()] else ScalarOp + operator = BuildParticleIR.get_binary_op(node.op) + + if operator == Operators.Mul: + if Types.Matrix in (lhs.type(), rhs.type()): + return self.keywords.keyword_matrix_multiplication([lhs, rhs]) + + if Types.Quaternion in (lhs.type(), rhs.type()): + return self.keywords.keyword_quaternion_multiplication([lhs, rhs]) + + op_class = OperatorClass.from_type_list([lhs.type(), rhs.type()]) return op_class(self.sim, lhs, rhs, BuildParticleIR.get_binary_op(node.op)) def visit_BoolOp(self, node): @@ -149,8 +158,9 @@ class BuildParticleIR(ast.NodeVisitor): lhs = self.visit(node.left) rhs = self.visit(node.comparators[0]) - op_class = VectorOp if Types.Vector in [lhs.type(), rhs.type()] else ScalarOp - return op_class(self.sim, lhs, rhs, BuildParticleIR.get_binary_op(node.ops[0])) + operator = BuildParticleIR.get_binary_op(node.ops[0]) + op_class = OperatorClass.from_type_list([lhs.type(), rhs.type()]) + return op_class(self.sim, lhs, rhs, operator) def visit_If(self, node): condition = self.visit(node.test) @@ -216,7 +226,7 @@ class BuildParticleIR(ast.NodeVisitor): operand = self.visit(node.operand) assert not isinstance(operand, UndefinedSymbol), \ f"Undefined operand used in UnaryOp: {operand.symbol_id}" - op_class = VectorOp if operand.type() == Types.Vector else ScalarOp + op_class = OperatorClass.from_type(operand.type()) return op_class(self.sim, operand, None, BuildParticleIR.get_unary_op(node.op)) @@ -264,3 +274,30 @@ def compute(sim, func, cutoff_radius=None, symbols={}): ir.visit(tree) sim.build_module_with_statements() + + +def setup(sim, func, symbols={}): + src = inspect.getsource(func) + tree = ast.parse(src, mode='exec') + + # Fetch function info + info = FetchParticleFuncInfo() + info.visit(tree) + params = info.params() + nparams = info.nparams() + + # Compute functions must have parameters + assert nparams == 1, "Number of parameters from setup functions must be one!" + + # Convert literal symbols + symbols = {symbol: Lit.cvt(sim, value) for symbol, value in symbols.items()} + + sim.init_block() + sim.module_name(func.__name__) + + for i in ParticleFor(sim): + ir = BuildParticleIR(sim, symbols) + ir.add_symbols({params[0]: i}) + ir.visit(tree) + + sim.build_module_with_statements() diff --git a/src/pairs/mapping/keywords.py b/src/pairs/mapping/keywords.py index 64d8f2ca25af4381b2f5e4dc936fdac91e4ef856..cf28e9cff2cb634ca1a31ac5a1b9626315ca9b8a 100644 --- a/src/pairs/mapping/keywords.py +++ b/src/pairs/mapping/keywords.py @@ -2,7 +2,9 @@ from pairs.ir.block import Block from pairs.ir.branches import Filter from pairs.ir.lit import Lit from pairs.ir.loops import Continue -from pairs.ir.math import Sqrt +from pairs.ir.math import Abs, Cos, Sin, Sqrt +from pairs.ir.matrices import Matrix +from pairs.ir.quaternions import Quaternion from pairs.ir.select import Select from pairs.ir.types import Types from pairs.ir.vectors import Vector, ZeroVector @@ -90,3 +92,107 @@ class Keywords: def keyword_zero_vector(self, args): assert len(args) == 0, "zero_vector() keyword requires no parameter!" return ZeroVector(self.sim) + + def keyword_transposed(self, args): + assert len(args) == 1, "transposed() keyword requires one parameter!" + matrix = args[0] + assert matrix.type() == Types.Matrix, "tranposed(): Argument must be a matrix!" + return Matrix(self.sim, [ matrix[0], matrix[3], matrix[6], + matrix[1], matrix[4], matrix[7], + matrix[2], matrix[5], matrix[8] ]) + + def keyword_diagonal_matrix(self, args): + assert len(args) == 1, "diagonal_matrix() keyword requires one parameter!" + value = args[0] + nelems = Types.number_of_elements(self.sim, Types.Matrix) + return Matrix(self.sim, [value if i % (self.sim.ndims() + 1) == 0 else 0.0 \ + for i in range(nelems)]) + + def keyword_matrix_multiplication(self, args): + assert len(args) == 2, "matrix_multiplication() keyword requires two parameters!" + lhs = args[0] + rhs = args[1] + nelems = Types.number_of_elements(self.sim, Types.Matrix) + assert Types.Matrix in (lhs.type(), rhs.type()), \ + "matrix_multiplication(): At least one matrix is needed!" + + # Matrix * Matrix + if lhs.type() == rhs.type(): + return Matrix(self.sim, [ lhs[0] * rhs[0] + lhs[1] * rhs[3] + lhs[2] * rhs[6], + lhs[0] * rhs[1] + lhs[1] * rhs[4] + lhs[2] * rhs[7], + lhs[0] * rhs[2] + lhs[1] * rhs[5] + lhs[2] * rhs[8], + lhs[3] * rhs[0] + lhs[4] * rhs[3] + lhs[5] * rhs[6], + lhs[3] * rhs[1] + lhs[4] * rhs[4] + lhs[5] * rhs[7], + lhs[3] * rhs[2] + lhs[4] * rhs[5] + lhs[5] * rhs[8], + lhs[6] * rhs[0] + lhs[7] * rhs[3] + lhs[8] * rhs[6], + lhs[6] * rhs[1] + lhs[7] * rhs[4] + lhs[8] * rhs[7], + lhs[6] * rhs[2] + lhs[7] * rhs[5] + lhs[8] * rhs[8] ]) + + if Types.Vector in (lhs.type(), rhs.type()): + # Matrix * Vector + if lhs.type() == Types.Matrix: + return Vector(self.sim, [ lhs[0] * rhs[0] + lhs[1] * rhs[1] + lhs[2] * rhs[2], + lhs[3] * rhs[0] + lhs[4] * rhs[1] + lhs[5] * rhs[2], + lhs[6] * rhs[0] + lhs[7] * rhs[1] + lhs[8] * rhs[2] ]) + + # Vector * Matrix + else: + return Vector(self.sim, [ lhs[0] * rhs[0] + lhs[1] * rhs[1] + lhs[2] * rhs[2], + lhs[0] * rhs[3] + lhs[1] * rhs[4] + lhs[2] * rhs[5], + lhs[0] * rhs[6] + lhs[1] * rhs[7] + lhs[2] * rhs[8] ]) + + # Scalar * Matrix + if rhs.type() == Types.Matrix: + return Matrix(self.sim, [rhs[i] * lhs for i in range(nelems)]) + + # Matrix * Scalar + return Matrix(self.sim, [lhs[i] * rhs for i in range(nelems)]) + + def keyword_quaternion(self, args): + assert len(args) == 2, "quaternion() keyword requires two parameters!" + axis = args[0] + angle = args[1] + epsilon = 1e-6 + assert axis.type() == Types.Vector, "quaternion(): First argument must be a vector." + assert Types.is_real(angle.type()), "quaternion(): Second argument must be a real value." + + axis_length = self.keyword_length([axis]) + zero_cond = Abs(self.sim, axis_length) < epsilon or Abs(self.sim, angle) < epsilon + sina = Select(self.sim, zero_cond, 0.0, Sin(self.sim, angle * 0.5)) + cosa = Select(self.sim, zero_cond, 1.0, Cos(self.sim, angle * 0.5)) + axisN = axis * (1.0 / axis_length) + return Quaternion(self.sim, [cosa, sina * axisN[0], sina * axisN[1], sina * axisN[2]]) + + def keyword_quaternion_multiplication(self, args): + assert len(args) == 2, "quaternion_multiplication() keyword requires two parameters!" + lhs = args[0] + rhs = args[1] + assert lhs.type() == Types.Quaternion, \ + "quaternion_multiplication(): Left-hand side operator is not a quaternion!" + assert rhs.type() == Types.Quaternion, \ + "quaternion_multiplication(): Right-hand side operator is not a quaternion!" + + r = lhs[0] * rhs[0] - lhs[1] * rhs[1] - lhs[2] * rhs[2] - lhs[3] * rhs[3] + i = lhs[0] * rhs[1] + lhs[1] * rhs[0] + lhs[2] * rhs[3] - lhs[3] * rhs[2] + j = lhs[0] * rhs[2] + lhs[2] * rhs[0] + lhs[3] * rhs[1] - lhs[1] * rhs[3] + k = lhs[0] * rhs[3] + lhs[3] * rhs[0] + lhs[1] * rhs[2] - lhs[2] * rhs[1] + + len2 = r * r + i * i + j * j + k * k + ilen = Select(self.sim, len2 - 1.0 < 1E-8, 1.0, 1.0 / Sqrt(self.sim, len2)) + return Quaternion(self.sim, [r * ilen, i * ilen, j * ilen, k * ilen]) + + def keyword_quaternion_to_rotation_matrix(self, args): + assert len(args) == 1, "quaternion_to_rotation_matrix() keyword requires one parameter!" + quat = args[0] + assert quat.type() == Types.Quaternion, \ + "quaternion_to_rotation_matrix(): Given argument is not a quaternion!" + + return Matrix(self.sim, [ 1.0 - 2.0 * quat[2] * quat[2] - 2.0 * quat[3] * quat[3], + 2.0 * (quat[1] * quat[2] - quat[0] * quat[3]), + 2.0 * (quat[1] * quat[3] + quat[0] * quat[2]), + 2.0 * (quat[1] * quat[2] + quat[0] * quat[3]), + 1.0 - 2.0 * quat[1] * quat[1] - 2.0 * quat[3] * quat[3], + 2.0 * (quat[2] * quat[3] - quat[0] * quat[1]), + 2.0 * (quat[1] * quat[3] - quat[0] * quat[2]), + 2.0 * (quat[2] * quat[3] + quat[0] * quat[1]), + 1.0 - 2.0 * quat[1] * quat[1] - 2.0 * quat[2] * quat[2] ]) diff --git a/src/pairs/sim/comm.py b/src/pairs/sim/comm.py index bb4b2612009d12276cf1f938add21fc68c544a4b..d945a404d335ea0b0294b025eb74258290328845 100644 --- a/src/pairs/sim/comm.py +++ b/src/pairs/sim/comm.py @@ -36,7 +36,9 @@ class Comm: @pairs_inline def synchronize(self): - prop_list = [self.sim.property(p) for p in ['position']] + # Every property that is not constant across timesteps and have neighbor accesses during any + # interaction kernel (i.e. property[j] in force calculation kernel) + prop_list = [self.sim.property(p) for p in ['position', 'linear_velocity', 'angular_velocity']] for step in range(self.dom_part.number_of_steps()): PackGhostParticles(self, step, prop_list) CommunicateData(self, step, prop_list) @@ -44,7 +46,9 @@ class Comm: @pairs_inline def borders(self): - prop_list = [self.sim.property(p) for p in ['mass', 'position', 'flags']] + # Every property that is constant across timesteps and have neighbor accesses during any + # interaction kernel (i.e. property[j] in force calculation kernel) + prop_list = [self.sim.property(p) for p in ['mass', 'position', 'linear_velocity', 'angular_velocity', 'flags']] Assign(self.sim, self.nsend_all, 0) Assign(self.sim, self.sim.nghost, 0) @@ -59,6 +63,7 @@ class Comm: @pairs_inline def exchange(self): + # Every property except volatiles prop_list = [self.sim.property(p) for p in ['mass', 'position', 'linear_velocity', 'shape', 'flags']] for step in range(self.dom_part.number_of_steps()): Assign(self.sim, self.nsend_all, 0) @@ -104,7 +109,7 @@ class CommunicateData(Lowerable): @pairs_inline def lower(self): - elem_size = sum([self.sim.ndims() if p.type() == Types.Vector else 1 for p in self.prop_list]) + elem_size = sum([Types.number_of_elements(self.sim, p.type()) for p in self.prop_list]) Call_Void(self.sim, "pairs->communicateData", [self.step, elem_size, self.comm.send_buffer, self.comm.send_offsets, self.comm.nsend, self.comm.recv_buffer, self.comm.recv_offsets, self.comm.nrecv]) @@ -192,7 +197,7 @@ class PackGhostParticles(Lowerable): self.sim.add_statement(self) def get_elems_per_particle(self): - return sum([self.sim.ndims() if p.type() == Types.Vector else 1 for p in self.prop_list]) + return sum([Types.number_of_elements(self.sim, p.type()) for p in self.prop_list]) #@pairs_host_block @pairs_device_block @@ -209,15 +214,16 @@ class PackGhostParticles(Lowerable): p_offset = 0 m = send_map[i] for p in self.prop_list: - if p.type() == Types.Vector: - for d in range(self.sim.ndims()): - src = p[m][d] + if not Types.is_scalar(p.type()): + nelems = Types.number_of_elements(self.sim, p.type()) + for e in range(nelems): + src = p[m][e] if p == self.sim.position(): - src += send_mult[i][d] * self.sim.grid.length(d) + src += send_mult[i][e] * self.sim.grid.length(e) - Assign(self.sim, send_buffer[i][p_offset + d], src) + Assign(self.sim, send_buffer[i][p_offset + e], src) - p_offset += self.sim.ndims() + p_offset += nelems else: cast_fn = lambda x: Cast(self.sim, x, Types.Double) if p.type() != Types.Double else x @@ -234,7 +240,7 @@ class UnpackGhostParticles(Lowerable): self.sim.add_statement(self) def get_elems_per_particle(self): - return sum([self.sim.ndims() if p.type() == Types.Vector else 1 for p in self.prop_list]) + return sum([Types.number_of_elements(self.sim, p.type()) for p in self.prop_list]) #@pairs_host_block @pairs_device_block @@ -249,11 +255,12 @@ class UnpackGhostParticles(Lowerable): for i in For(self.sim, start, ScalarOp.inline(start + sum([self.comm.nrecv[j] for j in step_indexes]))): p_offset = 0 for p in self.prop_list: - if p.type() == Types.Vector: - for d in range(self.sim.ndims()): - Assign(self.sim, p[nlocal + i][d], recv_buffer[i][p_offset + d]) + if not Types.is_scalar(p.type()): + nelems = Types.number_of_elements(self.sim, p.type()) + for e in range(nelems): + Assign(self.sim, p[nlocal + i][e], recv_buffer[i][p_offset + e]) - p_offset += self.sim.ndims() + p_offset += nelems else: cast_fn = lambda x: Cast(self.sim, x, p.type()) if p.type() != Types.Double else x @@ -301,9 +308,10 @@ class RemoveExchangedParticles_part2(Lowerable): for _ in Filter(self.sim, src > 0): dst = self.comm.send_map[i] for p in self.prop_list: - if p.type() == Types.Vector: - for d in range(self.sim.ndims()): - Assign(self.sim, p[dst][d], p[src][d]) + if not Types.is_scalar(p.type()): + nelems = Types.number_of_elements(self.sim, p.type()) + for e in range(nelems): + Assign(self.sim, p[dst][e], p[src][e]) else: Assign(self.sim, p[dst], p[src]) diff --git a/src/pairs/sim/properties.py b/src/pairs/sim/properties.py index fed66cc130fc11d4d538f29060da55babf06778a..775fe19a69a39c0c8278a3aa4a77228b3073e897 100644 --- a/src/pairs/sim/properties.py +++ b/src/pairs/sim/properties.py @@ -23,6 +23,10 @@ class AllocateProperties(FinalLowerable): sizes = [self.sim.particle_capacity] elif p.type() == Types.Vector: sizes = [self.sim.particle_capacity, self.sim.ndims()] + elif p.type() == Types.Matrix: + sizes = [self.sim.particle_capacity, self.sim.ndims() * self.sim.ndims()] + elif p.type() == Types.Quaternion: + sizes = [self.sim.particle_capacity, self.sim.ndims() + 1] else: raise Exception("Invalid property type!") @@ -44,6 +48,10 @@ class AllocateContactProperties(FinalLowerable): sizes = [self.sim.particle_capacity * self.sim.neighbor_capacity] elif p.type() == Types.Vector: sizes = [self.sim.particle_capacity * self.sim.neighbor_capacity, self.sim.ndims()] + elif p.type() == Types.Matrix: + sizes = [self.sim.particle_capacity * self.sim.neighbor_capacity, self.sim.ndims() * self.sim.ndims()] + elif p.type() == Types.Quaternion: + sizes = [self.sim.particle_capacity * self.sim.neighbor_capacity, self.sim.ndims() + 1] else: raise Exception("Invalid contact property type!") diff --git a/src/pairs/sim/simulation.py b/src/pairs/sim/simulation.py index 8fd8b85e87cef5a8882773f46a29ec5f7e697eb3..39ade23f01da7df11af17262602a4ae6d7846b2d 100644 --- a/src/pairs/sim/simulation.py +++ b/src/pairs/sim/simulation.py @@ -11,7 +11,7 @@ from pairs.ir.symbols import Symbol from pairs.ir.types import Types from pairs.ir.variables import Variables from pairs.graph.graphviz import ASTGraph -from pairs.mapping.funcs import compute +from pairs.mapping.funcs import compute, setup from pairs.sim.arrays import DeclareArrays from pairs.sim.cell_lists import CellLists, BuildCellLists, BuildCellListsStencil, PartitionCellLists from pairs.sim.comm import Comm @@ -58,6 +58,7 @@ class Simulation: self._capture_statements = True self._block = Block(self, []) self.setups = Block(self, []) + self.setup_functions = Block(self, []) self.functions = Block(self, []) self.module_list = [] self.kernel_list = [] @@ -201,6 +202,9 @@ class Simulation: def compute(self, func, cutoff_radius=None, symbols={}): return compute(self, func, cutoff_radius, symbols) + def setup(self, func, symbols={}): + return setup(self, func, symbols) + def init_block(self): self._block = Block(self, []) self._check_properties_resize = False @@ -219,14 +223,23 @@ class Simulation: else: raise Exception("Two sizes assigned to same capacity!") + def build_setup_module_with_statements(self): + self.setup_functions.add_statement( + Module(self, + name=self._module_name, + block=Block(self, self._block), + resizes_to_check=self._resizes_to_check, + check_properties_resize=self._check_properties_resize, + run_on_device=False)) + def build_module_with_statements(self, run_on_device=True): self.functions.add_statement( Module(self, - name=self._module_name, - block=Block(self, self._block), - resizes_to_check=self._resizes_to_check, - check_properties_resize=self._check_properties_resize, - run_on_device=run_on_device)) + name=self._module_name, + block=Block(self, self._block), + resizes_to_check=self._resizes_to_check, + check_properties_resize=self._check_properties_resize, + run_on_device=run_on_device)) def capture_statements(self, capture=True): self._capture_statements = capture @@ -299,6 +312,7 @@ class Simulation: body = Block.from_list(self, [ self.setups, + self.setup_functions, BuildCellListsStencil(self, self.cell_lists), VTKWrite(self, self.vtk_file, 0, self.vtk_frequency), timestep.as_block() diff --git a/src/pairs/transformations/expressions.py b/src/pairs/transformations/expressions.py index c5d8c20bcc14f390c6420697ed21d6e6c65d4d1f..bd85bf81c65608dd40bb5ced66a360fdad42720e 100644 --- a/src/pairs/transformations/expressions.py +++ b/src/pairs/transformations/expressions.py @@ -249,6 +249,48 @@ class AddExpressionDeclarations(Mutator): return ast_node + def mutate_Matrix(self, ast_node): + ast_node._values = [self.mutate(v) for v in ast_node._values] + matrix_id = id(ast_node) + if matrix_id not in self.declared_exprs and matrix_id not in self.params: + self.push_decl(Decl(ast_node.sim, ast_node)) + self.declared_exprs.append(matrix_id) + + return ast_node + + def mutate_MatrixOp(self, ast_node): + ast_node.lhs = self.mutate(ast_node.lhs) + if not ast_node.operator().is_unary(): + ast_node.rhs = self.mutate(ast_node.rhs) + + matrix_op_id = id(ast_node) + if matrix_op_id not in self.declared_exprs and matrix_op_id not in self.params: + self.push_decl(Decl(ast_node.sim, ast_node)) + self.declared_exprs.append(matrix_op_id) + + return ast_node + + def mutate_Quaternion(self, ast_node): + ast_node._values = [self.mutate(v) for v in ast_node._values] + quat_id = id(ast_node) + if quat_id not in self.declared_exprs and quat_id not in self.params: + self.push_decl(Decl(ast_node.sim, ast_node)) + self.declared_exprs.append(quat_id) + + return ast_node + + def mutate_QuaternionOp(self, ast_node): + ast_node.lhs = self.mutate(ast_node.lhs) + if not ast_node.operator().is_unary(): + ast_node.rhs = self.mutate(ast_node.rhs) + + quat_op_id = id(ast_node) + if quat_op_id not in self.declared_exprs and quat_op_id not in self.params: + self.push_decl(Decl(ast_node.sim, ast_node)) + self.declared_exprs.append(quat_op_id) + + return ast_node + def mutate_ContactPropertyAccess(self, ast_node): writing = self.writing ast_node.contact_prop = self.mutate(ast_node.contact_prop) diff --git a/src/pairs/transformations/loops.py b/src/pairs/transformations/loops.py index c7ab965dfa9fcac1c39d9f2e8bb6050029d1c921..829f6f7e11c6831303936e9062a12941ab289c2a 100644 --- a/src/pairs/transformations/loops.py +++ b/src/pairs/transformations/loops.py @@ -2,8 +2,10 @@ from pairs.ir.arrays import ArrayAccess from pairs.ir.features import FeaturePropertyAccess from pairs.ir.loops import For, While from pairs.ir.math import MathFunction +from pairs.ir.matrices import Matrix, MatrixOp from pairs.ir.mutator import Mutator from pairs.ir.properties import PropertyAccess, ContactPropertyAccess +from pairs.ir.quaternions import Quaternion, QuaternionOp from pairs.ir.scalars import ScalarOp from pairs.ir.select import Select from pairs.ir.vectors import Vector, VectorOp @@ -37,7 +39,11 @@ class LICM(Mutator): ContactPropertyAccess, FeaturePropertyAccess, MathFunction, + Matrix, + MatrixOp, PropertyAccess, + Quaternion, + QuaternionOp, ScalarOp, Select, Vector, diff --git a/src/pairs/transformations/modules.py b/src/pairs/transformations/modules.py index 87cd6c8538a1fc4c10b014636bcfde48db22d9e1..20249056728d5c76acb306ecb96ba03f8574d658 100644 --- a/src/pairs/transformations/modules.py +++ b/src/pairs/transformations/modules.py @@ -171,7 +171,8 @@ class ReplaceModulesByCalls(Mutator): if capacity == sim.particle_capacity: for p in properties.all(): new_capacity = sim.particle_capacity - sizes = [new_capacity, sim.ndims()] if p.type() == Types.Vector else [new_capacity] + sizes = [new_capacity] if Types.is_scalar(p.type()) else \ + [new_capacity, Types.number_of_elements(sim, p.type())] props_realloc += [ReallocProperty(sim, p, sizes)] resize_stmts.append(