Skip to content
Snippets Groups Projects
Commit 0834b5c9 authored by Markus Holzer's avatar Markus Holzer
Browse files

Debug checking

parents 64e167e9 f4218729
Branches
Tags
1 merge request!275WIP: Revamp the type system
Source diff could not be displayed: it is too large. Options to address this: view the blob.
import graphviz import graphviz
from graphviz import Digraph, lang try:
from graphviz import Digraph
import graphviz.quoting as quote
except ImportError:
from graphviz import Digraph
import graphviz.lang as quote
from sympy.printing.printer import Printer from sympy.printing.printer import Printer
...@@ -12,7 +17,7 @@ class DotPrinter(Printer): ...@@ -12,7 +17,7 @@ class DotPrinter(Printer):
super(DotPrinter, self).__init__() super(DotPrinter, self).__init__()
self._node_to_str_function = node_to_str_function self._node_to_str_function = node_to_str_function
self.dot = Digraph(**kwargs) self.dot = Digraph(**kwargs)
self.dot.quote_edge = lang.quote self.dot.quote_edge = quote.quote
def _print_KernelFunction(self, func): def _print_KernelFunction(self, func):
self.dot.node(str(id(func)), style='filled', fillcolor='#a056db', label=self._node_to_str_function(func)) self.dot.node(str(id(func)), style='filled', fillcolor='#a056db', label=self._node_to_str_function(func))
......
...@@ -10,7 +10,12 @@ from pystencils.kernel_wrapper import KernelWrapper ...@@ -10,7 +10,12 @@ from pystencils.kernel_wrapper import KernelWrapper
def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True): def to_dot(expr: sp.Expr, graph_style: Optional[Dict[str, Any]] = None, short=True):
"""Show a sympy or pystencils AST as dot graph""" """Show a sympy or pystencils AST as dot graph"""
from pystencils.astnodes import Node from pystencils.astnodes import Node
import graphviz try:
import graphviz
except ImportError:
print("graphviz is not installed. Visualizing the AST is not available")
return
graph_style = {} if graph_style is None else graph_style graph_style = {} if graph_style is None else graph_style
if isinstance(expr, Node): if isinstance(expr, Node):
......
...@@ -33,7 +33,8 @@ def test_two_arguments(dtype, func, target): ...@@ -33,7 +33,8 @@ def test_two_arguments(dtype, func, target):
dh.run_kernel(kernel) dh.run_kernel(kernel)
dh.all_to_cpu() dh.all_to_cpu()
np.testing.assert_allclose(dh.gather_array("x")[0, 0], float(func(1.0, 2.0).evalf())) np.testing.assert_allclose(dh.gather_array("x")[0, 0], float(func(1.0, 2.0).evalf()),
13 if dtype == 'float64' else 5)
@pytest.mark.parametrize('dtype', ["float64", "float32"]) @pytest.mark.parametrize('dtype', ["float64", "float32"])
......
...@@ -77,6 +77,7 @@ def test_strided(instruction_set, dtype): ...@@ -77,6 +77,7 @@ def test_strided(instruction_set, dtype):
default_number_float=npdtype) default_number_float=npdtype)
ast = ps.create_kernel(update_rule, config=config) ast = ps.create_kernel(update_rule, config=config)
assert len(warn) == 0 assert len(warn) == 0
ps.show_code(ast) ps.show_code(ast)
func = ast.compile() func = ast.compile()
ref_func = ps.create_kernel(update_rule).compile() ref_func = ps.create_kernel(update_rule).compile()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment