Skip to content
Snippets Groups Projects
Commit 20d37fba authored by Frederik Hennig's avatar Frederik Hennig
Browse files

code cleanup

parent 679bf618
No related branches found
No related tags found
No related merge requests found
Pipeline #60171 failed
from functools import reduce
from typing import Any, cast
from typing import cast
from pymbolic.primitives import Variable
from pymbolic.mapper.dependency import DependencyMapper
......@@ -47,7 +46,7 @@ class UndefinedVariablesCollector:
return self.collect(lhs) | self.collect(rhs)
case PsBlock(statements):
undefined_vars = set()
undefined_vars: set[PsTypedVariable] = set()
for stmt in statements[::-1]:
undefined_vars -= self.declared_variables(stmt)
undefined_vars |= self.collect(stmt)
......
......@@ -5,6 +5,7 @@ from .nodes import PsAstNode, PsBlock, failing_cast
from ..typed_expressions import PsTypedVariable
from ...enums import Target
class PsKernelFunction(PsAstNode):
"""A complete pystencils kernel function."""
......@@ -19,11 +20,11 @@ class PsKernelFunction(PsAstNode):
def target(self) -> Target:
"""See pystencils.Target"""
return self._target
@property
def body(self) -> PsBlock:
return self._body
@body.setter
def body(self, body: PsBlock):
self._body = body
......@@ -31,16 +32,16 @@ class PsKernelFunction(PsAstNode):
@property
def name(self) -> str:
return self._name
@name.setter
def name(self, value: str):
self._name = value
def num_children(self) -> int:
return 1
def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._body, )
yield from (self._body,)
def get_child(self, idx: int):
if idx not in (0, -1):
......@@ -51,14 +52,14 @@ class PsKernelFunction(PsAstNode):
if idx not in (0, -1):
raise IndexError(f"Child index out of bounds: {idx}")
self._body = failing_cast(PsBlock, c)
def get_parameters(self) -> Sequence[PsTypedVariable]:
"""Collect the list of parameters to this function.
This function performs a full traversal of the AST.
To improve performance, make sure to cache the result if necessary.
"""
from .analysis import UndefinedVariablesCollector
params = UndefinedVariablesCollector().collect(self)
return sorted(params, key=lambda p: p.name)
\ No newline at end of file
......@@ -3,8 +3,6 @@ from typing import Sequence, Generator, Iterable, cast
from abc import ABC, abstractmethod
import pymbolic.primitives as pb
from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant
from .util import failing_cast
......@@ -41,7 +39,6 @@ class PsAstNode(ABC):
class PsBlock(PsAstNode):
__match_args__ = ("statements",)
def __init__(self, cs: Sequence[PsAstNode]):
......@@ -62,9 +59,9 @@ class PsBlock(PsAstNode):
@property
def statements(self) -> list[PsAstNode]:
return self._statements
@statements.setter
def statemetns(self, stm: Sequence[PsAstNode]):
def statements(self, stm: Sequence[PsAstNode]):
self._statements = list(stm)
......@@ -127,8 +124,10 @@ class PsSymbolExpr(PsLvalueExpr):
class PsAssignment(PsAstNode):
__match_args__ = ("lhs", "rhs",)
__match_args__ = (
"lhs",
"rhs",
)
def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression):
self._lhs = lhs
......@@ -170,8 +169,10 @@ class PsAssignment(PsAstNode):
class PsDeclaration(PsAssignment):
__match_args__ = ("declared_variable", "rhs",)
__match_args__ = (
"declared_variable",
"rhs",
)
def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression):
super().__init__(lhs, rhs)
......@@ -203,7 +204,6 @@ class PsDeclaration(PsAssignment):
class PsLoop(PsAstNode):
__match_args__ = ("counter", "start", "stop", "step", "body")
def __init__(
......
class PsInternalCompilerError(Exception):
pass
class PsMalformedAstException(Exception):
pass
......@@ -226,13 +226,13 @@ class PsTypedConstant:
def __rsub__(self, other: Any):
return PsTypedConstant(self._rfix(other)._value - self._value, self._dtype)
@staticmethod
def _divrem(dividend, divisor):
quotient = abs(dividend) // abs(divisor)
quotient = quotient if (dividend * divisor > 0) else (- quotient)
quotient = abs(dividend) // abs(divisor)
quotient = quotient if (dividend * divisor > 0) else (-quotient)
rem = abs(dividend) % abs(divisor)
rem = rem if dividend >= 0 else (- rem)
rem = rem if dividend >= 0 else (-rem)
return quotient, rem
def __truediv__(self, other: Any):
......@@ -274,7 +274,7 @@ class PsTypedConstant:
def __neg__(self):
minus_one = PsTypedConstant(-1, self._dtype)
return pb.Product((minus_one, self))
def __bool__(self):
return bool(self._value)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment