Skip to content
Snippets Groups Projects
Commit 1c57a059 authored by Stephan Seitz's avatar Stephan Seitz
Browse files

Assert same length when performing vector assignment

parent 38da1c39
No related branches found
No related tags found
1 merge request!133Allow vector assignments
...@@ -26,7 +26,8 @@ if Assignment: ...@@ -26,7 +26,8 @@ if Assignment:
_old_new = sp.codegen.ast.Assignment.__new__ _old_new = sp.codegen.ast.Assignment.__new__
def _Assignment__new__(cls, lhs, rhs, *args, **kwargs): def _Assignment__new__(cls, lhs, rhs, *args, **kwargs):
if isinstance(lhs, (list, set, tuple, sp.Matrix)) and isinstance(rhs, (list, set, tuple, sp.Matrix)): if isinstance(lhs, (list, tuple, sp.Matrix)) and isinstance(rhs, (list, tuple, sp.Matrix)):
assert len(lhs) == len(rhs), f'{lhs} and {rhs} must have same length when performing vector assignment!'
return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs)) return tuple(_old_new(cls, a, b, *args, **kwargs) for a, b in zip(lhs, rhs))
return _old_new(cls, lhs, rhs, *args, **kwargs) return _old_new(cls, lhs, rhs, *args, **kwargs)
......
import pytest
import sympy as sp import sympy as sp
from pystencils import Assignment, AssignmentCollection from pystencils import Assignment, AssignmentCollection
...@@ -52,6 +53,18 @@ def test_vector_assignments(): ...@@ -52,6 +53,18 @@ def test_vector_assignments():
print(assignments) print(assignments)
def test_wrong_vector_assignments():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
import pystencils as ps
import sympy as sp
a, b = sp.symbols("a b")
with pytest.raises(AssertionError,
match=r'Matrix(.*) and Matrix(.*) must have same length when performing vector assignment!'):
ps.Assignment(sp.Matrix([a,b]), sp.Matrix([1,2,3]))
def test_vector_assignment_collection(): def test_vector_assignment_collection():
"""From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)""" """From #17 (https://i10git.cs.fau.de/pycodegen/pystencils/issues/17)"""
...@@ -64,4 +77,3 @@ def test_vector_assignment_collection(): ...@@ -64,4 +77,3 @@ def test_vector_assignment_collection():
assignments = ps.AssignmentCollection([ps.Assignment(y,x)]) assignments = ps.AssignmentCollection([ps.Assignment(y,x)])
print(assignments) print(assignments)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment