Skip to content
Snippets Groups Projects
Commit 1397bcb2 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

some minor API changes

parent 2ba2fd8d
No related branches found
No related tags found
No related merge requests found
Pipeline #70100 failed
......@@ -14,7 +14,7 @@ from ..composer import (
SfgComposer,
SfgComposerMixIn,
)
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude, SfgKernelParamVar
from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude
from ..ir import (
SfgCallTreeNode,
SfgCallTreeLeaf,
......@@ -73,7 +73,7 @@ class SyclHandler(AugExpr):
id_regex = re.compile(r"sycl::(id|item|nd_item)<\s*[0-9]\s*>")
def filter_id(param: SfgKernelParamVar) -> bool:
def filter_id(param: SfgVar) -> bool:
return (
isinstance(param.dtype, PsCustomType)
and id_regex.search(param.dtype.c_string()) is not None
......@@ -117,7 +117,7 @@ class SyclGroup(AugExpr):
id_regex = re.compile(r"sycl::id<\s*[0-9]\s*>")
def filter_id(param: SfgKernelParamVar) -> bool:
def filter_id(param: SfgVar) -> bool:
return (
isinstance(param.dtype, PsCustomType)
and id_regex.search(param.dtype.c_string()) is not None
......
......@@ -163,7 +163,7 @@ class SfgKernelHandle:
self._namespace = namespace
self._parameters = [SfgKernelParamVar(p) for p in parameters]
self._scalar_params: set[SfgKernelParamVar] = set()
self._scalar_params: set[SfgVar] = set()
self._fields: set[Field] = set()
for param in self._parameters:
......@@ -193,7 +193,7 @@ class SfgKernelHandle:
return self._parameters
@property
def scalar_parameters(self):
def scalar_parameters(self) -> set[SfgVar]:
return self._scalar_params
@property
......
......@@ -12,7 +12,7 @@ from .expressions import (
SrcVector,
)
from .types import Ref
from .types import Ref, strip_ptr_ref
__all__ = [
"SfgVar",
......@@ -27,4 +27,5 @@ __all__ = [
"SrcField",
"SrcVector",
"Ref",
"strip_ptr_ref"
]
......@@ -174,23 +174,30 @@ class AugExpr:
"""Create a new `AugExpr` by combining existing expressions."""
return AugExpr().bind(fmt, *deps, **kwdeps)
def bind(self, fmt: str, *deps, **kwdeps):
dependencies: set[SfgVar] = set()
from pystencils.sympyextensions import is_constant
for expr in chain(deps, kwdeps.values()):
if isinstance(expr, _ExprLike):
dependencies |= depends(expr)
elif isinstance(expr, sp.Expr) and not is_constant(expr):
raise ValueError(
f"Cannot parse SymPy expression as C++ expression: {expr}\n"
" * pystencils-sfg is currently unable to parse non-constant SymPy expressions "
"since they contain symbols without type information."
)
code = fmt.format(*deps, **kwdeps)
self._bind(DependentExpression(code, dependencies))
def bind(self, fmt: str | AugExpr, *deps, **kwdeps):
if isinstance(fmt, AugExpr):
if bool(deps) or bool(kwdeps):
raise ValueError("Binding to another AugExpr does not permit additional arguments")
if fmt._bound is None:
raise ValueError("Cannot rebind to unbound AugExpr.")
self._bind(fmt._bound)
else:
dependencies: set[SfgVar] = set()
from pystencils.sympyextensions import is_constant
for expr in chain(deps, kwdeps.values()):
if isinstance(expr, _ExprLike):
dependencies |= depends(expr)
elif isinstance(expr, sp.Expr) and not is_constant(expr):
raise ValueError(
f"Cannot parse SymPy expression as C++ expression: {expr}\n"
" * pystencils-sfg is currently unable to parse non-constant SymPy expressions "
"since they contain symbols without type information."
)
code = fmt.format(*deps, **kwdeps)
self._bind(DependentExpression(code, dependencies))
return self
def expr(self) -> DependentExpression:
......
from typing import Any
from pystencils.types import PsType
from pystencils.types import PsType, PsPointerType
class Ref(PsType):
......@@ -24,3 +24,13 @@ class Ref(PsType):
def __repr__(self) -> str:
return f"Ref({repr(self.base_type)})"
def strip_ptr_ref(dtype: PsType):
match dtype:
case Ref():
return strip_ptr_ref(dtype.base_type)
case PsPointerType():
return strip_ptr_ref(dtype.base_type)
case _:
return dtype
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment