Skip to content
Snippets Groups Projects
Commit 5a9e9787 authored by Michael Kuron's avatar Michael Kuron :mortar_board:
Browse files

Merge branch 'holzer-master-patch-07296' into 'master'

[FIX] Vector Size for SVE instruction set is in free kernel parameters

See merge request !412
parents a7d1b278 97ba71de
Branches
Tags release/1.3.6
1 merge request!412[FIX] Vector Size for SVE instruction set is in free kernel parameters
Pipeline #69486 passed
......@@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Set, Union
import sympy as sp
import pystencils
from pystencils.typing.utilities import create_type, get_next_parent_of_type
from pystencils.assignment import Assignment
from pystencils.enums import Target, Backend
from pystencils.field import Field
from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol
from pystencils.sympyextensions import fast_subs
from pystencils.typing import (create_type, get_next_parent_of_type,
FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol, CFunction)
NodeOrExpr = Union['Node', sp.Expr]
......@@ -270,6 +270,9 @@ class KernelFunction(Node):
parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols]
if hasattr(self, 'indexing'):
parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()]
# Exclude paramters of type CFunction. These parameters will result in a C function call that will be handled
# by including a respective header file in the compute kernel. Hence, it is not a free parameter.
parameters = [p for p in parameters if not isinstance(p.symbol, CFunction)]
parameters.sort(key=lambda p: p.symbol.name)
return parameters
......@@ -387,7 +390,7 @@ class Block(Node):
def symbols_defined(self):
result = set()
for a in self.args:
if isinstance(a, pystencils.Assignment):
if isinstance(a, Assignment):
result.update(a.free_symbols)
else:
result.update(a.symbols_defined)
......@@ -398,7 +401,7 @@ class Block(Node):
result = set()
defined_symbols = set()
for a in self.args:
if isinstance(a, pystencils.Assignment):
if isinstance(a, Assignment):
result.update(a.free_symbols)
defined_symbols.update({a.lhs})
else:
......
from pystencils.typing import CFunction
def get_argument_string(function_shortcut, first=''):
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
......@@ -66,10 +69,10 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
if instruction_set.startswith('sve') or instruction_set == 'sme':
base_names['stream'] = 'stnt1[0, 1]'
prefix = 'sv'
suffix = f'_f{bits[data_type]}'
suffix = f'_f{bits[data_type]}'
elif instruction_set == 'neon':
prefix = 'v'
suffix = f'q_f{bits[data_type]}'
suffix = f'q_f{bits[data_type]}'
if instruction_set in ['sve', 'sve2', 'sme']:
predicate = f'{prefix}whilelt_b{bits[data_type]}_u64({{loop_counter}}, {{loop_stop}})'
......@@ -91,7 +94,6 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result[intrinsic_id] = prefix + name + suffix + undef + arg_string
if instruction_set in ['sve', 'sve2', 'sme']:
from pystencils.backends.cbackend import CFunction
result['width'] = CFunction(width, "int")
result['intwidth'] = CFunction(intwidth, "int")
else:
......@@ -134,7 +136,7 @@ def get_vector_instruction_set_arm(data_type='double', instruction_set='neon'):
result['maskStoreS'] = result['storeS'].replace(predicate, '{3}')
if instruction_set.startswith('sve2') and instruction_set not in ('sve256', 'sve2048'):
result['maskStreamS'] = result['streamS'].replace(predicate, '{3}')
result['streamFence'] = '__dmb(15)'
if instruction_set == 'sme':
......
......@@ -6,7 +6,6 @@ from typing import Set
import numpy as np
import sympy as sp
from sympy.core import S
from sympy.core.cache import cacheit
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction
from sympy.functions.elementary.hyperbolic import HyperbolicFunction
......@@ -15,7 +14,7 @@ from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node
from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize
from pystencils.typing import (
PointerType, VectorType, CastFunc, create_type, get_type_of_expression,
ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol)
ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol, CFunction)
from pystencils.enums import Backend
from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt
from pystencils.functions import DivFunc, AddressOf
......@@ -166,23 +165,6 @@ class PrintNode(CustomCodeNode):
self.headers.append("<iostream>")
class CFunction(TypedSymbol):
def __new__(cls, function, dtype):
return CFunction.__xnew_cached_(cls, function, dtype)
def __new_stage2__(cls, function, dtype):
return super(CFunction, cls).__xnew__(cls, function, dtype)
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), {}
# ------------------------------------------- Printer ------------------------------------------------------------------
......
from pystencils.typing import CFunction
def get_argument_string(function_shortcut, last=''):
args = function_shortcut[function_shortcut.index('[') + 1: -1]
arg_string = "("
......@@ -78,7 +81,6 @@ def get_vector_instruction_set_riscv(data_type='double', instruction_set='rvv'):
result[intrinsic_id] = prefix + name + suffix2 + arg_string
from pystencils.backends.cbackend import CFunction
result['width'] = CFunction(width, "int")
result['intwidth'] = CFunction(intwidth, "int")
......
......@@ -62,7 +62,7 @@ import numpy as np
from pystencils import FieldType
from pystencils.astnodes import LoopOverCoordinate
from pystencils.backends.cbackend import generate_c, get_headers, CFunction
from pystencils.backends.cbackend import generate_c, get_headers
from pystencils.cpu.msvc_detection import get_environment
from pystencils.include import get_pystencils_include_path
from pystencils.kernel_wrapper import KernelWrapper
......@@ -447,8 +447,6 @@ def create_function_boilerplate_code(parameter_info, name, ast_node, insert_chec
parameters.append(f"buffer_{field.name}.strides[{param.symbol.coordinate}] / {item_size}")
elif param.is_field_shape:
parameters.append(f"buffer_{param.field_name}.shape[{param.symbol.coordinate}]")
elif type(param.symbol) is CFunction:
continue
else:
extract_function, target_type = type_mapping[param.symbol.dtype.numpy_dtype.type]
pre_call_code += template_extract_scalar.format(extract_function=extract_function,
......
......@@ -3,14 +3,14 @@ from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorM
from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType,
PointerType, StructType, create_type)
from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol,
FieldPointerSymbol)
FieldPointerSymbol, CFunction)
from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types,
get_type_of_expression, get_next_parent_of_type, parents_of_type)
__all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc',
'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType',
'VectorType', 'PointerType', 'StructType', 'create_type',
'assumptions_from_dtype', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol',
'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype',
'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction',
'typed_symbols', 'get_base_type', 'result_type', 'collate_types',
'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type']
......@@ -178,3 +178,20 @@ class FieldPointerSymbol(TypedSymbol):
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
class CFunction(TypedSymbol):
def __new__(cls, function, dtype):
return CFunction.__xnew_cached_(cls, function, dtype)
def __new_stage2__(cls, function, dtype):
return super(CFunction, cls).__xnew__(cls, function, dtype)
__xnew__ = staticmethod(__new_stage2__)
__xnew_cached_ = staticmethod(cacheit(__new_stage2__))
def __getnewargs__(self):
return self.name, self.dtype
def __getnewargs_ex__(self):
return (self.name, self.dtype), {}
......@@ -8,8 +8,10 @@ import sympy as sp
import pystencils as ps
from pystencils.backends.simd_instruction_sets import (get_cacheline_size, get_supported_instruction_sets,
get_vector_instruction_set)
from . import test_vectorization
from pystencils.enums import Target
from pystencils.typing import CFunction
from . import test_vectorization
supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else []
......@@ -292,6 +294,22 @@ def test_div_and_unevaluated_expr(dtype, instruction_set):
assert 'pow' not in code
@pytest.mark.parametrize('dtype', ('float32', 'float64'))
@pytest.mark.parametrize('instruction_set', ('sve', 'sve2', 'sme', 'rvv'))
def test_check_ast_parameters_sizeless(dtype, instruction_set):
f, g = ps.fields(f"f, g: {dtype}[3D]", layout='fzyx')
update_rule = [ps.Assignment(g.center(), 2 * f.center())]
config = pystencils.config.CreateKernelConfig(data_type=dtype,
cpu_vectorize_info={'instruction_set': instruction_set})
ast = ps.create_kernel(update_rule, config=config)
ast_symbols = [p.symbol for p in ast.get_parameters()]
assert ast.instruction_set['width'] not in ast_symbols
assert ast.instruction_set['intwidth'] not in ast_symbols
# TODO this test case needs a complete rework of the vectoriser. The reason is that the vectoriser does not
# TODO vectorise symbols at the moment because they could be strides or field sizes, thus involved in pointer arithmetic
# TODO This means that the vectoriser only works if fields are involved on the rhs.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment