Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Showing
with 464 additions and 1192 deletions
#ifndef OPENCL_STDINT
#define OPENCL_STDINT
typedef unsigned int uint_t;
typedef signed char int8_t;
typedef signed short int16_t;
typedef signed int int32_t;
typedef signed long int int64_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long int uint64_t;
#endif
from typing import List, Union
import sympy
import sympy as sp
from sympy.codegen import Assignment
from sympy.codegen.rewriting import ReplaceOptim, optimize
from pystencils.astnodes import Block, Node, SympyAssignment
from pystencils.backends.cbackend import CustomCodeNode
from pystencils.functions import DivFunc
from pystencils.simp import AssignmentCollection
class NodeCollection:
def __init__(self, assignments: List[Union[Node, Assignment]]):
self.all_assignments = assignments
if all((isinstance(a, Assignment) for a in assignments)):
self.is_Nodes = False
self.is_Assignments = True
elif all((isinstance(n, Node) for n in assignments)):
self.is_Nodes = True
self.is_Assignments = False
else:
raise ValueError(f'The list "{assignments}" is mixed. Pass either a list of "pystencils.Assignments" '
f'or a list of "pystencils.astnodes.Node')
self.simplification_hints = {}
@staticmethod
def from_assignment_collection(assignment_collection: AssignmentCollection):
nodes = list()
for assignemt in assignment_collection.all_assignments:
if isinstance(assignemt, Assignment):
nodes.append(SympyAssignment(assignemt.lhs, assignemt.rhs))
elif isinstance(assignemt, Node):
nodes.append(assignemt)
else:
raise ValueError(f"Unknown node in the AssignmentCollection: {assignemt}")
return NodeCollection(nodes)
def evaluate_terms(self):
evaluate_constant_terms = ReplaceOptim(
lambda e: hasattr(e, 'is_constant') and e.is_constant and not e.is_integer,
lambda p: p.evalf()
)
evaluate_pow = ReplaceOptim(
lambda e: e.is_Pow and e.exp.is_Integer and abs(e.exp) <= 8,
lambda p: sp.UnevaluatedExpr(sp.Mul(*([p.base] * +p.exp), evaluate=False)) if p.exp > 0 else
(DivFunc(sp.Integer(1), p.base) if p.exp == -1 else
DivFunc(sp.Integer(1), sp.UnevaluatedExpr(sp.Mul(*([p.base] * -p.exp), evaluate=False))))
)
sympy_optimisations = [evaluate_constant_terms, evaluate_pow]
if self.is_Nodes:
def visitor(node):
if isinstance(node, CustomCodeNode):
return node
elif isinstance(node, Block):
return node.func([visitor(child) for child in node.args])
elif isinstance(node, Node):
return node.func(*[visitor(child) for child in node.args])
elif isinstance(node, sympy.Basic):
return optimize(node, sympy_optimisations)
else:
raise NotImplementedError(f'{node} {type(node)} has no valid visitor')
self.all_assignments = [visitor(assignment) for assignment in self.all_assignments]
else:
self.all_assignments = [Assignment(a.lhs, optimize(a.rhs, sympy_optimisations))
if hasattr(a, 'lhs')
else a for a in self.all_assignments]
import numpy as np
import pystencils as ps
from pystencils import Assignment, Field, CreateKernelConfig, create_kernel, Target
def test_indexed_kernel():
arr = np.zeros((3, 4))
dtype = np.dtype([('x', int), ('y', int), ('value', arr.dtype)])
index_arr = np.zeros((3,), dtype=dtype)
index_arr[0] = (0, 2, 3.0)
index_arr[1] = (1, 3, 42.0)
index_arr[2] = (2, 1, 5.0)
indexed_field = Field.create_from_numpy_array('index', index_arr)
normal_field = Field.create_from_numpy_array('f', arr)
update_rule = Assignment(normal_field[0, 0], indexed_field('value'))
config = CreateKernelConfig(index_fields=[indexed_field])
ast = create_kernel([update_rule], config=config)
kernel = ast.compile()
kernel(f=arr, index=index_arr)
code = ps.get_code_str(kernel)
for i in range(index_arr.shape[0]):
np.testing.assert_allclose(arr[index_arr[i]['x'], index_arr[i]['y']], index_arr[i]['value'], atol=1e-13)
def test_indexed_cuda_kernel():
try:
import pycuda
except ImportError:
pycuda = None
if pycuda:
import pycuda.gpuarray as gpuarray
arr = np.zeros((3, 4))
dtype = np.dtype([('x', int), ('y', int), ('value', arr.dtype)])
index_arr = np.zeros((3,), dtype=dtype)
index_arr[0] = (0, 2, 3.0)
index_arr[1] = (1, 3, 42.0)
index_arr[2] = (2, 1, 5.0)
indexed_field = Field.create_from_numpy_array('index', index_arr)
normal_field = Field.create_from_numpy_array('f', arr)
update_rule = Assignment(normal_field[0, 0], indexed_field('value'))
config = CreateKernelConfig(target=Target.GPU, index_fields=[indexed_field])
ast = create_kernel([update_rule], config=config)
kernel = ast.compile()
gpu_arr = gpuarray.to_gpu(arr)
gpu_index_arr = gpuarray.to_gpu(index_arr)
kernel(f=gpu_arr, index=gpu_index_arr)
gpu_arr.get(arr)
for i in range(index_arr.shape[0]):
np.testing.assert_allclose(arr[index_arr[i]['x'], index_arr[i]['y']], index_arr[i]['value'], atol=1e-13)
else:
print("Did not run test on GPU since no pycuda is available")
This diff is collapsed.
This diff is collapsed.
import pytest
import pystencils
from sympy import oo
@pytest.mark.parametrize('type', ('float32', 'float64', 'int64'))
@pytest.mark.parametrize('negative', (False, 'Negative'))
@pytest.mark.parametrize('target', (pystencils.Target.CPU, pystencils.Target.GPU))
def test_print_infinity(type, negative, target):
x = pystencils.fields(f'x: {type}[1d]')
if negative:
assignment = pystencils.Assignment(x.center, -oo)
else:
assignment = pystencils.Assignment(x.center, oo)
ast = pystencils.create_kernel(assignment, data_type=type, target=target)
if target == pystencils.Target.GPU:
pytest.importorskip('pycuda')
ast.compile()
print(ast.compile().code)
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import pytest
import pystencils
from pystencils.backends.cbackend import CBackend
class UnsupportedNode(pystencils.astnodes.Node):
def __init__(self):
super().__init__()
def test_print_unsupported_node():
with pytest.raises(NotImplementedError, match='CBackend does not support node of type UnsupportedNode'):
CBackend()(UnsupportedNode())
import numpy as np
import sympy as sp
from pystencils import Assignment, Field, TypedSymbol, create_kernel, make_slice
from pystencils.simp import sympy_cse_on_assignment_list
def test_sliced_iteration():
size = (4, 4)
src_arr = np.ones(size)
dst_arr = np.zeros_like(src_arr)
src_field = Field.create_from_numpy_array('src', src_arr)
dst_field = Field.create_from_numpy_array('dst', dst_arr)
a, b = sp.symbols("a b")
update_rule = Assignment(dst_field[0, 0],
(a * src_field[0, 1] + a * src_field[0, -1] +
b * src_field[1, 0] + b * src_field[-1, 0]) / 4)
x_end = TypedSymbol("x_end", "int")
s = make_slice[1:x_end, 1]
x_end_value = size[1] - 1
kernel = create_kernel(sympy_cse_on_assignment_list([update_rule]), iteration_slice=s).compile()
kernel(src=src_arr, dst=dst_arr, a=1.0, b=1.0, x_end=x_end_value)
expected_result = np.zeros(size)
expected_result[1:x_end_value, 1] = 1
np.testing.assert_almost_equal(expected_result, dst_arr)
[pytest]
testpaths = src tests doc/notebooks
pythonpath = src
python_files = test_*.py *_test.py scenario_*.py
norecursedirs = *.egg-info .git .cache .ipynb_checkpoints htmlcov
addopts = --doctest-modules --durations=20 --cov-config pytest.ini
......@@ -17,20 +19,21 @@ filterwarnings =
[run]
branch = True
source = pystencils
pystencils_tests
source = src/pystencils
tests
omit = doc/*
pystencils_tests/*
tests/*
setup.py
quicktest.py
conftest.py
versioneer.py
pystencils/jupytersetup.py
pystencils/cpu/msvc_detection.py
pystencils/sympy_gmpy_bug_workaround.py
pystencils/cache.py
pystencils/pacxx/benchmark.py
pystencils/_version.py
src/pystencils/jupytersetup.py
src/pystencils/cpu/msvc_detection.py
src/pystencils/sympy_gmpy_bug_workaround.py
src/pystencils/cache.py
src/pystencils/pacxx/benchmark.py
src/pystencils/_version.py
venv/
[report]
......
#!/usr/bin/env python3
from contextlib import redirect_stdout
import io
from tests.test_quicktests import (
test_basic_kernel,
test_basic_blocking_staggered,
test_basic_vectorization,
)
quick_tests = [
test_basic_kernel,
test_basic_blocking_staggered,
test_basic_vectorization,
]
if __name__ == "__main__":
print("Running pystencils quicktests")
for qt in quick_tests:
print(f" -> {qt.__name__}")
with redirect_stdout(io.StringIO()):
qt()
# See the docstring in versioneer.py for instructions. Note that you must
# re-run 'versioneer.py setup' after changing this section, and commit the
# resulting files.
[versioneer]
VCS = git
style = pep440
versionfile_source = pystencils/_version.py
versionfile_build = pystencils/_version.py
tag_prefix = release/
parentdir_prefix = pystencils-
import distutils
import io
import os
from contextlib import redirect_stdout
from importlib import import_module
from setuptools import setup, __version__ as setuptools_version
import setuptools
if int(setuptools_version.split('.')[0]) < 61:
raise Exception(
"[ERROR] pystencils requires at least setuptools version 61 to install.\n"
"If this error occurs during an installation via pip, it is likely that there is a conflict between "
"versions of setuptools installed by pip and the system package manager. "
"In this case, it is recommended to install pystencils into a virtual environment instead."
)
import versioneer
try:
import cython # noqa
USE_CYTHON = True
except ImportError:
USE_CYTHON = False
quick_tests = [
'test_quicktests.test_basic_kernel',
'test_quicktests.test_basic_blocking_staggered',
'test_quicktests.test_basic_vectorization',
]
class SimpleTestRunner(distutils.cmd.Command):
"""A custom command to run selected tests"""
description = 'run some quick tests'
user_options = []
@staticmethod
def _run_tests_in_module(test):
"""Short test runner function - to work also if py.test is not installed."""
test = f'pystencils_tests.{test}'
mod, function_name = test.rsplit('.', 1)
if isinstance(mod, str):
mod = import_module(mod)
func = getattr(mod, function_name)
print(f" -> {function_name} in {mod.__name__}")
with redirect_stdout(io.StringIO()):
func()
def initialize_options(self):
pass
def finalize_options(self):
pass
def run(self):
"""Run command."""
for test in quick_tests:
self._run_tests_in_module(test)
def readme():
with open('README.md') as f:
return f.read()
def cython_extensions(*extensions):
from distutils.extension import Extension
if USE_CYTHON:
ext = '.pyx'
result = [Extension(e, [os.path.join(*e.split(".")) + ext]) for e in extensions]
from Cython.Build import cythonize
result = cythonize(result, language_level=3)
return result
elif all([os.path.exists(os.path.join(*e.split(".")) + '.c') for e in extensions]):
ext = '.c'
result = [Extension(e, [os.path.join(*e.split(".")) + ext]) for e in extensions]
return result
else:
return None
def get_cmdclass():
cmdclass = {"quicktest": SimpleTestRunner}
cmdclass.update(versioneer.get_cmdclass())
return cmdclass
return versioneer.get_cmdclass()
setuptools.setup(name='pystencils',
description='Speeding up stencil computations on CPUs and GPUs',
version=versioneer.get_version(),
long_description=readme(),
long_description_content_type="text/markdown",
author='Martin Bauer, Jan Hönig, Markus Holzer',
license='AGPLv3',
author_email='cs10-codegen@fau.de',
url='https://i10git.cs.fau.de/pycodegen/pystencils/',
packages=['pystencils'] + ['pystencils.' + s for s in setuptools.find_packages('pystencils')],
install_requires=['sympy>=1.6,<=1.11.1', 'numpy>=1.8.0', 'appdirs', 'joblib'],
package_data={'pystencils': ['include/*.h',
'backends/cuda_known_functions.txt',
'backends/opencl1.1_known_functions.txt',
'boundaries/createindexlistcython.c',
'boundaries/createindexlistcython.pyx']},
ext_modules=cython_extensions("pystencils.boundaries.createindexlistcython"),
classifiers=[
'Development Status :: 4 - Beta',
'Framework :: Jupyter',
'Topic :: Software Development :: Code Generators',
'Topic :: Scientific/Engineering :: Physics',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: GNU Affero General Public License v3 or later (AGPLv3+)',
],
project_urls={
"Bug Tracker": "https://i10git.cs.fau.de/pycodegen/pystencils/-/issues",
"Documentation": "https://pycodegen.pages.i10git.cs.fau.de/pystencils/",
"Source Code": "https://i10git.cs.fau.de/pycodegen/pystencils",
},
extras_require={
'gpu': ['pycuda'],
'alltrafos': ['islpy', 'py-cpuinfo'],
'bench_db': ['blitzdb', 'pymongo', 'pandas'],
'interactive': ['matplotlib', 'ipy_table', 'imageio', 'jupyter', 'pyevtk', 'rich', 'graphviz'],
'doc': ['sphinx', 'sphinx_rtd_theme', 'nbsphinx',
'sphinxcontrib-bibtex', 'sphinx_autodoc_typehints', 'pandoc'],
'use_cython': ['Cython']
},
tests_require=['pytest',
'pytest-cov',
'pytest-html',
'ansi2html',
'pytest-xdist',
'flake8',
'nbformat',
'nbconvert',
'ipython',
'randomgen>=1.18'],
python_requires=">=3.8",
cmdclass=get_cmdclass()
)
setup(
version=versioneer.get_version(),
cmdclass=get_cmdclass(),
)
......@@ -2,9 +2,8 @@
from .enums import Backend, Target
from . import fd
from . import stencil as stencil
from .assignment import Assignment, assignment_from_stencil
from pystencils.typing.typed_sympy import TypedSymbol
from .datahandling import create_data_handling
from .assignment import Assignment, AddAugmentedAssignment, assignment_from_stencil
from .typing.typed_sympy import TypedSymbol
from .display_utils import get_code_obj, get_code_str, show_code, to_dot
from .field import Field, FieldType, fields
from .config import CreateKernelConfig
......@@ -15,6 +14,7 @@ from .simp import AssignmentCollection
from .slicing import make_slice
from .spatial_coordinates import x_, x_staggered, x_staggered_vector, x_vector, y_, y_staggered, z_, z_staggered
from .sympyextensions import SymbolCreator
from .datahandling import create_data_handling
__all__ = ['Field', 'FieldType', 'fields',
'TypedSymbol',
......@@ -24,7 +24,7 @@ __all__ = ['Field', 'FieldType', 'fields',
'Target', 'Backend',
'show_code', 'to_dot', 'get_code_obj', 'get_code_str',
'AssignmentCollection',
'Assignment',
'Assignment', 'AddAugmentedAssignment',
'assignment_from_stencil',
'SymbolCreator',
'create_data_handling',
......@@ -36,7 +36,5 @@ __all__ = ['Field', 'FieldType', 'fields',
'fd',
'stencil']
from ._version import get_versions
__version__ = get_versions()['version']
del get_versions
from . import _version
__version__ = _version.get_versions()['version']
......@@ -25,7 +25,8 @@ def aligned_empty(shape, byte_alignment=True, dtype=np.float64, byte_offset=0, o
byte_alignment = 64
elif byte_alignment == 'cacheline':
cacheline_sizes = [get_cacheline_size(is_name) for is_name in instruction_sets]
if all([s is None for s in cacheline_sizes]):
if all([s is None for s in cacheline_sizes]) or \
max([s for s in cacheline_sizes if s is not None]) > 0x100000:
widths = [get_vector_instruction_set(dtype, is_name)['width'] * np.dtype(dtype).itemsize
for is_name in instruction_sets
if type(get_vector_instruction_set(dtype, is_name)['width']) is int]
......
import numpy as np
import sympy as sp
from sympy.codegen.ast import Assignment
from sympy.codegen.ast import Assignment, AugmentedAssignment, AddAugmentedAssignment
from sympy.printing.latex import LatexPrinter
__all__ = ['Assignment', 'assignment_from_stencil']
__all__ = ['Assignment', 'AugmentedAssignment', 'AddAugmentedAssignment', 'assignment_from_stencil']
def print_assignment_latex(printer, expr):
binop = f"{expr.binop}=" if isinstance(expr, AugmentedAssignment) else ''
"""sympy cannot print Assignments as Latex. Thus, this function is added to the sympy Latex printer"""
printed_lhs = printer.doprint(expr.lhs)
printed_rhs = printer.doprint(expr.rhs)
return fr"{printed_lhs} \leftarrow {printed_rhs}"
return fr"{printed_lhs} \leftarrow_{{{binop}}} {printed_rhs}"
def assignment_str(assignment):
return fr"{assignment.lhs}{assignment.rhs}"
op = f"{assignment.binop}=" if isinstance(assignment, AugmentedAssignment) else ''
return fr"{assignment.lhs} {op} {assignment.rhs}"
_old_new = sp.codegen.ast.Assignment.__new__
......@@ -32,6 +34,9 @@ Assignment.__str__ = assignment_str
Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex
AugmentedAssignment.__str__ = assignment_str
LatexPrinter._print_AugmentedAssignment = print_assignment_latex
sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
......
......@@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp
import pystencils
from pystencils.typing.utilities import create_type, get_next_parent_of_type
from pystencils.assignment import Assignment
from pystencils.enums import Target, Backend
from pystencils.field import Field
from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol
from pystencils.sympyextensions import fast_subs
from pystencils.typing import (create_type, get_next_parent_of_type,
FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol, CFunction)
NodeOrExpr = Union['Node', sp.Expr]
......@@ -270,6 +270,9 @@ class KernelFunction(Node):
parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
if hasattr(self, 'indexing'):
parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
# Exclude paramters of type CFunction. These parameters will result in a C function call that will be handled
# by including a respective header file in the compute kernel. Hence, it is not a free parameter.
parameters = [p for p in parameters if not isinstance(p.symbol, CFunction)]
parameters.sort(key=lambda p: p.symbol.name)
return parameters
......@@ -303,7 +306,7 @@ class SkipIteration(Node):
class Block(Node):
def __init__(self, nodes: List[Node]):
def __init__(self, nodes: Union[Node, List[Node]]):
super(Block, self).__init__()
if not isinstance(nodes, list):
nodes = [nodes]
......@@ -345,14 +348,6 @@ class Block(Node):
assert self._nodes.count(insert_before) == 1
idx = self._nodes.index(insert_before)
# move all assignment (definitions to the top)
if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
while idx > 0:
pn = self._nodes[idx - 1]
if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional):
idx -= 1
else:
break
if not if_not_exists or self._nodes[idx] != new_node:
self._nodes.insert(idx, new_node)
......@@ -361,14 +356,6 @@ class Block(Node):
assert self._nodes.count(insert_after) == 1
idx = self._nodes.index(insert_after) + 1
# move all assignment (definitions to the top)
if isinstance(new_node, SympyAssignment) and new_node.is_declaration:
while idx > 0:
pn = self._nodes[idx - 1]
if isinstance(pn, LoopOverCoordinate) or isinstance(pn, Conditional):
idx -= 1
else:
break
if not if_not_exists or not (self._nodes[idx - 1] == new_node
or (idx < len(self._nodes) and self._nodes[idx] == new_node)):
self._nodes.insert(idx, new_node)
......@@ -403,7 +390,7 @@ class Block(Node):
def symbols_defined(self):
result = set()
for a in self.args:
if isinstance(a, pystencils.Assignment):
if isinstance(a, Assignment):
result.update(a.free_symbols)
else:
result.update(a.symbols_defined)
......@@ -414,7 +401,7 @@ class Block(Node):
result = set()
defined_symbols = set()
for a in self.args:
if isinstance(a, pystencils.Assignment):
if isinstance(a, Assignment):
result.update(a.free_symbols)
defined_symbols.update({a.lhs})
else:
......@@ -444,7 +431,7 @@ class LoopOverCoordinate(Node):
LOOP_COUNTER_NAME_PREFIX = "ctr"
BLOCK_LOOP_COUNTER_NAME_PREFIX = "_blockctr"
def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False):
def __init__(self, body, coordinate_to_loop_over, start, stop, step=1, is_block_loop=False, custom_loop_ctr=None):
super(LoopOverCoordinate, self).__init__(parent=None)
self.body = body
body.parent = self
......@@ -455,10 +442,11 @@ class LoopOverCoordinate(Node):
self.body.parent = self
self.prefix_lines = []
self.is_block_loop = is_block_loop
self.custom_loop_ctr = custom_loop_ctr
def new_loop_with_different_body(self, new_body):
result = LoopOverCoordinate(new_body, self.coordinate_to_loop_over, self.start, self.stop,
self.step, self.is_block_loop)
self.step, self.is_block_loop, self.custom_loop_ctr)
result.prefix_lines = [prefix_line for prefix_line in self.prefix_lines]
return result
......@@ -521,10 +509,13 @@ class LoopOverCoordinate(Node):
@property
def loop_counter_name(self):
if self.is_block_loop:
return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
if self.custom_loop_ctr:
return self.custom_loop_ctr.name
else:
return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
if self.is_block_loop:
return LoopOverCoordinate.get_block_loop_counter_name(self.coordinate_to_loop_over)
else:
return LoopOverCoordinate.get_loop_counter_name(self.coordinate_to_loop_over)
@staticmethod
def is_loop_counter_symbol(symbol):
......@@ -548,10 +539,13 @@ class LoopOverCoordinate(Node):
@property
def loop_counter_symbol(self):
if self.is_block_loop:
return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
if self.custom_loop_ctr:
return self.custom_loop_ctr
else:
return self.get_loop_counter_symbol(self.coordinate_to_loop_over)
if self.is_block_loop:
return self.get_block_loop_counter_symbol(self.coordinate_to_loop_over)
else:
return self.get_loop_counter_symbol(self.coordinate_to_loop_over)
@property
def is_outermost_loop(self):
......@@ -577,10 +571,10 @@ class SympyAssignment(Node):
def __init__(self, lhs_symbol, rhs_expr, is_const=True, use_auto=False):
super(SympyAssignment, self).__init__(parent=None)
self._lhs_symbol = sp.sympify(lhs_symbol)
self.rhs = sp.sympify(rhs_expr)
self._rhs = sp.sympify(rhs_expr)
self._is_const = is_const
self._is_declaration = self.__is_declaration()
self.use_auto = use_auto
self._use_auto = use_auto
def __is_declaration(self):
from pystencils.typing import CastFunc
......@@ -594,15 +588,28 @@ class SympyAssignment(Node):
def lhs(self):
return self._lhs_symbol
@property
def rhs(self):
return self._rhs
@lhs.setter
def lhs(self, new_value):
self._lhs_symbol = new_value
self._is_declaration = self.__is_declaration()
@rhs.setter
def rhs(self, new_rhs_expr):
self._rhs = new_rhs_expr
def subs(self, subs_dict):
self.lhs = fast_subs(self.lhs, subs_dict)
self.rhs = fast_subs(self.rhs, subs_dict)
def fast_subs(self, subs_dict, skip=None):
self.lhs = fast_subs(self.lhs, subs_dict, skip)
self.rhs = fast_subs(self.rhs, subs_dict, skip)
return self
def optimize(self, optimizations):
try:
from sympy.codegen.rewriting import optimize
......@@ -612,7 +619,7 @@ class SympyAssignment(Node):
@property
def args(self):
return [self._lhs_symbol, self.rhs, sp.sympify(self._is_const)]
return [self._lhs_symbol, self.rhs]
@property
def symbols_defined(self):
......@@ -643,6 +650,10 @@ class SympyAssignment(Node):
def is_const(self):
return self._is_const
@property
def use_auto(self):
return self._use_auto
def replace(self, child, replacement):
if child == self.lhs:
replacement.parent = self
......@@ -665,7 +676,7 @@ class SympyAssignment(Node):
return hash((self.lhs, self.rhs))
def __eq__(self, other):
return type(self) == type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)
return type(self) is type(other) and (self.lhs, self.rhs) == (other.lhs, other.rhs)
class ResolvedFieldAccess(sp.Indexed):
......