Skip to content
Snippets Groups Projects
Commit 02327683 authored by Rafael Ravedutti
Browse files

Add new block transformations and fix small issues


Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
parent 1b7d44b3
Branches gpu
Tags
No related merge requests found
......@@ -31,6 +31,6 @@ psim.periodic(2.8)
psim.vtk_output("output/test")
psim.compute(lj, cutoff_radius, {'sigma6': sigma6, 'epsilon': epsilon})
psim.compute(euler, symbols={'dt': dt})
#psim.target(pairs.target_cpu())
psim.target(pairs.target_gpu())
psim.target(pairs.target_cpu())
#psim.target(pairs.target_gpu())
psim.generate()
......@@ -16,13 +16,13 @@ inline void cuda_assert(cudaError_t err, const char *file, int line) {
}
}
__host__ __device__ void *device_alloc(size_t size) {
__host__ void *device_alloc(size_t size) {
void *ptr;
CUDA_ASSERT(cudaMalloc(&ptr, size));
return ptr;
}
__host__ __device__ void *device_realloc(void *ptr, size_t size) {
__host__ void *device_realloc(void *ptr, size_t size) {
void *new_ptr;
CUDA_ASSERT(cudaFree(ptr));
CUDA_ASSERT(cudaMalloc(&new_ptr, size));
......
......@@ -47,7 +47,7 @@ size_t read_particle_data(PairsSim *ps, const char *filename, double *grid_buffe
float_ptr(n) = std::stod(in0);
} else {
std::cerr << "read_particle_data(): Invalid property type!" << std::endl;
return -1;
return 0;
}
}
......
from pairs.analysis.bin_ops import SetBinOpTerminals, SetUsedBinOps
from pairs.analysis.blocks import SetBlockVariants, SetParentBlock
from pairs.analysis.bin_ops import ResetInPlaceBinOps, SetBinOpTerminals, SetInPlaceBinOps, SetUsedBinOps
from pairs.analysis.blocks import SetBlockVariants, SetExprOwnerBlock, SetParentBlock
from pairs.analysis.devices import FetchKernelReferences
from pairs.analysis.modules import FetchModulesReferences
......@@ -11,6 +11,9 @@ class Analysis:
self._set_bin_op_terminals = SetBinOpTerminals(ast)
self._set_block_variants = SetBlockVariants(ast)
self._set_parent_block = SetParentBlock(ast)
self._set_expressions_owner_block = SetExprOwnerBlock(ast)
self._reset_in_place_bin_ops = ResetInPlaceBinOps(ast)
self._set_in_place_bin_ops = SetInPlaceBinOps(ast)
self._fetch_kernel_references = FetchKernelReferences(ast)
self._fetch_modules_references = FetchModulesReferences(ast)
......@@ -26,7 +29,13 @@ class Analysis:
def set_parent_block(self):
    """Run the ``SetParentBlock`` analysis pass held by this object."""
    parent_block_pass = self._set_parent_block
    parent_block_pass.visit()
def set_expressions_owner_block(self):
    """Run the expression-ownership analysis and hand back its results.

    Returns:
        A ``(ownership, expressions_to_lift)`` tuple read from the
        ownership pass after it has visited the AST.
    """
    owner_pass = self._set_expressions_owner_block
    owner_pass.visit()
    return (owner_pass.ownership, owner_pass.expressions_to_lift)
def fetch_kernel_references(self):
    """Refresh the ``in_place`` flags of binary operations, then fetch kernel references.

    The three passes run strictly in this order: reset the flags, set
    them, and finally collect the references per kernel.
    """
    ordered_passes = (
        self._reset_in_place_bin_ops,
        self._set_in_place_bin_ops,
        self._fetch_kernel_references,
    )
    for analysis_pass in ordered_passes:
        analysis_pass.visit()
def fetch_modules_references(self):
......
from pairs.ir.bin_op import BinOp
from pairs.ir.visitor import Visitor
......@@ -60,3 +61,21 @@ class SetUsedBinOps(Visitor):
ast_node.decl.used = not self.writing
self.writing = False
self.visit_children(ast_node)
class ResetInPlaceBinOps(Visitor):
    """Visitor that marks every binary operation it reaches as in-place."""

    def __init__(self, ast):
        super().__init__(ast)

    def visit_BinOp(self, ast_node):
        """Set ``in_place`` to ``True`` on ``ast_node``, then descend into its children."""
        ast_node.in_place = True
        self.visit_children(ast_node)
class SetInPlaceBinOps(Visitor):
    """Visitor that clears the in-place flag of binary operations that are declared."""

    def __init__(self, ast):
        super().__init__(ast)

    def visit_Decl(self, ast_node):
        """Set ``in_place`` to ``False`` on the declared element if it is a ``BinOp``.

        The children of the declaration are not visited by this handler.
        """
        declared = ast_node.elem
        if not isinstance(declared, BinOp):
            return

        declared.in_place = False
......@@ -119,3 +119,63 @@ class SetParentBlock(Visitor):
assert isinstance(ast_node, (For, While)), "Node must be a loop!"
loop_id = id(ast_node)
return self.parents[loop_id] if loop_id in self.parents else None
class SetExprOwnerBlock(Visitor):
    """Analysis pass that assigns an owner block to every visited expression.

    ``BinOp`` and ``PropertyAccess`` nodes are tracked. Each time such a node
    is met, its owner becomes the common parent of the previous owner and the
    block currently being visited.

    Results are exposed through:
      * ``ownership``: maps each tracked expression to its owner block.
      * ``expressions_to_lift``: expressions whose owner ended up being a
        block different from both blocks that were merged.
    """

    def __init__(self, ast):
        super().__init__(ast)
        self.ownership = {}
        self.expressions_to_lift = []
        # Nesting depth of each visited block (0 when no block encloses it).
        self.block_level = {}
        # Enclosing block of each visited block (None when there is none).
        self.block_parent = {}
        # Blocks currently being visited, innermost last.
        self.block_stack = []

    def common_parent_block(self, block1, block2):
        """Find the closest block that encloses both ``block1`` and ``block2``.

        Returns:
            A ``(block, must_lift)`` tuple. ``must_lift`` is ``True`` only
            when the block found is neither ``block1`` nor ``block2``. When
            one argument is ``None`` the other one is returned with
            ``must_lift`` set to ``False``.
        """
        if block1 is None:
            return (block2, False)

        if block2 is None:
            return (block1, False)

        cursor1, cursor2 = block1, block2
        while cursor1 != cursor2:
            depth1 = self.block_level[cursor1]
            depth2 = self.block_level[cursor2]
            # Both decisions use the depths sampled above, so cursors at the
            # same depth climb together within a single iteration.
            climb_first = depth1 >= depth2
            climb_second = depth2 >= depth1

            if climb_first:
                if depth1 == 0:
                    # Nothing left to climb on this side: stop without lifting.
                    return (cursor1, False)

                cursor1 = self.block_parent[cursor1]

            if climb_second:
                if depth2 == 0:
                    return (cursor2, False)

                cursor2 = self.block_parent[cursor2]

        must_lift = cursor1 != block1 and cursor1 != block2
        return (cursor1, must_lift)

    def set_ownership(self, ast_node):
        """Merge the innermost block on the stack into the ownership of ``ast_node``."""
        previous_owner = self.ownership.setdefault(ast_node, None)
        new_owner, must_lift = self.common_parent_block(previous_owner, self.block_stack[-1])
        self.ownership[ast_node] = new_owner

        if must_lift and ast_node not in self.expressions_to_lift:
            self.expressions_to_lift.append(ast_node)

    def visit_Block(self, ast_node):
        """Record depth and parent of ``ast_node`` and visit its children with it on the stack."""
        depth = len(self.block_stack)
        self.block_level[ast_node] = depth
        self.block_parent[ast_node] = self.block_stack[-1] if depth > 0 else None
        self.block_stack.append(ast_node)
        self.visit_children(ast_node)
        self.block_stack.pop()

    def visit_BinOp(self, ast_node):
        """Update the ownership of a binary operation, then visit its children."""
        self.set_ownership(ast_node)
        self.visit_children(ast_node)

    def visit_PropertyAccess(self, ast_node):
        """Update the ownership of a property access, then visit its children."""
        self.set_ownership(ast_node)
        self.visit_children(ast_node)
......@@ -33,7 +33,7 @@ class FetchKernelReferences(Visitor):
self.kernel_stack.append(ast_node)
self.visit_children(ast_node)
self.kernel_stack.pop()
ast_node.add_bin_op([b for b in self.kernel_used_bin_ops[kernel_id] if b not in self.kernel_decls[kernel_id]])
ast_node.add_bin_op([b for b in self.kernel_used_bin_ops[kernel_id] if b not in self.kernel_decls[kernel_id] and not b.in_place])
def visit_PropertyAccess(self, ast_node):
# Visit property and save current writing state
......@@ -51,8 +51,11 @@ class FetchKernelReferences(Visitor):
self.kernel_decls[k.kernel_id].append(ast_node.elem)
def visit_BinOp(self, ast_node):
for k in self.kernel_stack:
self.kernel_used_bin_ops[k.kernel_id].append(ast_node)
if ast_node.inlined is False:
for k in self.kernel_stack:
self.kernel_used_bin_ops[k.kernel_id].append(ast_node)
self.visit_children(ast_node)
def visit_Array(self, ast_node):
for k in self.kernel_stack:
......
......@@ -41,6 +41,7 @@ class BinOp(VectorExpression):
self.mem = mem
self.inlined = False
self.generated = False
self.in_place = False
self.bin_op_type = BinOp.infer_type(self.lhs, self.rhs, self.op)
self.terminals = set()
self.decl = Decl(sim, self)
......
......@@ -9,6 +9,7 @@ class Module(ASTNode):
def __init__(self, sim, name=None, block=None, resizes_to_check={}, check_properties_resize=False, run_on_device=False):
super().__init__(sim)
self._id = Module.last_module
self._name = name if name is not None else "module" + str(Module.last_module)
self._variables = {}
self._arrays = {}
......@@ -20,6 +21,10 @@ class Module(ASTNode):
sim.add_module(self)
Module.last_module += 1
@property
def module_id(self):
    """Return the identifier stored in ``self._id``."""
    identifier = self._id
    return identifier
@property
def name(self):
    """Return the module name stored in ``self._name``."""
    module_name = self._name
    return module_name
......
......@@ -58,13 +58,12 @@ class EnforcePBC(Lowerable):
for i in ParticleFor(sim):
# TODO: VecFilter?
pos = positions[i]
for d in range(0, ndims):
for _ in Filter(sim, pos[d] < grid.min(d)):
pos[d].add(grid.length(d))
for _ in Filter(sim, positions[i][d] < grid.min(d)):
positions[i][d].add(grid.length(d))
for _ in Filter(sim, pos[d] > grid.max(d)):
pos[d].sub(grid.length(d))
for _ in Filter(sim, positions[i][d] > grid.max(d)):
positions[i][d].sub(grid.length(d))
class SetupPBC(Lowerable):
......@@ -91,12 +90,11 @@ class SetupPBC(Lowerable):
for d in range(0, ndims):
for i in For(sim, 0, nlocal + npbc):
pos = positions[i]
last_id = nlocal + npbc
grid_length = grid.length(d)
# TODO: VecFilter?
for _ in Filter(sim, pos[d] < grid.min(d) + cutneigh):
last_pos = positions[last_id]
last_pos = positions[nlocal + npbc]
pbc_map[npbc].set(i)
pbc_mult[npbc][d].set(1)
last_pos[d].set(pos[d] + grid_length)
......@@ -108,7 +106,7 @@ class SetupPBC(Lowerable):
npbc.add(1)
for _ in Filter(sim, pos[d] > grid.max(d) - cutneigh):
last_pos = positions[last_id]
last_pos = positions[nlocal + npbc]
pbc_map[npbc].set(i)
pbc_mult[npbc][d].set(-1)
last_pos[d].set(pos[d] - grid_length)
......
from pairs.analysis import Analysis
from pairs.transformations.blocks import MergeAdjacentBlocks
from pairs.transformations.blocks import LiftExprOwnerBlocks, MergeAdjacentBlocks
from pairs.transformations.devices import AddDeviceCopies, AddDeviceKernels
from pairs.transformations.expressions import ReplaceSymbols, SimplifyExpressions, PrioritizeScalarOps
from pairs.transformations.loops import LICM
......@@ -16,6 +16,7 @@ class Transformations:
self._replace_symbols = ReplaceSymbols(ast)
self._simplify_expressions = SimplifyExpressions(ast)
self._prioritize_scalar_ops = PrioritizeScalarOps(ast)
self._lift_expressions_to_owner_blocks = LiftExprOwnerBlocks(ast)
self._licm = LICM(ast)
self._dereference_write_variables = DereferenceWriteVariables(ast)
self._add_resize_logic = AddResizeLogic(ast)
......@@ -41,6 +42,11 @@ class Transformations:
self._simplify_expressions.mutate()
self._analysis.set_used_bin_ops()
def lift_expressions_to_owner_blocks(self):
    """Run the ownership analysis and apply the lifting mutator with its results.

    The ``(ownership, expressions_to_lift)`` pair produced by the analysis
    is passed to the mutator through ``set_data`` before ``mutate`` is called.
    """
    analysis_result = self._analysis.set_expressions_owner_block()
    ownership, expressions_to_lift = analysis_result

    lifter = self._lift_expressions_to_owner_blocks
    lifter.set_data(ownership, expressions_to_lift)
    lifter.mutate()
def licm(self):
self._analysis.set_parent_block()
self._analysis.set_block_variants()
......@@ -67,6 +73,7 @@ class Transformations:
def apply_all(self):
self.lower_everything()
self.optimize_expressions()
self.lift_expressions_to_owner_blocks()
self.licm()
self.modularize()
self.add_device_copies()
......
from pairs.ir.block import Block
from pairs.ir.bin_op import Decl
from pairs.ir.mutator import Mutator
......@@ -18,3 +19,20 @@ class MergeAdjacentBlocks(Mutator):
ast_node.stmts = new_stmts
return ast_node
class LiftExprOwnerBlocks(Mutator):
    """Mutator that prepends declarations of lifted expressions to their owner blocks.

    ``set_data`` must supply the ownership map and the list of expressions
    to lift before the mutator is run; both start out as ``None``.
    """

    def __init__(self, ast):
        super().__init__(ast)
        self.ownership = None
        self.expressions_to_lift = None

    def set_data(self, ownership, expressions_to_lift):
        """Store the ownership map and the expressions selected for lifting."""
        self.ownership = ownership
        self.expressions_to_lift = expressions_to_lift

    def mutate_Block(self, ast_node):
        """Rebuild the statements of ``ast_node`` with lifted declarations first.

        A ``Decl`` is created for every expression that is owned by this
        block and is listed in ``expressions_to_lift``. The declarations are
        built before the existing statements are mutated, and are placed
        ahead of them in the resulting list.
        """
        lifted_decls = [
            Decl(ast_node.sim, expr)
            for expr, owner in self.ownership.items()
            if owner == ast_node and expr in self.expressions_to_lift
        ]
        mutated_stmts = [self.mutate(stmt) for stmt in ast_node.stmts]
        ast_node.stmts = lifted_decls + mutated_stmts
        return ast_node
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment