Skip to content
Snippets Groups Projects
Commit 46b7b88d authored by Martin Bauer's avatar Martin Bauer
Browse files

Merge branch 'vector-assignments' into 'master'

Allow vector assignments

Closes #17

See merge request !133
parents 75b3dad8 1c57a059
Branches
Tags
1 merge request!133Allow vector assignments
Pipeline #21249 passed
# -*- coding: utf-8 -*-
import numpy as np import numpy as np
import sympy as sp import sympy as sp
from sympy.printing.latex import LatexPrinter from sympy.printing.latex import LatexPrinter
...@@ -24,9 +23,20 @@ def assignment_str(assignment): ...@@ -24,9 +23,20 @@ def assignment_str(assignment):
if Assignment: if Assignment:
_old_new = sp.codegen.ast.Assignment.__new__
def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
return _old_new(cls, lhs, rhs, *args, **kwargs)
Assignment.__str__ = assignment_str Assignment.__str__ = assignment_str
Assignment.__new__ = _Assignment__new__
LatexPrinter._print_Assignment = print_assignment_latex LatexPrinter._print_Assignment = print_assignment_latex
sp.MutableDenseMatrix.__hash__ = lambda self: hash(tuple(self))
else: else:
# back port for older sympy versions that don't have Assignment yet # back port for older sympy versions that don't have Assignment yet
......
import itertools
from copy import copy from copy import copy
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Union
...@@ -43,6 +44,11 @@ class AssignmentCollection: ...@@ -43,6 +44,11 @@ class AssignmentCollection:
subexpressions = [Assignment(k, v) subexpressions = [Assignment(k, v)
for k, v in subexpressions.items()] for k, v in subexpressions.items()]
main_assignments = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in main_assignments]))
subexpressions = list(itertools.chain.from_iterable(
[(a if isinstance(a, Iterable) else [a]) for a in subexpressions]))
self.main_assignments = main_assignments self.main_assignments = main_assignments
self.subexpressions = subexpressions self.subexpressions = subexpressions
......
import pytest
import sympy as sp import sympy as sp
from pystencils import Assignment, AssignmentCollection from pystencils import Assignment, AssignmentCollection
...@@ -40,3 +41,39 @@ def test_free_and_defined_symbols(): ...@@ -40,3 +41,39 @@ def test_free_and_defined_symbols():
print(ac) print(ac)
print(ac.__repr__) print(ac.__repr__)
def test_vector_assignments():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b, c = sp.symbols("a b c")
assignments = ps.Assignment(sp.Matrix([a,b,c]), sp.Matrix([1,2,3]))
print(assignments)
def test_wrong_vector_assignments():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b = sp.symbols("a b")
with pytest.raises(AssertionError,
match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'):
ps.Assignment(sp.Matrix([a,b]), sp.Matrix([1,2,3]))
def test_vector_assignment_collection():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b, c = sp.symbols("a b c")
y, x = sp.Matrix([a,b,c]), sp.Matrix([1,2,3])
assignments = ps.AssignmentCollection({y: x})
print(assignments)
assignments = ps.AssignmentCollection([ps.Assignment(y,x)])
print(assignments)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment