Skip to content
Snippets Groups Projects
Commit 228a61ef authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Add first version of add_device_copies transformation

parent b78aa65c
Branches
Tags
No related merge requests found
...@@ -28,7 +28,7 @@ setuptools.setup(name='pairs', ...@@ -28,7 +28,7 @@ setuptools.setup(name='pairs',
install_requires=[], install_requires=[],
packages=['pairs'] + [f"pairs.{mod}" for mod in modules], packages=['pairs'] + [f"pairs.{mod}" for mod in modules],
package_dir={'pairs': 'src/pairs'}, package_dir={'pairs': 'src/pairs'},
package_data={'pairs.runtime': ['runtime/*.hpp']}, package_data={'pairs': ['runtime/*.hpp']},
classifiers=[ classifiers=[
"Programming Language :: Python :: 3", "Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
......
...@@ -5,6 +5,7 @@ from pairs.ir.branches import Branch ...@@ -5,6 +5,7 @@ from pairs.ir.branches import Branch
from pairs.ir.cast import Cast from pairs.ir.cast import Cast
from pairs.ir.bin_op import BinOp, Decl, VectorAccess from pairs.ir.bin_op import BinOp, Decl, VectorAccess
from pairs.ir.data_types import Type_Int, Type_Float, Type_String, Type_Vector from pairs.ir.data_types import Type_Int, Type_Float, Type_String, Type_Vector
from pairs.ir.device import DeviceCopy
from pairs.ir.functions import Call from pairs.ir.functions import Call
from pairs.ir.layouts import Layout_AoS, Layout_SoA, Layout_Invalid from pairs.ir.layouts import Layout_AoS, Layout_SoA, Layout_Invalid
from pairs.ir.lit import Lit from pairs.ir.lit import Lit
...@@ -130,6 +131,9 @@ class CGen: ...@@ -130,6 +131,9 @@ class CGen:
call = self.generate_expression(ast_node) call = self.generate_expression(ast_node)
self.print(f"{call};") self.print(f"{call};")
if isinstance(ast_node, DeviceCopy):
self.print(f"pairs::copy_to_device({ast_node.prop.name()})")
if isinstance(ast_node, For): if isinstance(ast_node, For):
iterator = self.generate_expression(ast_node.iterator) iterator = self.generate_expression(ast_node.iterator)
lower_range = None lower_range = None
......
...@@ -4,31 +4,13 @@ from pairs.ir.ast_node import ASTNode ...@@ -4,31 +4,13 @@ from pairs.ir.ast_node import ASTNode
class Block(ASTNode): class Block(ASTNode):
def __init__(self, sim, stmts): def __init__(self, sim, stmts):
super().__init__(sim) super().__init__(sim)
self.level = 0
self.variants = set() self.variants = set()
self.props_accessed = {}
self.props_to_sync = set()
if isinstance(stmts, Block): if isinstance(stmts, Block):
self.stmts = stmts.statements() self.stmts = stmts.statements()
else: else:
self.stmts = [stmts] if not isinstance(stmts, list) else stmts self.stmts = [stmts] if not isinstance(stmts, list) else stmts
def __lt__(self, other):
return self.level < other.level
def __le__(self, other):
return self.level <= other.level
def __gt__(self, other):
return self.level > other.level
def __ge__(self, other):
return self.level >= other.level
def set_level(self, level):
self.level = level
def add_statement(self, stmt): def add_statement(self, stmt):
if isinstance(stmt, list): if isinstance(stmt, list):
self.stmts = self.stmts + stmt self.stmts = self.stmts + stmt
...@@ -59,3 +41,24 @@ class Block(ASTNode): ...@@ -59,3 +41,24 @@ class Block(ASTNode):
result_block = Block.merge_blocks(result_block, block) result_block = Block.merge_blocks(result_block, block)
return result_block return result_block
class KernelBlock(Block):
def __init__(self, sim, stmts, run_on_host=False):
super().__init__(sim, stmts)
self.run_on_host = run_on_host
self.props_accessed = {}
def add_property_access(self, prop, oper):
prop_key = prop.name()
if prop_key not in self.props_accessed:
self.props_accessed[prop_key] = oper
elif oper not in self.props_accessed[prop_key]:
self.props_accessed[prop_key] += oper
def properties_to_synchronize(self):
return {p for p in self.props_accessed if self.props_accessed[p][0] == 'r'}
def writing_properties(self):
return {p for p in self.props_accessed if 'w' in self.props_accessed[p][0]}
from pairs.ir.ast_node import ASTNode
class DeviceCopy(ASTNode):
def __init__(self, sim, prop):
super().__init__(sim)
self.prop = prop
def children(self):
return [self.prop]
from pairs.ir.block import Block from pairs.ir.block import Block, KernelBlock
class KernelWrapper(): class KernelWrapper():
def __init__(self): def __init__(self, sim):
self.kernels = Block(self, []) self.sim = sim
self.kernels = Block(sim, [])
def add_kernel_block(self, block): def add_kernel_block(self, block):
self.kernels = Block.merge_blocks(self.kernels, block) self.kernels = Block.merge_blocks(self.kernels, KernelBlock(self.sim, block))
def lower(self): def lower(self):
return self.kernels return self.kernels
...@@ -21,6 +21,7 @@ from pairs.sim.setup_wrapper import SetupWrapper ...@@ -21,6 +21,7 @@ from pairs.sim.setup_wrapper import SetupWrapper
from pairs.sim.timestep import Timestep from pairs.sim.timestep import Timestep
from pairs.sim.variables import VariablesDecl from pairs.sim.variables import VariablesDecl
from pairs.sim.vtk import VTKWrite from pairs.sim.vtk import VTKWrite
from pairs.transformations.add_device_copies import add_device_copies
from pairs.transformations.prioritize_scalar_ops import prioritaze_scalar_ops from pairs.transformations.prioritize_scalar_ops import prioritaze_scalar_ops
from pairs.transformations.set_used_bin_ops import set_used_bin_ops from pairs.transformations.set_used_bin_ops import set_used_bin_ops
from pairs.transformations.simplify import simplify_expressions from pairs.transformations.simplify import simplify_expressions
...@@ -47,8 +48,8 @@ class ParticleSimulation: ...@@ -47,8 +48,8 @@ class ParticleSimulation:
self.nest = False self.nest = False
self.check_decl_usage = True self.check_decl_usage = True
self.block = Block(self, []) self.block = Block(self, [])
self.setups = SetupWrapper() self.setups = SetupWrapper(self)
self.kernels = KernelWrapper() self.kernels = KernelWrapper(self)
self.dims = dims self.dims = dims
self.ntimesteps = timesteps self.ntimesteps = timesteps
self.expr_id = 0 self.expr_id = 0
...@@ -218,6 +219,7 @@ class ParticleSimulation: ...@@ -218,6 +219,7 @@ class ParticleSimulation:
simplify_expressions(program) simplify_expressions(program)
move_loop_invariant_code(program) move_loop_invariant_code(program)
set_used_bin_ops(program) set_used_bin_ops(program)
add_device_copies(program)
# For this part on, all bin ops are generated without usage verification # For this part on, all bin ops are generated without usage verification
self.check_decl_usage = False self.check_decl_usage = False
......
...@@ -2,8 +2,8 @@ from pairs.ir.block import Block ...@@ -2,8 +2,8 @@ from pairs.ir.block import Block
class SetupWrapper(): class SetupWrapper():
def __init__(self): def __init__(self, sim):
self.setups = Block(self, []) self.setups = Block(sim, [])
def add_setup_block(self, block): def add_setup_block(self, block):
self.setups = Block.merge_blocks(self.setups, block) self.setups = Block.merge_blocks(self.setups, block)
......
from pairs.ir.block import KernelBlock
from pairs.ir.device import DeviceCopy
from pairs.ir.mutator import Mutator
from pairs.ir.visitor import Visitor
class AddAccessedProperties(Visitor):
def __init__(self, ast):
super().__init__(ast)
self.current_kernel_block = None
self.writing = False
def visit_Assign(self, ast_node):
for s in ast_node.sources():
self.visit(s)
self.writing = True
for d in ast_node.destinations():
self.visit(d)
self.writing = False
def visit_KernelBlock(self, ast_node):
self.current_kernel_block = ast_node
self.visit_children(ast_node)
def visit_PropertyAccess(self, ast_node):
if self.current_kernel_block is not None:
self.current_kernel_block.add_property_access(ast_node.prop, 'w' if self.writing else 'r')
class AddDeviceCopies(Mutator):
def __init__(self, ast):
super().__init__(ast)
self.synchronized_props = set()
self.props_to_copy = {}
def mutate_Block(self, ast_node):
new_stmts = []
stmts = [self.mutate(s) for s in ast_node.stmts]
for s in stmts:
if s is not None:
s_id = id(s)
if isinstance(s, KernelBlock) and s_id in self.props_to_copy:
for p in self.props.to_copy[s_id]:
new_stmts = new_stmts + DeviceCopy(ast_node.sim, p)
new_stmts.append(s)
ast_node.stmts = new_stmts
return ast_node
def mutate_KernelBlock(self, ast_node):
copying_properties = {p for p in ast_node.properties_to_synchronize() if p not in synchronized_props}
self.props_to_copy[id(ast_node)] = copying_properties
self.synchronized_props.update(copying_properties)
self.synchronized_props -= ast_node.writing_properties()
def add_device_copies(ast):
add_accessed_props = AddAccessedProperties(ast)
add_accessed_props.visit()
add_copies = AddDeviceCopies(ast)
add_copies.mutate()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment