From 9f3406f63b56d1ae2044e5618a3e625263b79f0d Mon Sep 17 00:00:00 2001 From: Markus Holzer <markus.holzer@fau.de> Date: Wed, 25 Sep 2024 13:03:29 +0200 Subject: [PATCH] Fix CFunction in parameters --- src/pystencils/astnodes.py | 14 ++++++++------ src/pystencils/backends/cbackend.py | 20 +------------------- src/pystencils/typing/__init__.py | 6 +++--- src/pystencils/typing/typed_sympy.py | 17 +++++++++++++++++ tests/test_vectorization_specific.py | 20 +++++++++++++++++++- 5 files changed, 48 insertions(+), 29 deletions(-) diff --git a/src/pystencils/astnodes.py b/src/pystencils/astnodes.py index f0755a418..a54d350bc 100644 --- a/src/pystencils/astnodes.py +++ b/src/pystencils/astnodes.py @@ -5,12 +5,12 @@ from typing import Any, List, Optional, Sequence, Set, Union import sympy as sp -import pystencils -from pystencils.typing.utilities import create_type, get_next_parent_of_type +from pystencils.assignment import Assignment from pystencils.enums import Target, Backend from pystencils.field import Field -from pystencils.typing.typed_sympy import FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol from pystencils.sympyextensions import fast_subs +from pystencils.typing import (CFunction, create_type, get_next_parent_of_type, + FieldPointerSymbol, FieldShapeSymbol, FieldStrideSymbol, TypedSymbol) NodeOrExpr = Union['Node', sp.Expr] @@ -270,7 +270,9 @@ class KernelFunction(Node): parameters = [self.Parameter(symbol, get_fields(symbol)) for symbol in argument_symbols] if hasattr(self, 'indexing'): parameters += [self.Parameter(s, []) for s in self.indexing.symbolic_parameters()] - parameters = [p for p in parameters if p.symbol.name != "svcntd()"] + # Exclude paramters of type CFunction. These parameters will result in a C function call that will be handled + # by including a respective header file in the compute kernel. Hence, it is not a free parameter. + # parameters = [p for p in parameters if not isinstance(p.symbol, CFunction)] parameters.sort(key=lambda p: p.symbol.name) return parameters @@ -388,7 +390,7 @@ class Block(Node): def symbols_defined(self): result = set() for a in self.args: - if isinstance(a, pystencils.Assignment): + if isinstance(a, Assignment): result.update(a.free_symbols) else: result.update(a.symbols_defined) @@ -399,7 +401,7 @@ class Block(Node): result = set() defined_symbols = set() for a in self.args: - if isinstance(a, pystencils.Assignment): + if isinstance(a, Assignment): result.update(a.free_symbols) defined_symbols.update({a.lhs}) else: diff --git a/src/pystencils/backends/cbackend.py b/src/pystencils/backends/cbackend.py index 7dbf84d37..7065f7feb 100644 --- a/src/pystencils/backends/cbackend.py +++ b/src/pystencils/backends/cbackend.py @@ -6,7 +6,6 @@ from typing import Set import numpy as np import sympy as sp from sympy.core import S -from sympy.core.cache import cacheit from sympy.logic.boolalg import BooleanFalse, BooleanTrue from sympy.functions.elementary.trigonometric import TrigonometricFunction, InverseTrigonometricFunction from sympy.functions.elementary.hyperbolic import HyperbolicFunction @@ -15,7 +14,7 @@ from pystencils.astnodes import KernelFunction, LoopOverCoordinate, Node from pystencils.cpu.vectorization import vec_all, vec_any, CachelineSize from pystencils.typing import ( PointerType, VectorType, CastFunc, create_type, get_type_of_expression, - ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol) + ReinterpretCastFunc, VectorMemoryAccess, BasicType, TypedSymbol, CFunction) from pystencils.enums import Backend from pystencils.fast_approximation import fast_division, fast_inv_sqrt, fast_sqrt from pystencils.functions import DivFunc, AddressOf @@ -166,23 +165,6 @@ class PrintNode(CustomCodeNode): self.headers.append("<iostream>") -class CFunction(TypedSymbol): - def __new__(cls, function, dtype): - return CFunction.__xnew_cached_(cls, function, dtype) - - def __new_stage2__(cls, function, dtype): - return super(CFunction, cls).__xnew__(cls, function, dtype) - - __xnew__ = staticmethod(__new_stage2__) - __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) - - def __getnewargs__(self): - return self.name, self.dtype - - def __getnewargs_ex__(self): - return (self.name, self.dtype), {} - - # ------------------------------------------- Printer ------------------------------------------------------------------ diff --git a/src/pystencils/typing/__init__.py b/src/pystencils/typing/__init__.py index ae4483da4..ae3974d82 100644 --- a/src/pystencils/typing/__init__.py +++ b/src/pystencils/typing/__init__.py @@ -3,14 +3,14 @@ from pystencils.typing.cast_functions import (CastFunc, BooleanCastFunc, VectorM from pystencils.typing.types import (is_supported_type, numpy_name_to_c, AbstractType, BasicType, VectorType, PointerType, StructType, create_type) from pystencils.typing.typed_sympy import (assumptions_from_dtype, TypedSymbol, FieldStrideSymbol, FieldShapeSymbol, - FieldPointerSymbol) + FieldPointerSymbol, CFunction) from pystencils.typing.utilities import (typed_symbols, get_base_type, result_type, collate_types, get_type_of_expression, get_next_parent_of_type, parents_of_type) __all__ = ['CastFunc', 'BooleanCastFunc', 'VectorMemoryAccess', 'ReinterpretCastFunc', 'PointerArithmeticFunc', 'is_supported_type', 'numpy_name_to_c', 'AbstractType', 'BasicType', - 'VectorType', 'PointerType', 'StructType', 'create_type', - 'assumptions_from_dtype', 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', + 'VectorType', 'PointerType', 'StructType', 'create_type', 'assumptions_from_dtype', + 'TypedSymbol', 'FieldStrideSymbol', 'FieldShapeSymbol', 'FieldPointerSymbol', 'CFunction', 'typed_symbols', 'get_base_type', 'result_type', 'collate_types', 'get_type_of_expression', 'get_next_parent_of_type', 'parents_of_type'] diff --git a/src/pystencils/typing/typed_sympy.py b/src/pystencils/typing/typed_sympy.py index 302c2f998..bbaf3ef91 100644 --- a/src/pystencils/typing/typed_sympy.py +++ b/src/pystencils/typing/typed_sympy.py @@ -178,3 +178,20 @@ class FieldPointerSymbol(TypedSymbol): __xnew__ = staticmethod(__new_stage2__) __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) + + +class CFunction(TypedSymbol): + def __new__(cls, function, dtype): + return CFunction.__xnew_cached_(cls, function, dtype) + + def __new_stage2__(cls, function, dtype): + return super(CFunction, cls).__xnew__(cls, function, dtype) + + __xnew__ = staticmethod(__new_stage2__) + __xnew_cached_ = staticmethod(cacheit(__new_stage2__)) + + def __getnewargs__(self): + return self.name, self.dtype + + def __getnewargs_ex__(self): + return (self.name, self.dtype), {} \ No newline at end of file diff --git a/tests/test_vectorization_specific.py b/tests/test_vectorization_specific.py index dcebeae60..26a1ae3ec 100644 --- a/tests/test_vectorization_specific.py +++ b/tests/test_vectorization_specific.py @@ -8,8 +8,10 @@ import sympy as sp import pystencils as ps from pystencils.backends.simd_instruction_sets import (get_cacheline_size, get_supported_instruction_sets, get_vector_instruction_set) -from . import test_vectorization from pystencils.enums import Target +from pystencils.typing import CFunction +from . import test_vectorization + supported_instruction_sets = get_supported_instruction_sets() if get_supported_instruction_sets() else [] @@ -32,6 +34,8 @@ def test_vectorisation_varying_arch(instruction_set): config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set}) ast = ps.create_kernel(update_rule, config=config) + for parameter in ast.get_parameters(): + assert not isinstance(parameter.symbol, CFunction) kernel = ast.compile() kernel(f=arr) np.testing.assert_equal(arr, 2) @@ -51,6 +55,8 @@ def test_vectorized_abs(instruction_set, dtype): config = pystencils.config.CreateKernelConfig(cpu_vectorize_info={'instruction_set': instruction_set}) ast = ps.create_kernel(update_rule, config=config) + for parameter in ast.get_parameters(): + assert not isinstance(parameter.symbol, CFunction) func = ast.compile() dst = np.zeros_like(arr) @@ -118,6 +124,8 @@ def test_alignment_and_correct_ghost_layers(gl_field, gl_kernel, instruction_set config = pystencils.config.CreateKernelConfig(target=dh.default_target, cpu_vectorize_info=opt, ghost_layers=gl_kernel) ast = ps.create_kernel(update_rule, config=config) + for parameter in ast.get_parameters(): + assert not isinstance(parameter.symbol, CFunction) kernel = ast.compile() if ('loadA' in ast.instruction_set or 'storeA' in ast.instruction_set) and gl_kernel != gl_field: with pytest.raises(ValueError): @@ -165,6 +173,8 @@ def test_square_root(dtype, instruction_set, field_layout): ps.Assignment(sp.Symbol("xi_2"), sp.Symbol("xi") * sp.sqrt(src_field.center))] ast = ps.create_kernel(eq, config=config) + for parameter in ast.get_parameters(): + assert not isinstance(parameter.symbol, CFunction) ast.compile() code = ps.get_code_str(ast) print(code) @@ -185,6 +195,8 @@ def test_square_root_2(dtype, instruction_set, padding): config = ps.CreateKernelConfig(data_type=dtype, default_number_float=dtype, cpu_vectorize_info=cpu_vec) ast = ps.create_kernel(up, config=config) + for parameter in ast.get_parameters(): + assert not isinstance(parameter.symbol, CFunction) ast.compile() code = ps.get_code_str(ast) @@ -208,6 +220,8 @@ def test_pow(dtype, instruction_set, padding): ps.Assignment(sp.Symbol("xi_2"), sp.Symbol("xi") * sp.Pow(src_field.center, 0.5))] ast = ps.create_kernel(eq, config=config) + for parameter in ast.get_parameters(): + assert not isinstance(parameter.symbol, CFunction) ast.compile() code = ps.get_code_str(ast) @@ -235,6 +249,8 @@ def test_issue62(dtype, instruction_set, padding): cpu_vectorize_info=opt) ast = ps.create_kernel(up, config=config) + for parameter in ast.get_parameters(): + assert not isinstance(parameter.symbol, CFunction) ast.compile() code = ps.get_code_str(ast) @@ -261,6 +277,8 @@ def test_div_and_unevaluated_expr(dtype, instruction_set): cpu_vectorize_info=opt) ast = ps.create_kernel(up, config=config) + for parameter in ast.get_parameters(): + assert not isinstance(parameter.symbol, CFunction) code = ps.get_code_str(ast) # print(code) -- GitLab