Skip to content
Snippets Groups Projects
Commit b57eabce authored by Rafael Ravedutti's avatar Rafael Ravedutti
Browse files

Remove KernelBlock and store varying properties in modules

parent 595e7577
Branches
Tags
No related merge requests found
from pairs.ir.assign import Assign from pairs.ir.assign import Assign
from pairs.ir.arrays import Array, ArrayAccess, ArrayDecl from pairs.ir.arrays import Array, ArrayAccess, ArrayDecl
from pairs.ir.block import Block, KernelBlock from pairs.ir.block import Block
from pairs.ir.branches import Branch from pairs.ir.branches import Branch
from pairs.ir.cast import Cast from pairs.ir.cast import Cast
from pairs.ir.bin_op import BinOp, Decl, VectorAccess from pairs.ir.bin_op import BinOp, Decl, VectorAccess
...@@ -12,7 +12,7 @@ from pairs.ir.lit import Lit ...@@ -12,7 +12,7 @@ from pairs.ir.lit import Lit
from pairs.ir.loops import For, Iter, ParticleFor, While from pairs.ir.loops import For, Iter, ParticleFor, While
from pairs.ir.math import Ceil, Sqrt from pairs.ir.math import Ceil, Sqrt
from pairs.ir.memory import Malloc, Realloc from pairs.ir.memory import Malloc, Realloc
from pairs.ir.module import Module_Call from pairs.ir.module import ModuleCall
from pairs.ir.properties import Property, PropertyAccess, PropertyList, RegisterProperty, UpdateProperty from pairs.ir.properties import Property, PropertyAccess, PropertyList, RegisterProperty, UpdateProperty
from pairs.ir.select import Select from pairs.ir.select import Select
from pairs.ir.sizeof import Sizeof from pairs.ir.sizeof import Sizeof
...@@ -176,11 +176,6 @@ class CGen: ...@@ -176,11 +176,6 @@ class CGen:
if isinstance(ast_node, DeviceCopy): if isinstance(ast_node, DeviceCopy):
self.print(f"pairs::copy_to_device({ast_node.prop.name()})") self.print(f"pairs::copy_to_device({ast_node.prop.name()})")
if isinstance(ast_node, KernelBlock):
self.print.add_indent(-4)
self.generate_statement(ast_node.block)
self.print.add_indent(4) # Workaround for fixing indentation of kernels
if isinstance(ast_node, For): if isinstance(ast_node, For):
iterator = self.generate_expression(ast_node.iterator) iterator = self.generate_expression(ast_node.iterator)
lower_range = None lower_range = None
...@@ -210,7 +205,7 @@ class CGen: ...@@ -210,7 +205,7 @@ class CGen:
else: else:
self.print(f"{array_name} = ({tkw} *) malloc({size});") self.print(f"{array_name} = ({tkw} *) malloc({size});")
if isinstance(ast_node, Module_Call): if isinstance(ast_node, ModuleCall):
module = ast_node.module module = ast_node.module
module_params = "" module_params = ""
for var in module.read_only_variables(): for var in module.read_only_variables():
......
...@@ -19,20 +19,10 @@ def pairs_device_block(func): ...@@ -19,20 +19,10 @@ def pairs_device_block(func):
func(*args, **kwargs) func(*args, **kwargs)
return Module(sim, return Module(sim,
name=sim._module_name, name=sim._module_name,
block=KernelBlock(sim, sim._block), block=Block(sim, sim._block),
resizes_to_check=sim._resizes_to_check, resizes_to_check=sim._resizes_to_check,
check_properties_resize=sim._check_properties_resize) check_properties_resize=sim._check_properties_resize,
run_on_device=True)
return inner
# TODO: Is this really useful? Or just pairs_block is enough?
def pairs_host_block(func):
def inner(*args, **kwargs):
sim = args[0].sim # self.sim
sim.init_block()
func(*args, **kwargs)
return KernelBlock(sim, sim._block, run_on_host=True)
return inner return inner
...@@ -68,8 +58,6 @@ class Block(ASTNode): ...@@ -68,8 +58,6 @@ class Block(ASTNode):
def merge_blocks(block1, block2): def merge_blocks(block1, block2):
assert isinstance(block1, Block), "First block type is not Block!" assert isinstance(block1, Block), "First block type is not Block!"
assert isinstance(block2, Block), "Second block type is not Block!" assert isinstance(block2, Block), "Second block type is not Block!"
assert not isinstance(block1, KernelBlock), "Kernel blocks cannot be merged!"
assert not isinstance(block2, KernelBlock), "Kernel blocks cannot be merged!"
return Block(block1.sim, block1.statements() + block2.statements()) return Block(block1.sim, block1.statements() + block2.statements())
def from_list(sim, block_list): def from_list(sim, block_list):
...@@ -83,28 +71,3 @@ class Block(ASTNode): ...@@ -83,28 +71,3 @@ class Block(ASTNode):
result_block.add_statement(block) result_block.add_statement(block)
return result_block return result_block
class KernelBlock(ASTNode):
def __init__(self, sim, block, run_on_host=False):
super().__init__(sim)
self.block = block if isinstance(block, Block) else Block(sim, block)
self.run_on_host = run_on_host
self.props_accessed = {}
def add_property_access(self, prop, oper):
prop_key = prop.name()
if prop_key not in self.props_accessed:
self.props_accessed[prop_key] = oper
elif oper not in self.props_accessed[prop_key]:
self.props_accessed[prop_key] += oper
def children(self):
return [self.block]
def properties_to_synchronize(self):
return {p for p in self.props_accessed if self.props_accessed[p][0] == 'r'}
def writing_properties(self):
return {p for p in self.props_accessed if 'w' in self.props_accessed[p][0]}
...@@ -7,15 +7,16 @@ from pairs.ir.variables import Var ...@@ -7,15 +7,16 @@ from pairs.ir.variables import Var
class Module(ASTNode): class Module(ASTNode):
last_module = 0 last_module = 0
def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False): def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False):
super().__init__(sim) super().__init__(sim)
self._name = name if name is not None else "module_" + str(Module.last_module) self._name = name if name is not None else "module_" + str(Module.last_module)
self._variables = {} self._variables = {}
self._arrays = set() self._arrays = {}
self._properties = set() self._properties = {}
self._block = block self._block = block
self._resizes_to_check = resizes_to_check self._resizes_to_check = resizes_to_check
self._check_properties_resize = check_properties_resize self._check_properties_resize = check_properties_resize
self._run_on_device = run_on_device
sim.add_module(self) sim.add_module(self)
Module.last_module += 1 Module.last_module += 1
...@@ -27,14 +28,18 @@ class Module(ASTNode): ...@@ -27,14 +28,18 @@ class Module(ASTNode):
def block(self): def block(self):
return self._block return self._block
@property
def run_on_device(self):
return self._run_on_device
def variables(self): def variables(self):
return self._variables return self._variables
def read_only_variables(self): def read_only_variables(self):
return [v for v in self._variables if not self._variables[v]] return [v for v in self._variables if 'w' not in self._variables[v]]
def write_variables(self): def write_variables(self):
return [v for v in self._variables if self._variables[v]] return [v for v in self._variables if 'w' in self._variables[v]]
def arrays(self): def arrays(self):
return self._arrays return self._arrays
...@@ -42,34 +47,40 @@ class Module(ASTNode): ...@@ -42,34 +47,40 @@ class Module(ASTNode):
def properties(self): def properties(self):
return self._properties return self._properties
def properties_to_synchronize(self):
return {p for p in self._properties if self._properties[p][0] == 'r'}
def write_properties(self):
return {p for p in self._properties if 'w' in self._properties[p]}
def add_array(self, array, write=False): def add_array(self, array, write=False):
array_list = array if isinstance(array, list) else [array] array_list = array if isinstance(array, list) else [array]
character = 'w' if write else 'r'
for a in array_list: for a in array_list:
assert isinstance(a, Array), "Module.add_array(): given element is not of type Array!" assert isinstance(a, Array), "Module.add_array(): given element is not of type Array!"
self._arrays.add(a) self._arrays[a] = character if a not in self._arrays else self._arrays[a] + character
def add_variable(self, variable, write=False): def add_variable(self, variable, write=False):
variable_list = variable if isinstance(variable, list) else [variable] variable_list = variable if isinstance(variable, list) else [variable]
character = 'w' if write else 'r'
for v in variable_list: for v in variable_list:
assert isinstance(v, Var), "Module.add_variable(): given element is not of type Var!" assert isinstance(v, Var), "Module.add_variable(): given element is not of type Var!"
if v not in self._variables: self._variables[v] = character if v not in self._variables else self._variables[v] + character
self._variables[v] = write
else:
self._variables[v] = self._variables[v] or write
def add_property(self, prop, write=False): def add_property(self, prop, write=False):
prop_list = prop if isinstance(prop, list) else [prop] prop_list = prop if isinstance(prop, list) else [prop]
character = 'w' if write else 'r'
for p in prop_list: for p in prop_list:
assert isinstance(p, Property), "Module.add_property(): given element is not of type Property!" assert isinstance(p, Property), "Module.add_property(): given element is not of type Property!"
self._properties.add(p) self._properties[p] = character if p not in self._properties else self._properties[p] + character
def children(self): def children(self):
return [self._block] return [self._block]
class Module_Call(ASTNode): class ModuleCall(ASTNode):
def __init__(self, sim, module): def __init__(self, sim, module):
assert isinstance(module, Module), "Module_Call(): given parameter is not of type Module!" assert isinstance(module, Module), "ModuleCall(): given parameter is not of type Module!"
super().__init__(sim) super().__init__(sim)
self._module = module self._module = module
......
...@@ -66,10 +66,6 @@ class Mutator: ...@@ -66,10 +66,6 @@ class Mutator:
ast_node.block = self.mutate(ast_node.block) ast_node.block = self.mutate(ast_node.block)
return ast_node return ast_node
def mutate_KernelBlock(self, ast_node):
ast_node.block = self.mutate(ast_node.block)
return ast_node
def mutate_ParticleFor(self, ast_node): def mutate_ParticleFor(self, ast_node):
return self.mutate_For(ast_node) return self.mutate_For(ast_node)
......
from pairs.ir.arrays import Arrays from pairs.ir.arrays import Arrays
from pairs.ir.block import Block, KernelBlock from pairs.ir.block import Block
from pairs.ir.branches import Filter from pairs.ir.branches import Filter
from pairs.ir.data_types import Type_Int, Type_Float, Type_Vector from pairs.ir.data_types import Type_Int, Type_Float, Type_Vector
from pairs.ir.layouts import Layout_AoS from pairs.ir.layouts import Layout_AoS
...@@ -181,9 +181,10 @@ class Simulation: ...@@ -181,9 +181,10 @@ class Simulation:
self.kernels.add_statement( self.kernels.add_statement(
Module(self, Module(self,
name=self._module_name, name=self._module_name,
block=KernelBlock(self, self._block), block=Block(self, self._block),
resizes_to_check=self._resizes_to_check, resizes_to_check=self._resizes_to_check,
check_properties_resize=self._check_properties_resize)) check_properties_resize=self._check_properties_resize,
run_on_device=True))
def add_statement(self, stmt): def add_statement(self, stmt):
if not self.scope: if not self.scope:
......
from pairs.ir.block import KernelBlock
from pairs.ir.device import DeviceCopy from pairs.ir.device import DeviceCopy
from pairs.ir.module import ModuleCall
from pairs.ir.mutator import Mutator from pairs.ir.mutator import Mutator
from pairs.ir.visitor import Visitor from pairs.ir.visitor import Visitor
class AddAccessedProperties(Visitor):
def __init__(self, ast):
super().__init__(ast)
self.current_kernel_block = None
self.writing = False
def visit_Assign(self, ast_node):
for s in ast_node.sources():
self.visit(s)
self.writing = True
for d in ast_node.destinations():
self.visit(d)
self.writing = False
def visit_KernelBlock(self, ast_node):
self.current_kernel_block = ast_node
self.visit_children(ast_node)
def visit_PropertyAccess(self, ast_node):
if self.current_kernel_block is not None:
self.current_kernel_block.add_property_access(ast_node.prop, 'w' if self.writing else 'r')
class AddDeviceCopies(Mutator): class AddDeviceCopies(Mutator):
def __init__(self, ast): def __init__(self, ast):
super().__init__(ast) super().__init__(ast)
...@@ -41,25 +17,22 @@ class AddDeviceCopies(Mutator): ...@@ -41,25 +17,22 @@ class AddDeviceCopies(Mutator):
for s in stmts: for s in stmts:
if s is not None: if s is not None:
s_id = id(s) s_id = id(s)
if isinstance(s, KernelBlock) and s_id in self.props_to_copy: if isinstance(s, ModuleCall) and s_id in self.props_to_copy:
new_stmts = new_stmts + [DeviceCopy(ast_node.sim, ast_node.sim.property(p)) for p in self.props_to_copy[s_id]] new_stmts = new_stmts + [DeviceCopy(ast_node.sim, p) for p in self.props_to_copy[s_id]]
new_stmts.append(s) new_stmts.append(s)
ast_node.stmts = new_stmts ast_node.stmts = new_stmts
return ast_node return ast_node
def mutate_KernelBlock(self, ast_node): def mutate_ModuleCall(self, ast_node):
ast_node.block = self.mutate(ast_node.block) copying_properties = {p for p in ast_node.module.properties_to_synchronize() if p not in self.synchronized_props}
copying_properties = {p for p in ast_node.properties_to_synchronize() if p not in self.synchronized_props}
self.props_to_copy[id(ast_node)] = copying_properties self.props_to_copy[id(ast_node)] = copying_properties
self.synchronized_props.update(copying_properties) self.synchronized_props.update(copying_properties)
self.synchronized_props -= ast_node.writing_properties() self.synchronized_props -= ast_node.module.write_properties()
return ast_node return ast_node
def add_device_copies(ast): def add_device_copies(ast):
add_accessed_props = AddAccessedProperties(ast)
add_accessed_props.visit()
add_copies = AddDeviceCopies(ast) add_copies = AddDeviceCopies(ast)
add_copies.mutate() add_copies.mutate()
...@@ -7,7 +7,7 @@ from pairs.ir.data_types import Type_Vector ...@@ -7,7 +7,7 @@ from pairs.ir.data_types import Type_Vector
from pairs.ir.lit import Lit from pairs.ir.lit import Lit
from pairs.ir.loops import While from pairs.ir.loops import While
from pairs.ir.memory import Realloc from pairs.ir.memory import Realloc
from pairs.ir.module import Module, Module_Call from pairs.ir.module import Module, ModuleCall
from pairs.ir.mutator import Mutator from pairs.ir.mutator import Mutator
from pairs.ir.properties import UpdateProperty from pairs.ir.properties import UpdateProperty
from pairs.ir.variables import Var, Deref from pairs.ir.variables import Var, Deref
...@@ -38,11 +38,11 @@ class FetchModulesReferences(Visitor): ...@@ -38,11 +38,11 @@ class FetchModulesReferences(Visitor):
def visit_Array(self, ast_node): def visit_Array(self, ast_node):
for m in self.module_stack: for m in self.module_stack:
m.add_array(ast_node) m.add_array(ast_node, self.writing)
def visit_Property(self, ast_node): def visit_Property(self, ast_node):
for m in self.module_stack: for m in self.module_stack:
m.add_property(ast_node) m.add_property(ast_node, self.writing)
def visit_Var(self, ast_node): def visit_Var(self, ast_node):
for m in self.module_stack: for m in self.module_stack:
...@@ -164,7 +164,7 @@ class ReplaceModulesByCalls(Mutator): ...@@ -164,7 +164,7 @@ class ReplaceModulesByCalls(Mutator):
return ast_node return ast_node
sim = ast_node.sim sim = ast_node.sim
call = Module_Call(sim, ast_node) call = ModuleCall(sim, ast_node)
if self.module_resizes[ast_node]: if self.module_resizes[ast_node]:
properties = sim.properties properties = sim.properties
init_stmts = [] init_stmts = []
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment