diff --git a/examples/lj.py b/examples/lj.py index 5d4a5d5010e784702e5520b6951946482f15efb5..c7392daff3d0ecfbe37b9b1aa4d31aa4a13c8974 100644 --- a/examples/lj.py +++ b/examples/lj.py @@ -18,7 +18,6 @@ target = sys.argv[1] if len(sys.argv[1]) > 1 else "none" if target != 'cpu' and target != 'gpu': print(f"Invalid target, use {cmd} <cpu/gpu>") - dt = 0.005 cutoff_radius = 2.5 skin = 0.3 diff --git a/src/pairs/code_gen/cgen.py b/src/pairs/code_gen/cgen.py index 4e7c43123a4cc264c54d261abd24af9d63a50518..7fd9480e77302e4b37e57fb9180f6286451a32cb 100644 --- a/src/pairs/code_gen/cgen.py +++ b/src/pairs/code_gen/cgen.py @@ -6,7 +6,7 @@ from pairs.ir.branches import Branch from pairs.ir.cast import Cast from pairs.ir.contexts import Contexts from pairs.ir.bin_op import BinOp, Decl, VectorAccess -from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, SetArrayFlag, SetPropertyFlag, HostRef +from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, DeviceStaticRef, SetArrayFlag, SetPropertyFlag, HostRef from pairs.ir.features import FeatureProperty, FeaturePropertyAccess, RegisterFeatureProperty from pairs.ir.functions import Call from pairs.ir.kernel import Kernel, KernelLaunch @@ -635,6 +635,10 @@ class CGen: var = self.generate_expression(ast_node.var) return f"(*{var})" + if isinstance(ast_node, DeviceStaticRef): + elem = self.generate_expression(ast_node.elem) + return f"d_{elem}" + if isinstance(ast_node, FeatureProperty): return ast_node.name() @@ -668,7 +672,7 @@ class CGen: return f"p{ast_node.id()}" + (f"_{index}" if ast_node.is_vector_kind() else "") if isinstance(ast_node, FeaturePropertyAccess): - feature_name = self.generate_expression(ast_node.feature_prop.name()) + feature_name = self.generate_expression(ast_node.feature_prop) if mem or ast_node.inlined is True: index = self.generate_expression(ast_node.index) return f"{feature_name}[{index}]" diff --git a/src/pairs/ir/device.py b/src/pairs/ir/device.py index 6ba75d860ba09286a97ff9c35bfa52a87f9734bd..90a3ad6ad3271f56c158392b25ffef53ff577520 100644 --- a/src/pairs/ir/device.py +++ b/src/pairs/ir/device.py @@ -17,6 +17,18 @@ class HostRef(ASTNode): return [self.elem] +class DeviceStaticRef(ASTNode): + def __init__(self, sim, elem): + super().__init__(sim) + self.elem = elem + + def type(self): + return self.elem.type() + + def children(self): + return [self.elem] + + class CopyArray(ASTNode): def __init__(self, sim, array, ctx): super().__init__(sim) diff --git a/src/pairs/ir/mutator.py b/src/pairs/ir/mutator.py index 2e5acb91a3c9c42a61af7ca5034aec798b1c013d..d87b2f8ea529207c9a5af3900e8b706ed4479b02 100644 --- a/src/pairs/ir/mutator.py +++ b/src/pairs/ir/mutator.py @@ -83,6 +83,10 @@ class Mutator: ast_node.elem = self.mutate(ast_node.elem) return ast_node + def mutate_DeviceStaticRef(self, ast_node): + ast_node.elem = self.mutate(ast_node.elem) + return ast_node + def mutate_Filter(self, ast_node): return self.mutate_Branch(ast_node) @@ -120,7 +124,7 @@ class Mutator: return ast_node def mutate_FeaturePropertyAccess(self, ast_node): - ast_node.prop = self.mutate(ast_node.feature_prop) + ast_node.feature_prop = self.mutate(ast_node.feature_prop) ast_node.index = self.mutate(ast_node.index) ast_node.expressions = {i: self.mutate(e) for i, e in ast_node.expressions.items()} return ast_node diff --git a/src/pairs/transformations/__init__.py b/src/pairs/transformations/__init__.py index e1cf7d3558b9e09ab299c6bbb1d0915d41d589a9..538d4164cff6c37dc66fa07f32cad4475911aa2f 100644 --- a/src/pairs/transformations/__init__.py +++ b/src/pairs/transformations/__init__.py @@ -1,6 +1,6 @@ from pairs.analysis import Analysis from pairs.transformations.blocks import LiftExprOwnerBlocks, MergeAdjacentBlocks -from pairs.transformations.devices import AddDeviceCopies, AddDeviceKernels, AddHostReferencesToModules +from pairs.transformations.devices import AddDeviceCopies, AddDeviceKernels, AddHostReferencesToModules, AddDeviceReferencesToModules from pairs.transformations.expressions import ReplaceSymbols, SimplifyExpressions, PrioritizeScalarOps, AddExpressionDeclarations from pairs.transformations.loops import LICM from pairs.transformations.lower import Lower @@ -74,6 +74,10 @@ class Transformations: if self._target.is_gpu(): self.apply(AddHostReferencesToModules()) + def add_device_references_to_modules(self): + if self._target.is_gpu(): + self.apply(AddDeviceReferencesToModules()) + def apply_all(self): self.lower() self.optimize_expressions() @@ -86,3 +90,4 @@ class Transformations: self.lower(True) self.add_expression_declarations() self.add_host_references_to_modules() + self.add_device_references_to_modules() diff --git a/src/pairs/transformations/devices.py b/src/pairs/transformations/devices.py index fcb118b0b9f87222eae01c82f9c09b5678be0489..836f98f1de648ca4f08fddec8b5a1abd14e19eb4 100644 --- a/src/pairs/transformations/devices.py +++ b/src/pairs/transformations/devices.py @@ -5,7 +5,7 @@ from pairs.ir.block import Block from pairs.ir.branches import Filter from pairs.ir.cast import Cast from pairs.ir.contexts import Contexts -from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, SetArrayFlag, SetPropertyFlag, HostRef +from pairs.ir.device import ClearArrayFlag, ClearPropertyFlag, CopyArray, CopyProperty, CopyVar, DeviceStaticRef, SetArrayFlag, SetPropertyFlag, HostRef from pairs.ir.kernel import Kernel, KernelLaunch from pairs.ir.lit import Lit from pairs.ir.loops import For @@ -161,3 +161,66 @@ class AddHostReferencesToModules(Mutator): return HostRef(ast_node.sim, ast_node) return ast_node + + +class AddDeviceReferencesToModules(Mutator): + def __init__(self, ast=None): + super().__init__(ast) + self.kernel_context = False + self.within_decl = False + self.add_reference = False + self.declared_objects = [] + + def must_add_reference(self, ast_node): + return id(ast_node) not in self.declared_objects and self.kernel_context and \ + (ast_node.inlined is True or self.within_decl) + + def mutate_ArrayAccess(self, ast_node): + if isinstance(ast_node.array, (DeviceStaticRef, HostRef)): + return ast_node + + _add_reference = self.add_reference + self.add_reference = ast_node.array.is_static() and self.must_add_reference(ast_node) + ast_node.array = self.mutate(ast_node.array) + self.add_reference = _add_reference + return ast_node + + def mutate_ArrayStatic(self, ast_node): + if self.add_reference: + return DeviceStaticRef(ast_node.sim, ast_node) + + return ast_node + + def mutate_FeaturePropertyAccess(self, ast_node): + _add_reference = self.add_reference + self.add_reference = self.must_add_reference(ast_node) + ast_node.feature_prop = self.mutate(ast_node.feature_prop) + self.add_reference = _add_reference + return ast_node + + def mutate_FeatureProperty(self, ast_node): + if self.add_reference: + return DeviceStaticRef(ast_node.sim, ast_node) + + return ast_node + + def mutate_DeviceStaticRef(self, ast_node): + return ast_node + + def mutate_Decl(self, ast_node): + _within_decl = self.within_decl + self.within_decl = True + ast_node.elem = self.mutate(ast_node.elem) + self.declared_objects.append(id(ast_node.elem)) + self.within_decl = _within_decl + return ast_node + + def mutate_HostRef(self, ast_node): + return ast_node + + def mutate_Kernel(self, ast_node): + _kernel_context = self.kernel_context + self.kernel_context = True + ast_node._block = self.mutate(ast_node._block) + self.kernel_context = _kernel_context + return ast_node diff --git a/src/pairs/transformations/expressions.py b/src/pairs/transformations/expressions.py index 7a366dd9e20179a8c7cda7686a9ea85f1d091ba9..15dc68843d63b9d00f3d68b426eea210622e96b0 100644 --- a/src/pairs/transformations/expressions.py +++ b/src/pairs/transformations/expressions.py @@ -193,7 +193,7 @@ class AddExpressionDeclarations(Mutator): def mutate_FeaturePropertyAccess(self, ast_node): assert self.writing is False, "Cannot change feature property!" - ast_node.prop = self.mutate(ast_node.prop) + ast_node.feature_prop = self.mutate(ast_node.feature_prop) ast_node.index = self.mutate(ast_node.index) ast_node.expressions = {i: self.mutate(e) for i, e in ast_node.expressions.items()}