Skip to content
Snippets Groups Projects

Draft: All current mantle convection forms including SUPG forms, ApplyScaled kernel operation, constant Coefficients, some math helper functions

Closed Andreas Burkhart requested to merge burk/MC into main
8 files
+ 2020
249
Compare changes
  • Side-by-side
  • Inline
Files
8
@@ -301,6 +301,199 @@ class Apply(KernelType):
def member_variables(self) -> List[CppMemberVariable]:
return []
class ApplyScaled(KernelType):
def __init__(
self,
src_space: FunctionSpace,
dst_space: FunctionSpace,
type_descriptor: HOGType,
dims: List[int] = [2, 3],
):
self.name = "applyScaled"
self.src: FunctionSpaceImpl = FunctionSpaceImpl.create_impl(
src_space, "src", type_descriptor
)
self.dst: FunctionSpaceImpl = FunctionSpaceImpl.create_impl(
dst_space, "dst", type_descriptor
)
self.src_fields = [self.src]
self.dst_fields = [self.dst]
self.dims = dims
self.result_prefix = "elMatVec_"
def macro_loop(dim: int) -> str:
Macro = {2: "Face", 3: "Cell"}[dim]
macro = {2: "face", 3: "cell"}[dim]
if dim in dims:
return (
f"for ( auto& it : storage_->get{Macro}s() )\n"
f"{{\n"
f" {Macro}& {macro} = *it.second;\n"
f"\n"
f" // get hold of the actual numerical data in the functions\n"
f"{indent(self.src.pointer_retrieval(dim), INDENT)}\n"
f"{indent(self.dst.pointer_retrieval(dim), INDENT)}\n"
f" $pointer_retrieval_{dim}D\n"
f"\n"
f" // Zero out dst halos only\n"
f" //\n"
f" // This is also necessary when using update type == Add.\n"
f" // During additive comm we then skip zeroing the data on the lower-dim primitives.\n"
f"{indent(self.dst.zero_halos(dim), INDENT)}\n"
f"\n"
f" $scalar_parameter_setup_{dim}D\n"
f"\n"
f' this->timingTree_->start( "kernel" );\n'
f" $kernel_function_call_{dim}D\n"
f' this->timingTree_->stop( "kernel" );\n'
f"}}\n"
f"\n"
f"// Push result to lower-dimensional primitives\n"
f"//\n"
f'this->timingTree_->start( "post-communication" );\n'
f"// Note: We could avoid communication here by implementing the apply() also for the respective\n"
f"// lower dimensional primitives!\n"
f"{self.dst.post_communication(dim, 'level, DoFType::All ^ flag, *storage_, updateType == Replace')}\n"
f'this->timingTree_->stop( "post-communication" );'
)
else:
return 'WALBERLA_ABORT( "Not implemented." );'
def halo_update(dim: int) -> str:
if dim in dims:
if dim == 2:
return (
f"{indent(self.src.pre_communication(2), INDENT)}\n"
f" $comm_fe_functions_2D\n"
)
elif dim == 3:
return (
f" // Note that the order of communication is important, since the face -> cell communication may overwrite\n"
f" // parts of the halos that carry the macro-vertex and macro-edge unknowns.\n"
f"{indent(self.src.pre_communication(3), INDENT)}\n"
f" $comm_fe_functions_3D\n"
)
else:
raise HOGException("Dim not supported.")
else:
return 'WALBERLA_ABORT( "Not implemented." );'
self._template = Template(
f'this->startTiming( "{self.name}" );\n'
f"\n"
f"// Make sure that halos are up-to-date\n"
f'this->timingTree_->start( "pre-communication" );\n'
f"if ( this->storage_->hasGlobalCells() )\n"
f"{{\n"
f"{halo_update(3)}"
f"}}\n"
f"else\n"
f"{{\n"
f"{halo_update(2)}"
f"}}\n"
f'this->timingTree_->stop( "pre-communication" );\n'
f"\n"
f"if ( updateType == Replace )\n"
f"{{\n"
f" // We need to zero the destination array (including halos).\n"
f" // However, we must not zero out anything that is not flagged with the specified BCs.\n"
f" // Therefore, we first zero out everything that flagged, and then, later,\n"
f" // the halos of the highest dim primitives.\n"
f" dst.interpolate( walberla::numeric_cast< {self.dst.type_descriptor.pystencils_type} >( 0 ), level, flag );\n"
f"}}\n"
f"\n"
f"if ( storage_->hasGlobalCells() )\n"
f"{{\n"
f"{indent(macro_loop(3), INDENT)}\n"
f"}}\n"
f"else\n"
f"{{\n"
f"{indent(macro_loop(2), INDENT)}\n"
f"}}\n"
f"\n"
f'this->stopTiming( "{self.name}" );'
)
def kernel_operation(
self,
src_vecs: List[sp.MatrixBase],
dst_vecs: List[sp.MatrixBase],
mat: sp.MatrixBase,
rows: int,
) -> List[SympyAssignment]:
kernel_ops = mat * src_vecs[0]
tmp_symbols = sp.numbered_symbols(self.result_prefix)
kernel_op_assignments = [
SympyAssignment(tmp, kernel_op)
for tmp, kernel_op in zip(tmp_symbols, kernel_ops)
]
return kernel_op_assignments
def kernel_post_operation(
self,
geometry: ElementGeometry,
element_index: Tuple[int, int, int],
element_type: Union[FaceType, CellType],
src_vecs_accesses: List[List[Field.Access]],
dst_vecs_accesses: List[List[Field.Access]],
) -> List[ps.astnodes.Node]:
tmp_symbols = sp.numbered_symbols(self.result_prefix)
scaling = sp.Symbol("OperatorScaling")
# Add and store result to destination.
store_dst_vecs = [
SympyAssignment(a, a + scaling*s) for a, s in zip(dst_vecs_accesses[0], tmp_symbols)
]
return store_dst_vecs
def includes(self) -> Set[str]:
return (
{"hyteg/operators/Operator.hpp"} | self.src.includes() | self.dst.includes()
)
def base_classes(self) -> List[str]:
return [
f"public Operator< {self.src.func_type_string()}, {self.dst.func_type_string()} >",
]
def wrapper_methods(self) -> List[CppMethod]:
return [
CppMethod(
name=self.name,
arguments=[
CppVariable(
name=self.src.name,
type=self.src.func_type_string(),
is_const=True,
is_reference=True,
),
CppVariable(
name=self.dst.name,
type=self.dst.func_type_string(),
is_const=True,
is_reference=True,
),
CppVariable(name="OperatorScaling", type="real_t"),
CppVariable(name="level", type="uint_t"),
CppVariable(name="flag", type="DoFType"),
CppDefaultArgument(
variable=CppVariable(name="updateType", type="UpdateType"),
default_value="Replace",
),
],
return_type="void",
is_const=True,
content=self._template.template,
)
]
def member_variables(self) -> List[CppMemberVariable]:
return []
class GEMV(KernelType):
def __init__(
Loading