Skip to content
Snippets Groups Projects
Select Git revision
  • c8dbf8cb8e2ba440dfb64e7a27ba53a27ce6ee08
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

test_boundary.py

Blame
  • test_global_definitions.py 3.99 KiB
    import sympy
    
    import pystencils.astnodes
    from pystencils.backends.cbackend import CBackend
    from pystencils.data_types import TypedSymbol
    
    
    class BogusDeclaration(pystencils.astnodes.Node):
        """Test-only leaf AST node that *defines* the symbol ``Foo`` (type ``double``).

        It stands in for a global declaration: it has no children, and its only
        observable effect is that ``symbols_defined`` reports ``Foo``.
        """

        def __init__(self, parent=None):
            self.parent = parent

        @property
        def args(self):
            """Children of this node; always empty because this is a leaf node."""
            return set()

        @property
        def symbols_defined(self):
            """Set of symbols defined by this node: only ``Foo`` of type ``double``."""
            return {TypedSymbol('Foo', 'double')}

        @property
        def undefined_symbols(self):
            """Symbols used but not defined inside this node; always empty."""
            # Explicitly return an empty set (not None) so the result can be
            # treated as a set like every other node's ``undefined_symbols``.
            return set()

        def subs(self, subs_dict):
            """Inplace! substitute in all children, similar to sympy's but modifying the AST.

            This is a no-op here because ``args`` is always empty.
            """
            for a in self.args:
                a.subs(subs_dict)

        @property
        def func(self):
            """The class of this node, mirroring sympy's ``expr.func``."""
            return self.__class__

        def atoms(self, arg_type):
            """Return all descendants (recursively) that are instances of ``arg_type``."""
            result = set()
            for arg in self.args:
                if isinstance(arg, arg_type):
                    result.add(arg)
                result.update(arg.atoms(arg_type))
            return result
    
    
    class BogusUsage(pystencils.astnodes.Node):
        """Test-only leaf AST node that *uses* the symbol ``Foo`` without defining it.

        Args:
            requires_global: when true, the node carries a ``BogusDeclaration`` in
                ``required_global_declarations``. That declaration defines ``Foo``.
                When false, the attribute is not set at all.
            parent: optional parent node.
        """

        def __init__(self, requires_global: bool, parent=None):
            self.parent = parent
            if not requires_global:
                return
            self.required_global_declarations = [BogusDeclaration()]

        @property
        def args(self):
            """Child nodes; there are none, so this is always an empty set."""
            return set()

        @property
        def symbols_defined(self):
            """This node defines nothing, so the set is empty."""
            return set()

        @property
        def undefined_symbols(self):
            """The node reads ``Foo`` (``double``) without defining it."""
            return {TypedSymbol('Foo', 'double')}

        def subs(self, subs_dict):
            """Apply ``subs_dict`` in place to every child (none exist for this node)."""
            for child in self.args:
                child.subs(subs_dict)

        @property
        def func(self):
            """Class object of this node (analogous to sympy's ``func``)."""
            return type(self)

        def atoms(self, arg_type):
            """Collect, recursively, every descendant that is an instance of ``arg_type``."""
            found = set()
            for child in self.args:
                if isinstance(child, arg_type):
                    found.add(child)
                found |= child.atoms(arg_type)
            return found
    
    
    def test_global_definitions_with_global_symbol():
        """``Foo`` must not become a kernel parameter when a global declaration defines it.

        ``BogusUsage(requires_global=True)`` uses ``Foo`` and carries a
        ``BogusDeclaration`` that defines it.
        """
        from unittest import mock

        # Teach our printer to print the new AST nodes. ``create=True`` lets the
        # patch add attributes CBackend does not have, and removes them again on
        # exit, so the change does not leak into other tests using CBackend.
        usage_printer = mock.patch.object(
            CBackend, '_print_BogusUsage', lambda _, __: "// Bogus would go here", create=True)
        declaration_printer = mock.patch.object(
            CBackend, '_print_BogusDeclaration', lambda _, __: "// Declaration would go here", create=True)

        with usage_printer, declaration_printer:
            # Unpack in the same order as the field names, so ``x`` really is field "x".
            z, x, y = pystencils.fields("z, x, y: [2d]")

            normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
                z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])

            ast = pystencils.create_kernel(normal_assignments)
            print(pystencils.show_code(ast))
            ast.body.append(BogusUsage(requires_global=True))
            print(pystencils.show_code(ast))
            kernel = ast.compile()
            assert kernel is not None

            assert TypedSymbol('Foo', 'double') not in [p.symbol for p in ast.get_parameters()]
    
    
    def test_global_definitions_without_global_symbol():
        """Without a global declaration, the used symbol ``Foo`` must be a kernel parameter.

        ``BogusUsage(requires_global=False)`` uses ``Foo`` but carries no
        declaration defining it.
        """
        from unittest import mock

        # Teach our printer to print the new AST nodes. ``create=True`` lets the
        # patch add attributes CBackend does not have, and removes them again on
        # exit, so the change does not leak into other tests using CBackend.
        usage_printer = mock.patch.object(
            CBackend, '_print_BogusUsage', lambda _, __: "// Bogus would go here", create=True)
        declaration_printer = mock.patch.object(
            CBackend, '_print_BogusDeclaration', lambda _, __: "// Declaration would go here", create=True)

        with usage_printer, declaration_printer:
            # Unpack in the same order as the field names, so ``x`` really is field "x".
            z, x, y = pystencils.fields("z, x, y: [2d]")

            normal_assignments = pystencils.AssignmentCollection([pystencils.Assignment(
                z[0, 0], x[0, 0] * sympy.log(x[0, 0] * y[0, 0]))], [])

            ast = pystencils.create_kernel(normal_assignments)
            print(pystencils.show_code(ast))
            ast.body.append(BogusUsage(requires_global=False))
            print(pystencils.show_code(ast))
            kernel = ast.compile()
            assert kernel is not None

            assert TypedSymbol('Foo', 'double') in [p.symbol for p in ast.get_parameters()]