Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (2)
Showing
with 904 additions and 110 deletions
pystencils/_version.py export-subst
src/pystencils/_version.py export-subst
......@@ -12,7 +12,7 @@ authors = [
]
license = { file = "COPYING.txt" }
requires-python = ">=3.10"
dependencies = ["sympy>=1.6,<=1.11.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml"]
dependencies = ["sympy>=1.9,<=1.12.1", "numpy>=1.8.0", "appdirs", "joblib", "pyyaml"]
classifiers = [
"Development Status :: 4 - Beta",
"Framework :: Jupyter",
......@@ -71,8 +71,7 @@ tests = [
[build-system]
requires = [
"setuptools>=61",
"versioneer>=0.29",
"tomli; python_version < '3.11'",
"versioneer[toml]>=0.29",
# 'Cython'
]
build-backend = "setuptools.build_meta"
......
......@@ -15,7 +15,7 @@ from .config import (
OpenMpConfig,
)
from .kernel_decorator import kernel, kernel_config
from .kernelcreation import create_kernel
from .kernelcreation import create_kernel, create_staggered_kernel
from .backend.kernelfunction import KernelFunction
from .slicing import make_slice
from .spatial_coordinates import (
......@@ -28,12 +28,13 @@ from .spatial_coordinates import (
z_,
z_staggered,
)
from .sympyextensions import Assignment, AssignmentCollection, AddAugmentedAssignment
from .sympyextensions.astnodes import assignment_from_stencil
from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
from .simp import AssignmentCollection
from .sympyextensions.typed_sympy import TypedSymbol
from .sympyextensions.math import SymbolCreator
from .sympyextensions import SymbolCreator
from .datahandling import create_data_handling
__all__ = [
"Field",
"FieldType",
......@@ -48,6 +49,7 @@ __all__ = [
"VectorizationConfig",
"OpenMpConfig",
"create_kernel",
"create_staggered_kernel",
"KernelFunction",
"Target",
"show_code",
......@@ -75,7 +77,5 @@ __all__ = [
"stencil",
]
from ._version import get_versions
__version__ = get_versions()["version"]
del get_versions
from . import _version
__version__ = _version.get_versions()['version']
......@@ -5,8 +5,9 @@
# directories (produced by setup.py build) will contain a much shorter file
# that just contains the computed version number.
# This file is released into the public domain. Generated by
# versioneer-0.19 (https://github.com/python-versioneer/python-versioneer)
# This file is released into the public domain.
# Generated by versioneer-0.29
# https://github.com/python-versioneer/python-versioneer
"""Git implementation of _version.py."""
......@@ -15,9 +16,11 @@ import os
import re
import subprocess
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple
import functools
def get_keywords():
def get_keywords() -> Dict[str, str]:
"""Get the keywords needed to look up the version information."""
# these strings will be replaced by git during git-archive.
# setup.py/versioneer.py will grep for the variable names, so they must
......@@ -33,8 +36,15 @@ def get_keywords():
class VersioneerConfig:
"""Container for Versioneer configuration parameters."""
VCS: str
style: str
tag_prefix: str
parentdir_prefix: str
versionfile_source: str
verbose: bool
def get_config():
def get_config() -> VersioneerConfig:
"""Create, populate and return the VersioneerConfig() object."""
# these strings are filled in when 'setup.py versioneer' creates
# _version.py
......@@ -43,7 +53,7 @@ def get_config():
cfg.style = "pep440"
cfg.tag_prefix = "release/"
cfg.parentdir_prefix = "pystencils-"
cfg.versionfile_source = "pystencils/_version.py"
cfg.versionfile_source = "src/pystencils/_version.py"
cfg.verbose = False
return cfg
......@@ -52,13 +62,13 @@ class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
LONG_VERSION_PY = {}
HANDLERS = {}
LONG_VERSION_PY: Dict[str, str] = {}
HANDLERS: Dict[str, Dict[str, Callable]] = {}
def register_vcs_handler(vcs, method): # decorator
def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator
"""Create decorator to mark a method as the handler of a VCS."""
def decorate(f):
def decorate(f: Callable) -> Callable:
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
HANDLERS[vcs] = {}
......@@ -67,22 +77,35 @@ def register_vcs_handler(vcs, method): # decorator
return decorate
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
env=None):
def run_command(
commands: List[str],
args: List[str],
cwd: Optional[str] = None,
verbose: bool = False,
hide_stderr: bool = False,
env: Optional[Dict[str, str]] = None,
) -> Tuple[Optional[str], Optional[int]]:
"""Call the given command(s)."""
assert isinstance(commands, list)
p = None
for c in commands:
process = None
popen_kwargs: Dict[str, Any] = {}
if sys.platform == "win32":
# This hides the console window if pythonw.exe is used
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
popen_kwargs["startupinfo"] = startupinfo
for command in commands:
try:
dispcmd = str([c] + args)
dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git
p = subprocess.Popen([c] + args, cwd=cwd, env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr
else None))
process = subprocess.Popen([command] + args, cwd=cwd, env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr
else None), **popen_kwargs)
break
except EnvironmentError:
e = sys.exc_info()[1]
except OSError as e:
if e.errno == errno.ENOENT:
continue
if verbose:
......@@ -93,16 +116,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,
if verbose:
print("unable to find command, tried %s" % (commands,))
return None, None
stdout = p.communicate()[0].strip().decode()
if p.returncode != 0:
stdout = process.communicate()[0].strip().decode()
if process.returncode != 0:
if verbose:
print("unable to run %s (error)" % dispcmd)
print("stdout was %s" % stdout)
return None, p.returncode
return stdout, p.returncode
return None, process.returncode
return stdout, process.returncode
def versions_from_parentdir(parentdir_prefix, root, verbose):
def versions_from_parentdir(
parentdir_prefix: str,
root: str,
verbose: bool,
) -> Dict[str, Any]:
"""Try to determine the version from the parent directory name.
Source tarballs conventionally unpack into a directory that includes both
......@@ -111,15 +138,14 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
"""
rootdirs = []
for i in range(3):
for _ in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {"version": dirname[len(parentdir_prefix):],
"full-revisionid": None,
"dirty": False, "error": None, "date": None}
else:
rootdirs.append(root)
root = os.path.dirname(root) # up a level
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
print("Tried directories %s but none started with prefix %s" %
......@@ -128,39 +154,42 @@ def versions_from_parentdir(parentdir_prefix, root, verbose):
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
def git_get_keywords(versionfile_abs: str) -> Dict[str, str]:
"""Extract version information from the given file."""
# the code embedded in _version.py can just fetch the value of these
# keywords. When used from setup.py, we don't want to import _version.py,
# so we do it with a regexp instead. This function is not used from
# _version.py.
keywords = {}
keywords: Dict[str, str] = {}
try:
f = open(versionfile_abs, "r")
for line in f.readlines():
if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["refnames"] = mo.group(1)
if line.strip().startswith("git_full ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["full"] = mo.group(1)
if line.strip().startswith("git_date ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["date"] = mo.group(1)
f.close()
except EnvironmentError:
with open(versionfile_abs, "r") as fobj:
for line in fobj:
if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["refnames"] = mo.group(1)
if line.strip().startswith("git_full ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["full"] = mo.group(1)
if line.strip().startswith("git_date ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["date"] = mo.group(1)
except OSError:
pass
return keywords
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
def git_versions_from_keywords(
keywords: Dict[str, str],
tag_prefix: str,
verbose: bool,
) -> Dict[str, Any]:
"""Get version information from git keywords."""
if not keywords:
raise NotThisMethod("no keywords at all, weird")
if "refnames" not in keywords:
raise NotThisMethod("Short version file found")
date = keywords.get("date")
if date is not None:
# Use only the last line. Previous lines may contain GPG signature
......@@ -179,11 +208,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = set([r.strip() for r in refnames.strip("()").split(",")])
refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])
tags = {r[len(TAG):] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
......@@ -192,7 +221,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = set([r for r in refs if re.search(r'\d', r)])
tags = {r for r in refs if re.search(r'\d', r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
......@@ -201,6 +230,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix):]
# Filter out refs that exactly match prefix or that don't start
# with a number once the prefix is stripped (mostly a concern
# when prefix is '')
if not re.match(r'\d', r):
continue
if verbose:
print("picking %s" % r)
return {"version": r,
......@@ -216,7 +250,12 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose):
@register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
def git_pieces_from_vcs(
tag_prefix: str,
root: str,
verbose: bool,
runner: Callable = run_command
) -> Dict[str, Any]:
"""Get version from 'git describe' in the root of the source tree.
This only gets called if the git-archive 'subst' keywords were *not*
......@@ -227,8 +266,15 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root,
hide_stderr=True)
# GIT_DIR can interfere with correct operation of Versioneer.
# It may be intended to be passed to the Versioneer-versioned project,
# but that should not change where we get our version from.
env = os.environ.copy()
env.pop("GIT_DIR", None)
runner = functools.partial(runner, env=env)
_, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root,
hide_stderr=not verbose)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
......@@ -236,24 +282,57 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty",
"--always", "--long",
"--match", "%s*" % tag_prefix],
cwd=root)
describe_out, rc = runner(GITS, [
"describe", "--tags", "--dirty", "--always", "--long",
"--match", f"{tag_prefix}[[:digit:]]*"
], cwd=root)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
describe_out = describe_out.strip()
full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root)
full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
if full_out is None:
raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip()
pieces = {}
pieces: Dict[str, Any] = {}
pieces["long"] = full_out
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"],
cwd=root)
# --abbrev-ref was added in git-1.6.3
if rc != 0 or branch_name is None:
raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
branch_name = branch_name.strip()
if branch_name == "HEAD":
# If we aren't exactly on a branch, pick a branch which represents
# the current commit. If all else fails, we are on a branchless
# commit.
branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
# --contains was added in git-1.5.4
if rc != 0 or branches is None:
raise NotThisMethod("'git branch --contains' returned error")
branches = branches.split("\n")
# Remove the first line if we're running detached
if "(" in branches[0]:
branches.pop(0)
# Strip off the leading "* " from the list of branches.
branches = [branch[2:] for branch in branches]
if "master" in branches:
branch_name = "master"
elif not branches:
branch_name = None
else:
# Pick the first branch that is returned. Good or bad.
branch_name = branches[0]
pieces["branch"] = branch_name
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
# TAG might have hyphens.
git_describe = describe_out
......@@ -270,7 +349,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
# TAG-NUM-gHEX
mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe)
if not mo:
# unparseable. Maybe git-describe is misbehaving?
# unparsable. Maybe git-describe is misbehaving?
pieces["error"] = ("unable to parse git-describe output: '%s'"
% describe_out)
return pieces
......@@ -295,13 +374,11 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
else:
# HEX: no tags
pieces["closest-tag"] = None
count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"],
cwd=root)
pieces["distance"] = int(count_out) # total number of commits
out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
pieces["distance"] = len(out.split()) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"],
cwd=root)[0].strip()
date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
......@@ -310,14 +387,14 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):
return pieces
def plus_or_dot(pieces):
def plus_or_dot(pieces: Dict[str, Any]) -> str:
"""Return a + if we don't already have one, else return a ."""
if "+" in pieces.get("closest-tag", ""):
return "."
return "+"
def render_pep440(pieces):
def render_pep440(pieces: Dict[str, Any]) -> str:
"""Build up version string, with post-release "local version identifier".
Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
......@@ -342,23 +419,71 @@ def render_pep440(pieces):
return rendered
def render_pep440_pre(pieces):
"""TAG[.post0.devDISTANCE] -- No -dirty.
def render_pep440_branch(pieces: Dict[str, Any]) -> str:
"""TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
The ".dev0" means not master branch. Note that .dev0 sorts backwards
(a feature branch will appear "older" than the master branch).
Exceptions:
1: no tags. 0.post0.devDISTANCE
1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0"
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+untagged.%d.g%s" % (pieces["distance"],
pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]:
"""Split pep440 version string at the post-release segment.
Returns the release segments before the post-release and the
post-release version number (or -1 if no post-release segment is present).
"""
vc = str.split(ver, ".post")
return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
def render_pep440_pre(pieces: Dict[str, Any]) -> str:
"""TAG[.postN.devDISTANCE] -- No -dirty.
Exceptions:
1: no tags. 0.post0.devDISTANCE
"""
if pieces["closest-tag"]:
if pieces["distance"]:
rendered += ".post0.dev%d" % pieces["distance"]
# update the post release segment
tag_version, post_version = pep440_split_post(pieces["closest-tag"])
rendered = tag_version
if post_version is not None:
rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
else:
rendered += ".post0.dev%d" % (pieces["distance"])
else:
# no commits, use the tag as the version
rendered = pieces["closest-tag"]
else:
# exception #1
rendered = "0.post0.dev%d" % pieces["distance"]
return rendered
def render_pep440_post(pieces):
def render_pep440_post(pieces: Dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]+gHEX] .
The ".dev0" means dirty. Note that .dev0 sorts backwards
......@@ -385,7 +510,36 @@ def render_pep440_post(pieces):
return rendered
def render_pep440_old(pieces):
def render_pep440_post_branch(pieces: Dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
The ".dev0" means not master branch.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%s" % pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+g%s" % pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def render_pep440_old(pieces: Dict[str, Any]) -> str:
"""TAG[.postDISTANCE[.dev0]] .
The ".dev0" means dirty.
......@@ -407,7 +561,7 @@ def render_pep440_old(pieces):
return rendered
def render_git_describe(pieces):
def render_git_describe(pieces: Dict[str, Any]) -> str:
"""TAG[-DISTANCE-gHEX][-dirty].
Like 'git describe --tags --dirty --always'.
......@@ -427,7 +581,7 @@ def render_git_describe(pieces):
return rendered
def render_git_describe_long(pieces):
def render_git_describe_long(pieces: Dict[str, Any]) -> str:
"""TAG-DISTANCE-gHEX[-dirty].
Like 'git describe --tags --dirty --always -long'.
......@@ -447,7 +601,7 @@ def render_git_describe_long(pieces):
return rendered
def render(pieces, style):
def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]:
"""Render the given version pieces into the requested style."""
if pieces["error"]:
return {"version": "unknown",
......@@ -461,10 +615,14 @@ def render(pieces, style):
if style == "pep440":
rendered = render_pep440(pieces)
elif style == "pep440-branch":
rendered = render_pep440_branch(pieces)
elif style == "pep440-pre":
rendered = render_pep440_pre(pieces)
elif style == "pep440-post":
rendered = render_pep440_post(pieces)
elif style == "pep440-post-branch":
rendered = render_pep440_post_branch(pieces)
elif style == "pep440-old":
rendered = render_pep440_old(pieces)
elif style == "git-describe":
......@@ -479,7 +637,7 @@ def render(pieces, style):
"date": pieces.get("date")}
def get_versions():
def get_versions() -> Dict[str, Any]:
"""Get version information or return default if unable to do so."""
# I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
# __file__, we can work backwards from there to the root. Some
......@@ -500,7 +658,7 @@ def get_versions():
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
for i in cfg.versionfile_source.split('/'):
for _ in cfg.versionfile_source.split('/'):
root = os.path.dirname(root)
except NameError:
return {"version": "0+unknown", "full-revisionid": None,
......
import numpy as np
import sympy as sp
from sympy.codegen.ast import Assignment, AugmentedAssignment
from sympy.codegen.ast import AddAugmentedAssignment as SpAddAugAssignment
from sympy.printing.latex import LatexPrinter
__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'assignment_from_stencil']
def print_assignment_latex(printer, expr):
binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else ''
"""sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
printed_lhs = printer.doprint(expr.lhs)
printed_rhs = printer.doprint(expr.rhs)
return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}"
def assignment_str(assignment):
op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else ''
return fr"{assignment.lhs} {op} {assignment.rhs}"
_old_new = sp.codegen.ast.Assignment.__new__
# TODO Typing Part2 add default type, defult_float_type, default_int_type and use sane defaults
def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
return _old_new(cls, lhs, rhs, *args, **kwargs)
Assignment.__str__ = assignment_str
Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex
AugmentedAssignment.__str__ = assignment_str
LatexPrinter._print_AugmentedAssignment = print_assignment_latex
sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
# Re-Export
AddAugmentedAssignment = SpAddAugAssignment
def assignment_from_stencil(stencil_array, input_field, output_field,
normalization_factor=None, order='visual') -> Assignment:
"""Creates an assignment
Args:
stencil_array: nested list of numpy array defining the stencil weights
input_field: field or field access, defining where the stencil should be applied to
output_field: field or field access where the result is written to
normalization_factor: optional normalization factor for the stencil
order: defines how the stencil_array is interpreted. Possible values are 'visual' and 'numpy'.
For details see examples
Returns:
Assignment that can be used to create a kernel
Examples:
>>> import pystencils as ps
>>> f, g = ps.fields("f, g: [2D]")
>>> stencil = [[0, 2, 0],
... [3, 4, 5],
... [0, 6, 0]]
By default 'visual ordering is used - i.e. the stencil is applied as the nested lists are written down
>>> expected_output = Assignment(g[0, 0], 3*f[-1, 0] + 6*f[0, -1] + 4*f[0, 0] + 2*f[0, 1] + 5*f[1, 0])
>>> assignment_from_stencil(stencil, f, g, order='visual') == expected_output
True
'numpy' ordering uses the first coordinate of the stencil array for x offset, second for y offset etc.
>>> expected_output = Assignment(g[0, 0], 2*f[-1, 0] + 3*f[0, -1] + 4*f[0, 0] + 5*f[0, 1] + 6*f[1, 0])
>>> assignment_from_stencil(stencil, f, g, order='numpy') == expected_output
True
You can also pass field accesses to apply the stencil at an already shifted position:
>>> expected_output = Assignment(g[2, 0], 3*f[0, 0] + 6*f[1, -1] + 4*f[1, 0] + 2*f[1, 1] + 5*f[2, 0])
>>> assignment_from_stencil(stencil, f[1, 0], g[2, 0]) == expected_output
True
"""
from pystencils.field import Field
stencil_array = np.array(stencil_array)
if order == 'visual':
stencil_array = np.swapaxes(stencil_array, 0, 1)
stencil_array = np.flip(stencil_array, axis=1)
elif order == 'numpy':
pass
else:
raise ValueError("'order' has to be either 'visual' or 'numpy'")
if isinstance(input_field, Field):
input_field = input_field.center
if isinstance(output_field, Field):
output_field = output_field.center
rhs = 0
offset = tuple(s // 2 for s in stencil_array.shape)
for index, factor in np.ndenumerate(stencil_array):
shift = tuple(i - o for i, o in zip(index, offset))
rhs += factor * input_field.get_shifted(*shift)
if normalization_factor:
rhs *= normalization_factor
return Assignment(output_field, rhs)
......@@ -9,7 +9,8 @@ import sympy as sp
from .context import KernelCreationContext
from ...field import Field
from ...sympyextensions import Assignment, AssignmentCollection
from ...assignment import Assignment
from ...simp import AssignmentCollection
from ..exceptions import PsInternalCompilerError, KernelConstraintsError
......
......@@ -7,7 +7,8 @@ import sympy.core.relational
import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
from ...sympyextensions.astnodes import Assignment, AssignmentCollection
from ...assignment import Assignment
from ...simp import AssignmentCollection
from ...sympyextensions import (
integer_functions,
ConditionalFieldAccess,
......
......@@ -6,7 +6,7 @@ from functools import reduce
from operator import mul
from ...defaults import DEFAULTS
from ...sympyextensions import AssignmentCollection
from ...simp import AssignmentCollection
from ...field import Field, FieldType
from ..symbols import PsSymbol
......
from typing import Any, List, Tuple, Sequence
from pystencils.sympyextensions import Assignment
from pystencils.assignment import Assignment
from pystencils.boundaries.boundaryhandling import BoundaryOffsetInfo
from pystencils.types import create_type
......
......@@ -4,7 +4,7 @@ import numpy as np
import sympy as sp
from pystencils import create_kernel, CreateKernelConfig, Target
from pystencils.sympyextensions import Assignment
from pystencils.assignment import Assignment
from pystencils.boundaries.createindexlist import (
create_boundary_index_array, numpy_data_type_for_boundary_object)
from pystencils.sympyextensions import TypedSymbol
......
......@@ -6,7 +6,7 @@ import sympy as sp
from pystencils.field import Field
from pystencils.stencil import direction_string_to_offset
from pystencils.sympyextensions.math import multidimensional_sum, prod
from pystencils.sympyextensions import multidimensional_sum, prod
from pystencils.utils import LinearEquationSystem, fully_contains
......
......@@ -3,7 +3,7 @@ from collections import defaultdict, namedtuple
import sympy as sp
from pystencils.field import Field
from pystencils.sympyextensions.math import normalize_product, prod
from pystencils.sympyextensions import normalize_product, prod
def _default_diff_sort_key(d):
......
......@@ -7,8 +7,8 @@ from pystencils.fd import Diff
from pystencils.fd.derivative import diff_args
from pystencils.fd.spatial import fd_stencils_standard
from pystencils.field import Field
from pystencils.sympyextensions import AssignmentCollection
from pystencils.sympyextensions.math import fast_subs
from pystencils.simp import AssignmentCollection
from pystencils.sympyextensions import fast_subs
FieldOrFieldAccess = Union[Field, Field.Access]
......
......@@ -15,8 +15,8 @@ from pystencils.alignedarray import aligned_empty
from pystencils.spatial_coordinates import x_staggered_vector, x_vector
from pystencils.stencil import direction_string_to_offset, inverse_direction, offset_to_direction_string
from pystencils.types import PsType, PsStructType, create_type
from pystencils.sympyextensions.typed_sympy import (FieldShapeSymbol, FieldStrideSymbol, TypedSymbol)
from pystencils.sympyextensions.math import is_integer_sequence
from pystencils.sympyextensions.typed_sympy import FieldShapeSymbol, FieldStrideSymbol, TypedSymbol
from pystencils.sympyextensions import is_integer_sequence
from pystencils.types import UserTypeSpec
......
......@@ -5,8 +5,8 @@ from typing import Callable, Union, List, Dict, Tuple
import sympy as sp
from .sympyextensions import Assignment
from .sympyextensions.math import SymbolCreator
from .assignment import Assignment
from .sympyextensions import SymbolCreator
from pystencils.config import CreateKernelConfig
__all__ = ['kernel', 'kernel_config']
......
......@@ -32,7 +32,8 @@ from .backend.transformations import (
SelectFunctions,
)
from .sympyextensions import AssignmentCollection, Assignment
from .simp import AssignmentCollection
from .assignment import Assignment
__all__ = ["create_kernel"]
......@@ -152,3 +153,7 @@ def create_kernel_function(
return KernelFunction(
body, target_spec, function_name, params, req_headers, ctx.constraints, jit
)
def create_staggered_kernel(assignments, target: Target = Target.CPU, gpu_exclusive_conditions=False, **kwargs):
raise NotImplementedError("Staggered kernels are not yet implemented for pystencils 2.0")
......@@ -2,7 +2,7 @@ from typing import List
import sympy as sp
from pystencils.sympyextensions import Assignment
from .assignment import Assignment
from pystencils.sympyextensions import is_constant
from pystencils.sympyextensions.astnodes import generic_visit
......
......@@ -2,10 +2,9 @@ import copy
import numpy as np
import sympy as sp
from pystencils.sympyextensions import TypedSymbol, CastFunc
from .sympyextensions import TypedSymbol, CastFunc, fast_subs
# from pystencils.sympyextensions.astnodes import LoopOverCoordinate # TODO nbackend: replace
# from pystencils.backends.cbackend import CustomCodeNode # TODO nbackend: replace
from pystencils.sympyextensions import fast_subs
# class RNGBase(CustomCodeNode): TODO nbackend: replace
......
from .assignment_collection import AssignmentCollection
from .simplifications import (
add_subexpressions_for_constants,
add_subexpressions_for_divisions,
add_subexpressions_for_field_reads,
add_subexpressions_for_sums,
apply_on_all_subexpressions,
apply_to_all_assignments,
subexpression_substitution_in_existing_subexpressions,
subexpression_substitution_in_main_assignments,
sympy_cse,
sympy_cse_on_assignment_list,
)
from .subexpression_insertion import (
insert_aliases,
insert_zeros,
insert_constants,
insert_constant_additions,
insert_constant_multiples,
insert_squares,
insert_symbol_times_minus_one,
)
from .simplificationstrategy import SimplificationStrategy
__all__ = [
"AssignmentCollection",
"SimplificationStrategy",
"sympy_cse",
"sympy_cse_on_assignment_list",
"apply_to_all_assignments",
"apply_on_all_subexpressions",
"subexpression_substitution_in_existing_subexpressions",
"subexpression_substitution_in_main_assignments",
"add_subexpressions_for_constants",
"add_subexpressions_for_divisions",
"add_subexpressions_for_sums",
"add_subexpressions_for_field_reads",
"insert_aliases",
"insert_zeros",
"insert_constants",
"insert_constant_additions",
"insert_constant_multiples",
"insert_squares",
"insert_symbol_times_minus_one",
]
import itertools
from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
import sympy as sp
import pystencils
from ..assignment import Assignment
from .simplifications import (sort_assignments_topologically, transform_lhs_and_rhs, transform_rhs)
from ..sympyextensions import count_operations, fast_subs
class AssignmentCollection:
"""
A collection of equations with subexpression definitions, also represented as assignments,
that are used in the main equations. AssignmentCollection can be passed to simplification methods.
These simplification methods can change the subexpressions, but the number and
left hand side of the main equations themselves is not altered.
Additionally a dictionary of simplification hints is stored, which are set by the functions that create
assignment collections to transport information to the simplification system.
Args:
main_assignments: List of assignments. Main assignments are characterised, that the right hand side of each
assignment is a field access. Thus the generated equations write on arrays.
subexpressions: List of assignments defining subexpressions used in main equations
simplification_hints: Dict that is used to annotate the assignment collection with hints that are
used by the simplification system. See documentation of the simplification rules for
potentially required hints and their meaning.
subexpression_symbol_generator: Generator for new symbols that are used when new subexpressions are added
used to get new symbols that are unique for this AssignmentCollection
"""
__match_args__ = ("main_assignments", "subexpressions")
# ------------------------------- Creation & Inplace Manipulation --------------------------------------------------
def __init__(self, main_assignments: Union[List[Assignment], Dict[sp.Expr, sp.Expr]],
subexpressions: Union[List[Assignment], Dict[sp.Expr, sp.Expr]] = None,
simplification_hints: Optional[Dict[str, Any]] = None,
subexpression_symbol_generator: Iterator[sp.Symbol] = None) -> None:
if subexpressions is None:
subexpressions = {}
if isinstance(main_assignments, Dict):
main_assignments = [Assignment(k, v)
for k, v in main_assignments.items()]
if isinstance(subexpressions, Dict):
subexpressions = [Assignment(k, v)
for k, v in subexpressions.items()]
main_assignments = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
subexpressions = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
self.main_assignments = main_assignments
self.subexpressions = subexpressions
if simplification_hints is None:
simplification_hints = {}
self.simplification_hints = simplification_hints
ctrs = [int(n.name[3:])for n in self.rhs_symbols if "xi_" in n.name]
max_ctr = max(ctrs) + 1 if len(ctrs) > 0 else 0
if subexpression_symbol_generator is None:
self.subexpression_symbol_generator = SymbolGen(ctr=max_ctr)
else:
self.subexpression_symbol_generator = subexpression_symbol_generator
def add_simplification_hint(self, key: str, value: Any) -> None:
"""Adds an entry to the simplification_hints dictionary and checks that is does not exist yet."""
assert key not in self.simplification_hints, "This hint already exists"
self.simplification_hints[key] = value
def add_subexpression(self, rhs: sp.Expr, lhs: Optional[sp.Symbol] = None, topological_sort=True) -> sp.Symbol:
"""Adds a subexpression to current collection.
Args:
rhs: right hand side of new subexpression
lhs: optional left hand side of new subexpression. If None a new unique symbol is generated.
topological_sort: sort the subexpressions topologically after insertion, to make sure that
definition of a symbol comes before its usage. If False, subexpression is appended.
Returns:
left hand side symbol (which could have been generated)
"""
if lhs is None:
lhs = next(self.subexpression_symbol_generator)
eq = Assignment(lhs, rhs)
self.subexpressions.append(eq)
if topological_sort:
self.topological_sort(sort_subexpressions=True,
sort_main_assignments=False)
return lhs
def topological_sort(self, sort_subexpressions: bool = True, sort_main_assignments: bool = True) -> None:
"""Sorts subexpressions and/or main_equations topologically to make sure symbol usage comes after definition."""
if sort_subexpressions:
self.subexpressions = sort_assignments_topologically(self.subexpressions)
if sort_main_assignments:
self.main_assignments = sort_assignments_topologically(self.main_assignments)
# ---------------------------------------------- Properties -------------------------------------------------------
@property
def all_assignments(self) -> List[Assignment]:
"""Subexpression and main equations as a single list."""
return self.subexpressions + self.main_assignments
@property
def rhs_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which occur on the rhs of any assignment."""
rhs_symbols = set()
for eq in self.all_assignments:
if isinstance(eq, Assignment):
rhs_symbols.update(eq.rhs.atoms(sp.Symbol))
# elif isinstance(eq, pystencils.astnodes.Node): # TODO remove or replace
# rhs_symbols.update(eq.undefined_symbols)
return rhs_symbols
@property
def free_symbols(self) -> Set[sp.Symbol]:
"""All symbols used in the assignment collection, which do not occur as left hand sides in any assignment."""
return self.rhs_symbols - self.bound_symbols
@property
def bound_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur on the left hand side of a main assignment or a subexpression."""
bound_symbols_set = set(
[assignment.lhs for assignment in self.all_assignments if isinstance(assignment, Assignment)]
)
assert len(bound_symbols_set) == len(list(a for a in self.all_assignments if isinstance(a, Assignment))), \
"Not in SSA form - same symbol assigned multiple times"
# bound_symbols_set = bound_symbols_set.union(*[
# assignment.symbols_defined for assignment in self.all_assignments
# if isinstance(assignment, pystencils.astnodes.Node)
# ]) TODO: replace?
return bound_symbols_set
@property
def rhs_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.rhs_symbols if hasattr(s, 'field')}
@property
def free_fields(self):
"""All fields accessed in the assignment collection, which do not occur as left hand sides in any assignment."""
return {s.field for s in self.free_symbols if hasattr(s, 'field')}
@property
def bound_fields(self):
"""All field accessed on the left hand side of a main assignment or a subexpression."""
return {s.field for s in self.bound_symbols if hasattr(s, 'field')}
@property
def defined_symbols(self) -> Set[sp.Symbol]:
"""All symbols which occur as left-hand-sides of one of the main equations"""
lhs_set = set([assignment.lhs for assignment in self.main_assignments if isinstance(assignment, Assignment)])
return lhs_set
# return (lhs_set.union(*[assignment.symbols_defined for assignment in self.main_assignments
# if isinstance(assignment, pystencils.astnodes.Node)])) TODO
@property
def operation_count(self):
"""See :func:`count_operations` """
return count_operations(self.all_assignments, only_type=None)
def atoms(self, *args):
return set().union(*[a.atoms(*args) for a in self.all_assignments])
def dependent_symbols(self, symbols: Iterable[sp.Symbol]) -> Set[sp.Symbol]:
"""Returns all symbols that depend on one of the passed symbols.
A symbol 'a' depends on a symbol 'b', if there is an assignment 'a <- some_expression(b)' i.e. when
'b' is required to compute 'a'.
"""
queue = list(symbols)
def add_symbols_from_expr(expr):
dependent_symbols = expr.atoms(sp.Symbol)
for ds in dependent_symbols:
queue.append(ds)
handled_symbols = set()
assignment_dict = {e.lhs: e.rhs for e in self.all_assignments}
while len(queue) > 0:
e = queue.pop(0)
if e in handled_symbols:
continue
if e in assignment_dict:
add_symbols_from_expr(assignment_dict[e])
handled_symbols.add(e)
return handled_symbols
def lambdify(self, symbols: Sequence[sp.Symbol], fixed_symbols: Optional[Dict[sp.Symbol, Any]] = None, module=None):
"""Returns a python function to evaluate this equation collection.
Args:
symbols: symbol(s) which are the parameter for the created function
fixed_symbols: dictionary with substitutions, that are applied before sympy's lambdify
module: same as sympy.lambdify parameter. Defines which module to use e.g. 'numpy'
Examples:
>>> a, b, c, d = sp.symbols("a b c d")
>>> ac = AssignmentCollection([Assignment(c, a + b), Assignment(d, a**2 + b)],
... subexpressions=[Assignment(b, a + b / 2)])
>>> python_function = ac.lambdify([a], fixed_symbols={b: 2})
>>> python_function(4)
{c: 6, d: 18}
"""
assignments = self.new_with_substitutions(fixed_symbols, substitute_on_lhs=False) if fixed_symbols else self
assignments = assignments.new_without_subexpressions().main_assignments
lambdas = {assignment.lhs: sp.lambdify(symbols, assignment.rhs, module) for assignment in assignments}
def f(*args, **kwargs):
return {s: func(*args, **kwargs) for s, func in lambdas.items()}
return f
# ---------------------------- Creating new modified collections ---------------------------------------------------
def copy(self,
main_assignments: Optional[List[Assignment]] = None,
subexpressions: Optional[List[Assignment]] = None) -> 'AssignmentCollection':
"""Returns a copy with optionally replaced main_assignments and/or subexpressions."""
res = copy(self)
res.simplification_hints = self.simplification_hints.copy()
res.subexpression_symbol_generator = copy(self.subexpression_symbol_generator)
if main_assignments is not None:
res.main_assignments = main_assignments
else:
res.main_assignments = self.main_assignments.copy()
if subexpressions is not None:
res.subexpressions = subexpressions
else:
res.subexpressions = self.subexpressions.copy()
return res
def new_with_substitutions(self, substitutions: Dict, add_substitutions_as_subexpressions: bool = False,
substitute_on_lhs: bool = True,
sort_topologically: bool = True) -> 'AssignmentCollection':
"""Returns new object, where terms are substituted according to the passed substitution dict.
Args:
substitutions: dict that is passed to sympy subs, substitutions are done main assignments and subexpressions
add_substitutions_as_subexpressions: if True, the substitutions are added as assignments to subexpressions
substitute_on_lhs: if False, the substitutions are done only on the right hand side of assignments
sort_topologically: if subexpressions are added as substitutions and this parameters is true,
the subexpressions are sorted topologically after insertion
Returns:
New AssignmentCollection where substitutions have been applied, self is not altered.
"""
transform = transform_lhs_and_rhs if substitute_on_lhs else transform_rhs
transformed_subexpressions = transform(self.subexpressions, fast_subs, substitutions)
transformed_assignments = transform(self.main_assignments, fast_subs, substitutions)
if add_substitutions_as_subexpressions:
transformed_subexpressions = [Assignment(b, a) for a, b in
substitutions.items()] + transformed_subexpressions
if sort_topologically:
transformed_subexpressions = sort_assignments_topologically(transformed_subexpressions)
return self.copy(transformed_assignments, transformed_subexpressions)
def new_merged(self, other: 'AssignmentCollection') -> 'AssignmentCollection':
"""Returns a new collection which contains self and other. Subexpressions are renamed if they clash."""
own_definitions = set([e.lhs for e in self.main_assignments])
other_definitions = set([e.lhs for e in other.main_assignments])
assert len(own_definitions.intersection(other_definitions)) == 0, \
"Cannot merge collections, since both define the same symbols"
own_subexpression_symbols = {e.lhs: e.rhs for e in self.subexpressions}
substitution_dict = {}
processed_other_subexpression_equations = []
for other_subexpression_eq in other.subexpressions:
if other_subexpression_eq.lhs in own_subexpression_symbols:
new_rhs = fast_subs(other_subexpression_eq.rhs, substitution_dict)
if new_rhs == own_subexpression_symbols[other_subexpression_eq.lhs]:
continue # exact the same subexpression equation exists already
else:
# different definition - a new name has to be introduced
new_lhs = next(self.subexpression_symbol_generator)
new_eq = Assignment(new_lhs, new_rhs)
processed_other_subexpression_equations.append(new_eq)
substitution_dict[other_subexpression_eq.lhs] = new_lhs
else:
processed_other_subexpression_equations.append(fast_subs(other_subexpression_eq, substitution_dict))
processed_other_main_assignments = [fast_subs(eq, substitution_dict) for eq in other.main_assignments]
return self.copy(self.main_assignments + processed_other_main_assignments,
self.subexpressions + processed_other_subexpression_equations)
def new_filtered(self, symbols_to_extract: Iterable[sp.Symbol]) -> 'AssignmentCollection':
"""Extracts equations that have symbols_to_extract as left hand side, together with necessary subexpressions.
Returns:
new AssignmentCollection, self is not altered
"""
symbols_to_extract = set(symbols_to_extract)
dependent_symbols = self.dependent_symbols(symbols_to_extract)
new_assignments = []
for eq in self.all_assignments:
if eq.lhs in symbols_to_extract:
new_assignments.append(eq)
new_sub_expr = [eq for eq in self.all_assignments
if eq.lhs in dependent_symbols and eq.lhs not in symbols_to_extract]
return self.copy(new_assignments, new_sub_expr)
def new_without_unused_subexpressions(self) -> 'AssignmentCollection':
"""Returns new collection that only contains subexpressions required to compute the main assignments."""
all_lhs = [eq.lhs for eq in self.main_assignments]
return self.new_filtered(all_lhs)
def new_with_inserted_subexpression(self, symbol: sp.Symbol) -> 'AssignmentCollection':
"""Eliminates the subexpression with the given symbol on its left hand side, by substituting it everywhere."""
new_subexpressions = []
subs_dict = None
for se in self.subexpressions:
if se.lhs == symbol:
subs_dict = {se.lhs: se.rhs}
else:
new_subexpressions.append(se)
if subs_dict is None:
return self
new_subexpressions = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in new_subexpressions]
new_eqs = [Assignment(eq.lhs, fast_subs(eq.rhs, subs_dict)) for eq in self.main_assignments]
return self.copy(new_eqs, new_subexpressions)
def new_without_subexpressions(self, subexpressions_to_keep=None) -> 'AssignmentCollection':
"""Returns a new collection where all subexpressions have been inserted."""
if subexpressions_to_keep is None:
subexpressions_to_keep = set()
if len(self.subexpressions) == 0:
return self.copy()
subexpressions_to_keep = set(subexpressions_to_keep)
kept_subexpressions = []
if self.subexpressions[0].lhs in subexpressions_to_keep:
substitution_dict = {}
kept_subexpressions.append(self.subexpressions[0])
else:
substitution_dict = {self.subexpressions[0].lhs: self.subexpressions[0].rhs}
subexpression = [e for e in self.subexpressions]
for i in range(1, len(subexpression)):
subexpression[i] = fast_subs(subexpression[i], substitution_dict)
if subexpression[i].lhs in subexpressions_to_keep:
kept_subexpressions.append(subexpression[i])
else:
substitution_dict[subexpression[i].lhs] = subexpression[i].rhs
new_assignment = [fast_subs(eq, substitution_dict) for eq in self.main_assignments]
return self.copy(new_assignment, kept_subexpressions)
# ----------------------------------------- Display and Printing -------------------------------------------------
def _repr_html_(self):
"""Interface to Jupyter notebook, to display as a nicely formatted HTML table"""
def make_html_equation_table(equations):
no_border = 'style="border:none"'
html_table = '<table style="border:none; width: 100%; ">'
line = '<tr {nb}> <td {nb}>$${eq}$$</td> </tr> '
for eq in equations:
format_dict = {'eq': sp.latex(eq),
'nb': no_border, }
html_table += line.format(**format_dict)
html_table += "</table>"
return html_table
result = ""
if len(self.subexpressions) > 0:
result += "<div>Subexpressions:</div>"
result += make_html_equation_table(self.subexpressions)
result += "<div>Main Assignments:</div>"
result += make_html_equation_table(self.main_assignments)
return result
def __repr__(self):
return f"AssignmentCollection: {str(tuple(self.defined_symbols))[1:-1]} <- f{tuple(self.free_symbols)}"
def __str__(self):
result = "Subexpressions:\n"
for eq in self.subexpressions:
result += f"\t{eq}\n"
result += "Main Assignments:\n"
for eq in self.main_assignments:
result += f"\t{eq}\n"
return result
def __iter__(self):
return self.all_assignments.__iter__()
@property
def main_assignments_dict(self):
return {a.lhs: a.rhs for a in self.main_assignments}
@property
def subexpressions_dict(self):
return {a.lhs: a.rhs for a in self.subexpressions}
def set_main_assignments_from_dict(self, main_assignments_dict):
self.main_assignments = [Assignment(k, v)
for k, v in main_assignments_dict.items()]
def set_sub_expressions_from_dict(self, sub_expressions_dict):
self.subexpressions = [Assignment(k, v)
for k, v in sub_expressions_dict.items()]
def find(self, *args, **kwargs):
return set.union(
*[a.find(*args, **kwargs) for a in self.all_assignments]
)
def match(self, *args, **kwargs):
rtn = {}
for a in self.all_assignments:
partial_result = a.match(*args, **kwargs)
if partial_result:
rtn.update(partial_result)
return rtn
def subs(self, *args, **kwargs):
return AssignmentCollection(
main_assignments=[a.subs(*args, **kwargs) for a in self.main_assignments],
subexpressions=[a.subs(*args, **kwargs) for a in self.subexpressions]
)
def replace(self, *args, **kwargs):
return AssignmentCollection(
main_assignments=[a.replace(*args, **kwargs) for a in self.main_assignments],
subexpressions=[a.replace(*args, **kwargs) for a in self.subexpressions]
)
def __eq__(self, other):
return set(self.all_assignments) == set(other.all_assignments)
def __bool__(self):
return bool(self.all_assignments)
class SymbolGen:
"""Default symbol generator producing number symbols ζ_0, ζ_1, ..."""
def __init__(self, symbol="xi", dtype=None, ctr=0):
self._ctr = ctr
self._symbol = symbol
self._dtype = dtype
def __iter__(self):
return self
def __next__(self):
name = f"{self._symbol}_{self._ctr}"
self._ctr += 1
if self._dtype is not None:
return pystencils.TypedSymbol(name, self._dtype)
return sp.Symbol(name)