Skip to content
Snippets Groups Projects
Select Git revision
  • fc12b0e5a8ab0e5318e237c52ca841fa4f6c5d98
  • master default protected
  • suffa/cumulantfourth_order_correction_with_psm
  • mr_refactor_wfb
  • Sparse
  • WallLaw
  • improved_comm
  • release/1.3.7
  • release/1.3.6
  • release/1.3.5
  • release/1.3.4
  • release/1.3.3
  • release/1.3.2
  • release/1.3.1
  • release/1.3
  • release/1.2
  • release/1.1.1
  • release/1.1
  • release/1.0.1
  • release/1.0
  • release/0.4.4
  • release/0.4.3
  • release/0.4.2
  • release/0.4.1
  • release/0.4.0
  • release/0.3.4
  • release/0.3.3
27 results

test_moment_transform_api.py

Blame
  • test_conditional_field_access.py 2.15 KiB
    # -*- coding: utf-8 -*-
    #
    # Copyright © 2019 Stephan Seitz <stephan.seitz@fau.de>
    #
    # Distributed under terms of the GPLv3 license.
    
    """
    
    """
    import itertools
    
    import numpy as np
    import pytest
    import sympy as sp
    
    import pystencils as ps
    from pystencils import Field, x_vector
    from pystencils.astnodes import ConditionalFieldAccess
    from pystencils.simp import sympy_cse
    
    
    def add_fixed_constant_boundary_handling(assignments, with_cse):
        """Wrap every relative field read of ``assignments`` in a bounds check.

        Each non-absolute ``Field.Access`` on the right-hand side of each entry of
        ``assignments.all_assignments`` is substituted by a
        ``ConditionalFieldAccess`` whose condition is true when the accessed
        position lies outside the spatial shape of the field being read. The
        position is the access offsets plus ``x_vector(ndim)``, presumably the
        current cell's loop coordinates.

        Left-hand sides are left untouched.

        Args:
            assignments: a ``ps.AssignmentCollection`` (its ``all_assignments``
                are read).
            with_cse: if truthy, the guarded collection is passed through
                ``sympy_cse`` before being returned.

        Returns:
            The guarded assignments as a ``ps.AssignmentCollection``, or the
            result of ``sympy_cse`` on it when ``with_cse`` is truthy.
        """

        def is_out_of_bound(position, shape):
            # True if any coordinate is negative or >= the extent along that axis.
            return sp.Or(*[sp.Or(coord < 0, coord >= extent)
                           for coord, extent in zip(position, shape)])

        def guarded(access):
            # Check against the shape of the field actually being read, not an
            # arbitrarily chosen "common" shape, so fields of differing shape are
            # each guarded correctly.
            shape = access.field.spatial_shape
            position = sp.Matrix(access.offsets) + x_vector(len(shape))
            return ConditionalFieldAccess(access, is_out_of_bound(position, shape))

        safe_assignments = [
            ps.Assignment(
                assignment.lhs,
                assignment.rhs.subs({
                    access: guarded(access)
                    for access in assignment.rhs.atoms(Field.Access)
                    # Absolute accesses do not depend on the loop position.
                    if not access.is_absolute_access
                }))
            for assignment in assignments.all_assignments
        ]

        collection = ps.AssignmentCollection(safe_assignments)
        return sympy_cse(collection) if with_cse else collection
    
    
    @pytest.mark.parametrize('with_cse', (False, True), ids=('without_cse', 'with_cse'))
    def test_boundary_check(with_cse):
        """Run a guarded 4-point averaging stencil over the full domain.

        With ``ghost_layers=0`` the +/-1 neighbour reads at the domain edges
        would fall outside ``f_arr``. After
        ``add_fixed_constant_boundary_handling`` the kernel must run without
        crashing. In the interior, where every neighbour is in bounds, it must
        also produce the plain stencil result.
        """
        f, g = ps.fields("f, g : [2D]")
        stencil = ps.Assignment(g[0, 0],
                                (f[1, 0] + f[-1, 0] + f[0, 1] + f[0, -1]) / 4)

        f_arr = np.random.rand(1000, 1000)
        g_arr = np.zeros_like(f_arr)

        assignments = add_fixed_constant_boundary_handling(ps.AssignmentCollection([stencil]), with_cse)

        # Diagnostic output; pytest shows captured stdout for failing tests.
        print(assignments)
        kernel_checked = ps.create_kernel(assignments, ghost_layers=0).compile()
        print(ps.show_code(kernel_checked))

        # No SEGFAULT, please!!
        kernel_checked(f=f_arr, g=g_arr)

        # Away from the edges no guard can trigger, so the result must match the
        # unguarded stencil there (allclose: CSE may reorder the float sum).
        expected_interior = (f_arr[2:, 1:-1] + f_arr[:-2, 1:-1]
                             + f_arr[1:-1, 2:] + f_arr[1:-1, :-2]) / 4
        np.testing.assert_allclose(g_arr[1:-1, 1:-1], expected_interior)