Skip to content
Snippets Groups Projects
Commit 8cbcc472 authored by Markus Holzer's avatar Markus Holzer
Browse files

Started refactor

parent 1fadc0dd
No related branches found
No related tags found
No related merge requests found
Pipeline #62629 failed
Showing
with 62 additions and 9 deletions
...@@ -2,18 +2,17 @@ ...@@ -2,18 +2,17 @@
from .enums import Backend, Target from .enums import Backend, Target
from . import fd from . import fd
from . import stencil as stencil from . import stencil as stencil
from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil from pystencils.sympyextensions.assignmentcollection.assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
from .typing.typed_sympy import TypedSymbol
from .display_utils import get_code_obj, get_code_str, show_code, to_dot from .display_utils import get_code_obj, get_code_str, show_code, to_dot
from .field import Field, FieldType, fields from .field import Field, FieldType, fields
from .config import CreateKernelConfig from .config import CreateKernelConfig
from .cache import clear_cache from .cache import clear_cache
from .kernel_decorator import kernel, kernel_config from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel, create_staggered_kernel from .kernelcreation import create_kernel
from .simp import AssignmentCollection from pystencils.sympyextensions.assignmentcollection import AssignmentCollection
from .slicing import make_slice from .slicing import make_slice
from .spatial_coordinates import x_, x_staggered, x_staggered_vector, x_vector, y_, y_staggered, z_, z_staggered from .spatial_coordinates import x_, x_staggered, x_staggered_vector, x_vector, y_, y_staggered, z_, z_staggered
from .sympyextensions import SymbolCreator from .sympyextensions.math import SymbolCreator
from .datahandling import create_data_handling from .datahandling import create_data_handling
__all__ = ['Field', 'FieldType', 'fields', __all__ = ['Field', 'FieldType', 'fields',
......
from __future__ import annotations from __future__ import annotations
from typing import Callable
from types import MethodType
from functools import wraps from functools import wraps
from typing import Callable
from types import MethodType
from .nodes import PsAstNode from .nodes import PsAstNode
......
...@@ -9,9 +9,8 @@ import sympy as sp ...@@ -9,9 +9,8 @@ import sympy as sp
from .context import KernelCreationContext from .context import KernelCreationContext
from ...field import Field from ...field import Field
from ...assignment import Assignment from pystencils.sympyextensions.assignmentcollection.assignment import Assignment
from ...simp import AssignmentCollection from ...simp import AssignmentCollection
from ...transformations import NestedScopes
from ..exceptions import PsInternalCompilerError, KernelConstraintsError from ..exceptions import PsInternalCompilerError, KernelConstraintsError
...@@ -156,3 +155,57 @@ class KernelAnalysis: ...@@ -156,3 +155,57 @@ class KernelAnalysis:
rec(arg) rec(arg)
rec(rhs) rec(rhs)
class NestedScopes:
"""Symbol visibility model using nested scopes
- every accessed symbol that was not defined before, is added as a "free parameter"
- free parameters are global, i.e. they are not in scopes
- push/pop adds or removes a scope
>>> s = NestedScopes()
>>> s.access_symbol("a")
>>> s.is_defined("a")
False
>>> s.free_parameters
{'a'}
>>> s.define_symbol("b")
>>> s.is_defined("b")
True
>>> s.push()
>>> s.is_defined_locally("b")
False
>>> s.define_symbol("c")
>>> s.pop()
>>> s.is_defined("c")
False
"""
def __init__(self):
self.free_parameters = set()
self._defined = [set()]
def access_symbol(self, symbol):
if not self.is_defined(symbol):
self.free_parameters.add(symbol)
def define_symbol(self, symbol):
self._defined[-1].add(symbol)
def is_defined(self, symbol):
return any(symbol in scopes for scopes in self._defined)
def is_defined_locally(self, symbol):
return symbol in self._defined[-1]
def push(self):
self._defined.append(set())
def pop(self):
self._defined.pop()
assert self.depth >= 1
@property
def depth(self):
return len(self._defined)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment