Skip to content
Snippets Groups Projects
Commit 9d56ac14 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

code style

parent e3757c08
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ from functools import wraps ...@@ -6,6 +6,7 @@ from functools import wraps
from .nodes import PsAstNode from .nodes import PsAstNode
class VisitorDispatcher: class VisitorDispatcher:
def __init__(self, wrapped_method): def __init__(self, wrapped_method):
self._dispatch_dict = {} self._dispatch_dict = {}
...@@ -13,7 +14,7 @@ class VisitorDispatcher: ...@@ -13,7 +14,7 @@ class VisitorDispatcher:
def case(self, node_type: type): def case(self, node_type: type):
"""Decorator for visitor's methods""" """Decorator for visitor's methods"""
def decorate(handler: Callable): def decorate(handler: Callable):
if node_type in self._dispatch_dict: if node_type in self._dispatch_dict:
raise ValueError(f"Duplicate visitor case {node_type}") raise ValueError(f"Duplicate visitor case {node_type}")
...@@ -37,4 +38,3 @@ class VisitorDispatcher: ...@@ -37,4 +38,3 @@ class VisitorDispatcher:
def ast_visitor(method): def ast_visitor(method):
return wraps(method)(VisitorDispatcher(method)) return wraps(method)(VisitorDispatcher(method))
...@@ -3,7 +3,7 @@ from abc import ABC ...@@ -3,7 +3,7 @@ from abc import ABC
from typing import Dict from typing import Dict
from pymbolic.primitives import Expression from pymbolic.primitives import Expression
from pymbolic.mapper.substitutor import CachedSubstitutionMapper, make_subst_func from pymbolic.mapper.substitutor import CachedSubstitutionMapper
from ..typed_expressions import PsTypedSymbol from ..typed_expressions import PsTypedSymbol
from .dispatcher import ast_visitor from .dispatcher import ast_visitor
...@@ -23,7 +23,7 @@ class PsAstTransformer(ABC): ...@@ -23,7 +23,7 @@ class PsAstTransformer(ABC):
class PsSymbolsSubstitutor(PsAstTransformer): class PsSymbolsSubstitutor(PsAstTransformer):
def __init__(self, subs_dict: Dict[PsTypedSymbol, Expression]): def __init__(self, subs_dict: Dict[PsTypedSymbol, Expression]):
self._subs_dict = subs_dict self._subs_dict = subs_dict
self._mapper = CachedSubstitutionMapper(lambda s : self._subs_dict.get(s, None)) self._mapper = CachedSubstitutionMapper(lambda s: self._subs_dict.get(s, None))
def subs(self, node: PsAstNode): def subs(self, node: PsAstNode):
return self.visit(node) return self.visit(node)
...@@ -44,7 +44,7 @@ class PsSymbolsSubstitutor(PsAstTransformer): ...@@ -44,7 +44,7 @@ class PsSymbolsSubstitutor(PsAstTransformer):
raise ValueError(f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter.") raise ValueError(f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter.")
self.transform_children(loop) self.transform_children(loop)
return loop return loop
@visit.case(PsExpression) @visit.case(PsExpression)
def expression(self, expr_node: PsExpression): def expression(self, expr_node: PsExpression):
self._mapper(expr_node.expression) self._mapper(expr_node.expression)
......
...@@ -6,6 +6,7 @@ import pymbolic.primitives as pb ...@@ -6,6 +6,7 @@ import pymbolic.primitives as pb
from ..typing import AbstractType, BasicType from ..typing import AbstractType, BasicType
class PsTypedSymbol(pb.Variable): class PsTypedSymbol(pb.Variable):
def __init__(self, name: str, dtype: AbstractType): def __init__(self, name: str, dtype: AbstractType):
super(PsTypedSymbol, self).__init__(name) super(PsTypedSymbol, self).__init__(name)
...@@ -26,7 +27,7 @@ class PsArrayAccess(pb.Subscript): ...@@ -26,7 +27,7 @@ class PsArrayAccess(pb.Subscript):
super(PsArrayAccess, self).__init__(base_ptr, index) super(PsArrayAccess, self).__init__(base_ptr, index)
PsLvalue : TypeAlias = Union[PsTypedSymbol, PsArrayAccess] PsLvalue: TypeAlias = Union[PsTypedSymbol, PsArrayAccess]
class PsTypedConstant: class PsTypedConstant:
...@@ -37,36 +38,36 @@ class PsTypedConstant: ...@@ -37,36 +38,36 @@ class PsTypedConstant:
if value._dtype != target_dtype: if value._dtype != target_dtype:
raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}") raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}")
return value return value
# TODO check legality # TODO check legality
return PsTypedConstant(value, target_dtype) return PsTypedConstant(value, target_dtype)
def __init__(self, value, dtype: AbstractType): def __init__(self, value, dtype: AbstractType):
"""Represents typed constants occuring in the pystencils AST""" """Represents typed constants occuring in the pystencils AST"""
if isinstance(dtype, BasicType): if isinstance(dtype, BasicType):
dtype = BasicType(dtype, const = True) dtype = BasicType(dtype, const=True)
self._value = dtype.numpy_dtype.type(value) self._value = dtype.numpy_dtype.type(value)
else: else:
raise ValueError(f"Cannot create constant of type {dtype}") raise ValueError(f"Cannot create constant of type {dtype}")
self._dtype = dtype self._dtype = dtype
def __str__(self) -> str: def __str__(self) -> str:
return str(self._value) return str(self._value)
def __add__(self, other: Any): def __add__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype) other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value + other._value, self._dtype) return PsTypedConstant(self._value + other._value, self._dtype)
def __mul__(self, other: Any): def __mul__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype) other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value * other._value, self._dtype) return PsTypedConstant(self._value * other._value, self._dtype)
def __sub__(self, other: Any): def __sub__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype) other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value - other._value, self._dtype) return PsTypedConstant(self._value - other._value, self._dtype)
# TODO: Remaining operators # TODO: Remaining operators
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment