Skip to content
Snippets Groups Projects
Commit cbf8ada5 authored by Frederik Hennig's avatar Frederik Hennig
Browse files

Support ConditionalFieldAccess + some fixes

 - Freeze ConditionalFieldAccess
 - Add freeze for multi-arg sp.And and sp.Or
 - Change config.default_dtype to take a UserTypeSpec
parent cb8e32d1
No related branches found
No related tags found
1 merge request!394Extend symbolic language support
Pipeline #67246 passed
...@@ -7,7 +7,12 @@ import sympy.core.relational ...@@ -7,7 +7,12 @@ import sympy.core.relational
import sympy.logic.boolalg import sympy.logic.boolalg
from sympy.codegen.ast import AssignmentBase, AugmentedAssignment from sympy.codegen.ast import AssignmentBase, AugmentedAssignment
from ...sympyextensions import Assignment, AssignmentCollection, integer_functions from ...sympyextensions import (
Assignment,
AssignmentCollection,
integer_functions,
ConditionalFieldAccess,
)
from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc from ...sympyextensions.typed_sympy import TypedSymbol, CastFunc
from ...sympyextensions.pointers import AddressOf from ...sympyextensions.pointers import AddressOf
from ...field import Field, FieldType from ...field import Field, FieldType
...@@ -353,6 +358,12 @@ class FreezeExpressions: ...@@ -353,6 +358,12 @@ class FreezeExpressions:
else: else:
return PsArrayAccess(ptr, index) return PsArrayAccess(ptr, index)
def map_ConditionalFieldAccess(self, acc: ConditionalFieldAccess):
facc = self.visit_expr(acc.access)
condition = self.visit_expr(acc.outofbounds_condition)
fallback = self.visit_expr(acc.outofbounds_value)
return PsTernary(condition, fallback, facc)
def map_Function(self, func: sp.Function) -> PsExpression: def map_Function(self, func: sp.Function) -> PsExpression:
"""Map SymPy function calls by mapping sympy function classes to backend-supported function symbols. """Map SymPy function calls by mapping sympy function classes to backend-supported function symbols.
...@@ -475,12 +486,12 @@ class FreezeExpressions: ...@@ -475,12 +486,12 @@ class FreezeExpressions:
raise FreezeError(f"Unsupported relation: {other}") raise FreezeError(f"Unsupported relation: {other}")
def map_And(self, conj: sympy.logic.And) -> PsAnd: def map_And(self, conj: sympy.logic.And) -> PsAnd:
arg1, arg2 = [self.visit_expr(arg) for arg in conj.args] args = [self.visit_expr(arg) for arg in conj.args]
return PsAnd(arg1, arg2) return reduce(PsAnd, args) # type: ignore
def map_Or(self, disj: sympy.logic.Or) -> PsOr: def map_Or(self, disj: sympy.logic.Or) -> PsOr:
arg1, arg2 = [self.visit_expr(arg) for arg in disj.args] args = [self.visit_expr(arg) for arg in disj.args]
return PsOr(arg1, arg2) return reduce(PsOr, args) # type: ignore
def map_Not(self, neg: sympy.logic.Not) -> PsNot: def map_Not(self, neg: sympy.logic.Not) -> PsNot:
arg = self.visit_expr(neg.args[0]) arg = self.visit_expr(neg.args[0])
......
...@@ -10,7 +10,7 @@ from .field import Field, FieldType ...@@ -10,7 +10,7 @@ from .field import Field, FieldType
from .backend.jit import JitBase from .backend.jit import JitBase
from .backend.exceptions import PsOptionsError from .backend.exceptions import PsOptionsError
from .types import PsIntegerType, PsNumericType, PsIeeeFloatType from .types import PsIntegerType, UserTypeSpec, PsIeeeFloatType
from .defaults import DEFAULTS from .defaults import DEFAULTS
...@@ -169,7 +169,7 @@ class CreateKernelConfig: ...@@ -169,7 +169,7 @@ class CreateKernelConfig:
index_dtype: PsIntegerType = DEFAULTS.index_dtype index_dtype: PsIntegerType = DEFAULTS.index_dtype
"""Data type used for all index calculations.""" """Data type used for all index calculations."""
default_dtype: PsNumericType = PsIeeeFloatType(64) default_dtype: UserTypeSpec = PsIeeeFloatType(64)
"""Default numeric data type. """Default numeric data type.
This data type will be applied to all untyped symbols. This data type will be applied to all untyped symbols.
......
...@@ -2,6 +2,7 @@ from typing import cast ...@@ -2,6 +2,7 @@ from typing import cast
from .enums import Target from .enums import Target
from .config import CreateKernelConfig from .config import CreateKernelConfig
from .types import create_numeric_type
from .backend import ( from .backend import (
KernelFunction, KernelFunction,
KernelParameter, KernelParameter,
...@@ -53,7 +54,7 @@ def create_kernel( ...@@ -53,7 +54,7 @@ def create_kernel(
""" """
ctx = KernelCreationContext( ctx = KernelCreationContext(
default_dtype=config.default_dtype, index_dtype=config.index_dtype default_dtype=create_numeric_type(config.default_dtype), index_dtype=config.index_dtype
) )
if isinstance(assignments, Assignment): if isinstance(assignments, Assignment):
......
...@@ -163,14 +163,21 @@ def test_freeze_booleans(): ...@@ -163,14 +163,21 @@ def test_freeze_booleans():
x2 = PsExpression.make(ctx.get_symbol("x")) x2 = PsExpression.make(ctx.get_symbol("x"))
y2 = PsExpression.make(ctx.get_symbol("y")) y2 = PsExpression.make(ctx.get_symbol("y"))
z2 = PsExpression.make(ctx.get_symbol("z")) z2 = PsExpression.make(ctx.get_symbol("z"))
w2 = PsExpression.make(ctx.get_symbol("w"))
x, y, z = sp.symbols("x, y, z") x, y, z, w = sp.symbols("x, y, z, w")
expr = freeze(sp.Not(sp.And(x, y)))
assert expr.structurally_equal(PsNot(PsAnd(x2, y2)))
expr = freeze(sp.Or(sp.Not(z), sp.And(y, sp.Not(x))))
assert expr.structurally_equal(PsOr(PsNot(z2), PsAnd(y2, PsNot(x2))))
expr1 = freeze(sp.Not(sp.And(x, y))) expr = freeze(sp.And(w, x, y, z))
assert expr1.structurally_equal(PsNot(PsAnd(x2, y2))) assert expr.structurally_equal(PsAnd(PsAnd(PsAnd(w2, x2), y2), z2))
expr2 = freeze(sp.Or(sp.Not(z), sp.And(y, sp.Not(x)))) expr = freeze(sp.Or(w, x, y, z))
assert expr2.structurally_equal(PsOr(PsNot(z2), PsAnd(y2, PsNot(x2)))) assert expr.structurally_equal(PsOr(PsOr(PsOr(w2, x2), y2), z2))
@pytest.mark.parametrize("rel_pair", [ @pytest.mark.parametrize("rel_pair", [
......
...@@ -51,6 +51,9 @@ def add_fixed_constant_boundary_handling(assignments, with_cse): ...@@ -51,6 +51,9 @@ def add_fixed_constant_boundary_handling(assignments, with_cse):
@pytest.mark.parametrize('dtype', ('float64', 'float32')) @pytest.mark.parametrize('dtype', ('float64', 'float32'))
@pytest.mark.parametrize('with_cse', (False, 'with_cse')) @pytest.mark.parametrize('with_cse', (False, 'with_cse'))
def test_boundary_check(dtype, with_cse): def test_boundary_check(dtype, with_cse):
if with_cse:
pytest.xfail("Doesn't typify correctly yet.")
f, g = ps.fields(f"f, g : {dtype}[2D]") f, g = ps.fields(f"f, g : {dtype}[2D]")
stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4) stencil = ps.Assignment(g[0, 0], (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4)
...@@ -59,7 +62,7 @@ def test_boundary_check(dtype, with_cse): ...@@ -59,7 +62,7 @@ def test_boundary_check(dtype, with_cse):
assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse) assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse)
config = ps.CreateKernelConfig(data_type=dtype, default_number_float=dtype, ghost_layers=0) config = ps.CreateKernelConfig(default_dtype=ps.create_type(dtype), ghost_layers=0)
kernel_checked = ps.create_kernel(assignments, config=config).compile() kernel_checked = ps.create_kernel(assignments, config=config).compile()
# ps.show_code(kernel_checked) # ps.show_code(kernel_checked)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment