Skip to content
Snippets Groups Projects
Commit bab19cf4 authored by Martin Bauer's avatar Martin Bauer
Browse files

Module to create waLBerla sweeps from pystencils

parent f1b61821
Branches
Tags
No related merge requests found
...@@ -17,7 +17,7 @@ class ResolvedFieldAccess(sp.Indexed): ...@@ -17,7 +17,7 @@ class ResolvedFieldAccess(sp.Indexed):
return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field)) return superClassContents + tuple(self.offsets) + (repr(self.idxCoordinateValues), hash(self.field))
def __getnewargs__(self): def __getnewargs__(self):
return self.name, self.indices[0], self.field, self.offsets, self.idxCoordinateValues return self.base, self.indices[0], self.field, self.offsets, self.idxCoordinateValues
class Node(object): class Node(object):
...@@ -96,7 +96,7 @@ class Conditional(Node): ...@@ -96,7 +96,7 @@ class Conditional(Node):
class KernelFunction(Node): class KernelFunction(Node):
class Argument: class Argument:
def __init__(self, name, dtype, kernelFunctionNode): def __init__(self, name, dtype, symbol, kernelFunctionNode):
from pystencils.transformations import symbolNameToVariableName from pystencils.transformations import symbolNameToVariableName
self.name = name self.name = name
self.dtype = dtype self.dtype = dtype
...@@ -106,6 +106,7 @@ class KernelFunction(Node): ...@@ -106,6 +106,7 @@ class KernelFunction(Node):
self.isFieldArgument = False self.isFieldArgument = False
self.fieldName = "" self.fieldName = ""
self.coordinate = None self.coordinate = None
self.symbol = symbol
if name.startswith(Field.DATA_PREFIX): if name.startswith(Field.DATA_PREFIX):
self.isFieldPtrArgument = True self.isFieldPtrArgument = True
...@@ -125,6 +126,23 @@ class KernelFunction(Node): ...@@ -125,6 +126,23 @@ class KernelFunction(Node):
fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed} fieldMap = {symbolNameToVariableName(f.name): f for f in kernelFunctionNode.fieldsAccessed}
self.field = fieldMap[self.fieldName] self.field = fieldMap[self.fieldName]
def __lt__(self, other):
def score(l):
if l.isFieldPtrArgument:
return -4
elif l.isFieldShapeArgument:
return -3
elif l.isFieldStrideArgument:
return -2
return 0
if score(self) < score(other):
return True
elif score(self) == score(other):
return self.name < other.name
else:
return False
def __repr__(self): def __repr__(self):
return '<{0} {1}>'.format(self.dtype, self.name) return '<{0} {1}>'.format(self.dtype, self.name)
...@@ -166,10 +184,9 @@ class KernelFunction(Node): ...@@ -166,10 +184,9 @@ class KernelFunction(Node):
def _updateParameters(self): def _updateParameters(self):
undefinedSymbols = self._body.undefinedSymbols - self.globalVariables undefinedSymbols = self._body.undefinedSymbols - self.globalVariables
self._parameters = [KernelFunction.Argument(s.name, s.dtype, self) for s in undefinedSymbols] self._parameters = [KernelFunction.Argument(s.name, s.dtype, s, self) for s in undefinedSymbols]
self._parameters.sort(key=lambda l: (l.fieldName, l.isFieldPtrArgument, l.isFieldShapeArgument,
l.isFieldStrideArgument, l.name), self._parameters.sort()
reverse=True)
def __str__(self): def __str__(self):
self._updateParameters() self._updateParameters()
......
...@@ -4,13 +4,13 @@ from pystencils.astnodes import Node ...@@ -4,13 +4,13 @@ from pystencils.astnodes import Node
from pystencils.types import createType, PointerType from pystencils.types import createType, PointerType
def generateC(astNode): def generateC(astNode, signatureOnly=False):
""" """
Prints the abstract syntax tree as C function Prints the abstract syntax tree as C function
""" """
fieldTypes = set([f.dtype for f in astNode.fieldsAccessed]) fieldTypes = set([f.dtype for f in astNode.fieldsAccessed])
useFloatConstants = createType("double") not in fieldTypes useFloatConstants = createType("double") not in fieldTypes
printer = CBackend(constantsAsFloats=useFloatConstants) printer = CBackend(constantsAsFloats=useFloatConstants, signatureOnly=signatureOnly)
return printer(astNode) return printer(astNode)
...@@ -51,13 +51,14 @@ class PrintNode(CustomCppCode): ...@@ -51,13 +51,14 @@ class PrintNode(CustomCppCode):
class CBackend(object): class CBackend(object):
def __init__(self, constantsAsFloats=False, sympyPrinter=None): def __init__(self, constantsAsFloats=False, sympyPrinter=None, signatureOnly=False):
if sympyPrinter is None: if sympyPrinter is None:
self.sympyPrinter = CustomSympyPrinter(constantsAsFloats) self.sympyPrinter = CustomSympyPrinter(constantsAsFloats)
else: else:
self.sympyPrinter = sympyPrinter self.sympyPrinter = sympyPrinter
self._indent = " " self._indent = " "
self._signatureOnly = signatureOnly
def __call__(self, node): def __call__(self, node):
return str(self._print(node)) return str(self._print(node))
...@@ -72,6 +73,9 @@ class CBackend(object): ...@@ -72,6 +73,9 @@ class CBackend(object):
def _print_KernelFunction(self, node): def _print_KernelFunction(self, node):
functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters] functionArguments = ["%s %s" % (str(s.dtype), s.name) for s in node.parameters]
funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments)) funcDeclaration = "FUNC_PREFIX void %s(%s)" % (node.functionName, ", ".join(functionArguments))
if self._signatureOnly:
return funcDeclaration
body = self._print(node.body) body = self._print(node.body)
return funcDeclaration + "\n" + body return funcDeclaration + "\n" + body
......
...@@ -42,7 +42,9 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}): ...@@ -42,7 +42,9 @@ def makePythonFunction(kernelFunctionNode, argumentDict={}):
shape = _checkArguments(parameters, fullArguments) shape = _checkArguments(parameters, fullArguments)
indexing = kernelFunctionNode.indexing indexing = kernelFunctionNode.indexing
dictWithBlockAndThreadNumbers = indexing.getCallParameters(shape, func) dictWithBlockAndThreadNumbers = indexing.getCallParameters(shape)
dictWithBlockAndThreadNumbers['block'] = tuple(int(i) for i in dictWithBlockAndThreadNumbers['block'])
dictWithBlockAndThreadNumbers['grid'] = tuple(int(i) for i in dictWithBlockAndThreadNumbers['grid'])
args = _buildNumpyArgumentList(parameters, fullArguments) args = _buildNumpyArgumentList(parameters, fullArguments)
cache[key] = (args, dictWithBlockAndThreadNumbers) cache[key] = (args, dictWithBlockAndThreadNumbers)
......
...@@ -31,12 +31,10 @@ class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})): ...@@ -31,12 +31,10 @@ class AbstractIndexing(abc.ABCMeta('ABC', (object,), {})):
return BLOCK_IDX + THREAD_IDX return BLOCK_IDX + THREAD_IDX
@abc.abstractmethod @abc.abstractmethod
def getCallParameters(self, arrShape, functionToCall): def getCallParameters(self, arrShape):
""" """
Determine grid and block size for kernel call Determine grid and block size for kernel call
:param arrShape: the numeric (not symbolic) shape of the array :param arrShape: the numeric (not symbolic) shape of the array
:param functionToCall: compile kernel function that should be called. Use this object to get information
about required resources like number of registers
:return: dict with keys 'blocks' and 'threads' with tuple values for number of (x,y,z) threads and blocks :return: dict with keys 'blocks' and 'threads' with tuple values for number of (x,y,z) threads and blocks
the kernel should be started with the kernel should be started with
""" """
...@@ -87,14 +85,14 @@ class BlockIndexing(AbstractIndexing): ...@@ -87,14 +85,14 @@ class BlockIndexing(AbstractIndexing):
return coordinates[:self._dim] return coordinates[:self._dim]
def getCallParameters(self, arrShape, functionToCall): def getCallParameters(self, arrShape):
substitutionDict = {sym: value for sym, value in zip(self._symbolicShape, arrShape) if sym is not None} substitutionDict = {sym: value for sym, value in zip(self._symbolicShape, arrShape) if sym is not None}
widths = [end - start for start, end in zip(_getStartFromSlice(self._iterationSlice), widths = [end - start for start, end in zip(_getStartFromSlice(self._iterationSlice),
_getEndFromSlice(self._iterationSlice, arrShape))] _getEndFromSlice(self._iterationSlice, arrShape))]
widths = sp.Matrix(widths).subs(substitutionDict) widths = sp.Matrix(widths).subs(substitutionDict)
grid = tuple(math.ceil(length / blockSize) for length, blockSize in zip(widths, self._blockSize)) grid = tuple(sp.ceiling(length / blockSize) for length, blockSize in zip(widths, self._blockSize))
extendBs = (1,) * (3 - len(self._blockSize)) extendBs = (1,) * (3 - len(self._blockSize))
extendGr = (1,) * (3 - len(grid)) extendGr = (1,) * (3 - len(grid))
...@@ -230,7 +228,7 @@ class LineIndexing(AbstractIndexing): ...@@ -230,7 +228,7 @@ class LineIndexing(AbstractIndexing):
def coordinates(self): def coordinates(self):
return [i + offset for i, offset in zip(self._coordinates, _getStartFromSlice(self._iterationSlice))] return [i + offset for i, offset in zip(self._coordinates, _getStartFromSlice(self._iterationSlice))]
def getCallParameters(self, arrShape, functionToCall): def getCallParameters(self, arrShape):
substitutionDict = {sym: value for sym, value in zip(self._symbolicShape, arrShape) if sym is not None} substitutionDict = {sym: value for sym, value in zip(self._symbolicShape, arrShape) if sym is not None}
widths = [end - start for start, end in zip(_getStartFromSlice(self._iterationSlice), widths = [end - start for start, end in zip(_getStartFromSlice(self._iterationSlice),
...@@ -242,7 +240,7 @@ class LineIndexing(AbstractIndexing): ...@@ -242,7 +240,7 @@ class LineIndexing(AbstractIndexing):
return 1 return 1
else: else:
idx = self._coordinates.index(cudaIdx) idx = self._coordinates.index(cudaIdx)
return int(widths[idx]) return widths[idx]
return {'block': tuple([getShapeOfCudaIdx(idx) for idx in THREAD_IDX]), return {'block': tuple([getShapeOfCudaIdx(idx) for idx in THREAD_IDX]),
'grid': tuple([getShapeOfCudaIdx(idx) for idx in BLOCK_IDX])} 'grid': tuple([getShapeOfCudaIdx(idx) for idx in BLOCK_IDX])}
......
...@@ -239,6 +239,10 @@ class BasicType(Type): ...@@ -239,6 +239,10 @@ class BasicType(Type):
def is_other(self): def is_other(self):
return self.numpyDtype in np.sctypes['others'] return self.numpyDtype in np.sctypes['others']
@property
def baseName(self):
return BasicType.numpyNameToC(str(self._dtype))
def __str__(self): def __str__(self):
result = BasicType.numpyNameToC(str(self._dtype)) result = BasicType.numpyNameToC(str(self._dtype))
if self.const: if self.const:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment