Skip to content
Snippets Groups Projects
Commit 256a6b5b authored by Markus Holzer's avatar Markus Holzer
Browse files

Checked in simple mul problem

parent 6615a379
No related branches found
No related tags found
No related merge requests found
Pipeline #43393 failed
......@@ -661,6 +661,16 @@ class VectorizedCustomSympyPrinter(CustomSympyPrinter):
# if k in code:
# return code.replace(k, f'{k}f')
# raise ValueError(f"{code} doesn't give {known=} function back.")
# elif isinstance(arg, sp.UnevaluatedExpr):
# unevaluated_expression = arg.args[0]
# test = self._print_Mul(unevaluated_expression, inside_add=True)
# # printed_args = [self._typed_vectorized_symbol(a, data_type) for a in arg.args]
# printed_args = [self._print(a) for a in unevaluated_expression.args]
# return test[1]
#
# # return self.instruction_set[instruction].format(*printed_args, **self._kwargs)
#
# return printed_args
else:
raise NotImplementedError('Vectorizer cannot cast between different datatypes')
# to_type = self.instruction_set['suffix'][data_type.base_type.c_name]
......
......@@ -32,7 +32,7 @@ def test_vector_type_propagation(instruction_set=instruction_set):
ast = ps.create_kernel(update_rule)
vectorize(ast, instruction_set=instruction_set)
# ps.show_code(ast)
ps.show_code(ast)
func = ast.compile()
dst = np.zeros_like(arr)
......@@ -40,6 +40,26 @@ def test_vector_type_propagation(instruction_set=instruction_set):
np.testing.assert_equal(dst[1:-1, 1:-1], 2 * 10.0 + 3)
def test_vectorised_multiplication(instruction_set=instruction_set):
instructions = get_vector_instruction_set(instruction_set=instruction_set)
a, b = sp.symbols("a b")
f = ps.fields("f:[2D]")
update_rule = ps.Assignment(f[0, 0], a * b)
ast = ps.create_kernel(update_rule)
print(instruction_set)
vectorize(ast, instruction_set=instruction_set)
# ps.show_code(ast)
func = ast.compile()
code = ps.get_code_str(ast)
mul_instruction = instructions["*"][:instructions["*"].find("(")]
assert mul_instruction in code
@pytest.mark.parametrize('openmp', [True, False])
def test_aligned_and_nt_stores(openmp, instruction_set=instruction_set):
domain_size = (24, 24)
......@@ -302,8 +322,7 @@ def test_issue40(*_):
def test_issue62(*_):
opt = {'instruction_set': "avx", 'assume_aligned': True,
'assume_inner_stride_one': True}
opt = {'instruction_set': "avx", 'assume_aligned': True, 'assume_inner_stride_one': True}
field_type = "float64" # if ctx.double_accuracy else "float32"
# ----- Solving the 2D Poisson equation with rhs --------------------------
......@@ -321,5 +340,7 @@ def test_issue62(*_):
ast = ps.create_kernel(up, config=config)
code = ps.get_code_str(ast)
print(code)
assert 'pow' not in code
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment