Skip to content
Snippets Groups Projects
Commit 0b2bcad8 authored by Martin Bauer's avatar Martin Bauer
Browse files

Bugfixes - float kernel generation should work now

parent cb4f5070
No related branches found
No related tags found
No related merge requests found
...@@ -313,7 +313,7 @@ def createLatticeBoltzmannAst(updateRule=None, optimizationParams={}, **kwargs): ...@@ -313,7 +313,7 @@ def createLatticeBoltzmannAst(updateRule=None, optimizationParams={}, **kwargs):
else: else:
splitGroups = () splitGroups = ()
res = createKernel(updateRule.allEquations, splitGroups=splitGroups, res = createKernel(updateRule.allEquations, splitGroups=splitGroups,
typeForSymbol='double' if optParams['doublePrecision'] else 'float', typeForSymbol='double' if optParams['doublePrecision'] else 'float32',
ghostLayers=1) ghostLayers=1)
elif optParams['target'] == 'gpu': elif optParams['target'] == 'gpu':
from pystencils.gpucuda import createCUDAKernel from pystencils.gpucuda import createCUDAKernel
...@@ -385,13 +385,14 @@ def createLatticeBoltzmannUpdateRule(lbMethod=None, optimizationParams={}, **kwa ...@@ -385,13 +385,14 @@ def createLatticeBoltzmannUpdateRule(lbMethod=None, optimizationParams={}, **kwa
else: else:
collisionRule = addEntropyCondition(collisionRule, omegaOutputField=params['omegaOutputField']) collisionRule = addEntropyCondition(collisionRule, omegaOutputField=params['omegaOutputField'])
fieldDtype = 'float64' if optimizationParams['doublePrecision'] else 'float32'
if optParams['fieldSize']: if optParams['fieldSize']:
fieldSize = [s + 2 for s in optParams['fieldSize']] + [len(stencil)] fieldSize = [s + 2 for s in optParams['fieldSize']] + [len(stencil)]
srcField = Field.createFixedSize(params['fieldName'], fieldSize, indexDimensions=1, srcField = Field.createFixedSize(params['fieldName'], fieldSize, indexDimensions=1,
layout=optParams['fieldLayout']) layout=optParams['fieldLayout'], dtype=fieldDtype)
else: else:
srcField = Field.createGeneric(params['fieldName'], spatialDimensions=lbMethod.dim, indexDimensions=1, srcField = Field.createGeneric(params['fieldName'], spatialDimensions=lbMethod.dim, indexDimensions=1,
layout=optParams['fieldLayout']) layout=optParams['fieldLayout'], dtype=fieldDtype)
dstField = srcField.newFieldWithDifferentName(params['secondFieldName']) dstField = srcField.newFieldWithDifferentName(params['secondFieldName'])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment