Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision

Target

Select target project
  • anirudh.jonnalagadda/pystencils
  • hyteg/pystencils
  • jbadwaik/pystencils
  • jngrad/pystencils
  • itischler/pystencils
  • ob28imeq/pystencils
  • hoenig/pystencils
  • Bindgen/pystencils
  • hammer/pystencils
  • da15siwa/pystencils
  • holzer/pystencils
  • alexander.reinauer/pystencils
  • ec93ujoh/pystencils
  • Harke/pystencils
  • seitz/pystencils
  • pycodegen/pystencils
16 results
Select Git revision
Show changes
Showing
with 692 additions and 121 deletions
"""Fixtures for the pystencils test suite
This module provides a number of fixtures used by the pystencils test suite.
Use these fixtures wherever applicable to extend the code surface area covered
by your tests:
- All tests that should work for every target should use the `target` fixture
- All tests that should work with the highest optimization level for every target
should use the `gen_config` fixture
- Use the `xp` fixture to access the correct array module (numpy or cupy) depending
on the target
"""
import pytest
from types import ModuleType
import pystencils as ps
AVAILABLE_TARGETS = [ps.Target.GenericCPU]
try:
import cupy
AVAILABLE_TARGETS += [ps.Target.CUDA]
except ImportError:
pass
AVAILABLE_TARGETS += ps.Target.available_vector_cpu_targets()
TARGET_IDS = [t.name for t in AVAILABLE_TARGETS]
def pytest_addoption(parser: pytest.Parser):
parser.addoption(
"--experimental-cpu-jit",
dest="experimental_cpu_jit",
action="store_true"
)
@pytest.fixture(params=AVAILABLE_TARGETS, ids=TARGET_IDS)
def target(request) -> ps.Target:
"""Provides all code generation targets available on the current hardware"""
return request.param
@pytest.fixture
def gen_config(request: pytest.FixtureRequest, target: ps.Target):
"""Default codegen configuration for the current target.
For GPU targets, set default indexing options.
For vector-CPU targets, set default vectorization config.
"""
gen_config = ps.CreateKernelConfig(target=target)
if target.is_vector_cpu():
gen_config.cpu.vectorize.enable = True
gen_config.cpu.vectorize.assume_inner_stride_one = True
if target.is_cpu() and request.config.getoption("experimental_cpu_jit"):
from pystencils.jit.cpu import CpuJit, GccInfo
gen_config.jit = CpuJit.create(compiler_info=GccInfo(target=target))
return gen_config
@pytest.fixture()
def xp(target: ps.Target) -> ModuleType:
"""Primary array module for the current target.
Returns:
`cupy` if `target == Target.CUDA`, and `numpy` otherwise
"""
if target == ps.Target.CUDA:
import cupy as xp
return xp
else:
import numpy as np
return np
...@@ -18,7 +18,7 @@ def test_max(dtype, sympy_function): ...@@ -18,7 +18,7 @@ def test_max(dtype, sympy_function):
z = dh.add_array('z', values_per_cell=1, dtype=dtype) z = dh.add_array('z', values_per_cell=1, dtype=dtype)
dh.fill("z", 2.0, ghost_layers=True) dh.fill("z", 2.0, ghost_layers=True)
config = pystencils.CreateKernelConfig(default_number_float=dtype) config = pystencils.CreateKernelConfig(default_dtype=dtype)
# test sp.Max with one argument # test sp.Max with one argument
assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3.3)) assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3.3))
...@@ -63,7 +63,7 @@ def test_max_integer(dtype, sympy_function): ...@@ -63,7 +63,7 @@ def test_max_integer(dtype, sympy_function):
z = dh.add_array('z', values_per_cell=1, dtype=dtype) z = dh.add_array('z', values_per_cell=1, dtype=dtype)
dh.fill("z", 2, ghost_layers=True) dh.fill("z", 2, ghost_layers=True)
config = pystencils.CreateKernelConfig(default_number_int=dtype) config = pystencils.CreateKernelConfig(default_dtype=dtype)
# test sp.Max with one argument # test sp.Max with one argument
assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3)) assignment_1 = pystencils.Assignment(x.center, sympy_function(y.center + 3))
......
import pytest import pytest
import pystencils.config
import sympy
import pystencils as ps import pystencils as ps
from pystencils.typing import CastFunc, create_type import sympy
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
def test_abs(target): def test_abs(target):
x, y, z = ps.fields('x, y, z: float64[2d]') if target == ps.Target.GPU:
# FIXME
pytest.xfail("GPU target not ready yet")
default_int_type = create_type('int64') x, y, z = ps.fields('x, y, z: int64[2d]')
assignments = ps.AssignmentCollection({x[0, 0]: sympy.Abs(CastFunc(y[0, 0], default_int_type))}) assignments = ps.AssignmentCollection({x[0, 0]: sympy.Abs(y[0, 0])})
config = pystencils.config.CreateKernelConfig(target=target) config = ps.CreateKernelConfig(target=target)
ast = ps.create_kernel(assignments, config=config) ast = ps.create_kernel(assignments, config=config)
code = ps.get_code_str(ast) code = ps.get_code_str(ast)
print(code) print(code)
......
...@@ -3,35 +3,37 @@ Test of pystencils.data_types.address_of ...@@ -3,35 +3,37 @@ Test of pystencils.data_types.address_of
""" """
import pytest import pytest
import pystencils import pystencils
from pystencils.typing import PointerType, CastFunc, BasicType from pystencils.types import PsPointerType, create_type
from pystencils.functions import AddressOf from pystencils.sympyextensions.pointers import AddressOf
from pystencils.simp.simplifications import sympy_cse from pystencils.sympyextensions.typed_sympy import tcast
from pystencils.simp import sympy_cse
import sympy as sp import sympy as sp
def test_address_of(): def test_address_of():
x, y = pystencils.fields('x, y: int64[2d]') x, y = pystencils.fields('x, y: int64[2d]')
s = pystencils.TypedSymbol('s', PointerType(BasicType('int64'))) s = pystencils.TypedSymbol('s', PsPointerType(create_type('int64')))
assert AddressOf(x[0, 0]).canonical() == x[0, 0] assert AddressOf(x[0, 0]).canonical() == x[0, 0]
assert AddressOf(x[0, 0]).dtype == PointerType(x[0, 0].dtype, restrict=True) assert AddressOf(x[0, 0]).dtype == PsPointerType(x[0, 0].dtype, restrict=True, const=True)
with pytest.raises(ValueError): with pytest.raises(ValueError):
assert AddressOf(sp.Symbol("a")).dtype assert AddressOf(sp.Symbol("a")).dtype
assignments = pystencils.AssignmentCollection({ assignments = pystencils.AssignmentCollection({
s: AddressOf(x[0, 0]), s: AddressOf(x[0, 0]),
y[0, 0]: CastFunc(s, BasicType('int64')) y[0, 0]: tcast(s, create_type('int64'))
}) })
kernel = pystencils.create_kernel(assignments).compile() _ = pystencils.create_kernel(assignments).compile()
# pystencils.show_code(kernel.ast) # pystencils.show_code(kernel.ast)
assignments = pystencils.AssignmentCollection({ assignments = pystencils.AssignmentCollection({
y[0, 0]: CastFunc(AddressOf(x[0, 0]), BasicType('int64')) y[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64'))
}) })
kernel = pystencils.create_kernel(assignments).compile() _ = pystencils.create_kernel(assignments).compile()
# pystencils.show_code(kernel.ast) # pystencils.show_code(kernel.ast)
...@@ -39,12 +41,12 @@ def test_address_of_with_cse(): ...@@ -39,12 +41,12 @@ def test_address_of_with_cse():
x, y = pystencils.fields('x, y: int64[2d]') x, y = pystencils.fields('x, y: int64[2d]')
assignments = pystencils.AssignmentCollection({ assignments = pystencils.AssignmentCollection({
x[0, 0]: CastFunc(AddressOf(x[0, 0]), BasicType('int64')) + 1 x[0, 0]: tcast(AddressOf(x[0, 0]), create_type('int64')) + 1
}) })
kernel = pystencils.create_kernel(assignments).compile() _ = pystencils.create_kernel(assignments).compile()
# pystencils.show_code(kernel.ast) # pystencils.show_code(kernel.ast)
assignments_cse = sympy_cse(assignments) assignments_cse = sympy_cse(assignments)
kernel = pystencils.create_kernel(assignments_cse).compile() _ = pystencils.create_kernel(assignments_cse).compile()
# pystencils.show_code(kernel.ast) # pystencils.show_code(kernel.ast)
...@@ -3,7 +3,6 @@ import sympy as sp ...@@ -3,7 +3,6 @@ import sympy as sp
import pystencils as ps import pystencils as ps
from pystencils import Assignment, AssignmentCollection from pystencils import Assignment, AssignmentCollection
from pystencils.astnodes import Conditional
from pystencils.simp.assignment_collection import SymbolGen from pystencils.simp.assignment_collection import SymbolGen
a, b, c = sp.symbols("a b c") a, b, c = sp.symbols("a b c")
...@@ -35,12 +34,12 @@ def test_assignment_collection(): ...@@ -35,12 +34,12 @@ def test_assignment_collection():
assert '<table' in ac_inserted._repr_html_() assert '<table' in ac_inserted._repr_html_()
def test_free_and_defined_symbols(): # def test_free_and_defined_symbols():
ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))], # ac = AssignmentCollection([Assignment(z, x + y), Conditional(t > 0, Assignment(a, b+1), Assignment(a, b+2))],
[], subexpression_symbol_generator=symbol_gen) # [], subexpression_symbol_generator=symbol_gen)
print(ac) # print(ac)
print(ac.__repr__) # print(ac.__repr__)
def test_vector_assignments(): def test_vector_assignments():
...@@ -170,3 +169,50 @@ def test_new_merged(): ...@@ -170,3 +169,50 @@ def test_new_merged():
assert ps.Assignment(d[0, 0](0), sp.symbols('xi_0')) in merged_ac.main_assignments assert ps.Assignment(d[0, 0](0), sp.symbols('xi_0')) in merged_ac.main_assignments
assert a1 in merged_ac.subexpressions assert a1 in merged_ac.subexpressions
assert a3 in merged_ac.subexpressions assert a3 in merged_ac.subexpressions
a1 = ps.Assignment(a, 20)
a2 = ps.Assignment(a, 10)
acommon = ps.Assignment(b, a)
# main assignments
a3 = ps.Assignment(f[0, 0](0), b)
a4 = ps.Assignment(d[0, 0](0), b)
ac = ps.AssignmentCollection([a3], subexpressions=[a1, acommon])
ac2 = ps.AssignmentCollection([a4], subexpressions=[a2, acommon])
merged_ac = ac.new_merged(ac2).new_without_subexpressions()
assert ps.Assignment(f[0, 0](0), 20) in merged_ac.main_assignments
assert ps.Assignment(d[0, 0](0), 10) in merged_ac.main_assignments
def test_assignment_collection_dict_conversion():
x, y = ps.fields('x,y: [2D]')
collection_normal = ps.AssignmentCollection(
[ps.Assignment(x.center(), y[1, 0] + y[0, 0])],
[]
)
collection_dict = ps.AssignmentCollection(
{x.center(): y[1, 0] + y[0, 0]},
{}
)
assert str(collection_normal) == str(collection_dict)
assert collection_dict.main_assignments_dict == {x.center(): y[1, 0] + y[0, 0]}
assert collection_dict.subexpressions_dict == {}
collection_normal = ps.AssignmentCollection(
[ps.Assignment(y[1, 0], x.center()),
ps.Assignment(y[0, 0], x.center())],
[]
)
collection_dict = ps.AssignmentCollection(
{y[1, 0]: x.center(),
y[0, 0]: x.center()},
{}
)
assert str(collection_normal) == str(collection_dict)
assert collection_dict.main_assignments_dict == {y[1, 0]: x.center(),
y[0, 0]: x.center()}
assert collection_dict.subexpressions_dict == {}
import numpy as np
import pytest import pytest
import pystencils as ps import pystencils as ps
from pystencils.assignment import assignment_from_stencil
def test_assignment_from_stencil():
stencil = [
[0, 0, 4, 1, 0, 0, 0],
[0, 0, 0, 2, 0, 0, 0],
[0, 0, 0, 3, 0, 0, 0]
]
x, y = ps.fields('x, y: [2D]')
assignment = assignment_from_stencil(stencil, x, y)
assert isinstance(assignment, ps.Assignment)
assert assignment.rhs == x[0, 1] + 4 * x[-1, 1] + 2 * x[0, 0] + 3 * x[0, -1]
assignment = assignment_from_stencil(stencil, x, y, normalization_factor=1 / np.sum(stencil))
assert isinstance(assignment, ps.Assignment)
@pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU]) @pytest.mark.parametrize('target', [ps.Target.CPU, ps.Target.GPU])
......
...@@ -3,11 +3,13 @@ import numpy as np ...@@ -3,11 +3,13 @@ import numpy as np
import pystencils as ps import pystencils as ps
from pystencils import Field, Assignment, create_kernel from pystencils import Field, Assignment, create_kernel
from pystencils.bit_masks import flag_cond from pystencils.sympyextensions.bit_masks import flag_cond
@pytest.mark.parametrize('mask_type', [np.uint8, np.uint16, np.uint32, np.uint64]) @pytest.mark.parametrize('mask_type', [np.uint8, np.uint16, np.uint32, np.uint64])
@pytest.mark.xfail(reason="Bit masks not yet supported by the new backend")
def test_flag_condition(mask_type): def test_flag_condition(mask_type):
f_arr = np.zeros((2, 2, 2), dtype=np.float64) f_arr = np.zeros((2, 2, 2), dtype=np.float64)
mask_arr = np.zeros((2, 2), dtype=mask_type) mask_arr = np.zeros((2, 2), dtype=mask_type)
......
...@@ -15,7 +15,7 @@ import sympy as sp ...@@ -15,7 +15,7 @@ import sympy as sp
import pystencils as ps import pystencils as ps
from pystencils import Field, x_vector from pystencils import Field, x_vector
from pystencils.astnodes import ConditionalFieldAccess from pystencils.sympyextensions.astnodes import ConditionalFieldAccess
from pystencils.simp import sympy_cse from pystencils.simp import sympy_cse
...@@ -59,7 +59,7 @@ def test_boundary_check(dtype, with_cse): ...@@ -59,7 +59,7 @@ def test_boundary_check(dtype, with_cse):
assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse) assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse)
config = ps.CreateKernelConfig(data_type=dtype, default_number_float=dtype, ghost_layers=0) config = ps.CreateKernelConfig(default_dtype=ps.create_type(dtype), ghost_layers=0)
kernel_checked = ps.create_kernel(assignments, config=config).compile() kernel_checked = ps.create_kernel(assignments, config=config).compile()
# ps.show_code(kernel_checked) # ps.show_code(kernel_checked)
......
...@@ -2,10 +2,11 @@ import pytest ...@@ -2,10 +2,11 @@ import pytest
import sympy as sp import sympy as sp
import pystencils as ps import pystencils as ps
from pystencils.fast_approximation import ( from pystencils.sympyextensions.fast_approximation import (
fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts) fast_division, fast_inv_sqrt, fast_sqrt, insert_fast_divisions, insert_fast_sqrts)
@pytest.mark.xfail(reason="Fast approximations are not implemented yet")
def test_fast_sqrt(): def test_fast_sqrt():
pytest.importorskip('cupy') pytest.importorskip('cupy')
f, g = ps.fields("f, g: double[2D]") f, g = ps.fields("f, g: double[2D]")
...@@ -29,6 +30,7 @@ def test_fast_sqrt(): ...@@ -29,6 +30,7 @@ def test_fast_sqrt():
assert '__frsqrt_rn' in code_str assert '__frsqrt_rn' in code_str
@pytest.mark.xfail(reason="Fast approximations are not implemented yet")
def test_fast_divisions(): def test_fast_divisions():
pytest.importorskip('cupy') pytest.importorskip('cupy')
f, g = ps.fields("f, g: double[2D]") f, g = ps.fields("f, g: double[2D]")
......
import numpy as np
import pytest
import sympy as sp
import pystencils as ps
from pystencils import DEFAULTS, DynamicType, create_type, fields
from pystencils.field import (
Field,
FieldType,
layout_string_to_tuple,
spatial_layout_string_to_tuple,
)
def test_field_basic():
f = Field.create_generic("f", spatial_dimensions=2)
assert FieldType.is_generic(f)
assert f.dtype == DynamicType.NUMERIC_TYPE
assert f["E"] == f[1, 0]
assert f["N"] == f[0, 1]
assert "_" in f.center._latex("dummy")
assert (
f.index_to_physical(index_coordinates=sp.Matrix([0, 0]), staggered=False)[0]
== 0
)
assert (
f.index_to_physical(index_coordinates=sp.Matrix([0, 0]), staggered=False)[1]
== 0
)
assert (
f.physical_to_index(physical_coordinates=sp.Matrix([0, 0]), staggered=False)[0]
== 0
)
assert (
f.physical_to_index(physical_coordinates=sp.Matrix([0, 0]), staggered=False)[1]
== 0
)
f1 = f.new_field_with_different_name("f1")
assert f1.ndim == f.ndim
assert f1.values_per_cell() == f.values_per_cell()
f = Field.create_fixed_size("f", (10, 10), strides=(10, 1), dtype=np.float64)
assert f.spatial_strides == (10, 1)
assert f.index_strides == ()
assert f.center_vector == sp.Matrix([f.center])
assert f.dtype == create_type("float64")
f1 = f.new_field_with_different_name("f1")
assert f1.ndim == f.ndim
assert f1.values_per_cell() == f.values_per_cell()
assert f1.dtype == create_type("float64")
f = Field.create_fixed_size("f", (8, 8, 2, 2), index_dimensions=2)
assert f.center_vector == sp.Matrix([[f(0, 0), f(0, 1)], [f(1, 0), f(1, 1)]])
field_access = f[1, 1]
assert field_access.nr_of_coordinates == 2
assert field_access.offset_name == "NE"
neighbor = field_access.neighbor(coord_id=0, offset=-2)
assert neighbor.offsets == (-1, 1)
assert "_" in neighbor._latex("dummy")
assert f.dtype == DynamicType.NUMERIC_TYPE
f = Field.create_fixed_size("f", (8, 8, 2, 2, 2), index_dimensions=3)
assert f.center_vector == sp.Array(
[[[f(i, j, k) for k in range(2)] for j in range(2)] for i in range(2)]
)
assert f.dtype == DynamicType.NUMERIC_TYPE
f = Field.create_generic("f", spatial_dimensions=5, index_dimensions=2)
field_access = f[1, -1, 2, -3, 0](1, 0)
assert field_access.offsets == (1, -1, 2, -3, 0)
assert field_access.index == (1, 0)
assert f.dtype == DynamicType.NUMERIC_TYPE
def test_field_description_parsing():
f, g = fields("f(1), g(3): [2D]")
assert f.dtype == g.dtype == DynamicType.NUMERIC_TYPE
assert f.spatial_dimensions == g.spatial_dimensions == 2
assert f.index_shape == (1,)
assert g.index_shape == (3,)
f = fields("f: dyn[3D]")
assert f.dtype == DynamicType.NUMERIC_TYPE
idx = fields("idx: dynidx[3D]")
assert idx.dtype == DynamicType.INDEX_TYPE
h = fields("h: float32[3D]")
assert h.index_shape == ()
assert h.spatial_dimensions == 3
assert h.index_dimensions == 0
assert h.dtype == create_type("float32")
f: Field = fields("f(5, 5) : double[20, 20]")
assert f.dtype == create_type("float64")
assert f.spatial_shape == (20, 20)
assert f.index_shape == (5, 5)
assert f.neighbor_vector((1, 1)).shape == (5, 5)
def test_error_handling():
struct_dtype = np.dtype(
[("a", np.int32), ("b", np.float64), ("c", np.uint32)], align=True
)
Field.create_generic(
"f", spatial_dimensions=2, index_dimensions=0, dtype=struct_dtype
)
with pytest.raises(ValueError) as e:
Field.create_generic(
"f", spatial_dimensions=2, index_dimensions=1, dtype=struct_dtype
)
assert "index dimension" in str(e.value)
arr = np.array([[[(1,) * 3, (2,) * 3, (3,) * 3]] * 2], dtype=struct_dtype)
Field.create_from_numpy_array("f", arr, index_dimensions=0)
with pytest.raises(ValueError) as e:
Field.create_from_numpy_array("f", arr, index_dimensions=1)
assert "Structured arrays" in str(e.value)
arr = np.zeros([3, 3, 3])
Field.create_from_numpy_array("f", arr, index_dimensions=2)
with pytest.raises(ValueError) as e:
Field.create_from_numpy_array("f", arr, index_dimensions=3)
assert "Too many" in str(e.value)
Field.create_fixed_size(
"f", (3, 2, 4), index_dimensions=0, dtype=struct_dtype, layout="reverse_numpy"
)
with pytest.raises(ValueError) as e:
Field.create_fixed_size(
"f",
(3, 2, 4),
index_dimensions=1,
dtype=struct_dtype,
layout="reverse_numpy",
)
assert "Structured arrays" in str(e.value)
f = Field.create_fixed_size("f", (10, 10))
with pytest.raises(ValueError) as e:
f[1]
assert "Wrong number of spatial indices" in str(e.value)
f = Field.create_generic("f", spatial_dimensions=2, index_shape=(3,))
with pytest.raises(ValueError) as e:
f(3)
assert "out of bounds" in str(e.value)
f = Field.create_fixed_size("f", (10, 10, 3, 4), index_dimensions=2)
with pytest.raises(ValueError) as e:
f(3, 0)
assert "out of bounds" in str(e.value)
with pytest.raises(ValueError) as e:
f(1, 0)(1, 0)
assert "Indexing an already indexed" in str(e.value)
with pytest.raises(ValueError) as e:
f(1)
assert "Wrong number of indices" in str(e.value)
with pytest.raises(ValueError) as e:
Field.create_generic("f", spatial_dimensions=2, layout="wrong")
assert "Unknown layout descriptor" in str(e.value)
assert layout_string_to_tuple("fzyx", dim=4) == (3, 2, 1, 0)
with pytest.raises(ValueError) as e:
layout_string_to_tuple("wrong", dim=4)
assert "Unknown layout descriptor" in str(e.value)
def test_decorator_scoping():
dst = fields("dst : double[2D]")
def f1():
a = sp.Symbol("a")
def f2():
b = sp.Symbol("b")
@ps.kernel
def decorated_func():
dst[0, 0] @= a + b
return decorated_func
return f2
assert f1()(), ps.Assignment(dst[0, 0], sp.Symbol("a") + sp.Symbol("b"))
def test_string_creation():
x, y, z = fields(" x(4), y(3,5) z : double[ 3, 47]")
assert x.index_shape == (4,)
assert y.index_shape == (3, 5)
assert z.spatial_shape == (3, 47)
def test_itemsize():
x = fields("x: float32[1d]")
y = fields("y: float64[2d]")
i = fields("i: int16[1d]")
assert x.itemsize == 4
assert y.itemsize == 8
assert i.itemsize == 2
def test_spatial_memory_layout_descriptors():
assert (
spatial_layout_string_to_tuple("AoS", 3)
== spatial_layout_string_to_tuple("aos", 3)
== spatial_layout_string_to_tuple("ZYXF", 3)
== spatial_layout_string_to_tuple("zyxf", 3)
== (2, 1, 0)
)
assert (
spatial_layout_string_to_tuple("SoA", 3)
== spatial_layout_string_to_tuple("soa", 3)
== spatial_layout_string_to_tuple("FZYX", 3)
== spatial_layout_string_to_tuple("fzyx", 3)
== spatial_layout_string_to_tuple("f", 3)
== spatial_layout_string_to_tuple("F", 3)
== (2, 1, 0)
)
assert (
spatial_layout_string_to_tuple("c", 3)
== spatial_layout_string_to_tuple("C", 3)
== (0, 1, 2)
)
assert spatial_layout_string_to_tuple("C", 5) == (0, 1, 2, 3, 4)
with pytest.raises(ValueError):
spatial_layout_string_to_tuple("aos", -1)
with pytest.raises(ValueError):
spatial_layout_string_to_tuple("aos", 4)
def test_memory_layout_descriptors():
assert (
layout_string_to_tuple("AoS", 4)
== layout_string_to_tuple("aos", 4)
== layout_string_to_tuple("ZYXF", 4)
== layout_string_to_tuple("zyxf", 4)
== (2, 1, 0, 3)
)
assert (
layout_string_to_tuple("SoA", 4)
== layout_string_to_tuple("soa", 4)
== layout_string_to_tuple("FZYX", 4)
== layout_string_to_tuple("fzyx", 4)
== layout_string_to_tuple("f", 4)
== layout_string_to_tuple("F", 4)
== (3, 2, 1, 0)
)
assert (
layout_string_to_tuple("c", 4)
== layout_string_to_tuple("C", 4)
== (0, 1, 2, 3)
)
assert layout_string_to_tuple("C", 5) == (0, 1, 2, 3, 4)
with pytest.raises(ValueError):
layout_string_to_tuple("aos", -1)
with pytest.raises(ValueError):
layout_string_to_tuple("aos", 5)
def test_staggered():
# D2Q5
j1, j2, j3 = fields(
"j1(2), j2(2,2), j3(2,2,2) : double[2D]", field_type=FieldType.STAGGERED
)
assert j1[0, 1](1) == j1.staggered_access((0, sp.Rational(1, 2)))
assert j1[0, 1](1) == j1.staggered_access(np.array((0, sp.Rational(1, 2))))
assert j1[1, 1](1) == j1.staggered_access((1, sp.Rational(1, 2)))
assert j1[0, 2](1) == j1.staggered_access((0, sp.Rational(3, 2)))
assert j1[0, 1](1) == j1.staggered_access("N")
assert j1[0, 0](1) == j1.staggered_access("S")
assert j1.staggered_vector_access("N") == sp.Matrix([j1.staggered_access("N")])
assert j1.staggered_stencil_name == "D2Q5"
assert j1.physical_coordinates[0] == DEFAULTS.spatial_counters[0]
assert j1.physical_coordinates[1] == DEFAULTS.spatial_counters[1]
assert j1.physical_coordinates_staggered[0] == DEFAULTS.spatial_counters[0] + 0.5
assert j1.physical_coordinates_staggered[1] == DEFAULTS.spatial_counters[1] + 0.5
assert (
j1.index_to_physical(index_coordinates=sp.Matrix([0, 0]), staggered=True)[0]
== 0.5
)
assert (
j1.index_to_physical(index_coordinates=sp.Matrix([0, 0]), staggered=True)[1]
== 0.5
)
assert (
j1.physical_to_index(physical_coordinates=sp.Matrix([0, 0]), staggered=True)[0]
== -0.5
)
assert (
j1.physical_to_index(physical_coordinates=sp.Matrix([0, 0]), staggered=True)[1]
== -0.5
)
assert j2[0, 1](1, 1) == j2.staggered_access((0, sp.Rational(1, 2)), 1)
assert j2[0, 1](1, 1) == j2.staggered_access("N", 1)
assert j2.staggered_vector_access("N") == sp.Matrix(
[j2.staggered_access("N", 0), j2.staggered_access("N", 1)]
)
assert j3[0, 1](1, 1, 1) == j3.staggered_access((0, sp.Rational(1, 2)), (1, 1))
assert j3[0, 1](1, 1, 1) == j3.staggered_access("N", (1, 1))
assert j3.staggered_vector_access("N") == sp.Matrix(
[[j3.staggered_access("N", (i, j)) for j in range(2)] for i in range(2)]
)
# D2Q9
k1, k2 = fields("k1(4), k2(2) : double[2D]", field_type=FieldType.STAGGERED)
assert k1[1, 1](2) == k1.staggered_access("NE")
assert k1[0, 0](2) == k1.staggered_access("SW")
assert k1[0, 0](3) == k1.staggered_access("NW")
a = k1.staggered_access("NE")
assert a._staggered_offset(a.offsets, a.index[0]) == [
sp.Rational(1, 2),
sp.Rational(1, 2),
]
a = k1.staggered_access("SW")
assert a._staggered_offset(a.offsets, a.index[0]) == [
sp.Rational(-1, 2),
sp.Rational(-1, 2),
]
a = k1.staggered_access("NW")
assert a._staggered_offset(a.offsets, a.index[0]) == [
sp.Rational(-1, 2),
sp.Rational(1, 2),
]
# sign reversed when using as flux field
r = fields("r(2) : double[2D]", field_type=FieldType.STAGGERED_FLUX)
assert r[0, 0](0) == r.staggered_access("W")
assert -r[1, 0](0) == r.staggered_access("E")
# test_staggered()
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
import sympy as sp import sympy as sp
import pystencils import pystencils
from pystencils.typing import create_type from pystencils.types import create_type
def test_floor_ceil_int_optimization(): def test_floor_ceil_int_optimization():
......
...@@ -11,12 +11,13 @@ ...@@ -11,12 +11,13 @@
import sympy as sp import sympy as sp
import pystencils import pystencils
from pystencils.typing import TypedSymbol, BasicType from pystencils.sympyextensions import TypedSymbol
from pystencils.types import create_type
def test_wild_typed_symbol(): def test_wild_typed_symbol():
x = pystencils.fields('x: float32[3d]') x = pystencils.fields('x: float32[3d]')
typed_symbol = TypedSymbol('a', BasicType('float64')) typed_symbol = TypedSymbol('a', create_type('float64'))
assert x.center().match(sp.Wild('w1')) assert x.center().match(sp.Wild('w1'))
assert typed_symbol.match(sp.Wild('w1')) assert typed_symbol.match(sp.Wild('w1'))
......
from copy import copy, deepcopy from copy import copy, deepcopy
from pystencils.field import Field from pystencils.field import Field
from pystencils.typing import TypedSymbol from pystencils.sympyextensions import TypedSymbol
from pystencils.types import create_type
def test_field_access(): def test_field_access():
...@@ -15,4 +16,4 @@ def test_typed_symbol(): ...@@ -15,4 +16,4 @@ def test_typed_symbol():
ts = TypedSymbol("s", "double") ts = TypedSymbol("s", "double")
copy(ts) copy(ts)
ts_copy = deepcopy(ts) ts_copy = deepcopy(ts)
assert str(ts_copy.dtype).strip() == "double" assert ts_copy.dtype == create_type("double")
import sympy as sp import sympy as sp
import pytest
import pystencils as ps import pystencils as ps
from pystencils import Assignment, AssignmentCollection from pystencils import Assignment, AssignmentCollection
...@@ -47,6 +48,8 @@ def test_simplification_strategy(): ...@@ -47,6 +48,8 @@ def test_simplification_strategy():
def test_split_inner_loop(): def test_split_inner_loop():
pytest.skip("Loop splitting not implemented yet")
dst = ps.fields('dst(8): double[2D]') dst = ps.fields('dst(8): double[2D]')
s = sp.symbols('s_:8') s = sp.symbols('s_:8')
x = sp.symbols('x') x = sp.symbols('x')
......
from sys import version_info as vs from sys import version_info as vs
import pytest import pytest
import pystencils.config
import sympy as sp import sympy as sp
import pystencils as ps import pystencils as ps
from pystencils import Assignment, AssignmentCollection, fields from pystencils import Assignment, AssignmentCollection, fields
from pystencils.simp import subexpression_substitution_in_main_assignments from pystencils.simp import (
from pystencils.simp import add_subexpressions_for_divisions subexpression_substitution_in_main_assignments,
from pystencils.simp import add_subexpressions_for_sums add_subexpressions_for_divisions,
from pystencils.simp import add_subexpressions_for_field_reads add_subexpressions_for_sums,
from pystencils.simp.simplifications import add_subexpressions_for_constants add_subexpressions_for_field_reads,
from pystencils.typing import BasicType, TypedSymbol add_subexpressions_for_constants,
)
from pystencils.sympyextensions import TypedSymbol
from pystencils.types import create_type
a, b, c, d, x, y, z = sp.symbols("a b c d x y z") a, b, c, d, x, y, z = sp.symbols("a b c d x y z")
s0, s1, s2, s3 = sp.symbols("s_:4") s0, s1, s2, s3 = sp.symbols("s_:4")
...@@ -144,37 +146,38 @@ def test_add_subexpressions_for_field_reads(): ...@@ -144,37 +146,38 @@ def test_add_subexpressions_for_field_reads():
ac3 = add_subexpressions_for_field_reads(ac, data_type="float32") ac3 = add_subexpressions_for_field_reads(ac, data_type="float32")
assert len(ac3.subexpressions) == 2 assert len(ac3.subexpressions) == 2
assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol) assert isinstance(ac3.subexpressions[0].lhs, TypedSymbol)
assert ac3.subexpressions[0].lhs.dtype == BasicType("float32") assert ac3.subexpressions[0].lhs.dtype == create_type("float32")
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) # TODO: What does this test mean to accomplish?
@pytest.mark.parametrize('dtype', ('float32', 'float64')) # @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
@pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason") # @pytest.mark.parametrize('dtype', ('float32', 'float64'))
def test_sympy_optimizations(target, dtype): # @pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
if target == ps.Target.GPU: # def test_sympy_optimizations(target, dtype):
pytest.importorskip("cupy") # if target == ps.Target.GPU:
src, dst = ps.fields(f'src, dst: {dtype}[2d]') # pytest.importorskip("cupy")
# src, dst = ps.fields(f'src, dst: {dtype}[2d]')
assignments = ps.AssignmentCollection({ # assignments = ps.AssignmentCollection({
src[0, 0]: 1.0 * (sp.exp(dst[0, 0]) - 1) # src[0, 0]: 1.0 * (sp.exp(dst[0, 0]) - 1)
}) # })
config = pystencils.config.CreateKernelConfig(target=target, default_number_float=dtype) # config = pystencils.config.CreateKernelConfig(target=target, default_dtype=dtype)
ast = ps.create_kernel(assignments, config=config) # ast = ps.create_kernel(assignments, config=config)
ps.show_code(ast) # ps.show_code(ast)
code = ps.get_code_str(ast) # code = ps.get_code_str(ast)
if dtype == 'float32': # if dtype == 'float32':
assert 'expf(' in code # assert 'expf(' in code
elif dtype == 'float64': # elif dtype == 'float64':
assert 'exp(' in code # assert 'exp(' in code
@pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU)) @pytest.mark.parametrize('target', (ps.Target.CPU, ps.Target.GPU))
@pytest.mark.parametrize('simplification', (True, False))
@pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason") @pytest.mark.skipif((vs.major, vs.minor, vs.micro) == (3, 8, 2), reason="does not work on python 3.8.2 for some reason")
def test_evaluate_constant_terms(target, simplification): @pytest.mark.xfail(reason="The new backend does not (yet) evaluate transcendental functions")
def test_evaluate_constant_terms(target):
if target == ps.Target.GPU: if target == ps.Target.GPU:
pytest.importorskip("cupy") pytest.importorskip("cupy")
src, dst = ps.fields('src, dst: float32[2d]') src, dst = ps.fields('src, dst: float32[2d]')
...@@ -184,7 +187,7 @@ def test_evaluate_constant_terms(target, simplification): ...@@ -184,7 +187,7 @@ def test_evaluate_constant_terms(target, simplification):
src[0, 0]: -sp.cos(1) + dst[0, 0] src[0, 0]: -sp.cos(1) + dst[0, 0]
}) })
config = pystencils.config.CreateKernelConfig(target=target, default_assignment_simplifications=simplification) config = ps.CreateKernelConfig(target=target)
ast = ps.create_kernel(assignments, config=config) ast = ps.create_kernel(assignments, config=config)
code = ps.get_code_str(ast) code = ps.get_code_str(ast)
assert 'cos(' not in code assert 'cos(' not in code and 'cosf(' not in code
...@@ -3,6 +3,7 @@ import numpy as np ...@@ -3,6 +3,7 @@ import numpy as np
import sympy as sp import sympy as sp
import pystencils import pystencils
from pystencils import Assignment
from pystencils.sympyextensions import replace_second_order_products from pystencils.sympyextensions import replace_second_order_products
from pystencils.sympyextensions import remove_higher_order_terms from pystencils.sympyextensions import remove_higher_order_terms
from pystencils.sympyextensions import complete_the_squares_in_exp from pystencils.sympyextensions import complete_the_squares_in_exp
...@@ -13,11 +14,18 @@ from pystencils.sympyextensions import common_denominator ...@@ -13,11 +14,18 @@ from pystencils.sympyextensions import common_denominator
from pystencils.sympyextensions import get_symmetric_part from pystencils.sympyextensions import get_symmetric_part
from pystencils.sympyextensions import scalar_product from pystencils.sympyextensions import scalar_product
from pystencils.sympyextensions import kronecker_delta from pystencils.sympyextensions import kronecker_delta
from pystencils.sympyextensions.fast_approximation import (
from pystencils import Assignment fast_division,
from pystencils.functions import DivFunc fast_inv_sqrt,
from pystencils.fast_approximation import (fast_division, fast_inv_sqrt, fast_sqrt, fast_sqrt,
insert_fast_divisions, insert_fast_sqrts) insert_fast_divisions,
insert_fast_sqrts,
)
from pystencils.sympyextensions.integer_functions import (
round_to_multiple_towards_zero,
ceil_to_multiple,
div_ceil,
)
def test_utility(): def test_utility():
...@@ -40,10 +48,10 @@ def test_utility(): ...@@ -40,10 +48,10 @@ def test_utility():
def test_replace_second_order_products(): def test_replace_second_order_products():
x, y = sympy.symbols('x y') x, y = sympy.symbols("x y")
expr = 4 * x * y expr = 4 * x * y
expected_expr_positive = 2 * ((x + y) ** 2 - x ** 2 - y ** 2) expected_expr_positive = 2 * ((x + y) ** 2 - x**2 - y**2)
expected_expr_negative = 2 * (-(x - y) ** 2 + x ** 2 + y ** 2) expected_expr_negative = 2 * (-((x - y) ** 2) + x**2 + y**2)
result = replace_second_order_products(expr, search_symbols=[x, y], positive=True) result = replace_second_order_products(expr, search_symbols=[x, y], positive=True)
assert result == expected_expr_positive assert result == expected_expr_positive
...@@ -56,15 +64,17 @@ def test_replace_second_order_products(): ...@@ -56,15 +64,17 @@ def test_replace_second_order_products():
result = replace_second_order_products(expr, search_symbols=[x, y], positive=None) result = replace_second_order_products(expr, search_symbols=[x, y], positive=None)
assert result == expected_expr_positive assert result == expected_expr_positive
a = [Assignment(sympy.symbols('z'), x + y)] a = [Assignment(sympy.symbols("z"), x + y)]
replace_second_order_products(expr, search_symbols=[x, y], positive=True, replace_mixed=a) replace_second_order_products(
expr, search_symbols=[x, y], positive=True, replace_mixed=a
)
assert len(a) == 2 assert len(a) == 2
assert replace_second_order_products(4 + y, search_symbols=[x, y]) == y + 4 assert replace_second_order_products(4 + y, search_symbols=[x, y]) == y + 4
def test_remove_higher_order_terms(): def test_remove_higher_order_terms():
x, y = sympy.symbols('x y') x, y = sympy.symbols("x y")
expr = sympy.Mul(x, y) expr = sympy.Mul(x, y)
...@@ -82,19 +92,19 @@ def test_remove_higher_order_terms(): ...@@ -82,19 +92,19 @@ def test_remove_higher_order_terms():
def test_complete_the_squares_in_exp(): def test_complete_the_squares_in_exp():
a, b, c, s, n = sympy.symbols('a b c s n') a, b, c, s, n = sympy.symbols("a b c s n")
expr = a * s ** 2 + b * s + c expr = a * s**2 + b * s + c
result = complete_the_squares_in_exp(expr, symbols_to_complete=[s]) result = complete_the_squares_in_exp(expr, symbols_to_complete=[s])
assert result == expr assert result == expr
expr = sympy.exp(a * s ** 2 + b * s + c) expr = sympy.exp(a * s**2 + b * s + c)
expected_result = sympy.exp(a*s**2 + c - b**2 / (4*a)) expected_result = sympy.exp(a * s**2 + c - b**2 / (4 * a))
result = complete_the_squares_in_exp(expr, symbols_to_complete=[s]) result = complete_the_squares_in_exp(expr, symbols_to_complete=[s])
assert result == expected_result assert result == expected_result
def test_extract_most_common_factor(): def test_extract_most_common_factor():
x, y = sympy.symbols('x y') x, y = sympy.symbols("x y")
expr = 1 / (x + y) + 3 / (x + y) + 3 / (x + y) expr = 1 / (x + y) + 3 / (x + y) + 3 / (x + y)
most_common_factor = extract_most_common_factor(expr) most_common_factor = extract_most_common_factor(expr)
...@@ -116,98 +126,98 @@ def test_extract_most_common_factor(): ...@@ -116,98 +126,98 @@ def test_extract_most_common_factor():
def test_count_operations(): def test_count_operations():
x, y, z = sympy.symbols('x y z') x, y, z = sympy.symbols("x y z")
expr = 1/x + y * sympy.sqrt(z) expr = 1 / x + y * sympy.sqrt(z)
ops = count_operations(expr, only_type=None) ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1 assert ops["adds"] == 1
assert ops['muls'] == 1 assert ops["muls"] == 1
assert ops['divs'] == 1 assert ops["divs"] == 1
assert ops['sqrts'] == 1 assert ops["sqrts"] == 1
expr = 1 / sympy.sqrt(z) expr = 1 / sympy.sqrt(z)
ops = count_operations(expr, only_type=None) ops = count_operations(expr, only_type=None)
assert ops['adds'] == 0 assert ops["adds"] == 0
assert ops['muls'] == 0 assert ops["muls"] == 0
assert ops['divs'] == 1 assert ops["divs"] == 1
assert ops['sqrts'] == 1 assert ops["sqrts"] == 1
expr = sympy.Rel(1 / sympy.sqrt(z), 5) expr = sympy.Rel(1 / sympy.sqrt(z), 5)
ops = count_operations(expr, only_type=None) ops = count_operations(expr, only_type=None)
assert ops['adds'] == 0 assert ops["adds"] == 0
assert ops['muls'] == 0 assert ops["muls"] == 0
assert ops['divs'] == 1 assert ops["divs"] == 1
assert ops['sqrts'] == 1 assert ops["sqrts"] == 1
expr = sympy.sqrt(x + y) expr = sympy.sqrt(x + y)
expr = insert_fast_sqrts(expr).atoms(fast_sqrt) expr = insert_fast_sqrts(expr).atoms(fast_sqrt)
ops = count_operations(*expr, only_type=None) ops = count_operations(*expr, only_type=None)
assert ops['fast_sqrts'] == 1 assert ops["fast_sqrts"] == 1
expr = sympy.sqrt(x / y) expr = sympy.sqrt(x / y)
expr = insert_fast_divisions(expr).atoms(fast_division) expr = insert_fast_divisions(expr).atoms(fast_division)
ops = count_operations(*expr, only_type=None) ops = count_operations(*expr, only_type=None)
assert ops['fast_div'] == 1 assert ops["fast_div"] == 1
expr = pystencils.Assignment(sympy.Symbol('tmp'), 3 / sympy.sqrt(x + y)) expr = pystencils.Assignment(sympy.Symbol("tmp"), 3 / sympy.sqrt(x + y))
expr = insert_fast_sqrts(expr).atoms(fast_inv_sqrt) expr = insert_fast_sqrts(expr).atoms(fast_inv_sqrt)
ops = count_operations(*expr, only_type=None) ops = count_operations(*expr, only_type=None)
assert ops['fast_inv_sqrts'] == 1 assert ops["fast_inv_sqrts"] == 1
expr = sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z expr = sympy.Piecewise((1.0, x > 0), (0.0, True)) + y * z
ops = count_operations(expr, only_type=None) ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1 assert ops["adds"] == 1
expr = sympy.Pow(1/x + y * sympy.sqrt(z), 100) expr = sympy.Pow(1 / x + y * sympy.sqrt(z), 100)
ops = count_operations(expr, only_type=None) ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1 assert ops["adds"] == 1
assert ops['muls'] == 99 assert ops["muls"] == 99
assert ops['divs'] == 1 assert ops["divs"] == 1
assert ops['sqrts'] == 1 assert ops["sqrts"] == 1
expr = DivFunc(x, y) expr = x / y
ops = count_operations(expr, only_type=None) ops = count_operations(expr, only_type=None)
assert ops['divs'] == 1 assert ops["divs"] == 1
expr = DivFunc(x + z, y + z) expr = x + z / y + z
ops = count_operations(expr, only_type=None) ops = count_operations(expr, only_type=None)
assert ops['adds'] == 2 assert ops["adds"] == 2
assert ops['divs'] == 1 assert ops["divs"] == 1
expr = sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False)) expr = sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False))
ops = count_operations(expr, only_type=None) ops = count_operations(expr, only_type=None)
assert ops['muls'] == 99 assert ops["muls"] == 99
expr = DivFunc(1, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))) expr = 1 / sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False))
ops = count_operations(expr, only_type=None) ops = count_operations(expr, only_type=None)
assert ops['divs'] == 1 assert ops["divs"] == 1
assert ops['muls'] == 99 assert ops["muls"] == 99
expr = DivFunc(y + z, sp.UnevaluatedExpr(sp.Mul(*[x]*100, evaluate=False))) expr = (y + z) / sp.UnevaluatedExpr(sp.Mul(*[x] * 100, evaluate=False))
ops = count_operations(expr, only_type=None) ops = count_operations(expr, only_type=None)
assert ops['adds'] == 1 assert ops["adds"] == 1
assert ops['divs'] == 1 assert ops["divs"] == 1
assert ops['muls'] == 99 assert ops["muls"] == 99
def test_common_denominator(): def test_common_denominator():
x = sympy.symbols('x') x = sympy.symbols("x")
expr = sympy.Rational(1, 2) + x * sympy.Rational(2, 3) expr = sympy.Rational(1, 2) + x * sympy.Rational(2, 3)
cm = common_denominator(expr) cm = common_denominator(expr)
assert cm == 6 assert cm == 6
def test_get_symmetric_part(): def test_get_symmetric_part():
x, y, z = sympy.symbols('x y z') x, y, z = sympy.symbols("x y z")
expr = x / 9 - y ** 2 / 6 + z ** 2 / 3 + z / 3 expr = x / 9 - y**2 / 6 + z**2 / 3 + z / 3
expected_result = x / 9 - y ** 2 / 6 + z ** 2 / 3 expected_result = x / 9 - y**2 / 6 + z**2 / 3
sym_part = get_symmetric_part(expr, sympy.symbols(f'y z')) sym_part = get_symmetric_part(expr, sympy.symbols(f"y z"))
assert sym_part == expected_result assert sym_part == expected_result
def test_simplify_by_equality(): def test_simplify_by_equality():
x, y, z = sp.symbols('x, y, z') x, y, z = sp.symbols("x, y, z")
p, q = sp.symbols('p, q') p, q = sp.symbols("p, q")
# Let x = y + z # Let x = y + z
expr = x * p - y * p + z * q expr = x * p - y * p + z * q
...@@ -220,9 +230,24 @@ def test_simplify_by_equality(): ...@@ -220,9 +230,24 @@ def test_simplify_by_equality():
expr = x * (y + z) - y * z expr = x * (y + z) - y * z
expr = simplify_by_equality(expr, x, y, z) expr = simplify_by_equality(expr, x, y, z)
assert expr == x*y + z**2 assert expr == x * y + z**2
# Let x = y + 2 # Let x = y + 2
expr = x * p - 2 * p expr = x * p - 2 * p
expr = simplify_by_equality(expr, x, y, 2) expr = simplify_by_equality(expr, x, y, 2)
assert expr == y * p assert expr == y * p
def test_integer_functions():
assert round_to_multiple_towards_zero(9, 4) == 8
assert round_to_multiple_towards_zero(11, -4) == 8
assert round_to_multiple_towards_zero(12, 4) == 12
assert round_to_multiple_towards_zero(-9, 4) == -8
assert round_to_multiple_towards_zero(-9, -4) == -8
assert ceil_to_multiple(9, 4) == 12
assert ceil_to_multiple(11, 4) == 12
assert ceil_to_multiple(12, 4) == 12
assert div_ceil(9, 4) == 3
assert div_ceil(8, 4) == 2