Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (2)
......@@ -744,3 +744,28 @@ class EmptyLine(Node):
def __repr__(self):
return self.__str__()
class ConditionalFieldAccess(sp.Function):
"""
:class:`pystencils.Field.Access` that is only executed if a certain condition is met.
Can be used, for instance, for out-of-bound checks.
"""
def __new__(cls, field_access, outofbounds_condition, outofbounds_value=0):
return sp.Function.__new__(cls, field_access, outofbounds_condition, sp.S(outofbounds_value))
@property
def access(self):
return self.args[0]
@property
def outofbounds_condition(self):
return self.args[1]
@property
def outofbounds_value(self):
return self.args[2]
def __getnewargs__(self):
return self.access, self.outofbounds_condition, self.outofbounds_value
......@@ -426,6 +426,9 @@ class CustomSympyPrinter(CCodePrinter):
)
return code
def _print_ConditionalFieldAccess(self, node):
return self._print(sp.Piecewise((node.outofbounds_value, node.outofbounds_condition), (node.access, True)))
_print_Max = C89CodePrinter._print_Max
_print_Min = C89CodePrinter._print_Min
......
# -*- coding: utf-8 -*-
#
# Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
#
# Distributed under terms of the GPLv3 license.
"""
"""
import itertools
import numpy as np
import pytest
import sympy as sp
import pystencils as ps
from pystencils import Field, x_vector
from pystencils.astnodes import ConditionalFieldAccess
from pystencils.simp import sympy_cse
def add_fixed_constant_boundary_handling(assignments, with_cse):
common_shape = next(iter(set().union(itertools.chain.from_iterable(
[a.atoms(Field.Access) for a in assignments]
)))).field.spatial_shape
ndim = len(common_shape)
def is_out_of_bound(access, shape):
return sp.Or(*[sp.Or(a < 0, a >= s) for a, s in zip(access, shape)])
safe_assignments = [ps.Assignment(
assignment.lhs, assignment.rhs.subs({
a: ConditionalFieldAccess(a, is_out_of_bound(sp.Matrix(a.offsets) + x_vector(ndim), common_shape))
for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access
})) for assignment in assignments.all_assignments]
subs = [{a: ConditionalFieldAccess(a, is_out_of_bound(
sp.Matrix(a.offsets) + x_vector(ndim), common_shape))
for a in assignment.rhs.atoms(Field.Access) if not a.is_absolute_access
} for assignment in assignments.all_assignments]
print(subs)
if with_cse:
safe_assignments = sympy_cse(ps.AssignmentCollection(safe_assignments))
return safe_assignments
else:
return ps.AssignmentCollection(safe_assignments)
@pytest.mark.parametrize('with_cse', (False, 'with_cse'))
def test_boundary_check(with_cse):
f, g = ps.fields("f, g : [2D]")
stencil = ps.Assignment(g[0, 0],
(f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4)
f_arr = np.random.rand(1000, 1000)
g_arr = np.zeros_like(f_arr)
# kernel(f=f_arr, g=g_arr)
assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse)
print(assignments)
kernel_checked = ps.create_kernel(assignments, ghost_layers=0).compile()
print(ps.show_code(kernel_checked))
# No SEGFAULT, please!!
kernel_checked(f=f_arr, g=g_arr)