Skip to content
Snippets Groups Projects
Commit 44815e98 authored by Martin Bauer's avatar Martin Bauer
Browse files

Fixes for new LBStep / Datahandling

parent 1d1a192a
Branches
Tags
No related merge requests found
import sympy as sp import sympy as sp
from functools import partial from functools import partial
from collections import defaultdict
from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction from pystencils.astnodes import SympyAssignment, Block, LoopOverCoordinate, KernelFunction
from pystencils.transformations import resolveBufferAccesses, resolveFieldAccesses, makeLoopOverDomain, \ from pystencils.transformations import resolveBufferAccesses, resolveFieldAccesses, makeLoopOverDomain, \
typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \ typeAllEquations, getOptimalLoopOrdering, parseBasePointerInfo, moveConstantsBeforeLoop, splitInnerLoop, \
......
...@@ -110,7 +110,7 @@ class DataHandling(ABC): ...@@ -110,7 +110,7 @@ class DataHandling(ABC):
:param name: name of the array to gather :param name: name of the array to gather
:param sliceObj: slice expression of the rectangular sub-part that should be gathered :param sliceObj: slice expression of the rectangular sub-part that should be gathered
:param allGather: if False only the root process receives the result, if True all processes :param allGather: if False only the root process receives the result, if True all processes
:return: generator expression yielding the gathered field, the gathered field does not include any ghost layers :return: gathered field that does not include any ghost layers, or None if gathered on another process
""" """
@abstractmethod @abstractmethod
......
...@@ -157,12 +157,16 @@ class ParallelDataHandling(DataHandling): ...@@ -157,12 +157,16 @@ class ParallelDataHandling(DataHandling):
sliceObj = tuple([slice(None, None, None)] * self.dim) sliceObj = tuple([slice(None, None, None)] * self.dim)
if self.dim == 2: if self.dim == 2:
sliceObj += (0.5,) sliceObj += (0.5,)
for array in wlb.field.gatherGenerator(self.blocks, name, sliceObj, allGather):
if self.fields[name].indexDimensions == 0: array = wlb.field.gatherField(self.blocks, name, sliceObj, allGather)
array = array[..., 0] if array is None:
if self.dim == 2: return None
array = array[:, :, 0]
yield array if self.fields[name].indexDimensions == 0:
array = array[..., 0]
if self.dim == 2:
array = array[:, :, 0]
return array
def _normalizeArrShape(self, arr, indexDimensions): def _normalizeArrShape(self, arr, indexDimensions):
if indexDimensions == 0: if indexDimensions == 0:
......
...@@ -98,6 +98,7 @@ class SerialDataHandling(DataHandling): ...@@ -98,6 +98,7 @@ class SerialDataHandling(DataHandling):
indexDimensions = 0 indexDimensions = 0
layoutTuple = spatialLayoutStringToTuple(layout, self.dim) layoutTuple = spatialLayoutStringToTuple(layout, self.dim)
# cpuArr is always created - since there is no createPycudaArrayWithLayout() # cpuArr is always created - since there is no createPycudaArrayWithLayout()
cpuArr = createNumpyArrayWithLayout(layout=layoutTuple, **kwargs) cpuArr = createNumpyArrayWithLayout(layout=layoutTuple, **kwargs)
if cpu: if cpu:
...@@ -111,7 +112,7 @@ class SerialDataHandling(DataHandling): ...@@ -111,7 +112,7 @@ class SerialDataHandling(DataHandling):
assert all(f.name != latexName for f in self.fields.values()), "Symbolic field with this name already exists" assert all(f.name != latexName for f in self.fields.values()), "Symbolic field with this name already exists"
self.fields[name] = Field.createFixedSize(latexName, shape=kwargs['shape'], indexDimensions=indexDimensions, self.fields[name] = Field.createFixedSize(latexName, shape=kwargs['shape'], indexDimensions=indexDimensions,
dtype=kwargs['dtype'], layout=layout) dtype=kwargs['dtype'], layout=layoutTuple)
self._fieldLatexNameToDataName[latexName] = name self._fieldLatexNameToDataName[latexName] = name
def addCustomData(self, name, cpuCreationFunction, def addCustomData(self, name, cpuCreationFunction,
...@@ -171,7 +172,10 @@ class SerialDataHandling(DataHandling): ...@@ -171,7 +172,10 @@ class SerialDataHandling(DataHandling):
sliceObj = normalizeSlice(sliceObj, arr.shape[:-indDimensions] if indDimensions > 0 else arr.shape) sliceObj = normalizeSlice(sliceObj, arr.shape[:-indDimensions] if indDimensions > 0 else arr.shape)
sliceObj = tuple(s if type(s) is slice else slice(s, s + 1, None) for s in sliceObj) sliceObj = tuple(s if type(s) is slice else slice(s, s + 1, None) for s in sliceObj)
arr = arr[sliceObj] arr = arr[sliceObj]
yield arr else:
arr = arr.view()
arr.flags.writeable = False
return arr
def swap(self, name1, name2, gpu=False): def swap(self, name1, name2, gpu=False):
if not gpu: if not gpu:
...@@ -216,7 +220,6 @@ class SerialDataHandling(DataHandling): ...@@ -216,7 +220,6 @@ class SerialDataHandling(DataHandling):
return self._synchronizationFunctor(names, stencilName, 'gpu') return self._synchronizationFunctor(names, stencilName, 'gpu')
def _synchronizationFunctor(self, names, stencil, target): def _synchronizationFunctor(self, names, stencil, target):
assert target in ('cpu', 'gpu') assert target in ('cpu', 'gpu')
if not hasattr(names, '__len__') or type(names) is str: if not hasattr(names, '__len__') or type(names) is str:
names = [names] names = [names]
...@@ -224,9 +227,9 @@ class SerialDataHandling(DataHandling): ...@@ -224,9 +227,9 @@ class SerialDataHandling(DataHandling):
filteredStencil = [] filteredStencil = []
neighbors = [-1, 0, 1] neighbors = [-1, 0, 1]
if stencil.startswith('D2'): if (stencil is None and self.dim == 2) or (stencil is not None and stencil.startswith('D2')):
directions = itertools.product(*[neighbors] * 2) directions = itertools.product(*[neighbors] * 2)
elif stencil.startswith('D3'): elif (stencil is None and self.dim == 3) or (stencil is not None and stencil.startswith('D3')):
directions = itertools.product(*[neighbors] * 3) directions = itertools.product(*[neighbors] * 3)
else: else:
raise ValueError("Invalid stencil") raise ValueError("Invalid stencil")
......
import matplotlib import pystencils.plot2d as plt
import matplotlib.pyplot as plt
import matplotlib.animation as animation import matplotlib.animation as animation
from IPython.display import HTML from IPython.display import HTML
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
import base64 import base64
from IPython import get_ipython from IPython import get_ipython
...@@ -152,7 +150,7 @@ display_animation_func = None ...@@ -152,7 +150,7 @@ display_animation_func = None
def disp(*args, **kwargs): def disp(*args, **kwargs):
if not display_animation_func: if not display_animation_func:
raise "Call set_display_mode first" raise Exception("Call set_display_mode first")
return display_animation_func(*args, **kwargs) return display_animation_func(*args, **kwargs)
...@@ -179,4 +177,3 @@ ipython = get_ipython() ...@@ -179,4 +177,3 @@ ipython = get_ipython()
if ipython: if ipython:
setDisplayMode('imageupdate') setDisplayMode('imageupdate')
ipython.magic("matplotlib inline") ipython.magic("matplotlib inline")
matplotlib.rcParams['figure.figsize'] = (16.0, 6.0)
\ No newline at end of file
...@@ -3,7 +3,7 @@ from pystencils.gpucuda.indexing import indexingCreatorFromParams ...@@ -3,7 +3,7 @@ from pystencils.gpucuda.indexing import indexingCreatorFromParams
def createKernel(equations, target='cpu', dataType="double", iterationSlice=None, ghostLayers=None, def createKernel(equations, target='cpu', dataType="double", iterationSlice=None, ghostLayers=None,
cpuOpenMP=True, cpuVectorizeInfo=None, cpuOpenMP=False, cpuVectorizeInfo=None,
gpuIndexing='block', gpuIndexingParams={}): gpuIndexing='block', gpuIndexingParams={}):
""" """
Creates abstract syntax tree (AST) of kernel, using a list of update equations. Creates abstract syntax tree (AST) of kernel, using a list of update equations.
......
...@@ -73,7 +73,7 @@ def vectorFieldAnimation(runFunction, step=2, rescale=True, plotSetupFunction=la ...@@ -73,7 +73,7 @@ def vectorFieldAnimation(runFunction, step=2, rescale=True, plotSetupFunction=la
field = runFunction() field = runFunction()
if rescale: if rescale:
maxNorm = np.max(norm(field, axis=2, ord=2)) maxNorm = np.max(norm(field, axis=2, ord=2))
field /= maxNorm field = field / maxNorm
if 'scale' not in kwargs: if 'scale' not in kwargs:
kwargs['scale'] = 1.0 kwargs['scale'] = 1.0
...@@ -85,7 +85,7 @@ def vectorFieldAnimation(runFunction, step=2, rescale=True, plotSetupFunction=la ...@@ -85,7 +85,7 @@ def vectorFieldAnimation(runFunction, step=2, rescale=True, plotSetupFunction=la
f = np.swapaxes(f, 0, 1) f = np.swapaxes(f, 0, 1)
if rescale: if rescale:
maxNorm = np.max(norm(f, axis=2, ord=2)) maxNorm = np.max(norm(f, axis=2, ord=2))
f /= maxNorm f = f / maxNorm
u, v = f[::step, ::step, 0], f[::step, ::step, 1] u, v = f[::step, ::step, 0], f[::step, ::step, 1]
quiverPlot.set_UVC(u, v) quiverPlot.set_UVC(u, v)
plotUpdateFunction() plotUpdateFunction()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment