Skip to content
Snippets Groups Projects
Commit c3a2d734 authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon:
Browse files

reenable quadloop

parent 4130b0a3
No related branches found
No related tags found
1 merge request!39Pystencils 2.0
Pipeline #65556 skipped
......@@ -27,11 +27,10 @@ from pystencils.backend.ast.expressions import (
PsExpression,
PsLiteral,
)
from pystencils.backend.ast.structural import PsComment, PsDeclaration
from pystencils.backend.ast.structural import PsDeclaration
from pystencils.backend.functions import CFunction
from pystencils.kernelcreation import FreezeExpressions, KernelCreationContext, Typifier
from pystencils.types import PsType
from pystencils.types.basic_types import PsCustomType
from pystencils.types import PsCustomType, PsType
from pystencils.types.quick import UInt
from hog.element_geometry import ElementGeometry
......@@ -539,12 +538,17 @@ class N1E1FunctionSpaceImpl(FunctionSpaceImpl):
symbols = [sp.Symbol(f"{name}_{i}_{i}") for i in range(n_dofs)]
diag_vars = [PsExpression.make(ctx.get_symbol(s.name)) for s in symbols]
trafo_matrix_var = PsExpression.make(
# Create a symbol with a unique name to ensure that the name is not
# changed later by a CanonicalizeSymbols pass. This allows us to use
# this symbol's name below.
trafo_matrix_symb = ctx.duplicate_symbol(
ctx.get_symbol(
name,
PsCustomType(f"Eigen::DiagonalMatrix< real_t, {n_dofs} >", const=True),
)
)
trafo_matrix_var = PsExpression.make(trafo_matrix_symb)
level_var = PsExpression.make(ctx.get_symbol("level", UInt(64)))
macro_var = PsExpression.make(
ctx.get_symbol(f"{macro}", PsCustomType(f"{Macro}&", const=True))
......@@ -577,9 +581,9 @@ class N1E1FunctionSpaceImpl(FunctionSpaceImpl):
var,
PsCast(
ctx.default_dtype,
CFunction(f"{name}.diagonal()", (UInt(64),), real_t)(
PsExpression.make(PsConstant(i, UInt(64)))
),
CFunction(
f"{trafo_matrix_symb.name}.diagonal()", (UInt(64),), real_t
)(PsExpression.make(PsConstant(i, UInt(64)))),
),
)
for i, var in enumerate(diag_vars)
......
......@@ -38,7 +38,7 @@ from pystencils.backend.ast.structural import (
from pystencils.backend.functions import CFunction
from pystencils.field import Field
from pystencils.kernelcreation import FreezeExpressions, KernelCreationContext, Typifier
from pystencils.types.basic_types import PsCustomType
from pystencils.types import PsCustomType
from pystencils.types.quick import SInt
from hog.cpp_printing import (
......
......@@ -16,6 +16,7 @@
from dataclasses import dataclass
from enum import auto, Enum
from functools import reduce
import logging
import os
from textwrap import indent
......@@ -30,17 +31,14 @@ from sympy.codegen.ast import Assignment
from pystencils import AssignmentCollection, Target, TypedSymbol
from pystencils.backend import KernelFunction
from pystencils.backend.ast import PsAstNode
from pystencils.backend.ast.structural import (
PsAssignment,
PsBlock,
PsComment,
PsLoop,
)
from pystencils.backend.ast.analysis import UndefinedSymbolsCollector
from pystencils.backend.ast.structural import PsBlock
from pystencils.backend.kernelcreation import (
FreezeExpressions,
KernelCreationContext,
Typifier,
)
from pystencils.backend.transformations import CanonicalizeSymbols
from pystencils.sympyextensions import fast_subs
from hog.cpp_printing import (
......@@ -67,7 +65,7 @@ from hog.operator_generation.function_space_impls import FunctionSpaceImpl
from hog.operator_generation.pystencils_extensions import create_generic_fields
from hog.forms import Form
from hog.ast import Operations, count_operations
from hog.ast import Operations
from hog.blending import GeometryMap
import hog.code_generation
import hog.cse
......@@ -754,14 +752,14 @@ class HyTeGElementwiseOperator:
)
if integration_info.quad_loop:
accessed_mat_entries = mat.atoms(TypedSymbol)
accessed_mat_entries &= set().union(
*[ass.undefined_symbols for ass in kernel_op_assignments]
accessed_symbols = reduce(
set.union, (ass.rhs.free_symbols for ass in kernel_op_assignments)
)
accessed_mat_entries = mat.atoms(sp.Symbol) & accessed_symbols
with TimedLogger("constructing quadrature loops"):
quad_loop = integration_info.quad_loop.construct_quad_loop(
accessed_mat_entries, self._optimizer.cse_impl()
ctx, accessed_mat_entries, self._optimizer.cse_impl()
)
else:
quad_loop = []
......@@ -784,8 +782,10 @@ class HyTeGElementwiseOperator:
)
jacobi_assignments = hog.code_generation.jacobi_matrix_assignments(
mat,
integration_info.tables + quad_loop,
mat
if integration_info.quad_loop is None
else integration_info.quad_loop.mat_integrand,
integration_info.tables,
geometry,
self.symbolizer,
affine_points=coord_symbols_for_jac_affine,
......@@ -851,7 +851,12 @@ class HyTeGElementwiseOperator:
# coefficient that "lives" on a FEM function).
dof_symbols_set: Set[DoFSymbol] = {
a
for ass in kernel_op_assignments + quad_loop
for ass in kernel_op_assignments
+ (
[]
if integration_info.quad_loop is None
else [integration_info.quad_loop.mat_integrand]
)
for a in ass.atoms(DoFSymbol)
}
dof_symbols = sorted(dof_symbols_set, key=lambda ds: ds.name)
......@@ -909,46 +914,47 @@ class HyTeGElementwiseOperator:
for component in range(geometry.dimensions)
]
body = (
loop_counter_custom_code_nodes
+ coords_assignments
+ load_vecs
+ quad_loop
+ kernel_op_assignments
)
if not self._optimizer[Opts.QUADLOOPS]:
# Only now we replace the quadrature points and weights - if there are any.
# We also setup sympy assignments in body
with TimedLogger(
"replacing quadrature points and weigths", logging.DEBUG
):
if not quadrature.is_exact() and not quadrature.inline_values:
subs_dict = dict(quadrature.points() + quadrature.weights())
for i, node in enumerate(body):
body[i] = fast_subs(node, subs_dict)
for i, node in enumerate(kernel_op_assignments):
kernel_op_assignments[i] = fast_subs(node, subs_dict)
# count operations
# TODO: count in post operation
# TODO: count in backend ast after optimizations
ops = Operations()
for stmt in body:
if isinstance(stmt, PsLoop):
for stmt2 in stmt.body.args:
count_operations(
stmt2.rhs, ops, loop_factor=stmt.stop - stmt.start
)
elif isinstance(stmt, sp.codegen.Assignment) or isinstance(
stmt, PsAssignment
):
count_operations(stmt.rhs, ops)
elif isinstance(stmt, PsComment):
pass
else:
ops.unknown_ops += 1
# for stmt in body:
# if isinstance(stmt, PsLoop):
# for stmt2 in stmt.body.statements:
# count_operations(
# stmt2.rhs, ops, loop_factor=stmt.stop - stmt.start
# )
# elif isinstance(stmt, sp.codegen.Assignment) or isinstance(
# stmt, PsAssignment
# ):
# count_operations(stmt.rhs, ops)
# elif isinstance(stmt, PsComment):
# pass
# else:
# ops.unknown_ops += 1
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
loop_bodies[element_type] = typify(freeze(AssignmentCollection(body)))
loop_bodies[element_type] = typify(
freeze(
AssignmentCollection(
loop_counter_custom_code_nodes + coords_assignments + load_vecs
)
)
)
loop_bodies[element_type].statements += [stmt.clone() for stmt in quad_loop]
loop_bodies[element_type].statements += typify(
freeze(AssignmentCollection(kernel_op_assignments))
).statements
loop_bodies[element_type].statements += kernel_op_post_ast
with TimedLogger(
......@@ -1008,10 +1014,10 @@ class HyTeGElementwiseOperator:
# Add quadrature points and weights array declarations, but only those
# which are actually needed.
if integration_info.quad_loop:
q_decls = integration_info.quad_loop.point_weight_decls()
undefined = block.undefined_symbols
q_decls = integration_info.quad_loop.point_weight_decls(ctx)
undefined = UndefinedSymbolsCollector()(block)
block.statements = [
q_decl for q_decl in q_decls if q_decl.lhs in undefined
q_decl for q_decl in q_decls if q_decl.lhs.symbol in undefined
] + block.statements
return (block, ops.to_table())
......@@ -1054,6 +1060,11 @@ class HyTeGElementwiseOperator:
ctx, dim, integration_info, loop_strategy, kernel_type
)
# TODO: remove unused symbols automatically
with TimedLogger("canonicalizing symbols", logging.DEBUG):
canonicalize = CanonicalizeSymbols(ctx)
kernel = canonicalize(kernel)
# optimizer applies optimizations
with TimedLogger(
f"Optimizing kernel: {kernel_type.name} in {dim}D", logging.INFO
......
......@@ -40,7 +40,6 @@ from hog.logger import TimedLogger
from hog.operator_generation.loop_strategies import (
LoopStrategy,
CUBES,
SAWTOOTH,
FUSEDROWS,
)
from hog.operator_generation.pystencils_extensions import (
......@@ -189,8 +188,6 @@ class Optimizer:
simplify_conditionals(loop, loop_counter_simplification=True)
if self[Opts.MOVECONSTANTS]:
# TODO: Symbols must be canonicalized before hoisting invariants.
with TimedLogger("moving constants out of loop", logging.DEBUG):
hoist_invariants = HoistLoopInvariantDeclarations(ctx)
kernel.statements = hoist_invariants(kernel).statements
......
......@@ -18,14 +18,22 @@ import logging
import sympy as sp
from typing import Iterable, List, Optional
import pystencils as ps
from pystencils import AssignmentCollection, TypedSymbol
from pystencils.backend.ast import PsAstNode
from pystencils.backend.ast.expressions import PsArrayInitList, PsConstant, PsExpression
from pystencils.backend.ast.structural import PsDeclaration
from pystencils.backend.kernelcreation import (
AstFactory,
FreezeExpressions,
KernelCreationContext,
Typifier,
)
from pystencils.types import PsArrayType
from .quadrature import Quadrature
import hog
from hog.cse import CseImplementation
from hog.logger import TimedLogger
from hog.operator_generation.pystencils_extensions import create_field_access
from hog.operator_generation.types import HOGType
from hog.symbolizer import Symbolizer
from hog.sympy_extensions import fast_subs
......@@ -34,7 +42,7 @@ from hog.sympy_extensions import fast_subs
class QuadLoop:
"""Implements a quadrature scheme by an explicit loop over quadrature points."""
# q_ctr = ps.TypedSymbol("q", BasicType(int))
q_ctr = sp.Symbol("q")
def __init__(
self,
......@@ -61,11 +69,7 @@ class QuadLoop:
if self.symmetric and row > col:
continue
q_acc = ps.TypedSymbol(
f"q_acc_{row}_{col}",
BasicType(self.type_descriptor.pystencils_type),
const=False,
)
q_acc = sp.Symbol(f"q_acc_{row}_{col}")
self.mat[row, col] = q_acc
if self.symmetric:
......@@ -73,7 +77,8 @@ class QuadLoop:
def construct_quad_loop(
self,
accessed_mat_entries: Iterable[ps.TypedSymbol],
ctx: KernelCreationContext,
accessed_mat_entries: Iterable[sp.Symbol],
cse: Optional[CseImplementation] = None,
) -> List[PsAstNode]:
ref_symbols = self.symbolizer.ref_coords_as_list(
......@@ -93,29 +98,35 @@ class QuadLoop:
continue
coord_subs_dict = {
symbol: create_field_access(
self.p_array_names[dim],
self.type_descriptor.pystencils_type,
symbol: sp.Indexed(
TypedSymbol(
self.p_array_names[dim],
PsArrayType(
ctx.default_dtype, len(self.quadrature.weights())
),
),
self.q_ctr,
)
for dim, symbol in enumerate(ref_symbols)
}
weight = create_field_access(
self.w_array_name, self.type_descriptor.pystencils_type, self.q_ctr
weight = sp.Indexed(
TypedSymbol(
self.w_array_name,
PsArrayType(ctx.default_dtype, len(self.quadrature.weights())),
),
self.q_ctr,
)
integrated = weight * fast_subs(
self.mat_integrand[row, col], coord_subs_dict
)
accumulator_declarations.append(
ast.SympyAssignment(q_acc, 0.0, is_const=False)
)
accumulator_declarations.append(sp.codegen.Assignment(q_acc, 0.0))
quadrature_assignments.append(
ast.SympyAssignment(tmp_symbol, integrated)
sp.codegen.Assignment(tmp_symbol, integrated)
)
accumulator_updates.append(
ast.SympyAssignment(q_acc, q_acc + tmp_symbol, is_const=False)
sp.codegen.ast.AddAugmentedAssignment(q_acc, tmp_symbol)
)
# common subexpression elimination
......@@ -125,40 +136,61 @@ class QuadLoop:
quadrature_assignments,
cse,
"tmp_qloop",
return_type=ast.SympyAssignment,
return_type=sp.codegen.Assignment,
)
return accumulator_declarations + [
ast.ForLoop(
ast.Block(quadrature_assignments + accumulator_updates),
self.q_ctr,
0,
len(self.quadrature.weights()),
)
]
ast_factory = AstFactory(ctx)
freeze = FreezeExpressions(ctx)
typify = Typifier(ctx)
ctx.get_symbol(self.q_ctr.name, ctx.index_dtype)
def point_weight_decls(self) -> List[PsAstNode]:
block = typify(freeze(AssignmentCollection(accumulator_declarations)))
loop_body = typify(
freeze(AssignmentCollection(quadrature_assignments + accumulator_updates))
)
loop = ast_factory.loop(
self.q_ctr.name, slice(len(self.quadrature.weights())), loop_body
)
block.statements.append(loop)
return block.statements
def point_weight_decls(self, ctx: KernelCreationContext) -> List[PsAstNode]:
"""Returns statements that declare the quadrature rules' points and weights as c arrays."""
typify = Typifier(ctx)
quad_decls = []
quad_decls.append(
ast.ArrayDeclaration(
ast.FieldPointerSymbol(
self.w_array_name,
BasicType(self.type_descriptor.pystencils_type),
False,
PsDeclaration(
PsExpression.make(
ctx.get_symbol(
self.w_array_name,
PsArrayType(ctx.default_dtype, len(self.quadrature.weights())),
)
),
PsArrayInitList(
PsExpression.make(PsConstant(w))
for _, w in self.quadrature.weights()
),
*(sp.Float(w) for _, w in self.quadrature.weights()),
)
)
for dim in range(0, self.quadrature.geometry.dimensions):
quad_decls.append(
ast.ArrayDeclaration(
ast.FieldPointerSymbol(
self.p_array_names[dim],
BasicType(self.type_descriptor.pystencils_type),
False,
PsDeclaration(
PsExpression.make(
ctx.get_symbol(
self.p_array_names[dim],
PsArrayType(
ctx.default_dtype, len(self.quadrature.weights())
),
)
),
PsArrayInitList(
PsExpression.make(PsConstant(point[dim]))
for point in self.quadrature._points
),
*(sp.Float(point[dim]) for point in self.quadrature._points),
)
)
return quad_decls
return [typify(d) for d in quad_decls]
......@@ -6,7 +6,7 @@ dependencies = [
"numpy==1.24.3",
"quadpy-gpl==0.16.10",
"poly-cse-py",
"pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@fafe58ec98a8a8f9693f65e1fa8a8b7a09e142e1",
"pystencils @ git+https://i10git.cs.fau.de/pycodegen/pystencils.git@d065462537241f2a60ab834743f34224d7b2101c",
"pytest==7.3.1",
"sympy==1.11.1",
"tabulate==0.9.0",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment