Skip to content
Snippets Groups Projects
Commit 9eb8c473 authored by Martin Bauer's avatar Martin Bauer
Browse files

Extended generated waLBerla sweeps - to make them callable on cell intervals

parent 884a03c4
No related merge requests found
...@@ -4,6 +4,7 @@ from itertools import product ...@@ -4,6 +4,7 @@ from itertools import product
from typing import Dict, Sequence, Tuple, Optional from typing import Dict, Sequence, Tuple, Optional
from pystencils import create_staggered_kernel, Field, create_kernel, Assignment, FieldType from pystencils import create_staggered_kernel, Field, create_kernel, Assignment, FieldType
from pystencils.backends.cbackend import get_headers
from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets from pystencils.backends.simd_instruction_sets import get_supported_instruction_sets
from pystencils.stencils import offset_to_direction_string, inverse_direction from pystencils.stencils import offset_to_direction_string, inverse_direction
from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env from pystencils_walberla.jinja_filters import add_pystencils_filters_to_jinja_env
...@@ -44,6 +45,7 @@ def generate_sweep(generation_context, class_name, assignments, ...@@ -44,6 +45,7 @@ def generate_sweep(generation_context, class_name, assignments,
'namespace': namespace, 'namespace': namespace,
'class_name': class_name, 'class_name': class_name,
'target': create_kernel_params.get("target", "cpu"), 'target': create_kernel_params.get("target", "cpu"),
'headers': get_headers(ast),
} }
header = env.get_template("Sweep.tmpl.h").render(**jinja_context) header = env.get_template("Sweep.tmpl.h").render(**jinja_context)
source = env.get_template("Sweep.tmpl.cpp").render(**jinja_context) source = env.get_template("Sweep.tmpl.cpp").render(**jinja_context)
...@@ -58,6 +60,7 @@ def generate_sweep(generation_context, class_name, assignments, ...@@ -58,6 +60,7 @@ def generate_sweep(generation_context, class_name, assignments,
'class_name': class_name, 'class_name': class_name,
'target': create_kernel_params.get("target", "cpu"), 'target': create_kernel_params.get("target", "cpu"),
'field': representative_field, 'field': representative_field,
'headers': get_headers(ast),
} }
header = env.get_template("SweepInnerOuter.tmpl.h").render(**jinja_context) header = env.get_template("SweepInnerOuter.tmpl.h").render(**jinja_context)
source = env.get_template("SweepInnerOuter.tmpl.cpp").render(**jinja_context) source = env.get_template("SweepInnerOuter.tmpl.cpp").render(**jinja_context)
...@@ -193,6 +196,7 @@ def default_create_kernel_parameters(generation_context, params): ...@@ -193,6 +196,7 @@ def default_create_kernel_parameters(generation_context, params):
vec = params['cpu_vectorize_info'] vec = params['cpu_vectorize_info']
vec['instruction_set'] = vec.get('instruction_set', default_vec_is) vec['instruction_set'] = vec.get('instruction_set', default_vec_is)
vec['assume_inner_stride_one'] = True
vec['assume_aligned'] = vec.get('assume_aligned', False) vec['assume_aligned'] = vec.get('assume_aligned', False)
vec['nontemporal'] = vec.get('nontemporal', False) vec['nontemporal'] = vec.get('nontemporal', False)
return params return params
......
...@@ -8,7 +8,7 @@ from pystencils.kernelparameters import SHAPE_DTYPE ...@@ -8,7 +8,7 @@ from pystencils.kernelparameters import SHAPE_DTYPE
from pystencils.sympyextensions import prod from pystencils.sympyextensions import prod
temporary_fieldMemberTemplate = """ temporary_fieldMemberTemplate = """
std::set< {type} *, field::SwapableCompare< {type} * > > cache_{original_field_name}_;""" private: std::set< {type} *, field::SwapableCompare< {type} * > > cache_{original_field_name}_;"""
temporary_fieldTemplate = """ temporary_fieldTemplate = """
// Getting temporary field {tmp_field_name} // Getting temporary field {tmp_field_name}
...@@ -53,6 +53,23 @@ def get_field_fsize(field): ...@@ -53,6 +53,23 @@ def get_field_fsize(field):
else: else:
return prod(field.index_shape) return prod(field.index_shape)
def get_field_stride(param):
field = param.fields[0]
type_str = get_base_type(param.symbol.dtype).base_name
stride_names = ['xStride()', 'yStride()', 'zStride()', 'fStride()']
stride_names = ["%s(%s->%s)" % (type_str, param.field_name, e) for e in stride_names]
strides = stride_names[:field.spatial_dimensions]
if field.index_dimensions > 0:
additional_strides = [1]
for shape in reversed(field.index_shape[1:]):
additional_strides.append(additional_strides[-1] * shape)
assert len(additional_strides) == field.index_dimensions
f_stride_name = stride_names[-1]
strides.extend(["%s(%d * %s)" % (type_str, e, f_stride_name) for e in reversed(additional_strides)])
return strides[param.symbol.coordinate]
def generate_declaration(kernel_info): def generate_declaration(kernel_info):
"""Generates the declaration of the kernel function""" """Generates the declaration of the kernel function"""
ast = kernel_info.ast ast = kernel_info.ast
...@@ -222,9 +239,8 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non ...@@ -222,9 +239,8 @@ def generate_call(ctx, kernel_info, ghost_layers_to_include=0, cell_interval=Non
kernel_call_lines.append("%s %s = %s->dataAt(%s, %s, %s, %s);" % kernel_call_lines.append("%s %s = %s->dataAt(%s, %s, %s, %s);" %
((param.symbol.dtype, param.symbol.name, param.field_name) + coordinates)) ((param.symbol.dtype, param.symbol.name, param.field_name) + coordinates))
elif param.is_field_stride: elif param.is_field_stride:
casted_stride = get_field_stride(param)
type_str = param.symbol.dtype.base_name type_str = param.symbol.dtype.base_name
stride_names = ('xStride()', 'yStride()', 'zStride()', 'fStride()')
casted_stride = "%s(%s->%s)" % (type_str, param.field_name, stride_names[param.symbol.coordinate])
kernel_call_lines.append("const %s %s = %s;" % (type_str, param.symbol.name, casted_stride)) kernel_call_lines.append("const %s %s = %s;" % (type_str, param.symbol.name, casted_stride))
elif param.is_field_shape: elif param.is_field_shape:
coord = param.symbol.coordinate coord = param.symbol.coordinate
...@@ -327,6 +343,9 @@ def generate_members(ctx, kernel_info, parameters_to_ignore=(), only_fields=Fals ...@@ -327,6 +343,9 @@ def generate_members(ctx, kernel_info, parameters_to_ignore=(), only_fields=Fals
field_type = make_field_type(get_base_type(f.dtype), f_size, is_gpu) field_type = make_field_type(get_base_type(f.dtype), f_size, is_gpu)
result.append(temporary_fieldMemberTemplate.format(type=field_type, original_field_name=original_field_name)) result.append(temporary_fieldMemberTemplate.format(type=field_type, original_field_name=original_field_name))
if hasattr(kernel_info, 'varying_parameters'):
result.extend(["%s %s;" % e for e in kernel_info.varying_parameters])
return "\n".join(result) return "\n".join(result)
......
...@@ -23,6 +23,9 @@ ...@@ -23,6 +23,9 @@
#include "core/DataTypes.h" #include "core/DataTypes.h"
#include "core/Macros.h" #include "core/Macros.h"
#include "{{class_name}}.h" #include "{{class_name}}.h"
{% for header in headers %}
#include {{header}}
{% endfor %}
{% if target is equalto 'cpu' -%} {% if target is equalto 'cpu' -%}
...@@ -36,6 +39,7 @@ ...@@ -36,6 +39,7 @@
# pragma GCC diagnostic ignored "-Wfloat-equal" # pragma GCC diagnostic ignored "-Wfloat-equal"
# pragma GCC diagnostic ignored "-Wshadow" # pragma GCC diagnostic ignored "-Wshadow"
# pragma GCC diagnostic ignored "-Wconversion" # pragma GCC diagnostic ignored "-Wconversion"
# pragma GCC diagnostic ignored "-Wunused-variable"
#endif #endif
using namespace std; using namespace std;
...@@ -43,15 +47,37 @@ using namespace std; ...@@ -43,15 +47,37 @@ using namespace std;
namespace walberla { namespace walberla {
namespace {{namespace}} { namespace {{namespace}} {
{{kernel|generate_definition}} {{kernel|generate_definition}}
void {{class_name}}::operator() ( IBlock * block )
void {{class_name}}::sweep( IBlock * block )
{ {
{{kernel|generate_block_data_to_field_extraction|indent(4)}} {{kernel|generate_block_data_to_field_extraction|indent(4)}}
{{kernel|generate_call(stream='stream_')|indent(4)}} {{kernel|generate_call(stream='stream_')|indent(4)}}
{{kernel|generate_swaps|indent(4)}} {{kernel|generate_swaps|indent(4)}}
} }
void {{class_name}}::sweepOnCellInterval( const shared_ptr<StructuredBlockStorage> & blocks,
const CellInterval & globalCellInterval,
cell_idx_t ghostLayers,
IBlock * block )
{
CellInterval ci = globalCellInterval;
CellInterval blockBB = blocks->getBlockCellBB( *block);
blockBB.expand( ghostLayers );
ci.intersect( blockBB );
blocks->transformGlobalToBlockLocalCellInterval( ci, *block );
if( ci.empty() )
return;
{{kernel|generate_block_data_to_field_extraction|indent(4)}}
{{kernel|generate_call(stream='stream_', cell_interval='ci')|indent(4)}}
{{kernel|generate_swaps|indent(4)}}
}
} // namespace {{namespace}} } // namespace {{namespace}}
} // namespace walberla } // namespace walberla
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
//! \\author pystencils //! \\author pystencils
//====================================================================================================================== //======================================================================================================================
#pragma once
#include "core/DataTypes.h" #include "core/DataTypes.h"
{% if target is equalto 'cpu' -%} {% if target is equalto 'cpu' -%}
...@@ -27,7 +28,7 @@ ...@@ -27,7 +28,7 @@
#include "field/SwapableCompare.h" #include "field/SwapableCompare.h"
#include "domain_decomposition/BlockDataID.h" #include "domain_decomposition/BlockDataID.h"
#include "domain_decomposition/IBlock.h" #include "domain_decomposition/IBlock.h"
#include "domain_decomposition/StructuredBlockStorage.h"
#include <set> #include <set>
#ifdef __GNUC__ #ifdef __GNUC__
...@@ -56,9 +57,28 @@ public: ...@@ -56,9 +57,28 @@ public:
{{ kernel| generate_destructor(class_name) |indent(4) }} {{ kernel| generate_destructor(class_name) |indent(4) }}
void operator() ( IBlock * block ); void operator()(IBlock * b) { sweep(b); }
std::function<void (IBlock*)> getSweep() {
return [this](IBlock * b) { this->sweep(b); };
}
std::function<void (IBlock*)> getSweepOnCellInterval(const shared_ptr<StructuredBlockStorage> & blocks,
const CellInterval & globalCellInterval,
cell_idx_t ghostLayers=1 )
{
return [this, blocks, globalCellInterval, ghostLayers] (IBlock * b) {
this->sweepOnCellInterval(blocks, globalCellInterval, ghostLayers, b);
};
}
{{ kernel|generate_members|indent(4) }}
private: private:
{{kernel|generate_members|indent(4)}} void sweep( IBlock * block );
void sweepOnCellInterval(const shared_ptr<StructuredBlockStorage> & blocks,
const CellInterval & globalCellInterval, cell_idx_t ghostLayers, IBlock * block );
{%if target is equalto 'gpu'%} {%if target is equalto 'gpu'%}
cudaStream_t stream_; cudaStream_t stream_;
{% endif %} {% endif %}
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment