Skip to content
Snippets Groups Projects
Commit 20d37fba authored by Frederik Hennig's avatar Frederik Hennig
Browse files

code cleanup

parent 679bf618
No related merge requests found
Pipeline #60171 failed with stages
in 3 minutes and 53 seconds
from functools import reduce from typing import cast
from typing import Any, cast
from pymbolic.primitives import Variable from pymbolic.primitives import Variable
from pymbolic.mapper.dependency import DependencyMapper from pymbolic.mapper.dependency import DependencyMapper
...@@ -47,7 +46,7 @@ class UndefinedVariablesCollector: ...@@ -47,7 +46,7 @@ class UndefinedVariablesCollector:
return self.collect(lhs) | self.collect(rhs) return self.collect(lhs) | self.collect(rhs)
case PsBlock(statements): case PsBlock(statements):
undefined_vars = set() undefined_vars: set[PsTypedVariable] = set()
for stmt in statements[::-1]: for stmt in statements[::-1]:
undefined_vars -= self.declared_variables(stmt) undefined_vars -= self.declared_variables(stmt)
undefined_vars |= self.collect(stmt) undefined_vars |= self.collect(stmt)
......
...@@ -5,6 +5,7 @@ from .nodes import PsAstNode, PsBlock, failing_cast ...@@ -5,6 +5,7 @@ from .nodes import PsAstNode, PsBlock, failing_cast
from ..typed_expressions import PsTypedVariable from ..typed_expressions import PsTypedVariable
from ...enums import Target from ...enums import Target
class PsKernelFunction(PsAstNode): class PsKernelFunction(PsAstNode):
"""A complete pystencils kernel function.""" """A complete pystencils kernel function."""
...@@ -19,11 +20,11 @@ class PsKernelFunction(PsAstNode): ...@@ -19,11 +20,11 @@ class PsKernelFunction(PsAstNode):
def target(self) -> Target: def target(self) -> Target:
"""See pystencils.Target""" """See pystencils.Target"""
return self._target return self._target
@property @property
def body(self) -> PsBlock: def body(self) -> PsBlock:
return self._body return self._body
@body.setter @body.setter
def body(self, body: PsBlock): def body(self, body: PsBlock):
self._body = body self._body = body
...@@ -31,16 +32,16 @@ class PsKernelFunction(PsAstNode): ...@@ -31,16 +32,16 @@ class PsKernelFunction(PsAstNode):
@property @property
def name(self) -> str: def name(self) -> str:
return self._name return self._name
@name.setter @name.setter
def name(self, value: str): def name(self, value: str):
self._name = value self._name = value
def num_children(self) -> int: def num_children(self) -> int:
return 1 return 1
def children(self) -> Generator[PsAstNode, None, None]: def children(self) -> Generator[PsAstNode, None, None]:
yield from (self._body, ) yield from (self._body,)
def get_child(self, idx: int): def get_child(self, idx: int):
if idx not in (0, -1): if idx not in (0, -1):
...@@ -51,14 +52,14 @@ class PsKernelFunction(PsAstNode): ...@@ -51,14 +52,14 @@ class PsKernelFunction(PsAstNode):
if idx not in (0, -1): if idx not in (0, -1):
raise IndexError(f"Child index out of bounds: {idx}") raise IndexError(f"Child index out of bounds: {idx}")
self._body = failing_cast(PsBlock, c) self._body = failing_cast(PsBlock, c)
def get_parameters(self) -> Sequence[PsTypedVariable]: def get_parameters(self) -> Sequence[PsTypedVariable]:
"""Collect the list of parameters to this function. """Collect the list of parameters to this function.
This function performs a full traversal of the AST. This function performs a full traversal of the AST.
To improve performance, make sure to cache the result if necessary. To improve performance, make sure to cache the result if necessary.
""" """
from .analysis import UndefinedVariablesCollector from .analysis import UndefinedVariablesCollector
params = UndefinedVariablesCollector().collect(self) params = UndefinedVariablesCollector().collect(self)
return sorted(params, key=lambda p: p.name) return sorted(params, key=lambda p: p.name)
\ No newline at end of file
...@@ -3,8 +3,6 @@ from typing import Sequence, Generator, Iterable, cast ...@@ -3,8 +3,6 @@ from typing import Sequence, Generator, Iterable, cast
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import pymbolic.primitives as pb
from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant from ..typed_expressions import PsTypedVariable, PsArrayAccess, PsLvalue, ExprOrConstant
from .util import failing_cast from .util import failing_cast
...@@ -41,7 +39,6 @@ class PsAstNode(ABC): ...@@ -41,7 +39,6 @@ class PsAstNode(ABC):
class PsBlock(PsAstNode): class PsBlock(PsAstNode):
__match_args__ = ("statements",) __match_args__ = ("statements",)
def __init__(self, cs: Sequence[PsAstNode]): def __init__(self, cs: Sequence[PsAstNode]):
...@@ -62,9 +59,9 @@ class PsBlock(PsAstNode): ...@@ -62,9 +59,9 @@ class PsBlock(PsAstNode):
@property @property
def statements(self) -> list[PsAstNode]: def statements(self) -> list[PsAstNode]:
return self._statements return self._statements
@statements.setter @statements.setter
def statemetns(self, stm: Sequence[PsAstNode]): def statements(self, stm: Sequence[PsAstNode]):
self._statements = list(stm) self._statements = list(stm)
...@@ -127,8 +124,10 @@ class PsSymbolExpr(PsLvalueExpr): ...@@ -127,8 +124,10 @@ class PsSymbolExpr(PsLvalueExpr):
class PsAssignment(PsAstNode): class PsAssignment(PsAstNode):
__match_args__ = (
__match_args__ = ("lhs", "rhs",) "lhs",
"rhs",
)
def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression): def __init__(self, lhs: PsLvalueExpr, rhs: PsExpression):
self._lhs = lhs self._lhs = lhs
...@@ -170,8 +169,10 @@ class PsAssignment(PsAstNode): ...@@ -170,8 +169,10 @@ class PsAssignment(PsAstNode):
class PsDeclaration(PsAssignment): class PsDeclaration(PsAssignment):
__match_args__ = (
__match_args__ = ("declared_variable", "rhs",) "declared_variable",
"rhs",
)
def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression): def __init__(self, lhs: PsSymbolExpr, rhs: PsExpression):
super().__init__(lhs, rhs) super().__init__(lhs, rhs)
...@@ -203,7 +204,6 @@ class PsDeclaration(PsAssignment): ...@@ -203,7 +204,6 @@ class PsDeclaration(PsAssignment):
class PsLoop(PsAstNode): class PsLoop(PsAstNode):
__match_args__ = ("counter", "start", "stop", "step", "body") __match_args__ = ("counter", "start", "stop", "step", "body")
def __init__( def __init__(
......
class PsInternalCompilerError(Exception): class PsInternalCompilerError(Exception):
pass pass
class PsMalformedAstException(Exception): class PsMalformedAstException(Exception):
pass pass
...@@ -226,13 +226,13 @@ class PsTypedConstant: ...@@ -226,13 +226,13 @@ class PsTypedConstant:
def __rsub__(self, other: Any): def __rsub__(self, other: Any):
return PsTypedConstant(self._rfix(other)._value - self._value, self._dtype) return PsTypedConstant(self._rfix(other)._value - self._value, self._dtype)
@staticmethod @staticmethod
def _divrem(dividend, divisor): def _divrem(dividend, divisor):
quotient = abs(dividend) // abs(divisor) quotient = abs(dividend) // abs(divisor)
quotient = quotient if (dividend * divisor > 0) else (- quotient) quotient = quotient if (dividend * divisor > 0) else (-quotient)
rem = abs(dividend) % abs(divisor) rem = abs(dividend) % abs(divisor)
rem = rem if dividend >= 0 else (- rem) rem = rem if dividend >= 0 else (-rem)
return quotient, rem return quotient, rem
def __truediv__(self, other: Any): def __truediv__(self, other: Any):
...@@ -274,7 +274,7 @@ class PsTypedConstant: ...@@ -274,7 +274,7 @@ class PsTypedConstant:
def __neg__(self): def __neg__(self):
minus_one = PsTypedConstant(-1, self._dtype) minus_one = PsTypedConstant(-1, self._dtype)
return pb.Product((minus_one, self)) return pb.Product((minus_one, self))
def __bool__(self): def __bool__(self):
return bool(self._value) return bool(self._value)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment