Skip to content
Snippets Groups Projects
Commit 9d56ac14 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

code style

parent e3757c08
No related merge requests found
...@@ -6,6 +6,7 @@ from functools import wraps ...@@ -6,6 +6,7 @@ from functools import wraps
from .nodes import PsAstNode from .nodes import PsAstNode
class VisitorDispatcher: class VisitorDispatcher:
def __init__(self, wrapped_method): def __init__(self, wrapped_method):
self._dispatch_dict = {} self._dispatch_dict = {}
...@@ -13,7 +14,7 @@ class VisitorDispatcher: ...@@ -13,7 +14,7 @@ class VisitorDispatcher:
def case(self, node_type: type): def case(self, node_type: type):
"""Decorator for visitor's methods""" """Decorator for visitor's methods"""
def decorate(handler: Callable): def decorate(handler: Callable):
if node_type in self._dispatch_dict: if node_type in self._dispatch_dict:
raise ValueError(f"Duplicate visitor case {node_type}") raise ValueError(f"Duplicate visitor case {node_type}")
...@@ -37,4 +38,3 @@ class VisitorDispatcher: ...@@ -37,4 +38,3 @@ class VisitorDispatcher:
def ast_visitor(method): def ast_visitor(method):
return wraps(method)(VisitorDispatcher(method)) return wraps(method)(VisitorDispatcher(method))
...@@ -3,7 +3,7 @@ from abc import ABC ...@@ -3,7 +3,7 @@ from abc import ABC
from typing import Dict from typing import Dict
from pymbolic.primitives import Expression from pymbolic.primitives import Expression
from pymbolic.mapper.substitutor import CachedSubstitutionMapper, make_subst_func from pymbolic.mapper.substitutor import CachedSubstitutionMapper
from ..typed_expressions import PsTypedSymbol from ..typed_expressions import PsTypedSymbol
from .dispatcher import ast_visitor from .dispatcher import ast_visitor
...@@ -23,7 +23,7 @@ class PsAstTransformer(ABC): ...@@ -23,7 +23,7 @@ class PsAstTransformer(ABC):
class PsSymbolsSubstitutor(PsAstTransformer): class PsSymbolsSubstitutor(PsAstTransformer):
def __init__(self, subs_dict: Dict[PsTypedSymbol, Expression]): def __init__(self, subs_dict: Dict[PsTypedSymbol, Expression]):
self._subs_dict = subs_dict self._subs_dict = subs_dict
self._mapper = CachedSubstitutionMapper(lambda s : self._subs_dict.get(s, None)) self._mapper = CachedSubstitutionMapper(lambda s: self._subs_dict.get(s, None))
def subs(self, node: PsAstNode): def subs(self, node: PsAstNode):
return self.visit(node) return self.visit(node)
...@@ -44,7 +44,7 @@ class PsSymbolsSubstitutor(PsAstTransformer): ...@@ -44,7 +44,7 @@ class PsSymbolsSubstitutor(PsAstTransformer):
raise ValueError(f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter.") raise ValueError(f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter.")
self.transform_children(loop) self.transform_children(loop)
return loop return loop
@visit.case(PsExpression) @visit.case(PsExpression)
def expression(self, expr_node: PsExpression): def expression(self, expr_node: PsExpression):
self._mapper(expr_node.expression) self._mapper(expr_node.expression)
......
...@@ -6,6 +6,7 @@ import pymbolic.primitives as pb ...@@ -6,6 +6,7 @@ import pymbolic.primitives as pb
from ..typing import AbstractType, BasicType from ..typing import AbstractType, BasicType
class PsTypedSymbol(pb.Variable): class PsTypedSymbol(pb.Variable):
def __init__(self, name: str, dtype: AbstractType): def __init__(self, name: str, dtype: AbstractType):
super(PsTypedSymbol, self).__init__(name) super(PsTypedSymbol, self).__init__(name)
...@@ -26,7 +27,7 @@ class PsArrayAccess(pb.Subscript): ...@@ -26,7 +27,7 @@ class PsArrayAccess(pb.Subscript):
super(PsArrayAccess, self).__init__(base_ptr, index) super(PsArrayAccess, self).__init__(base_ptr, index)
PsLvalue : TypeAlias = Union[PsTypedSymbol, PsArrayAccess] PsLvalue: TypeAlias = Union[PsTypedSymbol, PsArrayAccess]
class PsTypedConstant: class PsTypedConstant:
...@@ -37,36 +38,36 @@ class PsTypedConstant: ...@@ -37,36 +38,36 @@ class PsTypedConstant:
if value._dtype != target_dtype: if value._dtype != target_dtype:
raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}") raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}")
return value return value
# TODO check legality # TODO check legality
return PsTypedConstant(value, target_dtype) return PsTypedConstant(value, target_dtype)
def __init__(self, value, dtype: AbstractType): def __init__(self, value, dtype: AbstractType):
"""Represents typed constants occuring in the pystencils AST""" """Represents typed constants occuring in the pystencils AST"""
if isinstance(dtype, BasicType): if isinstance(dtype, BasicType):
dtype = BasicType(dtype, const = True) dtype = BasicType(dtype, const=True)
self._value = dtype.numpy_dtype.type(value) self._value = dtype.numpy_dtype.type(value)
else: else:
raise ValueError(f"Cannot create constant of type {dtype}") raise ValueError(f"Cannot create constant of type {dtype}")
self._dtype = dtype self._dtype = dtype
def __str__(self) -> str: def __str__(self) -> str:
return str(self._value) return str(self._value)
def __add__(self, other: Any): def __add__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype) other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value + other._value, self._dtype) return PsTypedConstant(self._value + other._value, self._dtype)
def __mul__(self, other: Any): def __mul__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype) other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value * other._value, self._dtype) return PsTypedConstant(self._value * other._value, self._dtype)
def __sub__(self, other: Any): def __sub__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype) other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value - other._value, self._dtype) return PsTypedConstant(self._value - other._value, self._dtype)
# TODO: Remaining operators # TODO: Remaining operators
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment