diff --git a/pystencils/nbackend/ast/dispatcher.py b/pystencils/nbackend/ast/dispatcher.py index a27f41bcf51ef68ccb1ef9c11860f1845e5ae311..e38485cb068c67dd01fb1a5958a9c689eab60b53 100644 --- a/pystencils/nbackend/ast/dispatcher.py +++ b/pystencils/nbackend/ast/dispatcher.py @@ -6,6 +6,7 @@ from functools import wraps from .nodes import PsAstNode + class VisitorDispatcher: def __init__(self, wrapped_method): self._dispatch_dict = {} @@ -13,7 +14,7 @@ class VisitorDispatcher: def case(self, node_type: type): """Decorator for visitor's methods""" - + def decorate(handler: Callable): if node_type in self._dispatch_dict: raise ValueError(f"Duplicate visitor case {node_type}") @@ -37,4 +38,3 @@ class VisitorDispatcher: def ast_visitor(method): return wraps(method)(VisitorDispatcher(method)) - diff --git a/pystencils/nbackend/ast/transformations.py b/pystencils/nbackend/ast/transformations.py index 0bf23c0c7d3326bc8eb2041b6117ee760b06e375..41a4d9e27e2157fddee5b48289bc607398c196a7 100644 --- a/pystencils/nbackend/ast/transformations.py +++ b/pystencils/nbackend/ast/transformations.py @@ -3,7 +3,7 @@ from abc import ABC from typing import Dict from pymbolic.primitives import Expression -from pymbolic.mapper.substitutor import CachedSubstitutionMapper, make_subst_func +from pymbolic.mapper.substitutor import CachedSubstitutionMapper from ..typed_expressions import PsTypedSymbol from .dispatcher import ast_visitor @@ -23,7 +23,7 @@ class PsAstTransformer(ABC): class PsSymbolsSubstitutor(PsAstTransformer): def __init__(self, subs_dict: Dict[PsTypedSymbol, Expression]): self._subs_dict = subs_dict - self._mapper = CachedSubstitutionMapper(lambda s : self._subs_dict.get(s, None)) + self._mapper = CachedSubstitutionMapper(lambda s: self._subs_dict.get(s, None)) def subs(self, node: PsAstNode): return self.visit(node) @@ -44,7 +44,7 @@ class PsSymbolsSubstitutor(PsAstTransformer): raise ValueError(f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter.") self.transform_children(loop) return loop - + @visit.case(PsExpression) def expression(self, expr_node: PsExpression): self._mapper(expr_node.expression) diff --git a/pystencils/nbackend/typed_expressions.py b/pystencils/nbackend/typed_expressions.py index 055a0b904f365b7a5c4dc9cec15401a18d13cf42..d418e875dbec0a2587a9df84733134c97ac050c7 100644 --- a/pystencils/nbackend/typed_expressions.py +++ b/pystencils/nbackend/typed_expressions.py @@ -6,6 +6,7 @@ import pymbolic.primitives as pb from ..typing import AbstractType, BasicType + class PsTypedSymbol(pb.Variable): def __init__(self, name: str, dtype: AbstractType): super(PsTypedSymbol, self).__init__(name) @@ -26,7 +27,7 @@ class PsArrayAccess(pb.Subscript): super(PsArrayAccess, self).__init__(base_ptr, index) -PsLvalue : TypeAlias = Union[PsTypedSymbol, PsArrayAccess] +PsLvalue: TypeAlias = Union[PsTypedSymbol, PsArrayAccess] class PsTypedConstant: @@ -37,36 +38,36 @@ class PsTypedConstant: if value._dtype != target_dtype: raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}") return value - + # TODO check legality return PsTypedConstant(value, target_dtype) def __init__(self, value, dtype: AbstractType): """Represents typed constants occuring in the pystencils AST""" if isinstance(dtype, BasicType): - dtype = BasicType(dtype, const = True) + dtype = BasicType(dtype, const=True) self._value = dtype.numpy_dtype.type(value) else: raise ValueError(f"Cannot create constant of type {dtype}") - + self._dtype = dtype def __str__(self) -> str: return str(self._value) - + def __add__(self, other: Any): other = PsTypedConstant._cast(other, self._dtype) - + return PsTypedConstant(self._value + other._value, self._dtype) def __mul__(self, other: Any): other = PsTypedConstant._cast(other, self._dtype) - + return PsTypedConstant(self._value * other._value, self._dtype) - + def __sub__(self, other: Any): other = PsTypedConstant._cast(other, self._dtype) - + return PsTypedConstant(self._value - other._value, self._dtype) # TODO: Remaining operators