Skip to content
Snippets Groups Projects
Commit d193fd9d authored by Philipp Suffa's avatar Philipp Suffa
Browse files

Merge branch 'CastFieldReads' into 'master'

Cast field reads also in BC

See merge request pycodegen/lbmpy!153
parents 2068f36f 29b7cce8
Branches
Tags release/1.3.2
No related merge requests found
Pipeline #57169 failed
...@@ -2,12 +2,13 @@ import numpy as np ...@@ -2,12 +2,13 @@ import numpy as np
import sympy as sp import sympy as sp
from lbmpy.advanced_streaming.indexing import BetweenTimestepsIndexing from lbmpy.advanced_streaming.indexing import BetweenTimestepsIndexing
from lbmpy.advanced_streaming.utility import is_inplace, Timestep, AccessPdfValues from lbmpy.advanced_streaming.utility import is_inplace, Timestep, AccessPdfValues
from pystencils import Field, Assignment, TypedSymbol, create_kernel from pystencils import Assignment, Field, TypedSymbol, create_kernel
from pystencils.stencil import inverse_direction from pystencils.stencil import inverse_direction
from pystencils import CreateKernelConfig, Target from pystencils import CreateKernelConfig, Target
from pystencils.boundaries import BoundaryHandling from pystencils.boundaries import BoundaryHandling
from pystencils.boundaries.createindexlist import numpy_data_type_for_boundary_object from pystencils.boundaries.createindexlist import numpy_data_type_for_boundary_object
from pystencils.backends.cbackend import CustomCodeNode from pystencils.backends.cbackend import CustomCodeNode
from pystencils.simp import add_subexpressions_for_field_reads
class LatticeBoltzmannBoundaryHandling(BoundaryHandling): class LatticeBoltzmannBoundaryHandling(BoundaryHandling):
...@@ -194,13 +195,16 @@ def create_lattice_boltzmann_boundary_kernel(pdf_field, index_field, lb_method, ...@@ -194,13 +195,16 @@ def create_lattice_boltzmann_boundary_kernel(pdf_field, index_field, lb_method,
boundary_assignments = boundary_functor(f_out, f_in, dir_symbol, inv_dir, lb_method, index_field) boundary_assignments = boundary_functor(f_out, f_in, dir_symbol, inv_dir, lb_method, index_field)
boundary_assignments = indexing.substitute_proxies(boundary_assignments) boundary_assignments = indexing.substitute_proxies(boundary_assignments)
# Code Elements inside the loop
elements = [Assignment(dir_symbol, index_field[0]('dir'))]
elements += boundary_assignments.all_assignments
config = CreateKernelConfig(index_fields=[index_field], target=target, default_number_int="int32", config = CreateKernelConfig(index_fields=[index_field], target=target, default_number_int="int32",
skip_independence_check=True, **kernel_creation_args) skip_independence_check=True, **kernel_creation_args)
default_data_type = config.data_type.default_factory()
if pdf_field.dtype != default_data_type:
boundary_assignments = add_subexpressions_for_field_reads(boundary_assignments, data_type=default_data_type)
elements = [Assignment(dir_symbol, index_field[0]('dir'))]
elements += boundary_assignments.all_assignments
kernel = create_kernel(elements, config=config) kernel = create_kernel(elements, config=config)
# Code Elements ahead of the loop # Code Elements ahead of the loop
......
...@@ -10,7 +10,8 @@ from pystencils.sympyextensions import fast_subs ...@@ -10,7 +10,8 @@ from pystencils.sympyextensions import fast_subs
# -------------------------------------------- LBM Kernel Creation ----------------------------------------------------- # -------------------------------------------- LBM Kernel Creation -----------------------------------------------------
def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=StreamPullTwoFieldsAccessor()): def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=StreamPullTwoFieldsAccessor(),
data_type=None):
"""Replaces the pre- and post collision symbols in the collision rule by field accesses. """Replaces the pre- and post collision symbols in the collision rule by field accesses.
Args: Args:
...@@ -19,6 +20,7 @@ def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=Stream ...@@ -19,6 +20,7 @@ def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=Stream
dst_field: field used for writing pdf values if accessor.is_inplace this parameter is ignored dst_field: field used for writing pdf values if accessor.is_inplace this parameter is ignored
accessor: instance of PdfFieldAccessor, defining where to read and write values accessor: instance of PdfFieldAccessor, defining where to read and write values
to create e.g. a fused stream-collide kernel See 'fieldaccess.PdfFieldAccessor' to create e.g. a fused stream-collide kernel See 'fieldaccess.PdfFieldAccessor'
data_type: If a datatype is given the field reads are isolated and written to TypedSymbols of type data_type
Returns: Returns:
LbmCollisionRule where pre- and post collision symbols have been replaced LbmCollisionRule where pre- and post collision symbols have been replaced
...@@ -49,8 +51,9 @@ def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=Stream ...@@ -49,8 +51,9 @@ def create_lbm_kernel(collision_rule, src_field, dst_field=None, accessor=Stream
new_split_groups.append([fast_subs(e, substitutions) for e in split_group]) new_split_groups.append([fast_subs(e, substitutions) for e in split_group])
result.simplification_hints['split_groups'] = new_split_groups result.simplification_hints['split_groups'] = new_split_groups
if accessor.is_inplace: if accessor.is_inplace or (data_type is not None and src_field.dtype != data_type):
result = add_subexpressions_for_field_reads(result, subexpressions=True, main_assignments=True) result = add_subexpressions_for_field_reads(result, subexpressions=True, main_assignments=True,
data_type=data_type)
return result return result
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment