You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/09/22 00:39:54 UTC

[GitHub] [tvm] areusch commented on a change in pull request #8668: Visualization Relay IR on terminal

areusch commented on a change in pull request #8668:
URL: https://github.com/apache/tvm/pull/8668#discussion_r713507093



##########
File path: python/tvm/contrib/relay_viz/__init__.py
##########
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay IR Visualizer"""
+import logging
+import copy
+from enum import Enum
+from tvm import relay
+from .plotter import Plotter
+from .render_callback import RenderCallback
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class PlotterBackend(Enum):
+    """Enumeration for available plotters."""
+
+    BOKEH = "bokeh"
+    TERMINAL = "terminal"
+
+
+class RelayVisualizer:
+    """Relay IR Visualizer"""
+
+    def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH):
+        """Visualize Relay IR.
+
+        Parameters
+        ----------
+        relay_mod : object
+                        Relay IR module
+        relay_param: dict
+                        Relay parameter dictionary
+        backend: PlotterBackend or a tuple
+                        PlotterBackend: The backend of plotting. Default "bokeh"
+                        Tuple: A tuple with two arguments. First is user-defined Plotter, \
+                               the second is user-defined RenderCallback
+        """
+
+        self._plotter, self._render_rules = get_plotter_and_render_rules(backend)
+        self._relay_param = relay_param if relay_param is not None else {}
+        # This field is used for book-keeping for each graph.
+        self._node_to_id = {}
+
+        global_vars = relay_mod.get_global_vars()
+        graph_names = []
+        # If we have main function, put it to the first.

Review comment:
       Could you explain why `main` should be put first?

##########
File path: python/tvm/contrib/relay_viz/__init__.py
##########
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay IR Visualizer"""
+import logging
+import copy
+from enum import Enum
+from tvm import relay
+from .plotter import Plotter
+from .render_callback import RenderCallback
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class PlotterBackend(Enum):
+    """Enumeration for available plotters."""
+
+    BOKEH = "bokeh"
+    TERMINAL = "terminal"
+
+
+class RelayVisualizer:
+    """Relay IR Visualizer"""
+
+    def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH):
+        """Visualize Relay IR.
+
+        Parameters
+        ----------
+        relay_mod : object
+                        Relay IR module
+        relay_param: dict
+                        Relay parameter dictionary
+        backend: PlotterBackend or a tuple
+                        PlotterBackend: The backend of plotting. Default "bokeh"
+                        Tuple: A tuple with two arguments. First is user-defined Plotter, \
+                               the second is user-defined RenderCallback
+        """
+
+        self._plotter, self._render_rules = get_plotter_and_render_rules(backend)
+        self._relay_param = relay_param if relay_param is not None else {}
+        # This field is used for book-keeping for each graph.
+        self._node_to_id = {}
+
+        global_vars = relay_mod.get_global_vars()
+        graph_names = []
+        # If we have main function, put it to the first.
+        for gv_name in global_vars:
+            if gv_name.name_hint == "main":
+                graph_names.insert(0, gv_name.name_hint)
+            else:
+                graph_names.append(gv_name.name_hint)
+
+        for name in graph_names:
+            # clear previous graph
+            self._node_to_id = {}
+            relay.analysis.post_order_visit(
+                relay_mod[name],
+                lambda node: self._traverse_expr(node),  # pylint: disable=unnecessary-lambda
+            )
+            graph = self._plotter.create_graph(name)
+            # shallow copy to prevent callback modify self._node_to_id
+            self._render_cb(graph, copy.copy(self._node_to_id), self._relay_param)
+
+    def _traverse_expr(self, node):
+        # based on https://github.com/apache/tvm/pull/4370
+        if node in self._node_to_id:
+            return
+        self._node_to_id[node] = len(self._node_to_id)
+
+    def _render_cb(self, graph, node_to_id, relay_param):
+        """a callback to Add nodes and edges to the graph.
+
+        Parameters
+        ----------
+        graph : class plotter.Graph
+
+        node_to_id : Dict[relay.expr, int]
+
+        relay_param : Dict[string, NDarray]
+        """
+        # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm
+        unknown_type = "unknown"
+        for node, node_id in node_to_id.items():
+            if type(node) in self._render_rules:  # pylint: disable=unidiomatic-typecheck
+                graph_info, edge_info = self._render_rules[type(node)](
+                    node, relay_param, node_to_id
+                )
+                if graph_info:
+                    graph.node(*graph_info)
+                for edge in edge_info:
+                    graph.edge(*edge)
+            else:
+                unknown_info = "Unknown node: {}".format(type(node))
+                _LOGGER.warning(unknown_info)
+                graph.node(node_id, unknown_type, unknown_info)
+
+    def render(self, filename):
+        return self._plotter.render(filename=filename)
+
+
+def get_plotter_and_render_rules(backend):
+    """Specify the Plottor and its render rules
+
+    Parameters
+        ----------
+        backend: PlotterBackend or a tuple
+                        PlotterBackend: The backend of plotting. Default "bokeh"
+                        Tuple: A tuple with two arguments. First is user-defined Plotter, \
+                               the second is user-defined RenderCallback
+    """
+    if type(backend) is tuple and len(backend) == 2:  # pylint: disable=unidiomatic-typecheck
+        if not isinstance(backend[0], Plotter):
+            raise ValueError("First elemnet of the backend should be a plotter")
+        plotter = backend[0]
+        if not isinstance(backend[1], RenderCallback):
+            raise ValueError("Second elemnet of the backend should be a callback")
+        render = backend[1]
+        render_rules = render.get_rules()
+        return plotter, render_rules
+
+    if backend in PlotterBackend:
+        if backend == PlotterBackend.BOKEH:
+            # pylint: disable=import-outside-toplevel

Review comment:
       Why do you import late (inside the function) here? If you need to do this, add a comment explaining why.

##########
File path: python/tvm/contrib/relay_viz/__init__.py
##########
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay IR Visualizer"""
+import logging
+import copy
+from enum import Enum
+from tvm import relay
+from .plotter import Plotter
+from .render_callback import RenderCallback
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class PlotterBackend(Enum):
+    """Enumeration for available plotters."""
+
+    BOKEH = "bokeh"
+    TERMINAL = "terminal"
+
+
+class RelayVisualizer:
+    """Relay IR Visualizer"""
+
+    def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH):

Review comment:
       @kueitang if you have time, it would be awesome to add type annotations here.

##########
File path: python/tvm/contrib/relay_viz/__init__.py
##########
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay IR Visualizer"""
+import logging
+import copy
+from enum import Enum
+from tvm import relay
+from .plotter import Plotter
+from .render_callback import RenderCallback
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class PlotterBackend(Enum):
+    """Enumeration for available plotters."""
+
+    BOKEH = "bokeh"
+    TERMINAL = "terminal"
+
+
+class RelayVisualizer:
+    """Relay IR Visualizer"""
+
+    def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH):
+        """Visualize Relay IR.
+
+        Parameters
+        ----------
+        relay_mod : object
+                        Relay IR module
+        relay_param: dict
+                        Relay parameter dictionary
+        backend: PlotterBackend or a tuple
+                        PlotterBackend: The backend of plotting. Default "bokeh"
+                        Tuple: A tuple with two arguments. First is user-defined Plotter, \
+                               the second is user-defined RenderCallback
+        """
+
+        self._plotter, self._render_rules = get_plotter_and_render_rules(backend)
+        self._relay_param = relay_param if relay_param is not None else {}
+        # This field is used for book-keeping for each graph.
+        self._node_to_id = {}
+
+        global_vars = relay_mod.get_global_vars()
+        graph_names = []
+        # If we have main function, put it to the first.
+        for gv_name in global_vars:
+            if gv_name.name_hint == "main":
+                graph_names.insert(0, gv_name.name_hint)
+            else:
+                graph_names.append(gv_name.name_hint)
+
+        for name in graph_names:
+            # clear previous graph

Review comment:
       Any reason to carry this (`self._node_to_id`) as an instance attribute rather than a local variable?

##########
File path: python/tvm/contrib/relay_viz/README.md
##########
@@ -0,0 +1,60 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+
+# IR Visualization
+
+This tool visualizes Relay IR.
+
+# Table of Contents
+1. [Requirement](#Requirement)
+2. [Usage](#Usage)
+3. [Credits](#Credits)
+
+## Requirement
+
+1. TVM
+2. graphviz
+3. pydot
+4. bokeh >= 2.3.1
+
+```
+# To install TVM, please refer to https://tvm.apache.org/docs/install/from_source.html
+
+# requirements of pydot

Review comment:
       I'd ideally like to add these to `python/gen_requirements.py`, but I'm not sure it's the best idea. `bokeh` in particular is pretty heavyweight. Before we can do that, we'll need to split the IR parsing code into a separate Python package that both this utility and TVM can depend on. So for now, let's leave them out of `gen_requirements.py`.

##########
File path: python/tvm/contrib/relay_viz/__init__.py
##########
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay IR Visualizer"""
+import logging
+import copy
+from enum import Enum
+from tvm import relay
+from .plotter import Plotter
+from .render_callback import RenderCallback
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class PlotterBackend(Enum):
+    """Enumeration for available plotters."""
+
+    BOKEH = "bokeh"
+    TERMINAL = "terminal"
+
+
+class RelayVisualizer:
+    """Relay IR Visualizer"""
+
+    def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH):
+        """Visualize Relay IR.
+
+        Parameters
+        ----------
+        relay_mod : object
+                        Relay IR module
+        relay_param: dict
+                        Relay parameter dictionary
+        backend: PlotterBackend or a tuple
+                        PlotterBackend: The backend of plotting. Default "bokeh"
+                        Tuple: A tuple with two arguments. First is user-defined Plotter, \
+                               the second is user-defined RenderCallback
+        """
+
+        self._plotter, self._render_rules = get_plotter_and_render_rules(backend)
+        self._relay_param = relay_param if relay_param is not None else {}
+        # This field is used for book-keeping for each graph.
+        self._node_to_id = {}
+
+        global_vars = relay_mod.get_global_vars()
+        graph_names = []
+        # If we have main function, put it to the first.
+        for gv_name in global_vars:
+            if gv_name.name_hint == "main":
+                graph_names.insert(0, gv_name.name_hint)
+            else:
+                graph_names.append(gv_name.name_hint)
+
+        for name in graph_names:
+            # clear previous graph
+            self._node_to_id = {}
+            relay.analysis.post_order_visit(
+                relay_mod[name],
+                lambda node: self._traverse_expr(node),  # pylint: disable=unnecessary-lambda
+            )
+            graph = self._plotter.create_graph(name)
+            # shallow copy to prevent callback modify self._node_to_id
+            self._render_cb(graph, copy.copy(self._node_to_id), self._relay_param)
+
+    def _traverse_expr(self, node):
+        # based on https://github.com/apache/tvm/pull/4370
+        if node in self._node_to_id:
+            return
+        self._node_to_id[node] = len(self._node_to_id)
+
+    def _render_cb(self, graph, node_to_id, relay_param):
+        """a callback to Add nodes and edges to the graph.
+
+        Parameters
+        ----------
+        graph : class plotter.Graph
+
+        node_to_id : Dict[relay.expr, int]
+
+        relay_param : Dict[string, NDarray]
+        """
+        # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm
+        unknown_type = "unknown"
+        for node, node_id in node_to_id.items():
+            if type(node) in self._render_rules:  # pylint: disable=unidiomatic-typecheck
+                graph_info, edge_info = self._render_rules[type(node)](
+                    node, relay_param, node_to_id
+                )
+                if graph_info:
+                    graph.node(*graph_info)
+                for edge in edge_info:
+                    graph.edge(*edge)
+            else:
+                unknown_info = "Unknown node: {}".format(type(node))
+                _LOGGER.warning(unknown_info)
+                graph.node(node_id, unknown_type, unknown_info)
+
+    def render(self, filename):
+        return self._plotter.render(filename=filename)
+
+
+def get_plotter_and_render_rules(backend):
+    """Specify the Plottor and its render rules
+
+    Parameters
+        ----------
+        backend: PlotterBackend or a tuple
+                        PlotterBackend: The backend of plotting. Default "bokeh"
+                        Tuple: A tuple with two arguments. First is user-defined Plotter, \
+                               the second is user-defined RenderCallback
+    """
+    if type(backend) is tuple and len(backend) == 2:  # pylint: disable=unidiomatic-typecheck
+        if not isinstance(backend[0], Plotter):
+            raise ValueError("First elemnet of the backend should be a plotter")

Review comment:
       nit: typo, `elemnet` should be `element`.

##########
File path: python/tvm/contrib/relay_viz/README.md
##########
@@ -0,0 +1,60 @@
+<!--- Licensed to the Apache Software Foundation (ASF) under one -->
+<!--- or more contributor license agreements.  See the NOTICE file -->
+<!--- distributed with this work for additional information -->
+<!--- regarding copyright ownership.  The ASF licenses this file -->
+<!--- to you under the Apache License, Version 2.0 (the -->
+<!--- "License"); you may not use this file except in compliance -->
+<!--- with the License.  You may obtain a copy of the License at -->
+
+<!---   http://www.apache.org/licenses/LICENSE-2.0 -->
+
+<!--- Unless required by applicable law or agreed to in writing, -->
+<!--- software distributed under the License is distributed on an -->
+<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
+<!--- KIND, either express or implied.  See the License for the -->
+<!--- specific language governing permissions and limitations -->
+<!--- under the License. -->
+
+
+# IR Visualization

Review comment:
       @kueitang any interest in adding a tutorial for this in `tutorials/`?

##########
File path: python/tvm/contrib/relay_viz/__init__.py
##########
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay IR Visualizer"""
+import logging
+import copy
+from enum import Enum
+from tvm import relay
+from .plotter import Plotter
+from .render_callback import RenderCallback
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class PlotterBackend(Enum):
+    """Enumeration for available plotters."""
+
+    BOKEH = "bokeh"
+    TERMINAL = "terminal"
+
+
+class RelayVisualizer:
+    """Relay IR Visualizer"""
+
+    def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH):
+        """Visualize Relay IR.
+
+        Parameters
+        ----------
+        relay_mod : object
+                        Relay IR module
+        relay_param: dict
+                        Relay parameter dictionary
+        backend: PlotterBackend or a tuple
+                        PlotterBackend: The backend of plotting. Default "bokeh"
+                        Tuple: A tuple with two arguments. First is user-defined Plotter, \
+                               the second is user-defined RenderCallback
+        """
+
+        self._plotter, self._render_rules = get_plotter_and_render_rules(backend)
+        self._relay_param = relay_param if relay_param is not None else {}
+        # This field is used for book-keeping for each graph.
+        self._node_to_id = {}
+
+        global_vars = relay_mod.get_global_vars()
+        graph_names = []
+        # If we have main function, put it to the first.
+        for gv_name in global_vars:
+            if gv_name.name_hint == "main":
+                graph_names.insert(0, gv_name.name_hint)
+            else:
+                graph_names.append(gv_name.name_hint)
+
+        for name in graph_names:
+            # clear previous graph
+            self._node_to_id = {}
+            relay.analysis.post_order_visit(
+                relay_mod[name],
+                lambda node: self._traverse_expr(node),  # pylint: disable=unnecessary-lambda
+            )
+            graph = self._plotter.create_graph(name)
+            # shallow copy to prevent callback modify self._node_to_id
+            self._render_cb(graph, copy.copy(self._node_to_id), self._relay_param)
+
+    def _traverse_expr(self, node):
+        # based on https://github.com/apache/tvm/pull/4370
+        if node in self._node_to_id:
+            return
+        self._node_to_id[node] = len(self._node_to_id)
+
+    def _render_cb(self, graph, node_to_id, relay_param):
+        """a callback to Add nodes and edges to the graph.
+
+        Parameters
+        ----------
+        graph : class plotter.Graph
+
+        node_to_id : Dict[relay.expr, int]
+
+        relay_param : Dict[string, NDarray]
+        """
+        # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm
+        unknown_type = "unknown"
+        for node, node_id in node_to_id.items():
+            if type(node) in self._render_rules:  # pylint: disable=unidiomatic-typecheck
+                graph_info, edge_info = self._render_rules[type(node)](
+                    node, relay_param, node_to_id
+                )
+                if graph_info:
+                    graph.node(*graph_info)
+                for edge in edge_info:
+                    graph.edge(*edge)
+            else:
+                unknown_info = "Unknown node: {}".format(type(node))
+                _LOGGER.warning(unknown_info)
+                graph.node(node_id, unknown_type, unknown_info)
+
+    def render(self, filename):
+        return self._plotter.render(filename=filename)
+
+
+def get_plotter_and_render_rules(backend):
+    """Specify the Plottor and its render rules
+
+    Parameters
+        ----------
+        backend: PlotterBackend or a tuple
+                        PlotterBackend: The backend of plotting. Default "bokeh"
+                        Tuple: A tuple with two arguments. First is user-defined Plotter, \
+                               the second is user-defined RenderCallback
+    """
+    if type(backend) is tuple and len(backend) == 2:  # pylint: disable=unidiomatic-typecheck
+        if not isinstance(backend[0], Plotter):
+            raise ValueError("First elemnet of the backend should be a plotter")
+        plotter = backend[0]
+        if not isinstance(backend[1], RenderCallback):
+            raise ValueError("Second elemnet of the backend should be a callback")
+        render = backend[1]
+        render_rules = render.get_rules()
+        return plotter, render_rules
+
+    if backend in PlotterBackend:

Review comment:
       Rather than handling the good case here, handle the bad case and bail out early:
   
   ```
   if backend not in PlotterBackend:
     raise ValueError(...)
   ```
   
   Then the rest of the function can be un-indented.

##########
File path: python/tvm/contrib/relay_viz/_bokeh.py
##########
@@ -0,0 +1,484 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Bokeh backend for Relay IR Visualizer."""
+import html
+import logging
+import functools
+
+import numpy as np
+import pydot
+
+from bokeh.io import output_file, save
+from bokeh.models import (
+    ColumnDataSource,
+    CustomJS,
+    Text,
+    Rect,
+    HoverTool,
+    MultiLine,
+    Legend,
+    Scatter,
+    Plot,
+    TapTool,
+    PanTool,
+    ResetTool,
+    WheelZoomTool,
+    SaveTool,
+)
+from bokeh.palettes import (
+    d3,
+)
+from bokeh.layouts import column
+
+from .plotter import (
+    Plotter,
+    Graph,
+)
+
+from .render_callback import RenderCallback  # pylint: disable=import-outside-toplevel
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class BokehRenderCallback(RenderCallback):
+    pass
+
+
+class NodeDescriptor:
+    """Descriptor used by Bokeh plotter."""
+
+    def __init__(self, node_id, node_type, node_detail):
+        self._node_id = node_id
+        self._node_type = node_type
+        self._node_detail = node_detail
+
+    @property
+    def node_id(self):
+        return self._node_id
+
+    @property
+    def node_type(self):
+        return self._node_type
+
+    @property
+    def detail(self):
+        return self._node_detail
+
+
+class GraphShaper:
+    """Provide the bounding-box, and node location, height, width given by pygraphviz."""
+
+    # defined by graphviz.
+    _px_per_inch = 72

Review comment:
       I think the style convention for constants is `_PX_PER_INCH`.

##########
File path: python/tvm/contrib/relay_viz/__init__.py
##########
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay IR Visualizer"""
+import logging
+import copy
+from enum import Enum
+from tvm import relay
+from .plotter import Plotter
+from .render_callback import RenderCallback
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class PlotterBackend(Enum):
+    """Enumeration for available plotters."""
+
+    BOKEH = "bokeh"
+    TERMINAL = "terminal"
+
+
+class RelayVisualizer:
+    """Relay IR Visualizer"""
+
+    def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH):
+        """Visualize Relay IR.
+
+        Parameters
+        ----------
+        relay_mod : object
+                        Relay IR module
+        relay_param: dict
+                        Relay parameter dictionary
+        backend: PlotterBackend or a tuple
+                        PlotterBackend: The backend of plotting. Default "bokeh"
+                        Tuple: A tuple with two arguments. First is user-defined Plotter, \
+                               the second is user-defined RenderCallback
+        """
+
+        self._plotter, self._render_rules = get_plotter_and_render_rules(backend)
+        self._relay_param = relay_param if relay_param is not None else {}
+        # This field is used for book-keeping for each graph.
+        self._node_to_id = {}
+
+        global_vars = relay_mod.get_global_vars()
+        graph_names = []
+        # If we have main function, put it to the first.
+        for gv_name in global_vars:
+            if gv_name.name_hint == "main":
+                graph_names.insert(0, gv_name.name_hint)
+            else:
+                graph_names.append(gv_name.name_hint)
+
+        for name in graph_names:
+            # clear previous graph
+            self._node_to_id = {}
+            relay.analysis.post_order_visit(
+                relay_mod[name],
+                lambda node: self._traverse_expr(node),  # pylint: disable=unnecessary-lambda
+            )
+            graph = self._plotter.create_graph(name)
+            # shallow copy to prevent callback modify self._node_to_id
+            self._render_cb(graph, copy.copy(self._node_to_id), self._relay_param)
+
+    def _traverse_expr(self, node):
+        # based on https://github.com/apache/tvm/pull/4370
+        if node in self._node_to_id:
+            return
+        self._node_to_id[node] = len(self._node_to_id)
+
+    def _render_cb(self, graph, node_to_id, relay_param):
+        """a callback to Add nodes and edges to the graph.
+
+        Parameters
+        ----------
+        graph : class plotter.Graph
+
+        node_to_id : Dict[relay.expr, int]
+
+        relay_param : Dict[string, NDarray]
+        """
+        # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm
+        unknown_type = "unknown"
+        for node, node_id in node_to_id.items():
+            if type(node) in self._render_rules:  # pylint: disable=unidiomatic-typecheck
+                graph_info, edge_info = self._render_rules[type(node)](
+                    node, relay_param, node_to_id
+                )
+                if graph_info:
+                    graph.node(*graph_info)
+                for edge in edge_info:
+                    graph.edge(*edge)
+            else:
+                unknown_info = "Unknown node: {}".format(type(node))
+                _LOGGER.warning(unknown_info)
+                graph.node(node_id, unknown_type, unknown_info)
+
+    def render(self, filename):
+        return self._plotter.render(filename=filename)
+
+
+def get_plotter_and_render_rules(backend):
+    """Specify the Plottor and its render rules
+
+    Parameters
+        ----------
+        backend: PlotterBackend or a tuple
+                        PlotterBackend: The backend of plotting. Default "bokeh"
+                        Tuple: A tuple with two arguments. First is user-defined Plotter, \
+                               the second is user-defined RenderCallback
+    """
+    if type(backend) is tuple and len(backend) == 2:  # pylint: disable=unidiomatic-typecheck

Review comment:
       Shouldn't this use `isinstance(backend, tuple)`? Also, shouldn't we assert that the tuple length is 2, rather than silently letting tuples of other lengths fall through?

##########
File path: python/tvm/contrib/relay_viz/_bokeh.py
##########
@@ -0,0 +1,484 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Bokeh backend for Relay IR Visualizer."""
+import html
+import logging
+import functools
+
+import numpy as np
+import pydot
+
+from bokeh.io import output_file, save
+from bokeh.models import (
+    ColumnDataSource,
+    CustomJS,
+    Text,
+    Rect,
+    HoverTool,
+    MultiLine,
+    Legend,
+    Scatter,
+    Plot,
+    TapTool,
+    PanTool,
+    ResetTool,
+    WheelZoomTool,
+    SaveTool,
+)
+from bokeh.palettes import (
+    d3,
+)
+from bokeh.layouts import column
+
+from .plotter import (
+    Plotter,
+    Graph,
+)
+
+from .render_callback import RenderCallback  # pylint: disable=import-outside-toplevel
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class BokehRenderCallback(RenderCallback):
+    pass
+
+
+class NodeDescriptor:
+    """Descriptor used by Bokeh plotter."""
+
+    def __init__(self, node_id, node_type, node_detail):
+        self._node_id = node_id
+        self._node_type = node_type
+        self._node_detail = node_detail
+
+    @property
+    def node_id(self):
+        return self._node_id
+
+    @property
+    def node_type(self):
+        return self._node_type
+
+    @property
+    def detail(self):
+        return self._node_detail
+
+
+class GraphShaper:
+    """Provide the bounding-box, and node location, height, width given by pygraphviz."""
+
+    # defined by graphviz.
+    _px_per_inch = 72
+
+    def __init__(self, pydot_graph, prog="dot", args=None):
+        if args is None:
+            args = []
+        # call the graphviz program to get layout
+        pydot_graph_str = pydot_graph.create([prog] + args, format="dot").decode()
+        # remember original nodes
+        self._nodes = [n.get_name() for n in pydot_graph.get_nodes()]
+        # parse layout
+        pydot_graph = pydot.graph_from_dot_data(pydot_graph_str)
+        if len(pydot_graph) != 1:
+            # should be unlikely.
+            _LOGGER.warning(
+                "Got %d pydot graphs. Only the first one will be used.", len(pydot_graph)
+            )
+        self._pydot_graph = pydot_graph[0]
+
+    def get_nodes(self):
+        return self._nodes
+
+    @functools.lru_cache()
+    def get_edge_path(self, start_node_id, end_node_id):
+        """Get explicit path points for MultiLine."""
+        edge = self._pydot_graph.get_edge(str(start_node_id), str(end_node_id))
+        if len(edge) != 1:
+            _LOGGER.warning(
+                "Got %d edges between %s and %s. Only the first one will be used.",
+                len(edge),
+                start_node_id,
+                end_node_id,
+            )
+        edge = edge[0]
+        # filter out quotes and newline
+        pos_str = edge.get_pos().strip('"').replace("\\\n", "")
+        tokens = pos_str.split(" ")
+        s_token = None
+        e_token = None
+        ret_x_pts = []
+        ret_y_pts = []
+        for token in tokens:
+            if token.startswith("e,"):
+                e_token = token
+            elif token.startswith("s,"):
+                s_token = token
+            else:
+                x_str, y_str = token.split(",")
+                ret_x_pts.append(float(x_str))
+                ret_y_pts.append(float(y_str))
+        if s_token is not None:
+            _, x_str, y_str = s_token.split(",")
+            ret_x_pts.insert(0, float(x_str))
+            ret_y_pts.insert(0, float(y_str))
+        if e_token is not None:
+            _, x_str, y_str = e_token.split(",")
+            ret_x_pts.append(float(x_str))
+            ret_y_pts.append(float(y_str))
+
+        return ret_x_pts, ret_y_pts
+
+    @functools.lru_cache()
+    def get_node_pos(self, node_name):
+        pos_str = self._get_node_attr(node_name, "pos", "0,0")
+        return list(map(float, pos_str.split(",")))
+
+    def get_node_height(self, node_name):
+        height_str = self._get_node_attr(node_name, "height", "20")
+        return float(height_str) * self._px_per_inch
+
+    def get_node_width(self, node_name):
+        width_str = self._get_node_attr(node_name, "width", "20")
+        return float(width_str) * self._px_per_inch
+
+    def _get_node_attr(self, node_name, attr_name, default_val):
+
+        node = self._pydot_graph.get_node(str(node_name))
+        if len(node) > 1:
+            _LOGGER.error(
+                "There are %d nodes with the name %s. Randomly choose one.", len(node), node_name
+            )
+        if len(node) == 0:
+            _LOGGER.warning(
+                "%s does not exist in the graph. Use default %s for attribute %s",
+                node_name,
+                default_val,
+                attr_name,
+            )
+            return default_val
+
+        node = node[0]
+        try:
+            val = node.obj_dict["attributes"][attr_name].strip('"')
+        except KeyError:
+            _LOGGER.warning(
+                "%s don't exist in node %s. Use default %s", attr_name, node_name, default_val
+            )
+            val = default_val
+        return val
+
+
+class BokehGraph(Graph):
+    """Use Bokeh library to plot Relay IR."""
+
+    def __init__(self):
+        self._pydot_digraph = pydot.Dot(graph_type="digraph")
+        self._id_to_node = {}
+
+    def node(self, node_id, node_type, node_detail):
+        # need string for pydot
+        node_id = str(node_id)
+        if node_id in self._id_to_node:
+            _LOGGER.warning("node_id %s already exists.", node_id)
+            return
+        self._pydot_digraph.add_node(pydot.Node(node_id, label=node_detail))
+        self._id_to_node[node_id] = NodeDescriptor(node_id, node_type, node_detail)
+
+    def edge(self, id_start, id_end):
+        # need string to pydot
+        id_start, id_end = str(id_start), str(id_end)
+        self._pydot_digraph.add_edge(pydot.Edge(id_start, id_end))
+
+    def render(self, plot):
+        """To draw a Bokeh Graph"""
+        shaper = GraphShaper(
+            self._pydot_digraph,
+            prog="dot",
+            args=["-Grankdir=TB", "-Gsplines=ortho", "-Gfontsize=14", "-Nordering=in"],
+        )
+
+        self._create_graph(plot, shaper)
+
+        self._add_scalable_glyph(plot, shaper)
+        return plot
+
+    def _get_type_to_color_map(self):
+        category20 = d3["Category20"][20]
+        # FIXME: a problem is, for different network we have different color
+        # for the same type.
+        all_types = list({v.node_type for v in self._id_to_node.values()})
+        all_types.sort()
+        if len(all_types) > 20:
+            _LOGGER.warning(
+                "The number of types %d is larger than 20. "
+                "Some colors are re-used for different types.",
+                len(all_types),
+            )
+        type_to_color = {}
+        for idx, t in enumerate(all_types):
+            type_to_color[t] = category20[idx % 20]
+        return type_to_color
+
+    def _create_graph(self, plot, shaper):
+
+        # Add edge first
+        edges = self._pydot_digraph.get_edges()
+        x_path_list = []
+        y_path_list = []
+        for edge in edges:
+            id_start = edge.get_source()
+            id_end = edge.get_destination()
+            x_pts, y_pts = shaper.get_edge_path(id_start, id_end)
+            x_path_list.append(x_pts)
+            y_path_list.append(y_pts)
+
+        multi_line_source = ColumnDataSource({"xs": x_path_list, "ys": y_path_list})
+        edge_line_color = "#888888"
+        edge_line_width = 3
+        multi_line_glyph = MultiLine(line_color=edge_line_color, line_width=edge_line_width)
+        plot.add_glyph(multi_line_source, multi_line_glyph)
+
+        # Then add nodes
+        type_to_color = self._get_type_to_color_map()
+
+        def cnvt_to_html(s):
+            return html.escape(s).replace("\n", "<br>")
+
+        label_to_ids = {}
+        for node_id in shaper.get_nodes():
+            label = self._id_to_node[node_id].node_type
+            if label not in label_to_ids:
+                label_to_ids[label] = []
+            label_to_ids[label].append(node_id)
+
+        renderers = []
+        legend_itmes = []
+        for label, id_list in label_to_ids.items():
+            source = ColumnDataSource(
+                {
+                    "x": [shaper.get_node_pos(n)[0] for n in id_list],
+                    "y": [shaper.get_node_pos(n)[1] for n in id_list],
+                    "width": [shaper.get_node_width(n) for n in id_list],
+                    "height": [shaper.get_node_height(n) for n in id_list],
+                    "node_detail": [cnvt_to_html(self._id_to_node[n].detail) for n in id_list],
+                    "node_type": [label] * len(id_list),
+                }
+            )
+            glyph = Rect(fill_color=type_to_color[label])
+            renderer = plot.add_glyph(source, glyph)
+            # set glyph for interactivity
+            renderer.nonselection_glyph = Rect(fill_color=type_to_color[label])
+            renderer.hover_glyph = Rect(
+                fill_color=type_to_color[label], line_color="firebrick", line_width=3
+            )
+            renderer.selection_glyph = Rect(
+                fill_color=type_to_color[label], line_color="firebrick", line_width=3
+            )
+            # Though it is called "muted_glyph", we actually use it
+            # to emphasize nodes in this renderer.
+            renderer.muted_glyph = Rect(
+                fill_color=type_to_color[label], line_color="firebrick", line_width=3
+            )
+            name = f"{self._get_graph_name(plot)}_{label}"
+            renderer.name = name
+            renderers.append(renderer)
+            legend_itmes.append((label, [renderer]))
+
+        # add legend
+        legend = Legend(
+            items=legend_itmes,
+            title="Click to highlight",
+            inactive_fill_color="firebrick",
+            inactive_fill_alpha=0.2,
+        )
+        legend.click_policy = "mute"
+        legend.location = "top_right"
+        plot.add_layout(legend)
+
+        # add tooltips
+        tooltips = [
+            ("node_type", "@node_type"),
+            ("description", "@node_detail{safe}"),
+        ]
+        inspect_tool = WheelZoomTool()
+        # only render nodes
+        hover_tool = HoverTool(tooltips=tooltips, renderers=renderers)
+        plot.add_tools(PanTool(), TapTool(), inspect_tool, hover_tool, ResetTool(), SaveTool())
+        plot.toolbar.active_scroll = inspect_tool
+
+    def _add_scalable_glyph(self, plot, shaper):
+        nodes = shaper.get_nodes()
+
+        def populate_detail(n_type, n_detail):
+            if n_detail:
+                return f"{n_type}\n{n_detail}"
+            return n_type
+
+        text_source = ColumnDataSource(
+            {
+                "x": [shaper.get_node_pos(n)[0] for n in nodes],
+                "y": [shaper.get_node_pos(n)[1] for n in nodes],
+                "text": [self._id_to_node[n].node_type for n in nodes],
+                "detail": [
+                    populate_detail(self._id_to_node[n].node_type, self._id_to_node[n].detail)
+                    for n in nodes
+                ],
+                "box_w": [shaper.get_node_width(n) for n in nodes],
+                "box_h": [shaper.get_node_height(n) for n in nodes],
+            }
+        )
+
+        text_glyph = Text(
+            x="x",
+            y="y",
+            text="text",
+            text_align="center",
+            text_baseline="middle",
+            text_font_size={"value": "14px"},
+        )
+        node_annotation = plot.add_glyph(text_source, text_glyph)
+
+        def get_scatter_loc(x_start, x_end, y_start, y_end, end_node):
+            """return x, y, angle as a tuple"""
+            node_x, node_y = shaper.get_node_pos(end_node)
+            node_w = shaper.get_node_width(end_node)
+            node_h = shaper.get_node_height(end_node)
+
+            # only 4 direction
+            if x_end - x_start > 0:
+                return node_x - node_w / 2, y_end, -np.pi / 2
+            if x_end - x_start < 0:
+                return node_x + node_w / 2, y_end, np.pi / 2
+            if y_end - y_start < 0:
+                return x_end, node_y + node_h / 2, np.pi
+            return x_end, node_y - node_h / 2, 0
+
+        scatter_source = {"x": [], "y": [], "angle": []}
+        for edge in self._pydot_digraph.get_edges():
+            id_start = edge.get_source()
+            id_end = edge.get_destination()
+            x_pts, y_pts = shaper.get_edge_path(id_start, id_end)
+            x_loc, y_loc, angle = get_scatter_loc(
+                x_pts[-2], x_pts[-1], y_pts[-2], y_pts[-1], id_end
+            )
+            scatter_source["angle"].append(angle)
+            scatter_source["x"].append(x_loc)
+            scatter_source["y"].append(y_loc)
+
+        scatter_glyph = Scatter(
+            x="x",
+            y="y",
+            angle="angle",
+            size=5,
+            marker="triangle",
+            fill_color="#AAAAAA",
+            fill_alpha=0.8,
+        )
+        edge_end_arrow = plot.add_glyph(ColumnDataSource(scatter_source), scatter_glyph)
+
+        plot.y_range.js_on_change(
+            "start",
+            CustomJS(
+                args=dict(
+                    plot=plot,
+                    node_annotation=node_annotation,
+                    text_source=text_source,
+                    edge_end_arrow=edge_end_arrow,
+                ),
+                code="""
+                 // fontsize is in px
+                 var fontsize = 14
+                 // ratio = data_point/px
+                 var ratio = (this.end - this.start)/plot.height
+                 var text_list = text_source.data["text"]
+                 var detail_list = text_source.data["detail"]
+                 var box_h_list = text_source.data["box_h"]
+                 for(var i = 0; i < text_list.length; i++) {
+                     var line_num = Math.floor((box_h_list[i]/ratio) / (fontsize*1.5))
+                     if(line_num <= 0) {
+                         // relieve for the first line
+                         if(Math.floor((box_h_list[i]/ratio) / (fontsize)) > 0) {
+                            line_num = 1
+                         }
+                     }
+                     var lines = detail_list[i].split("\\n")
+                     lines = lines.slice(0, line_num)
+                     text_list[i] = lines.join("\\n")
+                 }
+                 text_source.change.emit()
+
+                 node_annotation.glyph.text_font_size = {value: `${fontsize}px`}
+
+                 var new_scatter_size = Math.round(fontsize / ratio)
+                 edge_end_arrow.glyph.size = {value: new_scatter_size}
+                 """,
+            ),
+        )
+
+    @staticmethod
+    def _get_graph_name(plot):
+        return plot.title
+
+
+class BokehPlotter(Plotter):
+    """Use Bokeh library to plot Relay IR."""
+
+    def __init__(self):
+        self._name_to_graph = {}
+
+    def create_graph(self, name):
+        if name in self._name_to_graph:
+            _LOGGER.warning("Graph name %s exists. ")

Review comment:
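       The message has a `%s` placeholder, but no argument is passed to fill it (e.g. `_LOGGER.warning("Graph name %s exists.", name)`).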

##########
File path: python/tvm/contrib/relay_viz/_bokeh.py
##########
@@ -0,0 +1,484 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Bokeh backend for Relay IR Visualizer."""
+import html
+import logging
+import functools
+
+import numpy as np
+import pydot
+
+from bokeh.io import output_file, save
+from bokeh.models import (
+    ColumnDataSource,
+    CustomJS,
+    Text,
+    Rect,
+    HoverTool,
+    MultiLine,
+    Legend,
+    Scatter,
+    Plot,
+    TapTool,
+    PanTool,
+    ResetTool,
+    WheelZoomTool,
+    SaveTool,
+)
+from bokeh.palettes import (
+    d3,
+)
+from bokeh.layouts import column
+
+from .plotter import (
+    Plotter,
+    Graph,
+)
+
+from .render_callback import RenderCallback  # pylint: disable=import-outside-toplevel
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class BokehRenderCallback(RenderCallback):
+    pass
+
+
+class NodeDescriptor:
+    """Descriptor used by Bokeh plotter."""
+
+    def __init__(self, node_id, node_type, node_detail):
+        self._node_id = node_id
+        self._node_type = node_type
+        self._node_detail = node_detail
+
+    @property
+    def node_id(self):
+        return self._node_id
+
+    @property
+    def node_type(self):
+        return self._node_type
+
+    @property
+    def detail(self):
+        return self._node_detail
+
+
+class GraphShaper:
+    """Provide the bounding-box, and node location, height, width given by pygraphviz."""
+
+    # defined by graphviz.
+    _px_per_inch = 72
+
+    def __init__(self, pydot_graph, prog="dot", args=None):

Review comment:
       Can you write docstrings?

##########
File path: python/tvm/contrib/relay_viz/_bokeh.py
##########
@@ -0,0 +1,484 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Bokeh backend for Relay IR Visualizer."""
+import html
+import logging
+import functools
+
+import numpy as np
+import pydot
+
+from bokeh.io import output_file, save
+from bokeh.models import (
+    ColumnDataSource,
+    CustomJS,
+    Text,
+    Rect,
+    HoverTool,
+    MultiLine,
+    Legend,
+    Scatter,
+    Plot,
+    TapTool,
+    PanTool,
+    ResetTool,
+    WheelZoomTool,
+    SaveTool,
+)
+from bokeh.palettes import (
+    d3,
+)
+from bokeh.layouts import column
+
+from .plotter import (
+    Plotter,
+    Graph,
+)
+
+from .render_callback import RenderCallback  # pylint: disable=import-outside-toplevel
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class BokehRenderCallback(RenderCallback):
+    pass
+
+
+class NodeDescriptor:
+    """Descriptor used by Bokeh plotter."""
+
+    def __init__(self, node_id, node_type, node_detail):
+        self._node_id = node_id
+        self._node_type = node_type
+        self._node_detail = node_detail
+
+    @property
+    def node_id(self):
+        return self._node_id
+
+    @property
+    def node_type(self):
+        return self._node_type
+
+    @property
+    def detail(self):
+        return self._node_detail
+
+
+class GraphShaper:
+    """Provide the bounding-box, and node location, height, width given by pygraphviz."""
+
+    # defined by graphviz.
+    _px_per_inch = 72
+
+    def __init__(self, pydot_graph, prog="dot", args=None):
+        if args is None:
+            args = []
+        # call the graphviz program to get layout
+        pydot_graph_str = pydot_graph.create([prog] + args, format="dot").decode()
+        # remember original nodes
+        self._nodes = [n.get_name() for n in pydot_graph.get_nodes()]
+        # parse layout
+        pydot_graph = pydot.graph_from_dot_data(pydot_graph_str)
+        if len(pydot_graph) != 1:
+            # should be unlikely.
+            _LOGGER.warning(
+                "Got %d pydot graphs. Only the first one will be used.", len(pydot_graph)
+            )
+        self._pydot_graph = pydot_graph[0]
+
+    def get_nodes(self):
+        return self._nodes
+
+    @functools.lru_cache()
+    def get_edge_path(self, start_node_id, end_node_id):
+        """Get explicit path points for MultiLine."""
+        edge = self._pydot_graph.get_edge(str(start_node_id), str(end_node_id))
+        if len(edge) != 1:
+            _LOGGER.warning(
+                "Got %d edges between %s and %s. Only the first one will be used.",
+                len(edge),
+                start_node_id,
+                end_node_id,
+            )
+        edge = edge[0]
+        # filter out quotes and newline
+        pos_str = edge.get_pos().strip('"').replace("\\\n", "")
+        tokens = pos_str.split(" ")
+        s_token = None
+        e_token = None
+        ret_x_pts = []
+        ret_y_pts = []
+        for token in tokens:
+            if token.startswith("e,"):
+                e_token = token
+            elif token.startswith("s,"):
+                s_token = token
+            else:
+                x_str, y_str = token.split(",")
+                ret_x_pts.append(float(x_str))
+                ret_y_pts.append(float(y_str))
+        if s_token is not None:
+            _, x_str, y_str = s_token.split(",")
+            ret_x_pts.insert(0, float(x_str))
+            ret_y_pts.insert(0, float(y_str))
+        if e_token is not None:
+            _, x_str, y_str = e_token.split(",")
+            ret_x_pts.append(float(x_str))
+            ret_y_pts.append(float(y_str))
+
+        return ret_x_pts, ret_y_pts
+
+    @functools.lru_cache()
+    def get_node_pos(self, node_name):
+        pos_str = self._get_node_attr(node_name, "pos", "0,0")
+        return list(map(float, pos_str.split(",")))
+
+    def get_node_height(self, node_name):
+        height_str = self._get_node_attr(node_name, "height", "20")
+        return float(height_str) * self._px_per_inch
+
+    def get_node_width(self, node_name):
+        width_str = self._get_node_attr(node_name, "width", "20")
+        return float(width_str) * self._px_per_inch
+
+    def _get_node_attr(self, node_name, attr_name, default_val):
+
+        node = self._pydot_graph.get_node(str(node_name))
+        if len(node) > 1:
+            _LOGGER.error(
+                "There are %d nodes with the name %s. Randomly choose one.", len(node), node_name
+            )
+        if len(node) == 0:
+            _LOGGER.warning(
+                "%s does not exist in the graph. Use default %s for attribute %s",
+                node_name,
+                default_val,
+                attr_name,
+            )
+            return default_val
+
+        node = node[0]
+        try:
+            val = node.obj_dict["attributes"][attr_name].strip('"')
+        except KeyError:
+            _LOGGER.warning(
+                "%s don't exist in node %s. Use default %s", attr_name, node_name, default_val
+            )
+            val = default_val
+        return val
+
+
+class BokehGraph(Graph):
+    """Use Bokeh library to plot Relay IR."""
+
+    def __init__(self):
+        self._pydot_digraph = pydot.Dot(graph_type="digraph")
+        self._id_to_node = {}
+
+    def node(self, node_id, node_type, node_detail):
+        # need string for pydot
+        node_id = str(node_id)
+        if node_id in self._id_to_node:
+            _LOGGER.warning("node_id %s already exists.", node_id)
+            return
+        self._pydot_digraph.add_node(pydot.Node(node_id, label=node_detail))
+        self._id_to_node[node_id] = NodeDescriptor(node_id, node_type, node_detail)
+
+    def edge(self, id_start, id_end):
+        # need string to pydot
+        id_start, id_end = str(id_start), str(id_end)
+        self._pydot_digraph.add_edge(pydot.Edge(id_start, id_end))
+
+    def render(self, plot):
+        """To draw a Bokeh Graph"""
+        shaper = GraphShaper(
+            self._pydot_digraph,
+            prog="dot",
+            args=["-Grankdir=TB", "-Gsplines=ortho", "-Gfontsize=14", "-Nordering=in"],
+        )
+
+        self._create_graph(plot, shaper)
+
+        self._add_scalable_glyph(plot, shaper)
+        return plot
+
+    def _get_type_to_color_map(self):
+        category20 = d3["Category20"][20]
+        # FIXME: a problem is, for different network we have different color
+        # for the same type.
+        all_types = list({v.node_type for v in self._id_to_node.values()})
+        all_types.sort()
+        if len(all_types) > 20:
+            _LOGGER.warning(
+                "The number of types %d is larger than 20. "
+                "Some colors are re-used for different types.",
+                len(all_types),
+            )
+        type_to_color = {}
+        for idx, t in enumerate(all_types):
+            type_to_color[t] = category20[idx % 20]
+        return type_to_color
+
+    def _create_graph(self, plot, shaper):
+
+        # Add edge first
+        edges = self._pydot_digraph.get_edges()
+        x_path_list = []
+        y_path_list = []
+        for edge in edges:
+            id_start = edge.get_source()
+            id_end = edge.get_destination()
+            x_pts, y_pts = shaper.get_edge_path(id_start, id_end)
+            x_path_list.append(x_pts)
+            y_path_list.append(y_pts)
+
+        multi_line_source = ColumnDataSource({"xs": x_path_list, "ys": y_path_list})
+        edge_line_color = "#888888"
+        edge_line_width = 3
+        multi_line_glyph = MultiLine(line_color=edge_line_color, line_width=edge_line_width)
+        plot.add_glyph(multi_line_source, multi_line_glyph)
+
+        # Then add nodes
+        type_to_color = self._get_type_to_color_map()
+
+        def cnvt_to_html(s):
+            return html.escape(s).replace("\n", "<br>")
+
+        label_to_ids = {}
+        for node_id in shaper.get_nodes():
+            label = self._id_to_node[node_id].node_type
+            if label not in label_to_ids:
+                label_to_ids[label] = []
+            label_to_ids[label].append(node_id)
+
+        renderers = []
+        legend_itmes = []
+        for label, id_list in label_to_ids.items():
+            source = ColumnDataSource(
+                {
+                    "x": [shaper.get_node_pos(n)[0] for n in id_list],
+                    "y": [shaper.get_node_pos(n)[1] for n in id_list],
+                    "width": [shaper.get_node_width(n) for n in id_list],
+                    "height": [shaper.get_node_height(n) for n in id_list],
+                    "node_detail": [cnvt_to_html(self._id_to_node[n].detail) for n in id_list],
+                    "node_type": [label] * len(id_list),
+                }
+            )
+            glyph = Rect(fill_color=type_to_color[label])
+            renderer = plot.add_glyph(source, glyph)
+            # set glyph for interactivity
+            renderer.nonselection_glyph = Rect(fill_color=type_to_color[label])
+            renderer.hover_glyph = Rect(
+                fill_color=type_to_color[label], line_color="firebrick", line_width=3
+            )
+            renderer.selection_glyph = Rect(
+                fill_color=type_to_color[label], line_color="firebrick", line_width=3
+            )
+            # Though it is called "muted_glyph", we actually use it
+            # to emphasize nodes in this renderer.
+            renderer.muted_glyph = Rect(
+                fill_color=type_to_color[label], line_color="firebrick", line_width=3
+            )
+            name = f"{self._get_graph_name(plot)}_{label}"
+            renderer.name = name
+            renderers.append(renderer)
+            legend_itmes.append((label, [renderer]))
+
+        # add legend
+        legend = Legend(
+            items=legend_itmes,
+            title="Click to highlight",
+            inactive_fill_color="firebrick",
+            inactive_fill_alpha=0.2,
+        )
+        legend.click_policy = "mute"
+        legend.location = "top_right"
+        plot.add_layout(legend)
+
+        # add tooltips
+        tooltips = [
+            ("node_type", "@node_type"),
+            ("description", "@node_detail{safe}"),
+        ]
+        inspect_tool = WheelZoomTool()
+        # only render nodes
+        hover_tool = HoverTool(tooltips=tooltips, renderers=renderers)
+        plot.add_tools(PanTool(), TapTool(), inspect_tool, hover_tool, ResetTool(), SaveTool())
+        plot.toolbar.active_scroll = inspect_tool
+
+    def _add_scalable_glyph(self, plot, shaper):
+        nodes = shaper.get_nodes()
+
+        def populate_detail(n_type, n_detail):
+            if n_detail:
+                return f"{n_type}\n{n_detail}"
+            return n_type
+
+        text_source = ColumnDataSource(
+            {
+                "x": [shaper.get_node_pos(n)[0] for n in nodes],
+                "y": [shaper.get_node_pos(n)[1] for n in nodes],
+                "text": [self._id_to_node[n].node_type for n in nodes],
+                "detail": [
+                    populate_detail(self._id_to_node[n].node_type, self._id_to_node[n].detail)
+                    for n in nodes
+                ],
+                "box_w": [shaper.get_node_width(n) for n in nodes],
+                "box_h": [shaper.get_node_height(n) for n in nodes],
+            }
+        )
+
+        text_glyph = Text(
+            x="x",
+            y="y",
+            text="text",
+            text_align="center",
+            text_baseline="middle",
+            text_font_size={"value": "14px"},
+        )
+        node_annotation = plot.add_glyph(text_source, text_glyph)
+
+        def get_scatter_loc(x_start, x_end, y_start, y_end, end_node):
+            """return x, y, angle as a tuple"""
+            node_x, node_y = shaper.get_node_pos(end_node)
+            node_w = shaper.get_node_width(end_node)
+            node_h = shaper.get_node_height(end_node)
+
+            # only 4 direction
+            if x_end - x_start > 0:
+                return node_x - node_w / 2, y_end, -np.pi / 2
+            if x_end - x_start < 0:
+                return node_x + node_w / 2, y_end, np.pi / 2
+            if y_end - y_start < 0:
+                return x_end, node_y + node_h / 2, np.pi
+            return x_end, node_y - node_h / 2, 0
+
+        scatter_source = {"x": [], "y": [], "angle": []}
+        for edge in self._pydot_digraph.get_edges():
+            id_start = edge.get_source()
+            id_end = edge.get_destination()
+            x_pts, y_pts = shaper.get_edge_path(id_start, id_end)
+            x_loc, y_loc, angle = get_scatter_loc(
+                x_pts[-2], x_pts[-1], y_pts[-2], y_pts[-1], id_end
+            )
+            scatter_source["angle"].append(angle)
+            scatter_source["x"].append(x_loc)
+            scatter_source["y"].append(y_loc)
+
+        scatter_glyph = Scatter(
+            x="x",
+            y="y",
+            angle="angle",
+            size=5,
+            marker="triangle",
+            fill_color="#AAAAAA",
+            fill_alpha=0.8,
+        )
+        edge_end_arrow = plot.add_glyph(ColumnDataSource(scatter_source), scatter_glyph)
+
+        plot.y_range.js_on_change(
+            "start",
+            CustomJS(
+                args=dict(
+                    plot=plot,
+                    node_annotation=node_annotation,
+                    text_source=text_source,
+                    edge_end_arrow=edge_end_arrow,
+                ),
+                code="""
+                 // fontsize is in px
+                 var fontsize = 14
+                 // ratio = data_point/px
+                 var ratio = (this.end - this.start)/plot.height
+                 var text_list = text_source.data["text"]
+                 var detail_list = text_source.data["detail"]
+                 var box_h_list = text_source.data["box_h"]
+                 for(var i = 0; i < text_list.length; i++) {
+                     var line_num = Math.floor((box_h_list[i]/ratio) / (fontsize*1.5))
+                     if(line_num <= 0) {
+                         // relieve for the first line
+                         if(Math.floor((box_h_list[i]/ratio) / (fontsize)) > 0) {
+                            line_num = 1
+                         }
+                     }
+                     var lines = detail_list[i].split("\\n")
+                     lines = lines.slice(0, line_num)
+                     text_list[i] = lines.join("\\n")
+                 }
+                 text_source.change.emit()
+
+                 node_annotation.glyph.text_font_size = {value: `${fontsize}px`}
+
+                 var new_scatter_size = Math.round(fontsize / ratio)
+                 edge_end_arrow.glyph.size = {value: new_scatter_size}
+                 """,
+            ),
+        )
+
+    @staticmethod
+    def _get_graph_name(plot):
+        return plot.title
+
+
+class BokehPlotter(Plotter):
+    """Use Bokeh library to plot Relay IR."""
+
+    def __init__(self):
+        self._name_to_graph = {}
+
+    def create_graph(self, name):
+        if name in self._name_to_graph:
+            _LOGGER.warning("Graph name %s exists. ")
+        else:
+            self._name_to_graph[name] = BokehGraph()
+        return self._name_to_graph[name]
+
+    def render(self, filename):
+        if not filename.endswith(".html"):
+            filename = "{}.html".format(filename)

Review comment:
       nit: can use `f"{filename}.html"`




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org