From e7ace8fe843084f9b7f4392028c782a7f758ad2b Mon Sep 17 00:00:00 2001
From: Frederik Hennig <frederik.hennig@fau.de>
Date: Thu, 12 Dec 2024 14:19:21 +0100
Subject: [PATCH] update AST inspection to allow disabling certain tabs. Add a
 note on cupy.

---
 docs/source/reference/gpu_kernels.md    | 21 +++++-
 docs/source/reference/kernelcreation.md |  2 +-
 src/pystencils/inspection.py            | 95 ++++++++++++++++++-------
 3 files changed, 90 insertions(+), 28 deletions(-)

diff --git a/docs/source/reference/gpu_kernels.md b/docs/source/reference/gpu_kernels.md
index 1045f80d4..1e9456bf7 100644
--- a/docs/source/reference/gpu_kernels.md
+++ b/docs/source/reference/gpu_kernels.md
@@ -58,10 +58,22 @@ property, which tells us how many threads the kernel is expecting to be executed
 kernel.threads_range
 ```
 
-If a GPU is available and [cupy] is installed in the current environment,
+If a GPU is available and [CuPy][cupy] is installed in the current environment,
 the kernel can be compiled and run immediately.
-To execute the kernel, a {any}`cupy.ndarray` has to be passed for each field;
-this is the GPU analogue to {any}`numpy.ndarray`:
+To execute the kernel, a {any}`cupy.ndarray` has to be passed for each field.
+
+:::{note}
+[CuPy][cupy] is a Python library for numerical computations on GPU arrays,
+which operates much in the same way that [NumPy][numpy] works on CPU arrays.
+Cupy and NumPy expose nearly the same APIs for array operations;
+the difference being that CuPy allocates all its arrays on the GPU
+and performs its operations as CUDA kernels.
+Also, CuPy exposes a just-in-time-compiler for GPU kernels, which internally calls [nvcc].
+In pystencils, we use CuPy both to compile and provide executable kernels on-demand from within Python code,
+and to allocate and manage the data these kernels can be executed on.
+
+For more information on CuPy, refer to [their documentation][cupy-docs].
+:::
 
 ```{code-cell} ipython3
 :tags: [raises-exception]
@@ -212,3 +224,6 @@ only a part of the triangle is being processed.
 
 
 [cupy]: https://cupy.dev "CuPy Homepage"
+[numpy]: https://numpy.org "NumPy Homepage"
+[nvcc]: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html "NVIDIA CUDA Compiler Driver"
+[cupy-docs]: https://docs.cupy.dev/en/stable/overview.html "CuPy Documentation"
\ No newline at end of file
diff --git a/docs/source/reference/kernelcreation.md b/docs/source/reference/kernelcreation.md
index fc0b2248e..af8c01456 100644
--- a/docs/source/reference/kernelcreation.md
+++ b/docs/source/reference/kernelcreation.md
@@ -178,7 +178,7 @@ are using the `int32` data type, as specified in {py:data}`index_dtype <CreateKe
 
 driver = ps.kernelcreation.get_driver(cfg, retain_intermediates=True)
 kernel = driver(assignments)
-ps.inspect(driver.intermediates.materialized_ispace)
+ps.inspect(driver.intermediates.materialized_ispace, show_cpp=False)
 ```
 
 :::{note}
diff --git a/src/pystencils/inspection.py b/src/pystencils/inspection.py
index 7fa3047c6..cb03a1c8d 100644
--- a/src/pystencils/inspection.py
+++ b/src/pystencils/inspection.py
@@ -102,20 +102,31 @@ class AstInspection(CodeInspectionBase):
     explore an abstract syntax tree.
     """
 
-    def __init__(self, ast: PsAstNode):
+    def __init__(
+        self,
+        ast: PsAstNode,
+        show_ir: bool = True,
+        show_cpp: bool = True,
+        show_graph: bool = True,
+    ):
         super().__init__()
         self._ast = ast
+        self._show_ir = show_ir
+        self._show_cpp = show_cpp
+        self._show_graph = show_graph
 
     def _widget(self):
         import ipywidgets as widgets
 
-        tabs = widgets.Tab(
-            children=[
-                self._ir_tab(self._ast),
-                self._cpp_tab(self._ast),
-                self._graphviz_tab(self._ast),
-            ]
-        )
+        tabs = []
+        if self._show_ir:
+            tabs.append(self._ir_tab(self._ast))
+        if self._show_cpp:
+            tabs.append(self._cpp_tab(self._ast))
+        if self._show_graph:
+            tabs.append(self._graphviz_tab(self._ast))
+
+        tabs = widgets.Tab(children=tabs)
         tabs.titles = ["IR Code", "C Code", "AST Visualization"]
 
         tabs.layout.height = "250pt"
@@ -124,20 +135,31 @@ class AstInspection(CodeInspectionBase):
 
 
 class KernelInspection(CodeInspectionBase):
-    def __init__(self, kernel: KernelFunction) -> None:
+    def __init__(
+        self,
+        kernel: KernelFunction,
+        show_ir: bool = True,
+        show_cpp: bool = True,
+        show_graph: bool = True,
+    ) -> None:
         super().__init__()
         self._kernel = kernel
+        self._show_ir = show_ir
+        self._show_cpp = show_cpp
+        self._show_graph = show_graph
 
     def _widget(self):
         import ipywidgets as widgets
 
-        tabs = widgets.Tab(
-            children=[
-                self._ir_tab(self._kernel),
-                self._cpp_tab(self._kernel),
-                self._graphviz_tab(self._kernel),
-            ]
-        )
+        tabs = []
+        if self._show_ir:
+            tabs.append(self._ir_tab(self._kernel))
+        if self._show_cpp:
+            tabs.append(self._cpp_tab(self._kernel))
+        if self._show_graph:
+            tabs.append(self._graphviz_tab(self._kernel))
+
+        tabs = widgets.Tab(children=tabs)
         tabs.titles = ["IR Code", "C Code", "AST Visualization"]
 
         tabs.layout.height = "250pt"
@@ -146,8 +168,17 @@ class KernelInspection(CodeInspectionBase):
 
 
 class IntermediatesInspection:
-    def __init__(self, intermediates: CodegenIntermediates):
+    def __init__(
+        self,
+        intermediates: CodegenIntermediates,
+        show_ir: bool = True,
+        show_cpp: bool = True,
+        show_graph: bool = True,
+    ):
         self._intermediates = intermediates
+        self._show_ir = show_ir
+        self._show_cpp = show_cpp
+        self._show_graph = show_graph
 
     def _ipython_display_(self):
         from IPython.display import display
@@ -155,7 +186,15 @@ class IntermediatesInspection:
 
         stages = self._intermediates.available_stages
 
-        previews: list[AstInspection] = [AstInspection(stage.ast) for stage in stages]
+        previews: list[AstInspection] = [
+            AstInspection(
+                stage.ast,
+                show_ir=self._show_ir,
+                show_cpp=self._show_cpp,
+                show_graph=self._show_graph,
+            )
+            for stage in stages
+        ]
         labels: list[str] = [stage.label for stage in stages]
 
         code_views = [p._widget() for p in previews]
@@ -201,9 +240,9 @@ def inspect(obj: StageResult): ...
 def inspect(obj: CodegenIntermediates): ...
 
 
-def inspect(obj):
+def inspect(obj, show_ir: bool = True, show_cpp: bool = True, show_graph: bool = True):
     """Interactively inspect various products of the code generator.
-    
+
     When run inside a Jupyter notebook, this function displays an inspection widget
     for the following types of objects:
     - `PsAstNode`
@@ -216,13 +255,21 @@ def inspect(obj):
 
     match obj:
         case PsAstNode():
-            preview = AstInspection(obj)
+            preview = AstInspection(
+                obj, show_ir=show_ir, show_cpp=show_cpp, show_graph=show_cpp
+            )
         case KernelFunction():
-            preview = KernelInspection(obj)
+            preview = KernelInspection(
+                obj, show_ir=show_ir, show_cpp=show_cpp, show_graph=show_cpp
+            )
         case StageResult(ast, _):
-            preview = AstInspection(ast)
+            preview = AstInspection(
+                ast, show_ir=show_ir, show_cpp=show_cpp, show_graph=show_cpp
+            )
         case CodegenIntermediates():
-            preview = IntermediatesInspection(obj)
+            preview = IntermediatesInspection(
+                obj, show_ir=show_ir, show_cpp=show_cpp, show_graph=show_cpp
+            )
         case _:
             raise ValueError(f"Cannot inspect object of type {type(obj)}")
 
-- 
GitLab