Skip to content
Snippets Groups Projects
Select Git revision
  • 3f32ceca0bafb9c93e0fbcae13e2679d29d95dc2
  • master default protected
  • v2.0-dev protected
  • zikeliml/Task-96-dotExporterForAST
  • zikeliml/124-rework-tutorials
  • fma
  • fhennig/v2.0-deprecations
  • holzer-master-patch-46757
  • 66-absolute-access-is-probably-not-copied-correctly-after-_eval_subs
  • gpu_bufferfield_fix
  • hyteg
  • vectorization_sqrt_fix
  • target_dh_refactoring
  • const_fix
  • improved_comm
  • gpu_liveness_opts
  • release/1.3.7 protected
  • release/1.3.6 protected
  • release/2.0.dev0 protected
  • release/1.3.5 protected
  • release/1.3.4 protected
  • release/1.3.3 protected
  • release/1.3.2 protected
  • release/1.3.1 protected
  • release/1.3 protected
  • release/1.2 protected
  • release/1.1.1 protected
  • release/1.1 protected
  • release/1.0.1 protected
  • release/1.0 protected
  • release/0.4.4 protected
  • last/Kerncraft
  • last/OpenCL
  • last/LLVM
  • release/0.4.3 protected
  • release/0.4.2 protected
36 results

test_extensions.py

Blame
  • test_extensions.py 1.89 KiB
    
    import sympy as sp
    
    from pystencils import make_slice, Field, Assignment
    from pystencils.backend.kernelcreation import KernelCreationContext, AstFactory, FullIterationSpace
    from pystencils.backend.transformations import CanonicalizeSymbols, HoistLoopInvariantDeclarations, LowerToC
    from pystencils.backend.literals import PsLiteral
    from pystencils.backend.emission import CAstPrinter
    from pystencils.backend.ast.expressions import PsExpression, PsSubscript
    from pystencils.backend.ast.structural import PsBlock, PsDeclaration
    from pystencils.types.quick import Arr, Int
    
    
    def test_literals():
        """Check that backend literals survive canonicalization, hoisting and lowering.

        A constant array literal ``CELLS`` supplies the upper loop bounds of a
        3D iteration space, and a scalar literal ``C`` initializes a declaration
        of ``x`` inside the loop body. After the transformation pipeline the
        declaration of ``x`` must be the first of exactly two top-level
        statements. The printed C code must contain the ``const`` declaration
        and every subscript into ``CELLS``.
        """
        ctx = KernelCreationContext()
        factory = AstFactory(ctx)

        field = Field.create_generic("f", 3)
        x_sym = sp.Symbol("x")

        # Constant, 3-element array of 64-bit integers used as loop bounds.
        cells_literal = PsExpression.make(
            PsLiteral("CELLS", Arr(Int(64), (3,), const=True))
        )
        # Scalar literal of the context's default dtype; assigned to x below.
        c_literal = PsExpression.make(PsLiteral("C", ctx.default_dtype))

        # One slice 0:CELLS[dim] per dimension. Passing the tuple to make_slice
        # is equivalent to make_slice[0:CELLS[0], 0:CELLS[1], 0:CELLS[2]].
        bounds = tuple(
            slice(0, PsSubscript(cells_literal, (factory.parse_index(dim),)))
            for dim in range(3)
        )
        ispace = FullIterationSpace.create_from_slice(ctx, make_slice[bounds])
        ctx.set_iteration_space(ispace)

        x_decl = PsDeclaration(factory.parse_sympy(x_sym), c_literal)
        body = PsBlock(
            [x_decl, factory.parse_sympy(Assignment(field.center(), x_sym))]
        )
        ast = PsBlock([factory.loops_from_ispace(ispace, body)])

        # Each pass is constructed right before it is applied, in pipeline order.
        passes = (CanonicalizeSymbols, HoistLoopInvariantDeclarations, LowerToC)
        for make_pass in passes:
            ast = make_pass(ctx)(ast)

        # x is initialized from a literal only, so it is loop-invariant and must
        # end up first at the top level (the other statement presumably being
        # the loop nest).
        assert isinstance(ast, PsBlock)
        assert len(ast.statements) == 2
        assert ast.statements[0] == x_decl

        code = CAstPrinter()(ast)
        print(code)

        assert "const double x = C;" in code
        for expected in ("CELLS[0LL]", "CELLS[1LL]", "CELLS[2LL]"):
            assert expected in code