Skip to content
Snippets Groups Projects
Commit 9f3406f6 authored by Markus Holzer's avatar Markus Holzer
Browse files

Fix CFunction in parameters

parent dc2f8d99
No related branches found
No related tags found
1 merge request!412[FIX] Vector Size for SVE instruction set is in free kernel parameters
Pipeline #69185 failed
......@@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp
import pystencils
from pystencils.typing.utilities import create_type, get_next_parent_of_type
from pystencils.assignment import Assignment
from pystencils.enums import Target, Backend
from pystencils.field import Field
from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol
from pystencils.sympyextensions import fast_subs
from pystencils.typing import (CFunction, create_type, get_next_parent_of_type,
FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol)
NodeOrExpr = Union['Node', sp.Expr]
......@@ -270,7 +270,9 @@ class KernelFunction(Node):
parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
if hasattr(self, 'indexing'):
parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
parameters = [p for p in parameters if p.symbol.name != "svcntd()"]
# Exclude paramters of type CFunction. These parameters will result in a C function call that will be handled
# by including a respective header file in the compute kernel. Hence, it is not a free parameter.
# parameters = [p for p in parameters if not isinstance(p.symbol, CFunction)]
parameters.sort(key=lambda p: p.symbol.name)
return parameters
......@@ -388,7 +390,7 @@ class Block(Node):
def symbols_defined(self):
result = set()
for a in self.args:
if isinstance(a, pystencils.Assignment):
if isinstance(a, Assignment):
result.update(a.free_symbols)
else:
result.update(a.symbols_defined)
......@@ -399,7 +401,7 @@ class Block(Node):
result = set()
defined_symbols = set()
for a in self.args:
if isinstance(a, pystencils.Assignment):
if isinstance(a, Assignment):
result.update(a.free_symbols)
defined_symbols.update({a.lhs})
else:
......
......@@ -6,7 +6,6 @@ from typing import Set
import numpy as np
import sympy as sp
from sympy.core import S
from sympy.core.cache import cacheit
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction
......@@ -15,7 +14,7 @@ from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
from pystencils.typing import (
PointerType, VectorType, CastFunc, create_type, get_type_of_expression,
ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol, CFunction)
from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.functions import DivFunc, AddressOf
......@@ -166,23 +165,6 @@ class PrintNode(CustomCodeNode):
self.headers.append("<iostream>")
class CFunction(TypedSymbol):
def __new__(cls, function, dtype):
return CFunction.__xnew_cached_(cls, function, dtype)
def __new_stage2__(cls, function, dtype):
return super(CFunction, cls).__xnew__(cls, function, dtype)
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), {}
# ------------------------------------------- Printer ------------------------------------------------------------------
......
......@@ -3,14 +3,14 @@ from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorM
from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
PointerType, StructType, create_type)
from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
FieldPointerSymbol)
FieldPointerSymbol, CFunction)
from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
get_type_of_expression, get_next_parent_of_type, parents_of_type)
__all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
'VectorType', 'PointerType', 'StructType', 'create_type',
'assumptions_from_dtype', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol',
'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
......@@ -178,3 +178,20 @@ class FieldPointerSymbol(TypedSymbol):
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
class CFunction(TypedSymbol):
def __new__(cls, function, dtype):
return CFunction.__xnew_cached_(cls, function, dtype)
def __new_stage2__(cls, function, dtype):
return super(CFunction, cls).__xnew__(cls, function, dtype)
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), {}
\ No newline at end of file
......@@ -8,8 +8,10 @@ import sympy as sp
import pystencils as ps
from pystencils.backends.simd_instruction_sets import (get_cacheline_size, get_supported_instruction_sets,
get_vector_instruction_set)
from . import test_vectorization
from pystencils.enums import Target
from pystencils.typing import CFunction
from . import test_vectorization
supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else []
......@@ -32,6 +34,8 @@ def test_vectorisation_varying_arch(instruction_set):
config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
ast = ps.create_kernel(update_rule, config=config)
for parameter in ast.get_parameters():
assert not isinstance(parameter.symbol, CFunction)
kernel = ast.compile()
kernel(f=arr)
np.testing.assert_equal(arr, 2)
......@@ -51,6 +55,8 @@ def test_vectorized_abs(instruction_set, dtype):
config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set})
ast = ps.create_kernel(update_rule, config=config)
for parameter in ast.get_parameters():
assert not isinstance(parameter.symbol, CFunction)
func = ast.compile()
dst = np.zeros_like(arr)
......@@ -118,6 +124,8 @@ def test_alignment_and_correct_ghost_layers(gl_field, gl_kernel, instruction_set
config = pystencils.config.CreateKernelConfig(target=dh.default_target,
cpu_vectorize_info=opt, ghost_layers=gl_kernel)
ast = ps.create_kernel(update_rule, config=config)
for parameter in ast.get_parameters():
assert not isinstance(parameter.symbol, CFunction)
kernel = ast.compile()
if ('loadA' in ast.instruction_set or 'storeA' in ast.instruction_set) and gl_kernel != gl_field:
with pytest.raises(ValueError):
......@@ -165,6 +173,8 @@ def test_square_root(dtype, instruction_set, field_layout):
ps.Assignment(sp.Symbol("xi_2"), sp.Symbol("xi") * sp.sqrt(src_field.center))]
ast = ps.create_kernel(eq, config=config)
for parameter in ast.get_parameters():
assert not isinstance(parameter.symbol, CFunction)
ast.compile()
code = ps.get_code_str(ast)
print(code)
......@@ -185,6 +195,8 @@ def test_square_root_2(dtype, instruction_set, padding):
config = ps.CreateKernelConfig(data_type=dtype, default_number_float=dtype, cpu_vectorize_info=cpu_vec)
ast = ps.create_kernel(up, config=config)
for parameter in ast.get_parameters():
assert not isinstance(parameter.symbol, CFunction)
ast.compile()
code = ps.get_code_str(ast)
......@@ -208,6 +220,8 @@ def test_pow(dtype, instruction_set, padding):
ps.Assignment(sp.Symbol("xi_2"), sp.Symbol("xi") * sp.Pow(src_field.center, 0.5))]
ast = ps.create_kernel(eq, config=config)
for parameter in ast.get_parameters():
assert not isinstance(parameter.symbol, CFunction)
ast.compile()
code = ps.get_code_str(ast)
......@@ -235,6 +249,8 @@ def test_issue62(dtype, instruction_set, padding):
cpu_vectorize_info=opt)
ast = ps.create_kernel(up, config=config)
for parameter in ast.get_parameters():
assert not isinstance(parameter.symbol, CFunction)
ast.compile()
code = ps.get_code_str(ast)
......@@ -261,6 +277,8 @@ def test_div_and_unevaluated_expr(dtype, instruction_set):
cpu_vectorize_info=opt)
ast = ps.create_kernel(up, config=config)
for parameter in ast.get_parameters():
assert not isinstance(parameter.symbol, CFunction)
code = ps.get_code_str(ast)
# print(code)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment