Skip to content
Snippets Groups Projects
Commit 392e8366 authored by Rafael Ravedutti
Browse files

Fix GPU for DEM case


Signed-off-by: Rafael Ravedutti <rafaelravedutti@gmail.com>
parent 8fb42cd5
Branches
Tags
No related merge requests found
from pairs.ir.arrays import ArrayAccess from pairs.ir.arrays import ArrayAccess
from pairs.ir.branches import Branch
from pairs.ir.lit import Lit from pairs.ir.lit import Lit
from pairs.ir.loops import For from pairs.ir.loops import For
from pairs.ir.quaternions import QuaternionOp from pairs.ir.quaternions import QuaternionOp
...@@ -13,10 +14,25 @@ class MarkCandidateLoops(Visitor): ...@@ -13,10 +14,25 @@ class MarkCandidateLoops(Visitor):
super().__init__(ast) super().__init__(ast)
def visit_Module(self, ast_node):
    """Mark the module's non-literal-bounded loops as kernel candidates.

    Candidate loops are the ``For`` statements found either directly in
    the module block or one level down, inside the ``block_if`` /
    ``block_else`` of a ``Branch`` that sits directly in the module block.
    A loop is marked when at least one of its bounds (``min`` or ``max``)
    is not a ``Lit``.  The module's children are visited afterwards.
    """
    possible_candidates = []
    for stmt in ast_node._block.stmts:
        if stmt is None:
            continue

        if isinstance(stmt, Branch):
            # The else-block is optional; the if-block is always scanned.
            branch_blocks = [stmt.block_if]
            if stmt.block_else is not None:
                branch_blocks.append(stmt.block_else)

            for branch_block in branch_blocks:
                possible_candidates.extend(
                    branch_stmt
                    for branch_stmt in branch_block.stmts
                    if isinstance(branch_stmt, For)
                )

        if isinstance(stmt, For):
            possible_candidates.append(stmt)

    for loop in possible_candidates:
        # Loops whose bounds are both literals are left unmarked.
        if not (isinstance(loop.min, Lit) and isinstance(loop.max, Lit)):
            loop.mark_as_kernel_candidate()

    self.visit_children(ast_node)
......
import math import math
from pairs.ir.actions import Actions from pairs.ir.actions import Actions
from pairs.ir.block import Block from pairs.ir.block import Block
from pairs.ir.branches import Filter from pairs.ir.branches import Branch, Filter
from pairs.ir.cast import Cast from pairs.ir.cast import Cast
from pairs.ir.contexts import Contexts from pairs.ir.contexts import Contexts
from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef from pairs.ir.device import CopyArray, CopyContactProperty, CopyProperty, CopyVar, DeviceStaticRef, HostRef
...@@ -86,33 +86,72 @@ class AddDeviceCopies(Mutator): ...@@ -86,33 +86,72 @@ class AddDeviceCopies(Mutator):
class AddDeviceKernels(Mutator):
    """Replace kernel-candidate loops of device modules with kernel launches.

    For every module flagged ``run_on_device``, each ``For`` statement that
    reports ``is_kernel_candidate()`` -- directly in the module block or
    directly inside a top-level ``Branch`` -- is swapped for a
    ``KernelLaunch`` that targets a ``Kernel`` built from the loop body.
    """

    def __init__(self, ast=None):
        super().__init__(ast)
        # Name of the module currently being processed; used as the prefix
        # of generated kernel names.
        self._module_name = None
        # Running index appended to the kernel name within one module.
        self._kernel_id = 0

    def create_kernel(self, sim, iterator, rmax, block):
        """Return the kernel for the next candidate loop, creating it if needed.

        The kernel is looked up by the name
        ``"<module name>_kernel<index>"``.  When no kernel with that name is
        registered, a new one is built whose body is ``block`` wrapped in a
        ``Filter`` on ``iterator < rmax.copy(True)``.
        """
        kernel_name = f"{self._module_name}_kernel{self._kernel_id}"
        kernel = sim.find_kernel_by_name(kernel_name)

        if kernel is None:
            kernel_body = Filter(sim, ScalarOp.inline(iterator < rmax.copy(True)), block)
            kernel = Kernel(sim, kernel_name, kernel_body, iterator)

        # NOTE(review): the indentation of this increment was lost in the
        # source this was recovered from.  Advancing the index for every
        # candidate keeps names unique per loop even when an existing kernel
        # is reused; confirm this matches the intended numbering.
        self._kernel_id += 1
        return kernel

    def _launch_if_candidate(self, sim, stmt):
        """Return a ``KernelLaunch`` for a candidate loop, else ``stmt`` itself."""
        if isinstance(stmt, For) and stmt.is_kernel_candidate():
            kernel = self.create_kernel(sim, stmt.iterator, stmt.max, stmt.block)
            return KernelLaunch(sim, kernel, stmt.iterator, stmt.min, stmt.max)

        return stmt

    def mutate_Module(self, ast_node):
        """Lower the candidate loops of a device module, then mutate its block."""
        if ast_node.run_on_device:
            self._module_name = ast_node.name
            self._kernel_id = 0
            new_stmts = []

            for stmt in ast_node._block.stmts:
                if stmt is None:
                    # ``None`` entries are dropped from the rebuilt list.
                    continue

                lowered = self._launch_if_candidate(ast_node.sim, stmt)
                if lowered is stmt and isinstance(stmt, Branch):
                    # Branches are handled in statement order so kernel
                    # indices follow the order of appearance in the module.
                    lowered = self.check_and_mutate_branch(stmt)

                new_stmts.append(lowered)

            ast_node._block.stmts = new_stmts

        ast_node._block = self.mutate(ast_node._block)
        return ast_node

    def check_and_mutate_branch(self, ast_node):
        """Lower the candidate loops directly inside a branch's blocks.

        ``block_if`` is always processed; ``block_else`` only when present.
        ``None`` statements are dropped.  The branch node is modified in
        place and returned.
        """
        blocks = [ast_node.block_if]
        if ast_node.block_else is not None:
            blocks.append(ast_node.block_else)

        for block in blocks:
            block.stmts = [
                self._launch_if_candidate(ast_node.sim, stmt)
                for stmt in block.stmts
                if stmt is not None
            ]

        return ast_node
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment