Skip to content
Snippets Groups Projects
Commit 9d56ac14 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

code style

parent e3757c08
Branches
Tags
No related merge requests found
......@@ -6,6 +6,7 @@ from functools import wraps
from .nodes import PsAstNode
class VisitorDispatcher:
def __init__(self, wrapped_method):
self._dispatch_dict = {}
......@@ -13,7 +14,7 @@ class VisitorDispatcher:
def case(self, node_type: type):
"""Decorator for visitor's methods"""
def decorate(handler: Callable):
if node_type in self._dispatch_dict:
raise ValueError(f"Duplicate visitor case {node_type}")
......@@ -37,4 +38,3 @@ class VisitorDispatcher:
def ast_visitor(method):
return wraps(method)(VisitorDispatcher(method))
......@@ -3,7 +3,7 @@ from abc import ABC
from typing import Dict
from pymbolic.primitives import Expression
from pymbolic.mapper.substitutor import CachedSubstitutionMapper, make_subst_func
from pymbolic.mapper.substitutor import CachedSubstitutionMapper
from ..typed_expressions import PsTypedSymbol
from .dispatcher import ast_visitor
......@@ -23,7 +23,7 @@ class PsAstTransformer(ABC):
class PsSymbolsSubstitutor(PsAstTransformer):
def __init__(self, subs_dict: Dict[PsTypedSymbol, Expression]):
self._subs_dict = subs_dict
self._mapper = CachedSubstitutionMapper(lambda s : self._subs_dict.get(s, None))
self._mapper = CachedSubstitutionMapper(lambda s: self._subs_dict.get(s, None))
def subs(self, node: PsAstNode):
return self.visit(node)
......@@ -44,7 +44,7 @@ class PsSymbolsSubstitutor(PsAstTransformer):
raise ValueError(f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter.")
self.transform_children(loop)
return loop
@visit.case(PsExpression)
def expression(self, expr_node: PsExpression):
self._mapper(expr_node.expression)
......
......@@ -6,6 +6,7 @@ import pymbolic.primitives as pb
from ..typing import AbstractType, BasicType
class PsTypedSymbol(pb.Variable):
def __init__(self, name: str, dtype: AbstractType):
super(PsTypedSymbol, self).__init__(name)
......@@ -26,7 +27,7 @@ class PsArrayAccess(pb.Subscript):
super(PsArrayAccess, self).__init__(base_ptr, index)
PsLvalue : TypeAlias = Union[PsTypedSymbol, PsArrayAccess]
PsLvalue: TypeAlias = Union[PsTypedSymbol, PsArrayAccess]
class PsTypedConstant:
......@@ -37,36 +38,36 @@ class PsTypedConstant:
if value._dtype != target_dtype:
raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}")
return value
# TODO check legality
return PsTypedConstant(value, target_dtype)
def __init__(self, value, dtype: AbstractType):
"""Represents typed constants occuring in the pystencils AST"""
if isinstance(dtype, BasicType):
dtype = BasicType(dtype, const = True)
dtype = BasicType(dtype, const=True)
self._value = dtype.numpy_dtype.type(value)
else:
raise ValueError(f"Cannot create constant of type {dtype}")
self._dtype = dtype
def __str__(self) -> str:
return str(self._value)
def __add__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value + other._value, self._dtype)
def __mul__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value * other._value, self._dtype)
def __sub__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value - other._value, self._dtype)
# TODO: Remaining operators
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment