From cbf8ada5556f3f341626da4f3999a6149e9c3f3a Mon Sep 17 00:00:00 2001 From: Frederik Hennig <frederik.hennig@fau.de> Date: Tue, 2 Jul 2024 17:07:11 +0200 Subject: [PATCH] Support ConditionalFieldAccess + some fixes - Freeze ConditionalFieldAccess - Add freeze for multi-arg sp.And and sp.Or - Change config.default_dtype to take a UserTypeSpec --- .../backend/kernelcreation/freeze.py | 21 ++++++++++++++----- src/pystencils/config.py | 4 ++-- src/pystencils/kernelcreation.py | 3 ++- tests/nbackend/kernelcreation/test_freeze.py | 17 ++++++++++----- .../test_conditional_field_access.py | 5 ++++- 5 files changed, 36 insertions(+), 14 deletions(-) rename tests/{ => symbolics}/test_conditional_field_access.py (93%) diff --git a/src/pystencils/backend/kernelcreation/freeze.py b/src/pystencils/backend/kernelcreation/freeze.py index a73c23279..1e1355bbf 100644 --- a/src/pystencils/backend/kernelcreation/freeze.py +++ b/src/pystencils/backend/kernelcreation/freeze.py @@ -7,7 +7,12 @@ import sympy.core.relational import sympy.logic.boolalg from sympy.codegen.ast import AssignmentBase, AugmentedAssignment -from ...sympyextensions import Assignment, AssignmentCollection, integer_functions +from ...sympyextensions import ( + Assignment, + AssignmentCollection, + integer_functions, + ConditionalFieldAccess, +) from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc from ...sympyextensions.pointers import AddressOf from ...field import Field, FieldType @@ -353,6 +358,12 @@ class FreezeExpressions: else: return PsArrayAccess(ptr, index) + def map_ConditionalFieldAccess(self, acc: ConditionalFieldAccess): + facc = self.visit_expr(acc.access) + condition = self.visit_expr(acc.outofbounds_condition) + fallback = self.visit_expr(acc.outofbounds_value) + return PsTernary(condition, fallback, facc) + def map_Function(self, func: sp.Function) -> PsExpression: """Map SymPy function calls by mapping sympy function classes to backend-supported function symbols. @@ -475,12 +486,12 @@ class FreezeExpressions: raise FreezeError(f"Unsupported relation: {other}") def map_And(self, conj: sympy.logic.And) -> PsAnd: - arg1, arg2 = [self.visit_expr(arg) for arg in conj.args] - return PsAnd(arg1, arg2) + args = [self.visit_expr(arg) for arg in conj.args] + return reduce(PsAnd, args) # type: ignore def map_Or(self, disj: sympy.logic.Or) -> PsOr: - arg1, arg2 = [self.visit_expr(arg) for arg in disj.args] - return PsOr(arg1, arg2) + args = [self.visit_expr(arg) for arg in disj.args] + return reduce(PsOr, args) # type: ignore def map_Not(self, neg: sympy.logic.Not) -> PsNot: arg = self.visit_expr(neg.args[0]) diff --git a/src/pystencils/config.py b/src/pystencils/config.py index 7c49c4c37..fe0e87900 100644 --- a/src/pystencils/config.py +++ b/src/pystencils/config.py @@ -10,7 +10,7 @@ from .field import Field, FieldType from .backend.jit import JitBase from .backend.exceptions import PsOptionsError -from .types import PsIntegerType, PsNumericType, PsIeeeFloatType +from .types import PsIntegerType, UserTypeSpec, PsIeeeFloatType from .defaults import DEFAULTS @@ -169,7 +169,7 @@ class CreateKernelConfig: index_dtype: PsIntegerType = DEFAULTS.index_dtype """Data type used for all index calculations.""" - default_dtype: PsNumericType = PsIeeeFloatType(64) + default_dtype: UserTypeSpec = PsIeeeFloatType(64) """Default numeric data type. This data type will be applied to all untyped symbols. diff --git a/src/pystencils/kernelcreation.py b/src/pystencils/kernelcreation.py index 66c2a0d6c..3cda5aa46 100644 --- a/src/pystencils/kernelcreation.py +++ b/src/pystencils/kernelcreation.py @@ -2,6 +2,7 @@ from typing import cast from .enums import Target from .config import CreateKernelConfig +from .types import create_numeric_type from .backend import ( KernelFunction, KernelParameter, @@ -53,7 +54,7 @@ def create_kernel( """ ctx = KernelCreationContext( - default_dtype=config.default_dtype, index_dtype=config.index_dtype + default_dtype=create_numeric_type(config.default_dtype), index_dtype=config.index_dtype ) if isinstance(assignments, Assignment): diff --git a/tests/nbackend/kernelcreation/test_freeze.py b/tests/nbackend/kernelcreation/test_freeze.py index 1593be684..b1e8525b1 100644 --- a/tests/nbackend/kernelcreation/test_freeze.py +++ b/tests/nbackend/kernelcreation/test_freeze.py @@ -163,14 +163,21 @@ def test_freeze_booleans(): x2 = PsExpression.make(ctx.get_symbol("x")) y2 = PsExpression.make(ctx.get_symbol("y")) z2 = PsExpression.make(ctx.get_symbol("z")) + w2 = PsExpression.make(ctx.get_symbol("w")) - x, y, z = sp.symbols("x, y, z") + x, y, z, w = sp.symbols("x, y, z, w") + + expr = freeze(sp.Not(sp.And(x, y))) + assert expr.structurally_equal(PsNot(PsAnd(x2, y2))) + + expr = freeze(sp.Or(sp.Not(z), sp.And(y, sp.Not(x)))) + assert expr.structurally_equal(PsOr(PsNot(z2), PsAnd(y2, PsNot(x2)))) - expr1 = freeze(sp.Not(sp.And(x, y))) - assert expr1.structurally_equal(PsNot(PsAnd(x2, y2))) + expr = freeze(sp.And(w, x, y, z)) + assert expr.structurally_equal(PsAnd(PsAnd(PsAnd(w2, x2), y2), z2)) - expr2 = freeze(sp.Or(sp.Not(z), sp.And(y, sp.Not(x)))) - assert expr2.structurally_equal(PsOr(PsNot(z2), PsAnd(y2, PsNot(x2)))) + expr = freeze(sp.Or(w, x, y, z)) + assert expr.structurally_equal(PsOr(PsOr(PsOr(w2, x2), y2), z2)) @pytest.mark.parametrize("rel_pair", [ diff --git a/tests/test_conditional_field_access.py b/tests/symbolics/test_conditional_field_access.py similarity index 93% rename from tests/test_conditional_field_access.py rename to tests/symbolics/test_conditional_field_access.py index 1e120304b..bd384a959 100644 --- a/tests/test_conditional_field_access.py +++ b/tests/symbolics/test_conditional_field_access.py @@ -51,6 +51,9 @@ def add_fixed_constant_boundary_handling(assignments, with_cse): @pytest.mark.parametrize('dtype', ('float64', 'float32')) @pytest.mark.parametrize('with_cse', (False, 'with_cse')) def test_boundary_check(dtype, with_cse): + if with_cse: + pytest.xfail("Doesn't typify correctly yet.") + f, g = ps.fields(f"f, g : {dtype}[2D]") stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4) @@ -59,7 +62,7 @@ def test_boundary_check(dtype, with_cse): assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse) - config = ps.CreateKernelConfig(data_type=dtype, default_number_float=dtype, ghost_layers=0) + config = ps.CreateKernelConfig(default_dtype=ps.create_type(dtype), ghost_layers=0) kernel_checked = ps.create_kernel(assignments, config=config).compile() # ps.show_code(kernel_checked) -- GitLab