Skip to content
Snippets Groups Projects
Commit d6e4bfd8 authored by Daniel Bauer's avatar Daniel Bauer :speech_balloon:
Browse files

fix typing

parent 56d354b1
No related branches found
No related tags found
1 merge request!39Pystencils 2.0
Pipeline #72397 skipped
......@@ -24,8 +24,6 @@ import sympy as sp
import hog.ast
from hog.exception import HOGException
Assignment = Union[hog.ast.Assignment, ps.backend.ast.structural.PsAssignment]
class CseImplementation(enum.Enum):
SYMPY = 1
......@@ -36,11 +34,11 @@ class CseImplementation(enum.Enum):
def cse(
assignments: Iterable[Assignment],
assignments: Iterable[sp.codegen.ast.Assignment],
impl: CseImplementation,
tmp_symbol_prefix: str,
return_type: type = hog.ast.Assignment,
) -> List[Assignment]:
) -> List[sp.codegen.ast.Assignment]:
if impl == CseImplementation.POLYCSE:
ass_tuples: List[Tuple[str, sp.Expr]] = []
lhs_expressions: Dict[str, sp.Expr] = {}
......
......@@ -24,12 +24,12 @@ from pystencils.backend.ast.expressions import (
PsArrayInitList,
PsCast,
PsExpression,
PsLiteral,
PsLookup,
)
from pystencils.backend.ast.structural import PsDeclaration
from pystencils.backend.functions import CFunction
from pystencils.backend.kernelcreation import FreezeExpressions, KernelCreationContext
from pystencils.backend.literals import PsLiteral
from pystencils.types import PsCustomType, PsStructType, PsType
from pystencils.types.quick import UInt
......@@ -183,7 +183,7 @@ class N1E1FunctionSpaceImpl(FunctionSpaceImpl):
)
# Compute the (diagonal of) the transformation.
transform_function_call = PsDeclaration(
transform_function_call: PsAstNode = PsDeclaration(
trafo_matrix_var,
CFunction(
f"n1e1::macro{macro}::basisTransformation",
......@@ -201,7 +201,7 @@ class N1E1FunctionSpaceImpl(FunctionSpaceImpl):
element_type_var,
),
)
diag_declarations = [
diag_declarations: List[PsAstNode] = [
PsDeclaration(
var,
PsCast(
......
......@@ -17,12 +17,13 @@
from abc import ABC, abstractmethod
from string import Template
from textwrap import indent
from typing import List, Mapping, Set, Tuple, Union
from typing import cast, List, Mapping, Set, Tuple, Union
import sympy as sp
from sympy.codegen.ast import Assignment
from pystencils import AssignmentCollection
from pystencils.backend.ast import PsAstNode
from pystencils.backend.ast.expressions import (
PsArrayInitList,
PsCast,
......@@ -32,14 +33,13 @@ from pystencils.backend.ast.expressions import (
)
from pystencils.backend.ast.structural import (
PsAssignment,
PsAstNode,
PsComment,
PsDeclaration,
PsStatement,
)
from pystencils.backend.constants import PsConstant
from pystencils.backend.functions import CFunction
from pystencils.backend.kernelcreation import FreezeExpressions, KernelCreationContext
from pystencils.backend.memory import PsConstant
from pystencils.field import Field
from pystencils.types import PsCustomType, PsStructType
from pystencils.types.quick import SInt
......@@ -151,7 +151,7 @@ class KernelType(ABC):
def _to_backend_ast(
self, ctx: KernelCreationContext, assignments: List[Assignment]
) -> List[PsAssignment]:
) -> List[PsAstNode]:
freeze = FreezeExpressions(ctx)
return freeze(AssignmentCollection(assignments)).statements
......@@ -324,19 +324,19 @@ class Assemble(KernelType):
row_idx_decl = PsDeclaration(
row_idx,
CFunction(vec_uint_t, (ctx.index_dtype,), vec_uint_t)(
CFunction(vec_uint_t.name, (ctx.index_dtype,), vec_uint_t)(
PsExpression.make(PsConstant(nr))
),
)
col_idx_decl = PsDeclaration(
col_idx,
CFunction(vec_uint_t, (ctx.index_dtype,), vec_uint_t)(
CFunction(vec_uint_t.name, (ctx.index_dtype,), vec_uint_t)(
PsExpression.make(PsConstant(nc))
),
)
mat_decl = PsDeclaration(
mat,
CFunction(vec_real_t, (ctx.index_dtype,), vec_real_t)(
CFunction(vec_real_t.name, (ctx.index_dtype,), vec_real_t)(
PsExpression.make(PsConstant(mat_size))
),
)
......@@ -347,14 +347,14 @@ class Assemble(KernelType):
row_idx_init = [
PsAssignment(
PsLookup(row_idx, f"operator[]({i})"),
PsCast(uint_t, freeze(dst_access)),
PsCast(uint_t, cast(PsExpression, freeze(dst_access))),
)
for i, dst_access in enumerate(dst)
]
col_idx_init = [
PsAssignment(
PsLookup(col_idx, f"operator[]({i})"),
PsCast(uint_t, freeze(src_access)),
PsCast(uint_t, cast(PsExpression, freeze(src_access))),
)
for i, src_access in enumerate(src)
]
......@@ -424,19 +424,24 @@ class KernelWrapperType(ABC):
@property
@abstractmethod
def kernel_type(self) -> KernelType: ...
def kernel_type(self) -> KernelType:
...
@abstractmethod
def includes(self) -> Set[str]: ...
def includes(self) -> Set[str]:
...
@abstractmethod
def base_classes(self) -> List[str]: ...
def base_classes(self) -> List[str]:
...
@abstractmethod
def wrapper_methods(self) -> List[CppMethod]: ...
def wrapper_methods(self) -> List[CppMethod]:
...
@abstractmethod
def member_variables(self) -> List[CppMemberVariable]: ...
def member_variables(self) -> List[CppMemberVariable]:
...
def substitute(self, subs: Mapping[str, object]) -> None:
self._template = Template(self._template.safe_substitute(subs))
......
......@@ -30,7 +30,8 @@ from abc import ABC, abstractmethod
import sympy as sp
from typing import Dict, List, Type, Union
from pystencils.backend.ast.expressions import PsConstant, PsExpression
from pystencils.backend.constants import PsConstant
from pystencils.backend.ast.expressions import PsExpression
from pystencils.backend.ast.structural import PsBlock, PsComment, PsConditional, PsLoop
from pystencils.backend.kernelcreation import AstFactory, KernelCreationContext
from pystencils.defaults import DEFAULTS
......@@ -41,7 +42,6 @@ from hog.operator_generation.pystencils_extensions import (
loop_over_simplex,
loop_over_simplex_facet,
create_micro_element_loops,
fuse_loops_over_simplex,
)
from hog.operator_generation.indexing import (
all_element_types,
......@@ -187,12 +187,14 @@ class SAWTOOTH(LoopStrategy):
# add the loop bodies to each innermost loop
for element_type, body in loop_bodies.items():
nested_loop = element_loops[element_type].body.statements[0]
assert isinstance(nested_loop, PsLoop)
if dim == 2:
element_loops[element_type].body.statements[0].body = body
nested_loop.body = body
else:
element_loops[element_type].body.statements[0].body.statements[
0
].body = body
innermost_loop = nested_loop.body.statements[0]
assert isinstance(innermost_loop, PsLoop)
innermost_loop.body = body
# for each element type create a block of comment, pre loop stmts and spatial loop nest
block = PsBlock(
......@@ -238,6 +240,7 @@ class FUSEDROWS(LoopStrategy):
raise NotImplementedError(
"FUSEDROWS loop strategy has not been adapted to the new backend and is to be reimplemented in a more flexible manner (think arbitrary loop blocking)."
)
# element_loops = create_micro_element_loops(dim, micro_edges_per_macro_edge)
# (fused_loop, bodies) = fuse_loops_over_simplex(
# [elem_loop for elem_loop in element_loops.values()], 1, dim
......
......@@ -17,7 +17,8 @@
from typing import Dict, List, Tuple, Union
from pystencils import Field, FieldType
from pystencils.backend.ast.structural import PsAstNode, PsBlock, PsLoop
from pystencils.backend.ast import PsAstNode
from pystencils.backend.ast.structural import PsBlock, PsLoop
from pystencils.backend.kernelcreation import AstFactory, KernelCreationContext
from pystencils.defaults import DEFAULTS
from pystencils.types import UserTypeSpec
......@@ -166,46 +167,46 @@ def create_micro_element_loops(
return element_loops
def fuse_loops_over_simplex(
loops: List[PsLoop], dim_to_fuse: int, max_dim: int
) -> Tuple[PsLoop, List[PsAstNode]]:
"""Takes a list of simplex loops over max_dim dimensions and fuses them at dim_to_fuse.
E.g. for dim_to_fuse == 0: L_z(L_y(L_x_1(...))) + L_z(L_y(L_x_2(...))) = L_z(L_y([L_x_1(...), L_x_2(...)]))
"""
# fused loop will be constructed here
current_loop = loops[0]
fused_loops = {}
for d in range(max_dim, dim_to_fuse, -1):
if not isinstance(current_loop, PsLoop):
raise HOGException(f"Non-loop encountered: {current_loop}")
# reconstruct current loop
fused_loops[d] = current_loop.new_loop_with_different_body(Block([]))
# assert fusability
# ranges = [(loop.start, loop.step, loop.stop) for loop in loops]
# is_same = reduce(lambda p, q: p[0] - q[0] + p[1] - q[1] + p[2] - q[2] == 0, ranges, 0)
# if not is_same:
# raise HOGException(f"Loop ranges are not the same for dimension {d}!")
# iterate loop
current_loop = current_loop.body
# collect bodies, add to constructed loop
dim_to_fuse_loops = []
for loop in loops:
current_loop = loop
for d in range(max_dim - 1, dim_to_fuse, -1):
current_loop = current_loop.body
dim_to_fuse_loops.append(current_loop.body)
offset = 0 if max_dim == 2 else 1
fused_loops[max_dim - offset].body = Block(dim_to_fuse_loops)
for d in range(max_dim, dim_to_fuse + 1, -1):
fused_loops[d] = fused_loops[d].new_loop_with_different_body(fused_loops[d - 1])
return (fused_loops[max_dim], [loop.body for loop in dim_to_fuse_loops])
# def fuse_loops_over_simplex(
# loops: List[PsLoop], dim_to_fuse: int, max_dim: int
# ) -> Tuple[PsLoop, List[PsAstNode]]:
# """Takes a list of simplex loops over max_dim dimensions and fuses them at dim_to_fuse.
# E.g. for dim_to_fuse == 0: L_z(L_y(L_x_1(...))) + L_z(L_y(L_x_2(...))) = L_z(L_y([L_x_1(...), L_x_2(...)]))
# """
# # fused loop will be constructed here
# current_loop = loops[0]
# fused_loops = {}
# for d in range(max_dim, dim_to_fuse, -1):
# if not isinstance(current_loop, PsLoop):
# raise HOGException(f"Non-loop encountered: {current_loop}")
# # reconstruct current loop
# fused_loops[d] = current_loop.new_loop_with_different_body(Block([]))
# # assert fusability
# # ranges = [(loop.start, loop.step, loop.stop) for loop in loops]
# # is_same = reduce(lambda p, q: p[0] - q[0] + p[1] - q[1] + p[2] - q[2] == 0, ranges, 0)
# # if not is_same:
# # raise HOGException(f"Loop ranges are not the same for dimension {d}!")
# # iterate loop
# current_loop = current_loop.body
# # collect bodies, add to constructed loop
# dim_to_fuse_loops = []
# for loop in loops:
# current_loop = loop
# for d in range(max_dim - 1, dim_to_fuse, -1):
# current_loop = current_loop.body
# dim_to_fuse_loops.append(current_loop.body)
# offset = 0 if max_dim == 2 else 1
# fused_loops[max_dim - offset].body = Block(dim_to_fuse_loops)
# for d in range(max_dim, dim_to_fuse + 1, -1):
# fused_loops[d] = fused_loops[d].new_loop_with_different_body(fused_loops[d - 1])
# return (fused_loops[max_dim], [loop.body for loop in dim_to_fuse_loops])
def create_generic_fields(names: List[str], dtype: UserTypeSpec) -> List[Field]:
......
......@@ -16,12 +16,13 @@
import logging
import sympy as sp
from typing import Iterable, List, Optional
from typing import cast, Iterable, List, Optional
from pystencils import AssignmentCollection, TypedSymbol
from pystencils.backend.ast import PsAstNode
from pystencils.backend.ast.expressions import PsArrayInitList, PsConstant, PsExpression
from pystencils.backend.ast.structural import PsDeclaration
from pystencils.backend.ast.expressions import PsArrayInitList, PsExpression
from pystencils.backend.ast.structural import PsBlock, PsDeclaration
from pystencils.backend.constants import PsConstant
from pystencils.backend.kernelcreation import AstFactory, KernelCreationContext
from pystencils.types import PsArrayType
......@@ -168,9 +169,15 @@ class QuadLoop:
ctx.get_symbol(self.q_ctr.name, ctx.index_dtype)
block = ast_factory.parse_sympy(AssignmentCollection(accumulator_declarations))
loop_body = ast_factory.parse_sympy(
AssignmentCollection(quadrature_assignments + accumulator_updates)
block = cast(
PsBlock,
ast_factory.parse_sympy(AssignmentCollection(accumulator_declarations)),
)
loop_body = cast(
PsBlock,
ast_factory.parse_sympy(
AssignmentCollection(quadrature_assignments + accumulator_updates)
),
)
loop = ast_factory.loop(
self.q_ctr.name, slice(len(self.quadrature.weights())), loop_body
......@@ -179,7 +186,7 @@ class QuadLoop:
return block.statements
def point_weight_decls(self, ctx: KernelCreationContext) -> List[PsAstNode]:
def point_weight_decls(self, ctx: KernelCreationContext) -> List[PsDeclaration]:
"""Returns statements that declare the quadrature rules' points and weights as c arrays."""
quad_decls = []
quad_decls.append(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment