Skip to content
Snippets Groups Projects
Commit 9d56ac14 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

code style

parent e3757c08
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,7 @@ from functools import wraps
from .nodes import PsAstNode
class VisitorDispatcher:
def __init__(self, wrapped_method):
self._dispatch_dict = {}
......@@ -13,7 +14,7 @@ class VisitorDispatcher:
def case(self, node_type: type):
"""Decorator for visitor's methods"""
def decorate(handler: Callable):
if node_type in self._dispatch_dict:
raise ValueError(f"Duplicate visitor case {node_type}")
......@@ -37,4 +38,3 @@ class VisitorDispatcher:
def ast_visitor(method):
return wraps(method)(VisitorDispatcher(method))
......@@ -3,7 +3,7 @@ from abc import ABC
from typing import Dict
from pymbolic.primitives import Expression
from pymbolic.mapper.substitutor import CachedSubstitutionMapper, make_subst_func
from pymbolic.mapper.substitutor import CachedSubstitutionMapper
from ..typed_expressions import PsTypedSymbol
from .dispatcher import ast_visitor
......@@ -23,7 +23,7 @@ class PsAstTransformer(ABC):
class PsSymbolsSubstitutor(PsAstTransformer):
def __init__(self, subs_dict: Dict[PsTypedSymbol, Expression]):
self._subs_dict = subs_dict
self._mapper = CachedSubstitutionMapper(lambda s : self._subs_dict.get(s, None))
self._mapper = CachedSubstitutionMapper(lambda s: self._subs_dict.get(s, None))
def subs(self, node: PsAstNode):
return self.visit(node)
......@@ -44,7 +44,7 @@ class PsSymbolsSubstitutor(PsAstTransformer):
raise ValueError(f"Cannot substitute symbol {loop.counter.expression} that is defined as a loop counter.")
self.transform_children(loop)
return loop
@visit.case(PsExpression)
def expression(self, expr_node: PsExpression):
self._mapper(expr_node.expression)
......
......@@ -6,6 +6,7 @@ import pymbolic.primitives as pb
from ..typing import AbstractType, BasicType
class PsTypedSymbol(pb.Variable):
def __init__(self, name: str, dtype: AbstractType):
super(PsTypedSymbol, self).__init__(name)
......@@ -26,7 +27,7 @@ class PsArrayAccess(pb.Subscript):
super(PsArrayAccess, self).__init__(base_ptr, index)
PsLvalue : TypeAlias = Union[PsTypedSymbol, PsArrayAccess]
PsLvalue: TypeAlias = Union[PsTypedSymbol, PsArrayAccess]
class PsTypedConstant:
......@@ -37,36 +38,36 @@ class PsTypedConstant:
if value._dtype != target_dtype:
raise ValueError(f"Incompatible types: {value._dtype} and {target_dtype}")
return value
# TODO check legality
return PsTypedConstant(value, target_dtype)
def __init__(self, value, dtype: AbstractType):
"""Represents typed constants occuring in the pystencils AST"""
if isinstance(dtype, BasicType):
dtype = BasicType(dtype, const = True)
dtype = BasicType(dtype, const=True)
self._value = dtype.numpy_dtype.type(value)
else:
raise ValueError(f"Cannot create constant of type {dtype}")
self._dtype = dtype
def __str__(self) -> str:
return str(self._value)
def __add__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value + other._value, self._dtype)
def __mul__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value * other._value, self._dtype)
def __sub__(self, other: Any):
other = PsTypedConstant._cast(other, self._dtype)
return PsTypedConstant(self._value - other._value, self._dtype)
# TODO: Remaining operators
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment