diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index cbaa0e6790a2da0ee6f4f8b99986bcc74d14876b..2f84f29f8f35e73671be06e725923700bdc31dc9 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -33,15 +33,22 @@ testsuite: stage: "Tests" image: i10git.cs.fau.de:5005/pycodegen/pycodegen/full needs: [] + tags: + - docker before_script: - pip install "git+https://i10git.cs.fau.de/pycodegen/pystencils.git@v2.0-dev" - pip install -e . script: - - pytest -v --cov-report html --cov-report xml --cov-report term --cov=src/pystencilssfg tests + - coverage run -m pytest -v + - coverage report + - coverage html + - coverage xml + coverage: '/TOTAL.*\s+(\d+%)$/' artifacts: when: always paths: - htmlcov + - coverage.xml reports: coverage_report: coverage_format: cobertura diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ef8905c158e1336ca6909f9c5273cbafc9a39250..dbcaaf155503325f1282cd74f7ba013ee5693221 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -15,8 +15,8 @@ As such, any submission of contributions via merge requests is considered as agr ### Fork and Clone -To work within the `pystencils-sfg` source tree, first create a *fork* of this repository on GitLab and create -a local clone of your fork. +To work within the `pystencils-sfg` source tree, first create a *fork* of this repository +and clone it to your workstation. ### Set up your dev environment @@ -52,3 +52,16 @@ Both `flake8` and `mypy` are also run in the integration pipeline. You can automate the code quality checks by running them via a git pre-commit hook. Such a hook can be installed using the [`install_git_hooks.sh`](install_git_hooks.sh) script located at the project root. +### Test Your Code + +We are working toward near-complete test coverage of the module source files. +When you add code, make sure to include test cases for both its desired +and exceptional behavior at the appropriate locations in the [tests](tests) directory. + +Unit tests should be placed under a path and filename mirroring the location +of the API they are testing within the *pystencils-sfg* source tree. + +In [tests/generator_scripts](tests/generator_scripts), a framework is provided to test entire generator scripts +for successful execution, correctness, and compilability of their output. +Read the documentation within [test_generator_scripts.py](tests/generator_scripts/test_generator_scripts.py) +for more information. diff --git a/README.md b/README.md index ec8972cf4da36d32b89095e36242b0b64ec12bfc..9a2dd719e6435b089e34c4de76ff1ce2c96ea5c5 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,9 @@ # pystencils Source File Generator (pystencils-sfg) -[](https://pycodegen.pages.i10git.cs.fau.de/pystencils-sfg) -[](https://i10git.cs.fau.de/pycodegen/pystencils-sfg/commits/master) -[](https://i10git.cs.fau.de/pycodegen/pystencils-sfg/-/blob/master/LICENSE) +[](https://pycodegen.pages.i10git.cs.fau.de/pystencils-sfg) +[](https://i10git.cs.fau.de/pycodegen-/pystencils-sfg/commits/master) + +[](https://i10git.cs.fau.de/pycodegen/pystencils-sfg/-/blob/master/LICENSE) A bridge over the semantic gap between code emitted by pystencils and your C/C++/Cuda/HIP framework. diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..e1d0cdda1b64ec9eff93367589fe3a08d975deb7 --- /dev/null +++ b/conftest.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.fixture(autouse=True) +def prepare_composer(doctest_namespace): + from pystencilssfg import SfgContext, SfgComposer + + # Place a composer object in the environment for doctests + + sfg = SfgComposer(SfgContext()) + doctest_namespace["sfg"] = sfg diff --git a/docs/source/api/composer.rst b/docs/source/api/composer.rst index 75acd8abedc7ebe0c9c9024575f3bffe99f7c6bf..a969d4ff91455f0fda51813aab93e23c4b0b7098 100644 --- a/docs/source/api/composer.rst +++ b/docs/source/api/composer.rst @@ -14,4 +14,23 @@ Composer API (`pystencilssfg.composer`) .. autoclass:: pystencilssfg.composer.SfgClassComposer :members: +Custom Generators +================= + +.. autoclass:: pystencilssfg.composer.custom.CustomGenerator + :members: + + +Helper Methods and Builders +=========================== + .. autofunction:: pystencilssfg.composer.make_sequence + +.. autoclass:: pystencilssfg.composer.basic_composer.SfgNodeBuilder + :members: + +.. autoclass:: pystencilssfg.composer.basic_composer.SfgBranchBuilder + :members: + +.. autoclass:: pystencilssfg.composer.basic_composer.SfgSwitchBuilder + :members: diff --git a/docs/source/api/errors.rst b/docs/source/api/errors.rst new file mode 100644 index 0000000000000000000000000000000000000000..c793641baaf03069c32761f47e0df33e70afeb09 --- /dev/null +++ b/docs/source/api/errors.rst @@ -0,0 +1,6 @@ +********************* +Errors and Exceptions +********************* + +.. automodule:: pystencilssfg.exceptions + :members: diff --git a/docs/source/api/generation.rst b/docs/source/api/generation.rst index 45065c13edc444239e928ecb66a376bc700e7959..f15b1170099cfb10c9f333cb6f4540137e162654 100644 --- a/docs/source/api/generation.rst +++ b/docs/source/api/generation.rst @@ -8,4 +8,10 @@ Generator Script Interface .. autoclass:: pystencilssfg.SfgConfiguration :members: +.. autoclass:: pystencilssfg.SfgOutputMode + :members: + +.. autoclass:: pystencilssfg.SfgCodeStyle + :members: + .. autoattribute:: pystencilssfg.configuration.DEFAULT_CONFIG diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 681b6e851fffa26d94785dbdd71765953d6ba343..1ea987fdc32cc22a18c423c970d7bdfd0c8ca0cc 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -11,3 +11,4 @@ These pages provide a reference for the public API of *pystencils-sfg*. composer lang ir + errors diff --git a/docs/source/conf.py b/docs/source/conf.py index 84c2b779b553fe22fa8ef23a2f0a4872c9edb57c..3b7b20cc1efcd8d281bc18446dd7cd0aa08e3a37 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -27,6 +27,7 @@ extensions = [ "myst_parser", "sphinx.ext.autodoc", "sphinx.ext.napoleon", + "sphinx.ext.doctest", "sphinx.ext.intersphinx", "sphinx_autodoc_typehints", "sphinx_design", @@ -56,6 +57,7 @@ intersphinx_mapping = { "numpy": ("https://docs.scipy.org/doc/numpy/", None), "matplotlib": ("https://matplotlib.org/", None), "sympy": ("https://docs.sympy.org/latest/", None), + "pystencils": ("https://da15siwa.pages.i10git.cs.fau.de/dev-docs/pystencils-nbackend/", None), } @@ -64,6 +66,13 @@ intersphinx_mapping = { autodoc_member_order = "bysource" autodoc_typehints = "description" +# Doctest Setup + +doctest_global_setup = ''' +from pystencilssfg import SfgContext, SfgComposer +sfg = SfgComposer(SfgContext()) +''' + # Prepare code generation examples diff --git a/docs/source/index.md b/docs/source/index.md index 5f74297c9213c091444caccbf85f366adaf4cb28..c65d57700d041c31de6b950d7d5b7a703b27d73d 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -8,8 +8,9 @@ usage/index api/index ``` -[](https://i10git.cs.fau.de/pycodegen/pystencils-sfg/commits/master) -[](https://i10git.cs.fau.de/pycodegen/pystencils-sfg/-/blob/master/LICENSE) +[](https://i10git.cs.fau.de/pycodegen-/pystencils-sfg/commits/master) +[](https://i10git.cs.fau.de/pycodegen-/pystencils-sfg/commits/master) +[](https://i10git.cs.fau.de/pycodegen/pystencils-sfg/-/blob/master/LICENSE) A bridge over the semantic gap between code emitted by [pystencils](https://pypi.org/project/pystencils/) and your C/C++/Cuda/HIP framework. diff --git a/pyproject.toml b/pyproject.toml index 6812c5532412ce39a466c062e6eff6c89fef7c03..cfd486622dd95163c6e226f5637cab59ba875b06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,3 +48,6 @@ versionfile_source = "src/pystencilssfg/_version.py" versionfile_build = "pystencilssfg/_version.py" tag_prefix = "v" parentdir_prefix = "pystencilssfg-" + +[tool.coverage.run] +include = ["src/pystencilssfg/*"] diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..94a3a6c5cd76967b094d0e26bab983291057461d --- /dev/null +++ b/pytest.ini @@ -0,0 +1,8 @@ +[pytest] +testpaths = src/pystencilssfg tests/ +python_files = "test_*.py" +# Need to ignore the generator scripts, otherwise they would be executed +# during test collection +addopts = --doctest-modules --ignore=tests/generator_scripts/scripts + +doctest_optionflags = NORMALIZE_WHITESPACE IGNORE_EXCEPTION_DETAIL diff --git a/src/pystencilssfg/__init__.py b/src/pystencilssfg/__init__.py index 247800ca5c095a808b9e0f03ccab66e40a134035..b5ac38f9c9487c6b8caf77a72cda55ea7fc1e792 100644 --- a/src/pystencilssfg/__init__.py +++ b/src/pystencilssfg/__init__.py @@ -1,17 +1,22 @@ -from .configuration import SfgConfiguration, SfgOutputMode +from .configuration import SfgConfiguration, SfgOutputMode, SfgCodeStyle from .generator import SourceFileGenerator from .composer import SfgComposer from .context import SfgContext -from .lang import AugExpr +from .lang import SfgVar, AugExpr +from .exceptions import SfgException __all__ = [ "SourceFileGenerator", "SfgComposer", "SfgConfiguration", "SfgOutputMode", + "SfgCodeStyle", "SfgContext", + "SfgVar", "AugExpr", + "SfgException", ] from . import _version -__version__ = _version.get_versions()['version'] + +__version__ = _version.get_versions()["version"] diff --git a/src/pystencilssfg/_version.py b/src/pystencilssfg/_version.py index 3ac6be9aa8bbf0b129bd1165fa7ac8d68ba69a6d..a4215b0a6436b9b268f7374186112d8715984881 100644 --- a/src/pystencilssfg/_version.py +++ b/src/pystencilssfg/_version.py @@ -1,4 +1,3 @@ - # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -68,12 +67,14 @@ HANDLERS: Dict[str, Dict[str, Callable]] = {} def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate @@ -100,10 +101,14 @@ def run_command( try: dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - process = subprocess.Popen([command] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None), **popen_kwargs) + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) break except OSError as e: if e.errno == errno.ENOENT: @@ -141,15 +146,21 @@ def versions_from_parentdir( for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -212,7 +223,7 @@ def git_versions_from_keywords( # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -221,7 +232,7 @@ def git_versions_from_keywords( # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = {r for r in refs if re.search(r'\d', r)} + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -229,32 +240,36 @@ def git_versions_from_keywords( for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] # Filter out refs that exactly match prefix or that don't start # with a number once the prefix is stripped (mostly a concern # when prefix is '') - if not re.match(r'\d', r): + if not re.match(r"\d", r): continue if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") def git_pieces_from_vcs( - tag_prefix: str, - root: str, - verbose: bool, - runner: Callable = run_command + tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command ) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. @@ -273,8 +288,7 @@ def git_pieces_from_vcs( env.pop("GIT_DIR", None) runner = functools.partial(runner, env=env) - _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=not verbose) + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -282,10 +296,19 @@ def git_pieces_from_vcs( # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = runner(GITS, [ - "describe", "--tags", "--dirty", "--always", "--long", - "--match", f"{tag_prefix}[[:digit:]]*" - ], cwd=root) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -300,8 +323,7 @@ def git_pieces_from_vcs( pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None - branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], - cwd=root) + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) # --abbrev-ref was added in git-1.6.3 if rc != 0 or branch_name is None: raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") @@ -341,17 +363,16 @@ def git_pieces_from_vcs( dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparsable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -360,10 +381,12 @@ def git_pieces_from_vcs( if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -412,8 +435,7 @@ def render_pep440(pieces: Dict[str, Any]) -> str: rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -442,8 +464,7 @@ def render_pep440_branch(pieces: Dict[str, Any]) -> str: rendered = "0" if pieces["branch"] != "master": rendered += ".dev0" - rendered += "+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -604,11 +625,13 @@ def render_git_describe_long(pieces: Dict[str, Any]) -> str: def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -632,9 +655,13 @@ def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } def get_versions() -> Dict[str, Any]: @@ -648,8 +675,7 @@ def get_versions() -> Dict[str, Any]: verbose = cfg.verbose try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose) except NotThisMethod: pass @@ -658,13 +684,16 @@ def get_versions() -> Dict[str, Any]: # versionfile_source is the relative path from the top of the source # tree (where the .git directory might live) to this file. Invert # this to find the root from __file__. - for _ in cfg.versionfile_source.split('/'): + for _ in cfg.versionfile_source.split("/"): root = os.path.dirname(root) except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } try: pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) @@ -678,6 +707,10 @@ def get_versions() -> Dict[str, Any]: except NotThisMethod: pass - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/src/pystencilssfg/composer/basic_composer.py b/src/pystencilssfg/composer/basic_composer.py index 59938026ed670cb0e22c678220e7254312146eef..135f6fb85866c486626bf965bbc475c9a3accf18 100644 --- a/src/pystencilssfg/composer/basic_composer.py +++ b/src/pystencilssfg/composer/basic_composer.py @@ -1,13 +1,19 @@ from __future__ import annotations -from typing import Sequence +from typing import Sequence, TypeAlias from abc import ABC, abstractmethod import numpy as np import sympy as sp from functools import reduce from pystencils import Field -from pystencils.backend import KernelParameter, KernelFunction -from pystencils.types import create_type, UserTypeSpec, PsCustomType, PsPointerType +from pystencils.backend import KernelFunction +from pystencils.types import ( + create_type, + UserTypeSpec, + PsCustomType, + PsPointerType, + PsType, +) from ..context import SfgContext from .custom import CustomGenerator @@ -25,6 +31,7 @@ from ..ir import ( ) from ..ir.postprocessing import ( SfgDeferredParamMapping, + SfgDeferredParamSetter, SfgDeferredFieldMapping, SfgDeferredVectorMapping, ) @@ -37,9 +44,20 @@ from ..ir.source_components import ( SfgConstructor, SfgMemberVariable, SfgClassKeyword, +) +from ..lang import ( + VarLike, + ExprLike, + _VarLike, + _ExprLike, + asvar, + depends, SfgVar, + AugExpr, + SrcField, + IFieldExtraction, + SrcVector, ) -from ..lang import IFieldExtraction, SrcVector, AugExpr, SrcField from ..exceptions import SfgException @@ -53,13 +71,16 @@ class SfgIComposer(ABC): class SfgNodeBuilder(ABC): + """Base class for node builders used by the composer""" + @abstractmethod def resolve(self) -> SfgCallTreeNode: pass -ExprLike = str | SfgVar | AugExpr -SequencerArg = tuple | str | AugExpr | SfgCallTreeNode | SfgNodeBuilder +_SequencerArg = (tuple, ExprLike, SfgCallTreeNode, SfgNodeBuilder) +SequencerArg: TypeAlias = tuple | ExprLike | SfgCallTreeNode | SfgNodeBuilder +"""Valid arguments to `make_sequence` and any sequencer that uses it.""" class SfgBasicComposer(SfgIComposer): @@ -74,27 +95,87 @@ class SfgBasicComposer(SfgIComposer): The string should not contain C/C++ comment delimiters, since these will be added automatically during code generation. + + :Example: + >>> sfg.prelude("This file was generated using pystencils-sfg; do not modify it directly!") + + will appear in the generated files as + + .. code-block:: C++ + + /* + * This file was generated using pystencils-sfg; do not modify it directly! + */ + """ self._ctx.append_to_prelude(content) + def code(self, *code: str): + """Add arbitrary lines of code to the generated header file. + + :Example: + + >>> sfg.code( + ... "#define PI 3.14 // more than enough for engineers", + ... "using namespace std;" + ... ) + + will appear as + + .. code-block:: C++ + + #define PI 3.14 // more than enough for engineers + using namespace std; + + """ + for c in code: + self._ctx.add_definition(c) + def define(self, *definitions: str): - """Add custom definitions to the generated header file.""" - for d in definitions: - self._ctx.add_definition(d) + from warnings import warn + + warn( + "The `define` method of `SfgBasicComposer` is deprecated and will be removed in a future version." + "Use `sfg.code()` instead.", + FutureWarning, + ) + + self.code(*definitions) def define_once(self, *definitions: str): - """Same as `define`, but only adds definitions only if the same code string was not already added.""" + """Add unique definitions to the header file. + + Each code string given to `define_once` will only be added if the exact same string + was not already added before. + """ for definition in definitions: if all(d != definition for d in self._ctx.definitions()): self._ctx.add_definition(definition) def namespace(self, namespace: str): - """Set the inner code namespace. Throws an exception if a namespace was already set.""" + """Set the inner code namespace. Throws an exception if a namespace was already set. + + :Example: + + After adding the following to your generator script: + + >>> sfg.namespace("codegen_is_awesome") + + All generated code will be placed within that namespace: + + .. code-block:: C++ + + namespace codegen_is_awesome { + /* all generated code */ + } + """ self._ctx.set_namespace(namespace) def generate(self, generator: CustomGenerator): """Invoke a custom code generator with the underlying context.""" - generator.generate(self._ctx) + from .composer import SfgComposer + + generator.generate(SfgComposer(self)) @property def kernels(self) -> SfgKernelNamespace: @@ -120,9 +201,21 @@ class SfgBasicComposer(SfgIComposer): """Include a header file. Args: - header_file: Path to the header file. Enclose in `<>` for a system header. - private: If `True`, in header-implementation code generation, the header file is + header_file: Path to the header file. Enclose in ``<>`` for a system header. + private: If ``True``, in header-implementation code generation, the header file is only included in the implementation file. + + :Example: + + >>> sfg.include("<vector>") + >>> sfg.include("custom.h") + + will be printed as + + .. code-block:: C++ + + #include <vector> + #include "custom.h" """ self._ctx.add_include(SfgHeaderInclude.parse(header_file, private)) @@ -137,7 +230,7 @@ class SfgBasicComposer(SfgIComposer): if self._ctx.get_class(name) is not None: raise SfgException(f"Class with name {name} already exists.") - cls = struct_from_numpy_dtype(name, dtype, add_constructor=add_constructor) + cls = _struct_from_numpy_dtype(name, dtype, add_constructor=add_constructor) self._ctx.add_class(cls) return cls @@ -179,7 +272,7 @@ class SfgBasicComposer(SfgIComposer): if self._ctx.get_function(name) is not None: raise ValueError(f"Function {name} already exists.") - def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): + def sequencer(*args: SequencerArg): tree = make_sequence(*args) func = SfgFunction(name, tree) self._ctx.add_function(func) @@ -208,18 +301,22 @@ class SfgBasicComposer(SfgIComposer): num_blocks_str = str(num_blocks) tpb_str = str(threads_per_block) stream_str = str(stream) if stream is not None else None - depends = _depends(num_blocks) | _depends(threads_per_block) | _depends(stream) + + deps = depends(num_blocks) | depends(threads_per_block) + if stream is not None: + deps |= depends(stream) + return SfgCudaKernelInvocation( - kernel_handle, num_blocks_str, tpb_str, stream_str, depends + kernel_handle, num_blocks_str, tpb_str, stream_str, deps ) def seq(self, *args: tuple | str | SfgCallTreeNode | SfgNodeBuilder) -> SfgSequence: """Syntax sequencing. For details, see `make_sequence`""" return make_sequence(*args) - def params(self, *args: SfgVar) -> SfgFunctionParams: + def params(self, *args: AugExpr) -> SfgFunctionParams: """Use inside a function body to add parameters to the function.""" - return SfgFunctionParams(args) + return SfgFunctionParams([x.as_variable() for x in args]) def require(self, *includes: str | SfgHeaderInclude) -> SfgRequireIncludes: return SfgRequireIncludes( @@ -232,7 +329,7 @@ class SfgBasicComposer(SfgIComposer): ptr: bool = False, ref: bool = False, const: bool = False, - ): + ) -> PsType: if ptr and ref: raise SfgException("Create either a pointer, or a ref type, not both!") @@ -250,11 +347,23 @@ class SfgBasicComposer(SfgIComposer): else: return base_type - def var(self, name: str, dtype: UserTypeSpec) -> SfgVar: + def var(self, name: str, dtype: UserTypeSpec) -> AugExpr: """Create a variable with given name and data type.""" - return SfgVar(name, create_type(dtype)) + return AugExpr(create_type(dtype)).var(name) - def init(self, lhs: SfgVar) -> SfgInplaceInitBuilder: + def vars(self, names: str, dtype: UserTypeSpec) -> tuple[AugExpr, ...]: + """Create multiple variables with given names and the same data type. + + Example: + + >>> sfg.vars("x, y, z", "float32") + (x, y, z) + + """ + varnames = names.split(",") + return tuple(self.var(n.strip(), dtype) for n in varnames) + + def init(self, lhs: VarLike): """Create a C++ in-place initialization. Usage: @@ -270,9 +379,51 @@ class SfgBasicComposer(SfgIComposer): SomeClass obj { arg1, arg2, arg3 }; """ - return SfgInplaceInitBuilder(lhs) + lhs_var = asvar(lhs) + + def parse_args(*args: ExprLike): + args_str = ", ".join(str(arg) for arg in args) + deps: set[SfgVar] = reduce(set.union, (depends(arg) for arg in args), set()) + return SfgStatements( + f"{lhs_var.dtype} {lhs_var.name} {{ {args_str} }};", + (lhs_var,), + deps, + ) + + return parse_args + + def expr(self, fmt: str, *deps, **kwdeps) -> AugExpr: + """Create an expression while keeping track of variables it depends on. + + This method is meant to be used similarly to `str.format`; in fact, + it calls `str.format` internally and therefore supports all of its + formatting features. + In addition, however, the format arguments are scanned for *variables* + (e.g. created using `var`), which are attached to the expression. + This way, *pystencils-sfg* keeps track of any variables an expression depends on. + + :Example: - def expr(self, fmt: str, *deps, **kwdeps): + >>> x, y, z, w = sfg.vars("x, y, z, w", "float32") + >>> expr = sfg.expr("{} + {} * {}", x, y, z) + >>> expr + x + y * z + + You can look at the expression's dependencies: + + >>> sorted(expr.depends, key=lambda v: v.name) + [x: float, y: float, z: float] + + If you use an existing expression to create a larger one, the new expression + inherits all variables from its parts: + + >>> expr2 = sfg.expr("{} + {}", expr, w) + >>> expr2 + x + y * z + w + >>> sorted(expr2.depends, key=lambda v: v.name) + [w: float, x: float, y: float, z: float] + + """ return AugExpr.format(fmt, *deps, **kwdeps) @property @@ -306,38 +457,47 @@ class SfgBasicComposer(SfgIComposer): """ return SfgDeferredFieldMapping(field, index_provider) + def set_param(self, param: VarLike | sp.Symbol, expr: ExprLike): + deps = depends(expr) + var: SfgVar | sp.Symbol = asvar(param) if isinstance(param, _VarLike) else param + return SfgDeferredParamSetter(var, deps, str(expr)) + def map_param( self, - lhs: SfgVar, - rhs: SfgVar | Sequence[SfgVar], + param: VarLike | sp.Symbol, + depends: VarLike | Sequence[VarLike], mapping: str, ): - """Arbitrary parameter mapping: Add a single line of code to define a left-hand - side object from one or multiple right-hand side dependencies.""" - if isinstance(rhs, (KernelParameter, SfgVar)): - rhs = [rhs] - return SfgDeferredParamMapping(lhs, set(rhs), mapping) + from warnings import warn + + warn( + "The `map_param` method of `SfgBasicComposer` is deprecated and will be removed " + "in a future version. Use `sfg.set_param` instead.", + FutureWarning, + ) - def map_vector(self, lhs_components: Sequence[SfgVar | sp.Symbol], rhs: SrcVector): + if isinstance(depends, _VarLike): + depends = [depends] + lhs_var: SfgVar | sp.Symbol = ( + asvar(param) if isinstance(param, _VarLike) else param + ) + return SfgDeferredParamMapping(lhs_var, set(asvar(v) for v in depends), mapping) + + def map_vector(self, lhs_components: Sequence[VarLike | sp.Symbol], rhs: SrcVector): """Extracts scalar numerical values from a vector data type. Args: lhs_components: Vector components as a list of symbols. rhs: A `SrcVector` object representing a vector data structure. """ - return SfgDeferredVectorMapping(lhs_components, rhs) + components: list[SfgVar | sp.Symbol] = [ + (asvar(c) if isinstance(c, _VarLike) else c) for c in lhs_components + ] + return SfgDeferredVectorMapping(components, rhs) def make_statements(arg: ExprLike) -> SfgStatements: - match arg: - case str(): - return SfgStatements(arg, (), ()) - case SfgVar(name, _): - return SfgStatements(name, (), (arg,)) - case AugExpr(): - return SfgStatements(str(arg), (), arg.depends) - case _: - assert False + return SfgStatements(str(arg), (), depends(arg)) def make_sequence(*args: SequencerArg) -> SfgSequence: @@ -354,37 +514,37 @@ def make_sequence(*args: SequencerArg) -> SfgSequence: - Sub-ASTs and AST builders, which are often produced by the syntactic sugar and factory methods of `SfgComposer`. - Its usage is best shown by example: + :Example: - .. code-block:: Python + .. code-block:: Python - tree = make_sequence( - "int a = 0;", - "int b = 1;", - ( - "int tmp = b;", - "b = a;", - "a = tmp;" - ), - SfgKernelCall(kernel_handle) - ) + tree = make_sequence( + "int a = 0;", + "int b = 1;", + ( + "int tmp = b;", + "b = a;", + "a = tmp;" + ), + SfgKernelCall(kernel_handle) + ) - sfg.context.add_function("myFunction", tree) + sfg.context.add_function("myFunction", tree) - will translate to + will translate to - .. code-block:: C++ + .. code-block:: C++ - void myFunction() { - int a = 0; - int b = 0; - { - int tmp = b; - b = a; - a = tmp; + void myFunction() { + int a = 0; + int b = 0; + { + int tmp = b; + b = a; + a = tmp; + } + kernels::kernel( ... ); } - kernels::kernel( ... ); - } """ children = [] for i, arg in enumerate(args): @@ -392,10 +552,8 @@ def make_sequence(*args: SequencerArg) -> SfgSequence: children.append(arg.resolve()) elif isinstance(arg, SfgCallTreeNode): children.append(arg) - elif isinstance(arg, AugExpr): - children.append(SfgStatements(str(arg), (), arg.depends)) - elif isinstance(arg, str): - children.append(SfgStatements(arg, (), ())) + elif isinstance(arg, _ExprLike): + children.append(make_statements(arg)) elif isinstance(arg, tuple): # Tuples are treated as blocks subseq = make_sequence(*arg) @@ -406,35 +564,9 @@ def make_sequence(*args: SequencerArg) -> SfgSequence: return SfgSequence(children) -class SfgInplaceInitBuilder(SfgNodeBuilder): - def __init__(self, lhs: SfgVar) -> None: - self._lhs: SfgVar = lhs - self._depends: set[SfgVar] = set() - self._rhs: str | None = None - - def __call__( - self, - *rhs: str | AugExpr, - ) -> SfgInplaceInitBuilder: - if self._rhs is not None: - raise SfgException("Assignment builder used multiple times.") - - self._rhs = ", ".join(str(expr) for expr in rhs) - self._depends = reduce( - set.union, (obj.depends for obj in rhs if isinstance(obj, AugExpr)), set() - ) - return self - - def resolve(self) -> SfgCallTreeNode: - assert self._rhs is not None - return SfgStatements( - f"{self._lhs.dtype} {self._lhs.name} {{ {self._rhs} }};", - [self._lhs], - self._depends, - ) - - class SfgBranchBuilder(SfgNodeBuilder): + """Multi-call builder for C++ ``if/else`` statements.""" + def __init__(self) -> None: self._phase = 0 @@ -471,6 +603,8 @@ class SfgBranchBuilder(SfgNodeBuilder): class SfgSwitchBuilder(SfgNodeBuilder): + """Builder for C++ switches.""" + def __init__(self, switch_arg: ExprLike): self._switch_arg = switch_arg self._cases: dict[str, SfgSequence] = dict() @@ -505,7 +639,7 @@ class SfgSwitchBuilder(SfgNodeBuilder): return SfgSwitch(make_statements(self._switch_arg), self._cases, self._default) -def struct_from_numpy_dtype( +def _struct_from_numpy_dtype( struct_name: str, dtype: np.dtype, add_constructor: bool = True ): cls = SfgClass(struct_name, class_keyword=SfgClassKeyword.STRUCT) @@ -533,15 +667,3 @@ def struct_from_numpy_dtype( cls.default.append_member(SfgConstructor(constr_params, constr_inits)) return cls - - -def _depends(expr: ExprLike | Sequence[ExprLike] | None) -> set[SfgVar]: - match expr: - case None | str(): - return set() - case SfgVar(): - return {expr} - case AugExpr(): - return expr.depends - case _: - raise ValueError(f"Invalid expression: {expr}") diff --git a/src/pystencilssfg/composer/class_composer.py b/src/pystencilssfg/composer/class_composer.py index 63588b7829b2a94122e5c9d7d38770cabce9d6f5..bd906782d6e074955c60221d9a8f4d9b15bae772 100644 --- a/src/pystencilssfg/composer/class_composer.py +++ b/src/pystencilssfg/composer/class_composer.py @@ -3,7 +3,13 @@ from typing import Sequence from pystencils.types import PsCustomType, UserTypeSpec -from ..ir import SfgCallTreeNode +from ..lang import ( + _VarLike, + VarLike, + ExprLike, + asvar, +) + from ..ir.source_components import ( SfgClass, SfgClassMember, @@ -14,12 +20,14 @@ from ..ir.source_components import ( SfgClassKeyword, SfgVisibility, SfgVisibilityBlock, - SfgVar, ) from ..exceptions import SfgException from .mixin import SfgComposerMixIn -from .basic_composer import SfgNodeBuilder, make_sequence +from .basic_composer import ( + make_sequence, + SequencerArg, +) class SfgClassComposer(SfgComposerMixIn): @@ -46,7 +54,7 @@ class SfgClassComposer(SfgComposerMixIn): def __call__( self, *args: ( - SfgClassMember | SfgClassComposer.ConstructorBuilder | SfgVar | str + SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str ), ): for arg in args: @@ -63,15 +71,21 @@ class SfgClassComposer(SfgComposerMixIn): Returned by `constructor`. """ - def __init__(self, *params: SfgVar): - self._params = params + def __init__(self, *params: VarLike): + self._params = tuple(asvar(p) for p in params) self._initializers: list[str] = [] self._body: str | None = None - def init(self, initializer: str) -> SfgClassComposer.ConstructorBuilder: + def init(self, var: VarLike): """Add an initialization expression to the constructor's initializer list.""" - self._initializers.append(initializer) - return self + + def init_sequencer(*args: ExprLike): + expr = ", ".join(str(arg) for arg in args) + initializer = f"{asvar(var)}{{ {expr} }}" + self._initializers.append(initializer) + return self + + return init_sequencer def body(self, body: str): """Define the constructor body""" @@ -120,7 +134,7 @@ class SfgClassComposer(SfgComposerMixIn): """Create a `private` visibility block in a class or struct body""" return SfgClassComposer.VisibilityContext(SfgVisibility.PRIVATE) - def constructor(self, *params: SfgVar): + def constructor(self, *params: VarLike): """In a class or struct body or visibility block, add a constructor. Args: @@ -145,7 +159,7 @@ class SfgClassComposer(SfgComposerMixIn): const: Whether or not the method is const-qualified. """ - def sequencer(*args: str | tuple | SfgCallTreeNode | SfgNodeBuilder): + def sequencer(*args: SequencerArg): tree = make_sequence(*args) return SfgMethod( name, @@ -171,7 +185,7 @@ class SfgClassComposer(SfgComposerMixIn): SfgClassComposer.VisibilityContext | SfgClassMember | SfgClassComposer.ConstructorBuilder - | SfgVar + | VarLike | str ), ): @@ -186,9 +200,9 @@ class SfgClassComposer(SfgComposerMixIn): ( SfgClassMember, SfgClassComposer.ConstructorBuilder, - SfgVar, str, - ), + ) + + _VarLike, ): if default_ended: raise SfgException( @@ -204,13 +218,17 @@ class SfgClassComposer(SfgComposerMixIn): @staticmethod def _resolve_member( - arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | SfgVar | str, - ): - if isinstance(arg, SfgVar): - return SfgMemberVariable(arg.name, arg.dtype) - elif isinstance(arg, str): - return SfgInClassDefinition(arg) - elif isinstance(arg, SfgClassComposer.ConstructorBuilder): - return arg.resolve() - else: - return arg + arg: SfgClassMember | SfgClassComposer.ConstructorBuilder | VarLike | str, + ) -> SfgClassMember: + match arg: + case _ if isinstance(arg, _VarLike): + var = asvar(arg) + return SfgMemberVariable(var.name, var.dtype) + case str(): + return SfgInClassDefinition(arg) + case SfgClassComposer.ConstructorBuilder(): + return arg.resolve() + case SfgClassMember(): + return arg + case _: + raise ValueError(f"Invalid class member: {arg}") diff --git a/src/pystencilssfg/composer/custom.py b/src/pystencilssfg/composer/custom.py index 26e9b933456904576202837b763fa09a8ae6b141..7df364c6cd78c1a56d68283f3c617092938a4dcf 100644 --- a/src/pystencilssfg/composer/custom.py +++ b/src/pystencilssfg/composer/custom.py @@ -1,10 +1,14 @@ +from __future__ import annotations from abc import ABC, abstractmethod -from ..context import SfgContext +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .composer import SfgComposer class CustomGenerator(ABC): """Abstract base class for custom code generators that may be passed to - [SfgComposer.generate][pystencilssfg.SfgComposer.generate].""" + `SfgComposer.generate`.""" @abstractmethod - def generate(self, ctx: SfgContext) -> None: ... + def generate(self, sfg: SfgComposer) -> None: ... diff --git a/src/pystencilssfg/configuration.py b/src/pystencilssfg/configuration.py index 2eb3efe7855af24cfe21e2815d7a40d6eea60686..a76251a04cde0d875b0ba41908a0fadc5fc90553 100644 --- a/src/pystencilssfg/configuration.py +++ b/src/pystencilssfg/configuration.py @@ -40,7 +40,10 @@ class SfgCodeStyle: """ force_clang_format: bool = False - """If set to True, abort code generation if `clang-format` binary cannot be found.""" + """If set to True, abort code generation if ``clang-format`` binary cannot be found.""" + + skip_clang_format: bool = False + """If set to True, skip formatting using ``clang-format``.""" clang_format_binary: str = "clang-format" """Path to the clang-format executable""" diff --git a/src/pystencilssfg/context.py b/src/pystencilssfg/context.py index 69ad83f8e2ab8ddf9436e1ebc3505b819d6a5cf2..bd3591889cbed6ab2a4bc023787944857585a601 100644 --- a/src/pystencilssfg/context.py +++ b/src/pystencilssfg/context.py @@ -69,7 +69,7 @@ class SfgContext: # Source Components self._prelude: str = "" - self._includes: set[SfgHeaderInclude] = set() + self._includes: list[SfgHeaderInclude] = [] self._definitions: list[str] = [] self._kernel_namespaces = { self._default_kernel_namespace.name: self._default_kernel_namespace @@ -79,10 +79,6 @@ class SfgContext: self._declarations_ordered: list[str | SfgFunction | SfgClass] = list() - # Standard stuff - self.add_include(SfgHeaderInclude("cstdint", system_header=True)) - self.add_definition("#define RESTRICT __restrict__") - @property def argv(self) -> Sequence[str]: """If this context was created by a `pystencilssfg.SourceFileGenerator`, provides the command @@ -159,7 +155,7 @@ class SfgContext: yield from self._includes def add_include(self, include: SfgHeaderInclude): - self._includes.add(include) + self._includes.append(include) def definitions(self) -> Generator[str, None, None]: """Definitions are arbitrary custom lines of code.""" diff --git a/src/pystencilssfg/emission/clang_format.py b/src/pystencilssfg/emission/clang_format.py index 5c5084ce340d17cd0b11f3b58091df99db4d46c9..eea152a062474a6761fdbbdecfaaeb88bb63c4d3 100644 --- a/src/pystencilssfg/emission/clang_format.py +++ b/src/pystencilssfg/emission/clang_format.py @@ -24,6 +24,9 @@ def invoke_clang_format(code: str, codestyle: SfgCodeStyle) -> str: be executed (binary not found, or error during exection), the function will throw an exception. """ + if codestyle.skip_clang_format: + return code + args = [codestyle.clang_format_binary, f"--style={codestyle.code_style}"] if not shutil.which(codestyle.clang_format_binary): diff --git a/src/pystencilssfg/emission/printers.py b/src/pystencilssfg/emission/printers.py index 1b3a805cdb3efbac12e81005b7747859d0e4e2cb..93371619e302932d97441ff069190ba956b0f04d 100644 --- a/src/pystencilssfg/emission/printers.py +++ b/src/pystencilssfg/emission/printers.py @@ -178,14 +178,6 @@ class SfgHeaderPrinter(SfgGeneralPrinter): return code -def delimiter(content): - return f"""\ -/************************************************************************************* - * {content} -*************************************************************************************/ -""" - - class SfgImplPrinter(SfgGeneralPrinter): def __init__( self, ctx: SfgContext, output_spec: SfgOutputSpec, inline_impl: bool = False @@ -219,11 +211,8 @@ class SfgImplPrinter(SfgGeneralPrinter): parts = interleave( chain( - [delimiter("Kernels")], ctx.kernel_namespaces(), - [delimiter("Functions")], ctx.functions(), - [delimiter("Class Methods")], ctx.classes(), ), repeat(SfgEmptyLines(1)), diff --git a/src/pystencilssfg/extensions/sycl.py b/src/pystencilssfg/extensions/sycl.py index cc3f83fa6a049febd76d1e62dcdda39d4a502575..af59f4f82b1451928b3e3bbdb6835d4ce92f33c5 100644 --- a/src/pystencilssfg/extensions/sycl.py +++ b/src/pystencilssfg/extensions/sycl.py @@ -16,14 +16,14 @@ from ..composer import ( SfgComposerMixIn, ) from ..ir.source_components import SfgKernelHandle, SfgHeaderInclude -from ..ir.source_components import SfgVar, SfgSymbolLike +from ..ir.source_components import SfgSymbolLike from ..ir import ( SfgCallTreeNode, SfgCallTreeLeaf, SfgKernelCallNode, ) -from ..lang import AugExpr +from ..lang import SfgVar, AugExpr class SyclComposerMixIn(SfgComposerMixIn): diff --git a/src/pystencilssfg/generator.py b/src/pystencilssfg/generator.py index 22db7e2cc54d0670bd4a378dae88c95559d04e18..aa8396c4ecae7471c333a14587435ac3f35eec3a 100644 --- a/src/pystencilssfg/generator.py +++ b/src/pystencilssfg/generator.py @@ -58,6 +58,11 @@ class SourceFileGenerator: project_info=config.project_info, ) + from pystencilssfg.ir import SfgHeaderInclude + + self._context.add_include(SfgHeaderInclude("cstdint", system_header=True)) + self._context.add_definition("#define RESTRICT __restrict__") + self._emitter: AbstractEmitter match config.output_mode: case SfgOutputMode.HEADER_ONLY: diff --git a/src/pystencilssfg/ir/__init__.py b/src/pystencilssfg/ir/__init__.py index 43660069ce5e049f60b42193698427f1756d2fae..1ae1749367527472aaf5ba77ddf09a8081ed578b 100644 --- a/src/pystencilssfg/ir/__init__.py +++ b/src/pystencilssfg/ir/__init__.py @@ -19,7 +19,6 @@ from .source_components import ( SfgEmptyLines, SfgKernelNamespace, SfgKernelHandle, - SfgVar, SfgSymbolLike, SfgFunction, SfgVisibility, @@ -51,7 +50,6 @@ __all__ = [ "SfgEmptyLines", "SfgKernelNamespace", "SfgKernelHandle", - "SfgVar", "SfgSymbolLike", "SfgFunction", "SfgVisibility", diff --git a/src/pystencilssfg/ir/call_tree.py b/src/pystencilssfg/ir/call_tree.py index 34d50182b7bca8348a161553e4f7edaa4ae3c0d0..4a084649b068487a5e712fb1d7f49f87153bc5ee 100644 --- a/src/pystencilssfg/ir/call_tree.py +++ b/src/pystencilssfg/ir/call_tree.py @@ -4,7 +4,8 @@ from typing import TYPE_CHECKING, Sequence, Iterable, NewType from abc import ABC, abstractmethod from itertools import chain -from .source_components import SfgHeaderInclude, SfgKernelHandle, SfgVar +from .source_components import SfgHeaderInclude, SfgKernelHandle +from ..lang import SfgVar if TYPE_CHECKING: from ..context import SfgContext @@ -123,6 +124,10 @@ class SfgStatements(SfgCallTreeLeaf): def required_includes(self) -> set[SfgHeaderInclude]: return self._required_includes + @property + def code_string(self) -> str: + return self._code_string + def get_code(self, ctx: SfgContext) -> str: return self._code_string diff --git a/src/pystencilssfg/ir/postprocessing.py b/src/pystencilssfg/ir/postprocessing.py index c30910d59f6c6626dd6fc147f93491e9b175dace..851a981862cb900442d92624bbaef0dbece68f89 100644 --- a/src/pystencilssfg/ir/postprocessing.py +++ b/src/pystencilssfg/ir/postprocessing.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Sequence, Iterable import warnings from functools import reduce from dataclasses import dataclass @@ -9,6 +9,7 @@ from abc import ABC, abstractmethod import sympy as sp from pystencils import Field, TypedSymbol +from pystencils.types import deconstify from pystencils.backend.kernelfunction import ( FieldPointerParam, FieldShapeParam, @@ -18,8 +19,8 @@ from pystencils.backend.kernelfunction import ( from ..exceptions import SfgException from .call_tree import SfgCallTreeNode, SfgCallTreeLeaf, SfgSequence, SfgStatements -from ..ir.source_components import SfgVar, SfgSymbolLike -from ..lang import IFieldExtraction, SrcField, SrcVector +from ..ir.source_components import SfgSymbolLike +from ..lang import SfgVar, IFieldExtraction, SrcField, SrcVector if TYPE_CHECKING: from ..context import SfgContext @@ -61,7 +62,7 @@ class FlattenSequences: class PostProcessingContext: def __init__(self, enclosing_class: SfgClass | None = None) -> None: self.enclosing_class: SfgClass | None = enclosing_class - self.live_objects: set[SfgVar] = set() + self._live_variables: dict[str, SfgVar] = dict() def is_method(self) -> bool: return self.enclosing_class is not None @@ -72,6 +73,65 @@ class PostProcessingContext: return self.enclosing_class + @property + def live_variables(self) -> set[SfgVar]: + return set(self._live_variables.values()) + + def get_live_variable(self, name: str) -> SfgVar | None: + return self._live_variables.get(name) + + def _define(self, vars: Iterable[SfgVar], expr: str): + for var in vars: + if var.name in self._live_variables: + live_var = self._live_variables[var.name] + + live_var_dtype = live_var.dtype + def_dtype = var.dtype + + # A const definition conflicts with a non-const live variable + # A non-const definition is always OK, but then the types must be the same + if (def_dtype.const and not live_var_dtype.const) or ( + deconstify(def_dtype) != deconstify(live_var_dtype) + ): + warnings.warn( + f"Type conflict at variable definition: Expected type {live_var_dtype}, but got {def_dtype}.\n" + f" * At definition {expr}", + UserWarning, + ) + + del self._live_variables[var.name] + + def _use(self, vars: Iterable[SfgVar]): + for var in vars: + if var.name in self._live_variables: + live_var = self._live_variables[var.name] + + if var != live_var: + if var.dtype == live_var.dtype: + # This can only happen if the variables are SymbolLike, + # i.e. wrap a field-associated kernel parameter + # TODO: Once symbol properties are a thing, check and combine them here + warnings.warn( + "Encountered two non-identical variables with same name and data type:\n" + f" {var.name_and_type()}\n" + "and\n" + f" {live_var.name_and_type()}\n" + ) + elif deconstify(var.dtype) == deconstify(live_var.dtype): + # Same type, just different constness + # One of them must be non-const -> keep the non-const one + if live_var.dtype.const and not var.dtype.const: + self._live_variables[var.name] = var + else: + raise SfgException( + "Encountered two variables with same name but different data types:\n" + f" {var.name_and_type()}\n" + "and\n" + f" {live_var.name_and_type()}" + ) + else: + self._live_variables[var.name] = var + @dataclass(frozen=True) class PostProcessingResult: @@ -84,30 +144,8 @@ class CallTreePostProcessing: self._flattener = FlattenSequences() def __call__(self, ast: SfgCallTreeNode) -> PostProcessingResult: - params = self.get_live_objects(ast) - params_by_name: dict[str, SfgVar] = dict() - - for param in params: - if param.name in params_by_name: - other = params_by_name[param.name] - - if param.dtype == other.dtype: - warnings.warn( - "Encountered two non-identical parameters with same name and data type:\n" - f" {repr(param)}\n" - "and\n" - f" {repr(other)}\n" - ) - else: - raise SfgException( - "Encountered two parameters with same name but different data types:\n" - f" {repr(param)}\n" - "and\n" - f" {repr(other)}" - ) - params_by_name[param.name] = param - - return PostProcessingResult(set(params_by_name.values())) + live_vars = self.get_live_variables(ast) + return PostProcessingResult(live_vars) def handle_sequence(self, seq: SfgSequence, ppc: PostProcessingContext): def iter_nested_sequences(seq: SfgSequence): @@ -122,18 +160,18 @@ class CallTreePostProcessing: iter_nested_sequences(c) else: if isinstance(c, SfgStatements): - ppc.live_objects -= c.defines + ppc._define(c.defines, c.code_string) - ppc.live_objects |= self.get_live_objects(c) + ppc._use(self.get_live_variables(c)) iter_nested_sequences(seq) - def get_live_objects(self, node: SfgCallTreeNode) -> set[SfgVar]: + def get_live_variables(self, node: SfgCallTreeNode) -> set[SfgVar]: match node: case SfgSequence(): ppc = self._ppc() self.handle_sequence(node, ppc) - return ppc.live_objects + return ppc.live_variables case SfgCallTreeLeaf(): return node.depends @@ -144,7 +182,7 @@ class CallTreePostProcessing: case _: return reduce( lambda x, y: x | y, - (self.get_live_objects(c) for c in node.children), + (self.get_live_variables(c) for c in node.children), set(), ) @@ -177,14 +215,30 @@ class SfgDeferredNode(SfgCallTreeNode, ABC): class SfgDeferredParamMapping(SfgDeferredNode): - def __init__(self, lhs: SfgVar, rhs: set[SfgVar], mapping: str): + def __init__(self, lhs: SfgVar | sp.Symbol, depends: set[SfgVar], mapping: str): self._lhs = lhs - self._rhs = rhs + self._depends = depends self._mapping = mapping def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: - if self._lhs in ppc.live_objects: - return SfgStatements(self._mapping, (self._lhs,), tuple(self._rhs)) + live_var = ppc.get_live_variable(self._lhs.name) + if live_var is not None: + return SfgStatements(self._mapping, (live_var,), tuple(self._depends)) + else: + return SfgSequence([]) + + +class SfgDeferredParamSetter(SfgDeferredNode): + def __init__(self, param: SfgVar | sp.Symbol, depends: set[SfgVar], rhs_expr: str): + self._lhs = param + self._depends = depends + self._rhs_expr = rhs_expr + + def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: + live_var = ppc.get_live_variable(self._lhs.name) + if live_var is not None: + code = f"{live_var.dtype} {live_var.name} = {self._rhs_expr};" + return SfgStatements(code, (live_var,), tuple(self._depends)) else: return SfgSequence([]) @@ -209,7 +263,7 @@ class SfgDeferredFieldMapping(SfgDeferredNode): self._field.strides ) - for param in ppc.live_objects: + for param in ppc.live_variables: # idk why, but mypy does not understand these pattern matches match param: case SfgSymbolLike(FieldPointerParam(_, _, field)) if field == self._field: # type: ignore @@ -288,7 +342,7 @@ class SfgDeferredVectorMapping(SfgDeferredNode): def expand(self, ppc: PostProcessingContext) -> SfgCallTreeNode: nodes = [] - for param in ppc.live_objects: + for param in ppc.live_variables: if param.name in self._scalars: idx, _ = self._scalars[param.name] expr = self._vector.extract_component(idx) diff --git a/src/pystencilssfg/ir/source_components.py b/src/pystencilssfg/ir/source_components.py index 06788cf5918c8dddd3aa0ca687a938d3722879ae..8a4e90967d74aabc4f50529c7ff0e51c76a3d869 100644 --- a/src/pystencilssfg/ir/source_components.py +++ b/src/pystencilssfg/ir/source_components.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC from enum import Enum, auto -from typing import TYPE_CHECKING, Sequence, Generator, TypeVar, Generic, Any +from typing import TYPE_CHECKING, Sequence, Generator, TypeVar, Generic from dataclasses import replace from itertools import chain @@ -14,6 +14,7 @@ from pystencils.backend.kernelfunction import ( ) from pystencils.types import PsType, PsCustomType +from ..lang import SfgVar from ..exceptions import SfgException if TYPE_CHECKING: @@ -31,6 +32,7 @@ class SfgEmptyLines: class SfgHeaderInclude: + """Represent ``#include``-directives.""" @staticmethod def parse(incl: str | SfgHeaderInclude, private: bool = False): @@ -197,59 +199,12 @@ class SfgKernelHandle: @property def fields(self): - return self.fields + return self._fields def get_kernel_function(self) -> KernelFunction: return self._namespace.get_kernel_function(self) -class SfgVar: - __match_args__ = ("name", "dtype") - - def __init__( - self, - name: str, - dtype: PsType, - required_includes: set[SfgHeaderInclude] | None = None, - ): - self._name = name - self._dtype = dtype - - self._required_includes = ( - required_includes if required_includes is not None else set() - ) - - @property - def name(self) -> str: - return self._name - - @property - def dtype(self) -> PsType: - return self._dtype - - def _args(self) -> tuple[Any, ...]: - return (self._name, self._dtype) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, SfgVar): - return False - - return self._args() == other._args() - - def __hash__(self) -> int: - return hash(self._args()) - - @property - def required_includes(self) -> set[SfgHeaderInclude]: - return self._required_includes - - def __str__(self) -> str: - return self._name - - def __repr__(self) -> str: - return f"SfgVar( {self._name}, {repr(self._dtype)} )" - - SymbolLike_T = TypeVar("SymbolLike_T", bound=KernelParameter) @@ -517,7 +472,7 @@ class SfgClass: self._definitions: list[SfgInClassDefinition] = [] self._constructors: list[SfgConstructor] = [] - self._methods: dict[str, SfgMethod] = dict() + self._methods: list[SfgMethod] = [] self._member_vars: dict[str, SfgMemberVariable] = dict() @property @@ -595,11 +550,9 @@ class SfgClass: self, visibility: SfgVisibility | None = None ) -> Generator[SfgMethod, None, None]: if visibility is not None: - yield from filter( - lambda m: m.visibility == visibility, self._methods.values() - ) + yield from filter(lambda m: m.visibility == visibility, self._methods) else: - yield from self._methods.values() + yield from self._methods # PRIVATE @@ -621,16 +574,10 @@ class SfgClass: self._definitions.append(definition) def _add_constructor(self, constr: SfgConstructor): - # TODO: Check for signature conflicts? self._constructors.append(constr) def _add_method(self, method: SfgMethod): - if method.name in self._methods: - raise SfgException( - f"Duplicate method name {method.name} in class {self._class_name}" - ) - - self._methods[method.name] = method + self._methods.append(method) def _add_member_variable(self, variable: SfgMemberVariable): if variable.name in self._member_vars: diff --git a/src/pystencilssfg/lang/__init__.py b/src/pystencilssfg/lang/__init__.py index 543d3094131f7dd4d0d8edf1673c2e5a27597065..d67ffa0c845b16c867064365cec8b5af5ffe6a2a 100644 --- a/src/pystencilssfg/lang/__init__.py +++ b/src/pystencilssfg/lang/__init__.py @@ -1,15 +1,30 @@ from .expressions import ( - DependentExpression, + SfgVar, AugExpr, + VarLike, + _VarLike, + ExprLike, + _ExprLike, + asvar, + depends, IFieldExtraction, SrcField, SrcVector, ) +from .types import Ref + __all__ = [ - "DependentExpression", + "SfgVar", "AugExpr", + "VarLike", + "_VarLike", + "ExprLike", + "_ExprLike", + "asvar", + "depends", "IFieldExtraction", "SrcField", "SrcVector", + "Ref", ] diff --git a/src/pystencilssfg/lang/expressions.py b/src/pystencilssfg/lang/expressions.py index c456194a11fd6da541c96b76f2bf1e5b3a8bc984..481922e728549e48550e91b562b6099ec9b8c094 100644 --- a/src/pystencilssfg/lang/expressions.py +++ b/src/pystencilssfg/lang/expressions.py @@ -1,13 +1,91 @@ from __future__ import annotations -from typing import Iterable +from typing import Iterable, TypeAlias, Any, TYPE_CHECKING from itertools import chain from abc import ABC, abstractmethod -from pystencils.types import PsType +import sympy as sp + +from pystencils import TypedSymbol +from pystencils.types import PsType, UserTypeSpec, create_type -from ..ir.source_components import SfgVar, SfgHeaderInclude from ..exceptions import SfgException +if TYPE_CHECKING: + from ..ir.source_components import SfgHeaderInclude + + +__all__ = [ + "SfgVar", + "AugExpr", + "VarLike", + "ExprLike", + "asvar", + "depends", + "IFieldExtraction", + "SrcField", + "SrcVector", +] + + +class SfgVar: + """C++ Variable. + + Args: + name: Name of the variable. Must be a valid C++ identifer. + dtype: Data type of the variable. + """ + + __match_args__ = ("name", "dtype") + + def __init__( + self, + name: str, + dtype: UserTypeSpec, + required_includes: set[SfgHeaderInclude] | None = None, + ): + # TODO: Replace `required_includes` by using a property + # Includes attached this way may currently easily be lost during postprocessing, + # since they are not part of `_args` + self._name = name + self._dtype = create_type(dtype) + + self._required_includes = ( + required_includes if required_includes is not None else set() + ) + + @property + def name(self) -> str: + return self._name + + @property + def dtype(self) -> PsType: + return self._dtype + + def _args(self) -> tuple[Any, ...]: + return (self._name, self._dtype) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SfgVar): + return False + + return self._args() == other._args() + + def __hash__(self) -> int: + return hash(self._args()) + + @property + def required_includes(self) -> set[SfgHeaderInclude]: + return self._required_includes + + def name_and_type(self) -> str: + return f"{self._name}: {self._dtype}" + + def __str__(self) -> str: + return self._name + + def __repr__(self) -> str: + return self.name_and_type() + class DependentExpression: __match_args__ = ("expr", "depends") @@ -50,14 +128,41 @@ class DependentExpression: return DependentExpression(self.expr + other.expr, self.depends | other.depends) +class VarExpr(DependentExpression): + def __init__(self, var: SfgVar): + self._var = var + super().__init__(var.name, (var,)) + + @property + def variable(self) -> SfgVar: + return self._var + + class AugExpr: - def __init__(self, dtype: PsType | None = None): - self._dtype = dtype + """C++ expression augmented with variable dependencies and a type-dependent interface. + + `AugExpr` is the primary class for modelling C++ expressions in *pystencils-sfg*. + It stores both an expression's code string and the set of variables (`SfgVar`) + the expression depends on. This dependency information is used by the postprocessing + system to infer function parameter lists. + + In addition, subclasses of `AugExpr` can mimic C++ APIs by defining factory methods that + build expressions for C++ method calls, etc., from a list of argument expressions. + + Args: + dtype: Optional, data type of this expression interface + """ + + __match_args__ = ("expr", "dtype") + + def __init__(self, dtype: UserTypeSpec | None = None): + self._dtype = create_type(dtype) if dtype is not None else None self._bound: DependentExpression | None = None + self._is_variable = False def var(self, name: str): v = SfgVar(name, self.get_dtype(), self.required_includes) - expr = DependentExpression(name, (v,)) + expr = VarExpr(v) return self._bind(expr) @staticmethod @@ -66,14 +171,26 @@ class AugExpr: @staticmethod def format(fmt: str, *deps, **kwdeps) -> AugExpr: + """Create a new `AugExpr` by combining existing expressions.""" return AugExpr().bind(fmt, *deps, **kwdeps) def bind(self, fmt: str, *deps, **kwdeps): - depends = filter( - lambda obj: isinstance(obj, (SfgVar, AugExpr)), chain(deps, kwdeps.values()) - ) + dependencies: set[SfgVar] = set() + + from pystencils.sympyextensions import is_constant + + for expr in chain(deps, kwdeps.values()): + if isinstance(expr, _ExprLike): + dependencies |= depends(expr) + elif isinstance(expr, sp.Expr) and not is_constant(expr): + raise ValueError( + f"Cannot parse SymPy expression as C++ expression: {expr}\n" + " * pystencils-sfg is currently unable to parse non-constant SymPy expressions " + "since they contain symbols without type information." + ) + code = fmt.format(*deps, **kwdeps) - self._bind(DependentExpression(code, depends)) + self._bind(DependentExpression(code, dependencies)) return self def expr(self) -> DependentExpression: @@ -82,6 +199,12 @@ class AugExpr: return self._bound + @property + def code(self) -> str: + if self._bound is None: + raise SfgException("No syntax bound to this AugExpr.") + return str(self._bound) + @property def depends(self) -> set[SfgVar]: if self._bound is None: @@ -99,6 +222,15 @@ class AugExpr: return self._dtype + @property + def is_variable(self) -> bool: + return isinstance(self._bound, VarExpr) + + def as_variable(self) -> SfgVar: + if not isinstance(self._bound, VarExpr): + raise SfgException("This expression is not a variable") + return self._bound.variable + @property def required_includes(self) -> set[SfgHeaderInclude]: return set() @@ -109,6 +241,9 @@ class AugExpr: else: return str(self._bound) + def __repr__(self) -> str: + return str(self) + def _bind(self, expr: DependentExpression): if self._bound is not None: raise SfgException("Attempting to bind an already-bound AugExpr.") @@ -120,6 +255,88 @@ class AugExpr: return self._bound is not None +_VarLike = (AugExpr, SfgVar, TypedSymbol) +VarLike: TypeAlias = AugExpr | SfgVar | TypedSymbol +"""Things that may act as a variable. + +Variable-like objects are entities from pystencils and pystencils-sfg that define +a variable name and data type. +Any `VarLike` object can be transformed into a canonical representation (i.e. `SfgVar`) +using `asvar`. +""" + + +_ExprLike = (str, AugExpr, SfgVar, TypedSymbol) +ExprLike: TypeAlias = str | AugExpr | SfgVar | TypedSymbol +"""Things that may act as a C++ expression. + +This type combines all objects that *pystencils-sfg* can handle in the place of C++ +expressions. These include all valid variable types (`VarLike`), plain strings, and +complex expressions with variable dependency information (`AugExpr`). + +The set of variables an expression depends on can be determined using `depends`. +""" + + +def asvar(var: VarLike) -> SfgVar: + """Cast a variable-like object to its canonical representation, + + Args: + var: Variable-like object + + Returns: + SfgVar: Variable cast as `SfgVar`. + + Raises: + ValueError: If given a non-variable `AugExpr`, + a `TypedSymbol` with a `DynamicType`, + or any non-variable-like object. + """ + match var: + case SfgVar(): + return var + case AugExpr(): + return var.as_variable() + case TypedSymbol(): + from pystencils import DynamicType + + if isinstance(var.dtype, DynamicType): + raise ValueError( + f"Unable to cast dynamically typed symbol {var} to a variable.\n" + f"{var} has dynamic type {var.dtype}, which cannot be resolved to a type outside of a kernel." + ) + + return SfgVar(var.name, var.dtype) + case _: + raise ValueError(f"Invalid variable: {var}") + + +def depends(expr: ExprLike) -> set[SfgVar]: + """Determine the set of variables an expression depends on. + + Args: + expr: Expression-like object to examine + + Returns: + set[SfgVar]: Set of variables the expression depends on + + Raises: + ValueError: If the argument was not a valid expression + """ + + match expr: + case None | str(): + return set() + case SfgVar(): + return {expr} + case TypedSymbol(): + return {asvar(expr)} + case AugExpr(): + return expr.depends + case _: + raise ValueError(f"Invalid expression: {expr}") + + class IFieldExtraction(ABC): """Interface for objects defining how to extract low-level field parameters from high-level data structures.""" @@ -138,7 +355,11 @@ class IFieldExtraction(ABC): class SrcField(AugExpr): - """Represents a C++ data structure that can be mapped to a *pystencils* field.""" + """Represents a C++ data structure that can be mapped to a *pystencils* field. + + Args: + dtype: Data type of the field data structure + """ @abstractmethod def get_extraction(self) -> IFieldExtraction: @@ -146,7 +367,11 @@ class SrcField(AugExpr): class SrcVector(AugExpr, ABC): - """Represents a C++ data structure that represents a mathematical vector.""" + """Represents a C++ data structure that represents a mathematical vector. + + Args: + dtype: Data type of the vector data structure + """ @abstractmethod def extract_component(self, coordinate: int) -> AugExpr: diff --git a/src/pystencilssfg/lang/types.py b/src/pystencilssfg/lang/types.py new file mode 100644 index 0000000000000000000000000000000000000000..6f23160075050c6dfa33fd17636c2a3f826f263a --- /dev/null +++ b/src/pystencilssfg/lang/types.py @@ -0,0 +1,26 @@ +from typing import Any +from pystencils.types import PsType + + +class Ref(PsType): + """C++ reference type.""" + + __match_args__ = "base_type" + + def __init__(self, base_type: PsType, const: bool = False): + super().__init__(False) + self._base_type = base_type + + def __args__(self) -> tuple[Any, ...]: + return (self.base_type,) + + @property + def base_type(self) -> PsType: + return self._base_type + + def c_string(self) -> str: + base_str = self.base_type.c_string() + return base_str + "&" + + def __repr__(self) -> str: + return f"Ref({repr(self.base_type)})" diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/compressed_pair.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/compressed_pair.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/compressed_pair.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/compressed_pair.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/config.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/config.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/config.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/config.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/default_accessor.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/default_accessor.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/default_accessor.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/default_accessor.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/dynamic_extent.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/dynamic_extent.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/dynamic_extent.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/dynamic_extent.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/extents.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/extents.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/extents.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/extents.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/full_extent_t.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/full_extent_t.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/full_extent_t.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/full_extent_t.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/layout_left.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/layout_left.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/layout_left.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/layout_left.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/layout_right.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/layout_right.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/layout_right.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/layout_right.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/layout_stride.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/layout_stride.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/layout_stride.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/layout_stride.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/macros.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/macros.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/macros.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/macros.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/mdspan.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/mdspan.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/mdspan.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/mdspan.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/no_unique_address.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/no_unique_address.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/no_unique_address.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/no_unique_address.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/trait_backports.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/trait_backports.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/trait_backports.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/trait_backports.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/type_list.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/type_list.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/type_list.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/type_list.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p0009_bits/utility.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/utility.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p0009_bits/utility.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p0009_bits/utility.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p1684_bits/mdarray.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p1684_bits/mdarray.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p1684_bits/mdarray.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p1684_bits/mdarray.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p2389_bits/dims.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p2389_bits/dims.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p2389_bits/dims.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p2389_bits/dims.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p2630_bits/strided_slice.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p2630_bits/strided_slice.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p2630_bits/strided_slice.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p2630_bits/strided_slice.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p2630_bits/submdspan.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p2630_bits/submdspan.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p2630_bits/submdspan.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p2630_bits/submdspan.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p2630_bits/submdspan_extents.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p2630_bits/submdspan_extents.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p2630_bits/submdspan_extents.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p2630_bits/submdspan_extents.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p2630_bits/submdspan_mapping.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p2630_bits/submdspan_mapping.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p2630_bits/submdspan_mapping.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p2630_bits/submdspan_mapping.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p2642_bits/layout_padded.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p2642_bits/layout_padded.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p2642_bits/layout_padded.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p2642_bits/layout_padded.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/__p2642_bits/layout_padded_fwd.hpp b/tests/generator_scripts/deps/mdspan/include/experimental/__p2642_bits/layout_padded_fwd.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/__p2642_bits/layout_padded_fwd.hpp rename to tests/generator_scripts/deps/mdspan/include/experimental/__p2642_bits/layout_padded_fwd.hpp diff --git a/tests/integration/deps/mdspan/include/experimental/mdarray b/tests/generator_scripts/deps/mdspan/include/experimental/mdarray similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/mdarray rename to tests/generator_scripts/deps/mdspan/include/experimental/mdarray diff --git a/tests/integration/deps/mdspan/include/experimental/mdspan b/tests/generator_scripts/deps/mdspan/include/experimental/mdspan similarity index 100% rename from tests/integration/deps/mdspan/include/experimental/mdspan rename to tests/generator_scripts/deps/mdspan/include/experimental/mdspan diff --git a/tests/integration/deps/mdspan/include/mdspan/mdarray.hpp b/tests/generator_scripts/deps/mdspan/include/mdspan/mdarray.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/mdspan/mdarray.hpp rename to tests/generator_scripts/deps/mdspan/include/mdspan/mdarray.hpp diff --git a/tests/integration/deps/mdspan/include/mdspan/mdspan.hpp b/tests/generator_scripts/deps/mdspan/include/mdspan/mdspan.hpp similarity index 100% rename from tests/integration/deps/mdspan/include/mdspan/mdspan.hpp rename to tests/generator_scripts/deps/mdspan/include/mdspan/mdspan.hpp diff --git a/tests/integration/expected/SimpleClasses.cpp b/tests/generator_scripts/expected/SimpleClasses.h similarity index 76% rename from tests/integration/expected/SimpleClasses.cpp rename to tests/generator_scripts/expected/SimpleClasses.h index 9d27ebeddd488ac23d3b87d126f3d17b312e936a..93e93cabdc7bc80331e2f5c7efc3a988f87e841a 100644 --- a/tests/integration/expected/SimpleClasses.cpp +++ b/tests/generator_scripts/expected/SimpleClasses.h @@ -1,5 +1,9 @@ +#pragma once + #include <cstdint> +#define RESTRICT __restrict__ + class Point { public: const int64_t & getX() const { diff --git a/tests/generator_scripts/expected/Structural.h b/tests/generator_scripts/expected/Structural.h new file mode 100644 index 0000000000000000000000000000000000000000..0eb1e25f0704e43172ef77b6cea7f022fd744c42 --- /dev/null +++ b/tests/generator_scripts/expected/Structural.h @@ -0,0 +1,18 @@ +/* + * Expect the unexpected, and you shall never be surprised. + */ + +#pragma once + +#include <cstdint> + +#include <iostream> +#include "config.h" + +namespace awesome { + +#define RESTRICT __restrict__ +#define PI 3.1415 +using namespace std; + +} // namespace awesome diff --git a/tests/generator_scripts/expected/Variables.h b/tests/generator_scripts/expected/Variables.h new file mode 100644 index 0000000000000000000000000000000000000000..96c16d7b306fe7594db0caa02e9168b2f5c86fc6 --- /dev/null +++ b/tests/generator_scripts/expected/Variables.h @@ -0,0 +1,13 @@ +#pragma once + +#include <cstdint> + +#define RESTRICT __restrict__ + +class Scale { +private: + float alpha; +public: + Scale(float alpha) : alpha{ alpha } {} + void operator() (float *const _data_f, float *const _data_g); +}; diff --git a/tests/integration/scripts/SimpleClasses.py b/tests/generator_scripts/scripts/SimpleClasses.py similarity index 91% rename from tests/integration/scripts/SimpleClasses.py rename to tests/generator_scripts/scripts/SimpleClasses.py index a729d1f125e69c5ef3d178f2ed8be1d65cd423a8..26c236f8a200658ee103150b6295f44fbae8647a 100644 --- a/tests/integration/scripts/SimpleClasses.py +++ b/tests/generator_scripts/scripts/SimpleClasses.py @@ -1,12 +1,9 @@ from pystencilssfg import SourceFileGenerator with SourceFileGenerator() as sfg: - - sfg.include("<cstdint>") - sfg.klass("Point")( sfg.public( - sfg.method("getX", returns="const int64_t &", const=True)( + sfg.method("getX", returns="const int64_t &", const=True, inline=True)( "return this->x;" ) ), diff --git a/tests/integration/scripts/SimpleJacobi.py b/tests/generator_scripts/scripts/SimpleJacobi.py similarity index 99% rename from tests/integration/scripts/SimpleJacobi.py rename to tests/generator_scripts/scripts/SimpleJacobi.py index 199419c541fb4aef68d44a1baf7cfc2f28d67e86..e84c872b05c354f4e1473b2eab17781f0880f035 100644 --- a/tests/integration/scripts/SimpleJacobi.py +++ b/tests/generator_scripts/scripts/SimpleJacobi.py @@ -20,4 +20,4 @@ with SourceFileGenerator() as sfg: sfg.map_field(u_dst, mdspan_ref(u_dst)), sfg.map_field(f, mdspan_ref(f)), sfg.call(poisson_kernel) - ) \ No newline at end of file + ) diff --git a/tests/generator_scripts/scripts/Structural.py b/tests/generator_scripts/scripts/Structural.py new file mode 100644 index 0000000000000000000000000000000000000000..e8cf2ab0d87ecf770871c6ef2d15deb7820b4561 --- /dev/null +++ b/tests/generator_scripts/scripts/Structural.py @@ -0,0 +1,16 @@ +from pystencilssfg import SourceFileGenerator, SfgConfiguration, SfgCodeStyle + +# Do not use clang-format, since it reorders headers +cfg = SfgConfiguration( + codestyle=SfgCodeStyle(skip_clang_format=True) +) + +with SourceFileGenerator(cfg) as sfg: + sfg.prelude("Expect the unexpected, and you shall never be surprised.") + sfg.include("<iostream>") + sfg.include("config.h") + + sfg.namespace("awesome") + + sfg.code("#define PI 3.1415") + sfg.code("using namespace std;") diff --git a/tests/generator_scripts/scripts/Variables.py b/tests/generator_scripts/scripts/Variables.py new file mode 100644 index 0000000000000000000000000000000000000000..9fd4e0027104451738abea72e15084047cf4465e --- /dev/null +++ b/tests/generator_scripts/scripts/Variables.py @@ -0,0 +1,22 @@ +import sympy as sp +from pystencils import TypedSymbol, fields, kernel + +from pystencilssfg import SourceFileGenerator, SfgConfiguration + +with SourceFileGenerator() as sfg: + α = TypedSymbol("alpha", "float32") + f, g = fields("f, g: float32[10]") + + @kernel + def scale(): + f[0] @= α * g.center() + + khandle = sfg.kernels.create(scale) + + sfg.klass("Scale")( + sfg.private(α), + sfg.public( + sfg.constructor(α).init(α)(α.name), + sfg.method("operator()")(sfg.init(α)(f"this->{α}"), sfg.call(khandle)), + ), + ) diff --git a/tests/integration/test_generator_scripts.py b/tests/generator_scripts/test_generator_scripts.py similarity index 58% rename from tests/integration/test_generator_scripts.py rename to tests/generator_scripts/test_generator_scripts.py index 33f747820b966bd7a2630e9497cbb00a7937340c..51e05c01f2cae3a637e235e4a94d1585b390a5af 100644 --- a/tests/integration/test_generator_scripts.py +++ b/tests/generator_scripts/test_generator_scripts.py @@ -14,21 +14,59 @@ EXPECTED_DIR = path.join(THIS_DIR, "expected") @dataclass class ScriptInfo: + @staticmethod + def make(name, *args, **kwargs): + return pytest.param(ScriptInfo(name, *args, **kwargs), id=f"{name}.py") + script_name: str + """Name of the generator script, without .py-extension. + + Generator scripts must be located in the ``scripts`` folder. + """ + expected_outputs: tuple[str, ...] + """List of file extensions expected to be emitted by the generator script. + + Output files will all be placed in the ``out`` folder. + """ compilable_output: str | None = None + """File extension of the output file that can be compiled. + + If this is set, and the expected file exists, the ``compile_cmd`` will be + executed to check for error-free compilation of the output. + """ + compile_cmd: str = f"g++ --std=c++17 -I {THIS_DIR}/deps/mdspan/include" + """Command to be invoked to compile the generated source file.""" + + def __repr__(self) -> str: + return self.script_name + +"""Scripts under test. +When adding new generator scripts to the `scripts` directory, +do not forget to include them here. +""" SCRIPTS = [ - ScriptInfo("SimpleJacobi", ("h", "cpp"), compilable_output="cpp"), - ScriptInfo("SimpleClasses", ("h", "cpp")), + ScriptInfo.make("Structural", ("h", "cpp")), + ScriptInfo.make("SimpleJacobi", ("h", "cpp"), compilable_output="cpp"), + ScriptInfo.make("SimpleClasses", ("h", "cpp")), + ScriptInfo.make("Variables", ("h", "cpp"), compilable_output="cpp"), ] @pytest.mark.parametrize("script_info", SCRIPTS) def test_generator_script(script_info: ScriptInfo): + """Test a generator script defined by ``script_info``. + + The generator script will be run, with its output placed in the ``out`` folder. + If it is successful, its output files will be compared against + any files of the same name from the ``expected`` folder. + Finally, if any compilable files are specified, the test will attempt to compile them. + """ + script_name = script_info.script_name script_file = path.join(SCRIPTS_DIR, script_name + ".py") @@ -67,7 +105,7 @@ def test_generator_script(script_info: ScriptInfo): # Strip whitespace expected = "".join(expected.split()) - actual = "".join(expected.split()) + actual = "".join(actual.split()) assert expected == actual diff --git a/tests/ir/test_postprocessing.py b/tests/ir/test_postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..e144024fa831bf5c31c17b7efc93d82857654671 --- /dev/null +++ b/tests/ir/test_postprocessing.py @@ -0,0 +1,77 @@ +import sympy as sp +from pystencils import fields, kernel, TypedSymbol + +from pystencilssfg import SfgContext, SfgComposer +from pystencilssfg.composer import make_sequence + +from pystencilssfg.ir import SfgStatements +from pystencilssfg.ir.postprocessing import CallTreePostProcessing + + +def test_live_vars(): + ctx = SfgContext() + sfg = SfgComposer(ctx) + + f, g = fields("f, g(2): double[2D]") + x, y = [TypedSymbol(n, "double") for n in "xy"] + z = sp.Symbol("z") + + @kernel + def update(): + f[0, 0] @= x * g.center(0) + y * g.center(1) - z + + khandle = sfg.kernels.create(update) + + a = sfg.var("a", "float") + b = sfg.var("b", "float") + + call_tree = make_sequence( + sfg.init(x)(a), sfg.init(y)(sfg.expr("{} - {}", b, x)), sfg.call(khandle) # # + ) + + pp = CallTreePostProcessing() + free_vars = pp.get_live_variables(call_tree) + + expected = {a.as_variable(), b.as_variable()} | set( + param for param in khandle.parameters if param.name not in "xy" + ) + + assert free_vars == expected + + +def test_find_sympy_symbols(): + ctx = SfgContext() + sfg = SfgComposer(ctx) + + f, g = fields("f, g(2): double[2D]") + x, y, z = sp.symbols("x, y, z") + + @kernel + def update(): + f[0, 0] @= x * g.center(0) + y * g.center(1) - z + + khandle = sfg.kernels.create(update) + + a = sfg.var("a", "double") + b = sfg.var("b", "double") + + call_tree = make_sequence( + sfg.set_param(x, b), + sfg.set_param(y, sfg.expr("{} / {}", x.name, a)), + sfg.call(khandle), + ) + + pp = CallTreePostProcessing() + live_vars = pp.get_live_variables(call_tree) + + expected = {a.as_variable(), b.as_variable()} | set( + param for param in khandle.parameters if param.name not in "xy" + ) + + assert live_vars == expected + + assert isinstance(call_tree.children[0], SfgStatements) + assert call_tree.children[0].code_string == "const double x = b;" + + assert isinstance(call_tree.children[1], SfgStatements) + assert call_tree.children[1].code_string == "const double y = x / a;" diff --git a/tests/lang/test_expressions.py b/tests/lang/test_expressions.py new file mode 100644 index 0000000000000000000000000000000000000000..ef2f1943c07b3f0e1a23750302f043c0ebe89105 --- /dev/null +++ b/tests/lang/test_expressions.py @@ -0,0 +1,95 @@ +import pytest + +from pystencilssfg import SfgException +from pystencilssfg.lang import asvar, SfgVar, AugExpr + +import sympy as sp + +from pystencils import TypedSymbol, DynamicType + + +def test_asvar(): + # SfgVar must be returned as-is + var = SfgVar("p", "uint64") + assert var is asvar(var) + + # TypedSymbol is transformed + ts = TypedSymbol("q", "int32") + assert asvar(ts) == SfgVar("q", "int32") + + # Variable AugExprs get lowered to SfgVar + augexpr = AugExpr("uint16").var("l") + assert asvar(augexpr) == SfgVar("l", "uint16") + + # Complex AugExprs cannot be parsed + cexpr = AugExpr.format("{} + {}", SfgVar("m", "int32"), AugExpr("int32").var("n")) + with pytest.raises(SfgException): + _ = asvar(cexpr) + + # Untyped SymPy symbols won't be parsed + x = sp.Symbol("x") + with pytest.raises(ValueError): + _ = asvar(x) + + # Dynamically typed TypedSymbols cannot be parsed + y = TypedSymbol("y", DynamicType.NUMERIC_TYPE) + with pytest.raises(ValueError): + _ = asvar(y) + + +def test_augexpr_format(): + expr = AugExpr.format("std::vector< real_t > {{ 0.1, 0.2, 0.3 }}") + assert expr.code == "std::vector< real_t > { 0.1, 0.2, 0.3 }" + assert not expr.depends + + expr = AugExpr("int").var("p") + assert expr.code == "p" + assert expr.depends == {SfgVar("p", "int")} + + expr = AugExpr.format( + "{} + {} / {}", + AugExpr("int").var("p"), + AugExpr("int").var("q"), + AugExpr("uint32").var("r"), + ) + + assert str(expr) == expr.code == "p + q / r" + + assert expr.depends == { + SfgVar("p", "int"), + SfgVar("q", "int"), + SfgVar("r", "uint32"), + } + + # Must find TypedSymbols as dependencies + expr = AugExpr.format( + "{} + {} / {}", + AugExpr("int").var("p"), + TypedSymbol("x", "int32"), + TypedSymbol("y", "int32"), + ) + + assert expr.code == "p + x / y" + assert expr.depends == { + SfgVar("p", "int"), + SfgVar("x", "int32"), + SfgVar("y", "int32"), + } + + # Can parse constant SymPy expressions + expr = AugExpr.format("{}", sp.sympify(1)) + + assert expr.code == "1" + assert not expr.depends + + +def test_augexpr_illegal_format(): + x, y, z = sp.symbols("x, y, z") + + with pytest.raises(ValueError): + # Cannot parse SymPy symbols + _ = AugExpr.format("{}", x) + + with pytest.raises(ValueError): + # Cannot parse expressions containing symbols + _ = AugExpr.format("{} + {}", x + 3, y / (2 * z))