diff --git a/src/pystencils/astnodes.py b/src/pystencils/astnodes.py index f399287ed02ec4eb0d3d0e295d72ee1cf5ecc14b..c8928a218afeecc1c914a55696826ed7d055cf62 100644 --- a/src/pystencils/astnodes.py +++ b/src/pystencils/astnodes.py @@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Set, Union import sympy as sp -import pystencils -from pystencils.typing.utilities import create_type, get_next_parent_of_type +from pystencils.assignment import Assignment from pystencils.enums import Target, Backend from pystencils.field import Field -from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol from pystencils.sympyextensions import fast_subs +from pystencils.typing import (create_type, get_next_parent_of_type, + FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol, CFunction) NodeOrExpr = Union['Node', sp.Expr] @@ -270,6 +270,9 @@ class KernelFunction(Node): parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols] if hasattr(self, 'indexing'): parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()] + # Exclude parameters of type CFunction. These parameters will result in a C function call that will be handled + # by including a respective header file in the compute kernel. Hence, they are not free parameters. 
+ parameters = [p for p in parameters if not isinstance(p.symbol, CFunction)] parameters.sort(key=lambda p: p.symbol.name) return parameters @@ -387,7 +390,7 @@ class Block(Node): def symbols_defined(self): result = set() for a in self.args: - if isinstance(a, pystencils.Assignment): + if isinstance(a, Assignment): result.update(a.free_symbols) else: result.update(a.symbols_defined) @@ -398,7 +401,7 @@ class Block(Node): result = set() defined_symbols = set() for a in self.args: - if isinstance(a, pystencils.Assignment): + if isinstance(a, Assignment): result.update(a.free_symbols) defined_symbols.update({a.lhs}) else: diff --git a/src/pystencils/backends/arm_instruction_sets.py b/src/pystencils/backends/arm_instruction_sets.py index 227224f4e65460a291bd3a6cd3309ed3525072fa..fc0a8c45040879c2d1f9cd21ae91f5dcc49a35a2 100644 --- a/src/pystencils/backends/arm_instruction_sets.py +++ b/src/pystencils/backends/arm_instruction_sets.py @@ -1,3 +1,6 @@ +from pystencils.typing import CFunction + + def get_argument_string(function_shortcut, first=''): args = function_shortcut[function_shortcut.index('[') + 1: -1] arg_string = "(" @@ -66,10 +69,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): if instruction_set.startswith('sve') or instruction_set == 'sme': base_names['stream'] = 'stnt1[0, 1]' prefix = 'sv' - suffix = f'_f{bits[data_type]}' + suffix = f'_f{bits[data_type]}' elif instruction_set == 'neon': prefix = 'v' - suffix = f'q_f{bits[data_type]}' + suffix = f'q_f{bits[data_type]}' if instruction_set in ['sve', 'sve2', 'sme']: predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})' @@ -91,7 +94,6 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): result[intrinsic_id] = prefix + name + suffix + undef + arg_string if instruction_set in ['sve', 'sve2', 'sme']: - from pystencils.backends.cbackend import CFunction result['width'] = CFunction(width, "int") result['intwidth'] = 
CFunction(intwidth, "int") else: @@ -134,7 +136,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'): result['maskStoreS'] = result['storeS'].replace(predicate, '{3}') if instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'): result['maskStreamS'] = result['streamS'].replace(predicate, '{3}') - + result['streamFence'] = '__dmb(15)' if instruction_set == 'sme': diff --git a/src/pystencils/backends/cbackend.py b/src/pystencils/backends/cbackend.py index 657f60d2f16f14a20f81ebfc77414eb31ba0236a..5d664760720eee0f9dc9a799e1ca731ae08df1d3 100644 --- a/src/pystencils/backends/cbackend.py +++ b/src/pystencils/backends/cbackend.py @@ -6,7 +6,6 @@ from typing import Set import numpy as np import sympy as sp from sympy.core import S -from sympy.core.cache import cacheit from sympy.logic.boolalg import BooleanFalse, BooleanTrue from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction from sympy.functions.elementary.hyperbolic import HyperbolicFunction @@ -15,7 +14,7 @@ from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize from pystencils.typing import ( PointerType, VectorType, CastFunc, create_type, get_type_of_expression, - ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol) + ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol, CFunction) from pystencils.enums import Backend from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.functions import DivFunc, AddressOf @@ -166,23 +165,6 @@ class PrintNode(CustomCodeNode): self.headers.append("<iostream>") -class CFunction(TypedSymbol): - def __new__(cls, function, dtype): - return CFunction.__xnew_cached_(cls, function, dtype) - - def __new_stage2__(cls, function, dtype): - return super(CFunction, cls).__xnew__(cls, function, dtype) - - __xnew__ = 
staticmethod(__new_stage2__) - __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) - - def __getnewargs__(self): - return self.name, self.dtype - - def __getnewargs_ex__(self): - return (self.name, self.dtype), {} - - # ------------------------------------------- Printer ------------------------------------------------------------------ diff --git a/src/pystencils/backends/riscv_instruction_sets.py b/src/pystencils/backends/riscv_instruction_sets.py index 27f631e7f92d25e366bc767c759697ac898f3308..0b303393e93fb137fdc78f7ea47be2ec9b49027c 100644 --- a/src/pystencils/backends/riscv_instruction_sets.py +++ b/src/pystencils/backends/riscv_instruction_sets.py @@ -1,3 +1,6 @@ +from pystencils.typing import CFunction + + def get_argument_string(function_shortcut, last=''): args = function_shortcut[function_shortcut.index('[') + 1: -1] arg_string = "(" @@ -78,7 +81,6 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'): result[intrinsic_id] = prefix + name + suffix2 + arg_string - from pystencils.backends.cbackend import CFunction result['width'] = CFunction(width, "int") result['intwidth'] = CFunction(intwidth, "int") diff --git a/src/pystencils/cpu/cpujit.py b/src/pystencils/cpu/cpujit.py index 8cc10004505dffb6251b61474c2db1321ba7cffd..d9a320e7674d9d5ebec9d903e76bb79dd549b606 100644 --- a/src/pystencils/cpu/cpujit.py +++ b/src/pystencils/cpu/cpujit.py @@ -62,7 +62,7 @@ import numpy as np from pystencils import FieldType from pystencils.astnodes import LoopOverCoordinate -from pystencils.backends.cbackend import generate_c, get_headers, CFunction +from pystencils.backends.cbackend import generate_c, get_headers from pystencils.cpu.msvc_detection import get_environment from pystencils.include import get_pystencils_include_path from pystencils.kernel_wrapper import KernelWrapper @@ -447,8 +447,6 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec 
parameters.append(f"buffer_{field.name}.strides[{param.symbol.coordinate}] / {item_size}") elif param.is_field_shape: parameters.append(f"buffer_{param.field_name}.shape[{param.symbol.coordinate}]") - elif type(param.symbol) is CFunction: - continue else: extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type] pre_call_code += template_extract_scalar.format(extract_function=extract_function, diff --git a/src/pystencils/typing/__init__.py b/src/pystencils/typing/__init__.py index ae4483da44f9ddfb43e365d0f16e6ea2d9dc97c2..ae3974d82e07156a6063b62a8a2fa1b781ae0dcd 100644 --- a/src/pystencils/typing/__init__.py +++ b/src/pystencils/typing/__init__.py @@ -3,14 +3,14 @@ from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorM from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType, PointerType, StructType, create_type) from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol, - FieldPointerSymbol) + FieldPointerSymbol, CFunction) from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types, get_type_of_expression, get_next_parent_of_type, parents_of_type) __all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc', 'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType', - 'VectorType', 'PointerType', 'StructType', 'create_type', - 'assumptions_from_dtype', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', + 'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype', + 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction', 'typed_symbols', 'get_base_type', 'result_type', 'collate_types', 'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type'] diff --git a/src/pystencils/typing/typed_sympy.py 
b/src/pystencils/typing/typed_sympy.py index 302c2f9987b2db1a907710678ddbb7234668cfc6..03471228a4ffdb02a98cdd064d235d4b0b9da397 100644 --- a/src/pystencils/typing/typed_sympy.py +++ b/src/pystencils/typing/typed_sympy.py @@ -178,3 +178,20 @@ class FieldPointerSymbol(TypedSymbol): __xnew__ = staticmethod(__new_stage2__) __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) + + +class CFunction(TypedSymbol): + def __new__(cls, function, dtype): + return CFunction.__xnew_cached_(cls, function, dtype) + + def __new_stage2__(cls, function, dtype): + return super(CFunction, cls).__xnew__(cls, function, dtype) + + __xnew__ = staticmethod(__new_stage2__) + __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) + + def __getnewargs__(self): + return self.name, self.dtype + + def __getnewargs_ex__(self): + return (self.name, self.dtype), {} diff --git a/tests/test_vectorization_specific.py b/tests/test_vectorization_specific.py index 19c6e0033c1b73a967d18cc36fbb93438c7359f5..48dd20f51f052c1a9d0c78fc385512e5ad33d29d 100644 --- a/tests/test_vectorization_specific.py +++ b/tests/test_vectorization_specific.py @@ -8,8 +8,10 @@ import sympy as sp import pystencils as ps from pystencils.backends.simd_instruction_sets import (get_cacheline_size, get_supported_instruction_sets, get_vector_instruction_set) -from . import test_vectorization from pystencils.enums import Target +from pystencils.typing import CFunction +from . 
import test_vectorization + supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else [] @@ -274,6 +276,22 @@ def test_div_and_unevaluated_expr(dtype, instruction_set): assert 'pow' not in code +@pytest.mark.parametrize('dtype', ('float32', 'float64')) +@pytest.mark.parametrize('instruction_set', ('sve', 'sve2', 'sme', 'rvv')) +def test_check_ast_parameters_sizeless(dtype, instruction_set): + f, g = ps.fields(f"f, g: {dtype}[3D]", layout='fzyx') + + update_rule = [ps.Assignment(g.center(), 2 * f.center())] + + config = pystencils.config.CreateKernelConfig(data_type=dtype, + cpu_vectorize_info={'instruction_set': instruction_set}) + ast = ps.create_kernel(update_rule, config=config) + ast_symbols = [p.symbol for p in ast.get_parameters()] + assert ast.instruction_set['width'] not in ast_symbols + assert ast.instruction_set['intwidth'] not in ast_symbols + + + # TODO this test case needs a complete rework of the vectoriser. The reason is that the vectoriser does not # TODO vectorise symbols at the moment because they could be strides or field sizes, thus involved in pointer arithmetic # TODO This means that the vectoriser only works if fields are involved on the rhs.