You are viewing a plain-text version of this content. The canonical link for it was a hyperlink on the original archive page and is not preserved in this copy.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/12/06 15:33:37 UTC

[GitHub] [tvm] chiwwang commented on a change in pull request #8668: Visualization of Relay IR

chiwwang commented on a change in pull request #8668:
URL: https://github.com/apache/tvm/pull/8668#discussion_r763116588



##########
File path: python/tvm/contrib/relay_viz/__init__.py
##########
@@ -0,0 +1,143 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relay IR Visualizer"""
+from typing import (
+    Dict,
+    Tuple,
+    Union,
+)
+from enum import Enum
+import tvm
+from tvm import relay
+from .plotter import Plotter
+from .node_edge_gen import NodeEdgeGenerator
+
+
+class PlotterBackend(Enum):
+    """Enumeration for available plotter backends."""
+
+    BOKEH = "bokeh"
+    TERMINAL = "terminal"
+
+
+class RelayVisualizer:
+    """Relay IR Visualizer
+
+    Parameters
+    ----------
+    relay_mod : tvm.IRModule
+        Relay IR module.
+    relay_param: None | Dict[str, tvm.runtime.NDArray]
+        Relay parameter dictionary. Default `None`.
+    backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator]
+        The backend used to render graphs. It can be a tuple of an implemented Plotter instance and
+        NodeEdgeGenerator instance to introduce customized parsing and visualization logics.
+        Default ``PlotterBackend.TERMINAL``.
+    """
+
+    def __init__(
+        self,
+        relay_mod: tvm.IRModule,
+        relay_param: Union[None, Dict[str, tvm.runtime.NDArray]] = None,
+        backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] = PlotterBackend.TERMINAL,
+    ):
+
+        self._plotter, self._ne_generator = get_plotter_and_generator(backend)
+        self._relay_param = relay_param if relay_param is not None else {}
+
+        global_vars = relay_mod.get_global_vars()
+        graph_names = []
+        # If we have main function, put it to the first.
+        # Then main function can be shown on the top.
+        for gv_name in global_vars:
+            if gv_name.name_hint == "main":
+                graph_names.insert(0, gv_name.name_hint)
+            else:
+                graph_names.append(gv_name.name_hint)
+
+        node_to_id = {}
+
+        def traverse_expr(node):
+            if node in node_to_id:
+                return
+            node_to_id[node] = len(node_to_id)
+
+        for name in graph_names:
+            node_to_id.clear()
+            relay.analysis.post_order_visit(relay_mod[name], traverse_expr)
+            graph = self._plotter.create_graph(name)
+            self._add_nodes(graph, node_to_id, self._relay_param)
+
+    def _add_nodes(self, graph, node_to_id, relay_param):
+        """add nodes and to the graph.
+
+        Parameters
+        ----------
+        graph : plotter.Graph
+
+        node_to_id : Dict[relay.expr, str | int]
+
+        relay_param : Dict[str, tvm.runtime.NDarray]
+        """
+        for node in node_to_id:
+            node_info, edge_info = self._ne_generator.get_node_edges(node, relay_param, node_to_id)
+            if node_info is not None:
+                graph.node(node_info.identity, node_info.type_str, node_info.detail)
+            for edge in edge_info:
+                graph.edge(edge.start, edge.end)
+
+    def render(self, filename: str = None) -> None:
+        self._plotter.render(filename=filename)
+
+
+def get_plotter_and_generator(backend):
+    """Specify the Plottor and its NodeEdgeGenerator"""

Review comment:
       Done.

##########
File path: python/tvm/contrib/relay_viz/node_edge_gen.py
##########
@@ -0,0 +1,250 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""NodeEdgeGenerator interface for :py:class:`tvm.contrib.relay_viz.plotter.Graph`."""
+import abc
+from typing import (
+    Dict,
+    Union,
+    Tuple,
+    List,
+)
+import tvm
+from tvm import relay
+
+UNKNOWN_TYPE = "unknown"
+
+
+class VizNode:
+    """Node carry information used by `plotter.Graph` interface."""

Review comment:
       Done.

##########
File path: python/tvm/contrib/relay_viz/node_edge_gen.py
##########
@@ -0,0 +1,250 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""NodeEdgeGenerator interface for :py:class:`tvm.contrib.relay_viz.plotter.Graph`."""
+import abc
+from typing import (
+    Dict,
+    Union,
+    Tuple,
+    List,
+)
+import tvm
+from tvm import relay
+
+UNKNOWN_TYPE = "unknown"
+
+
+class VizNode:
+    """Node carry information used by `plotter.Graph` interface."""
+
+    def __init__(self, node_id: Union[int, str], node_type: str, node_detail: str):
+        self._id = node_id
+        self._type = node_type
+        self._detail = node_detail
+
+    @property
+    def identity(self) -> Union[int, str]:
+        return self._id
+
+    @property
+    def type_str(self) -> str:
+        return self._type
+
+    @property
+    def detail(self) -> str:
+        return self._detail
+
+
+class VizEdge:
+    """Edges for `plotter.Graph` interface."""
+

Review comment:
       Done.

##########
File path: python/tvm/contrib/relay_viz/node_edge_gen.py
##########
@@ -0,0 +1,250 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""NodeEdgeGenerator interface for :py:class:`tvm.contrib.relay_viz.plotter.Graph`."""
+import abc
+from typing import (
+    Dict,
+    Union,
+    Tuple,
+    List,
+)
+import tvm
+from tvm import relay
+
+UNKNOWN_TYPE = "unknown"
+
+
+class VizNode:
+    """Node carry information used by `plotter.Graph` interface."""
+
+    def __init__(self, node_id: Union[int, str], node_type: str, node_detail: str):
+        self._id = node_id
+        self._type = node_type
+        self._detail = node_detail
+
+    @property
+    def identity(self) -> Union[int, str]:
+        return self._id
+
+    @property
+    def type_str(self) -> str:
+        return self._type
+
+    @property
+    def detail(self) -> str:
+        return self._detail
+
+
+class VizEdge:
+    """Edges for `plotter.Graph` interface."""
+
+    def __init__(self, start_node: Union[int, str], end_node: Union[int, str]):
+        self._start_node = start_node
+        self._end_node = end_node
+
+    @property
+    def start(self) -> Union[int, str]:
+        return self._start_node
+
+    @property
+    def end(self) -> Union[int, str]:
+        return self._end_node
+
+
+class NodeEdgeGenerator(abc.ABC):
+    """An interface class to generate nodes and edges information for Graph interfaces."""
+
+    @abc.abstractmethod
+    def get_node_edges(
+        self,
+        node: relay.Expr,
+        relay_param: Dict[str, tvm.runtime.NDArray],
+        node_to_id: Dict[relay.Expr, Union[int, str]],
+    ) -> Tuple[Union[VizNode, None], List[VizEdge]]:
+        """Generate node and edges consumed by Graph interfaces.
+
+        Parameters
+        ----------
+        node : relay.Expr
+            relay.Expr which will be parsed and generate a node and edges.
+
+        relay_param: Dict[str, tvm.runtime.NDArray]
+            relay parameters dictionary.
+
+        node_to_id : Dict[relay.Expr, Union[int, str]]
+            a mapping from relay.Expr to node id which should be unique.
+
+        Returns
+        -------
+        rv1 : Union[VizNode, None]
+            VizNode represent the relay.Expr. If the relay.Expr is not intended to introduce a node
+            to the graph, return None.
+
+        rv2 : List[VizEdge]
+            a list of VizEdge to describe the connectivity of the relay.Expr.
+            Can be empty list to indicate no connectivity.
+        """
+
+
+class DefaultNodeEdgeGenerator(NodeEdgeGenerator):
+    """NodeEdgeGenerator generate for nodes and edges consumed by Graph.
+    This class is a default implementation for common relay types, heavily based on
+    `visualize` function in https://tvm.apache.org/2020/07/14/bert-pytorch-tvm
+    """
+
+    def __init__(self):
+        self._render_rules = {}
+        self._build_rules()
+
+    def get_node_edges(
+        self,
+        node: relay.Expr,
+        relay_param: Dict[str, tvm.runtime.NDArray],
+        node_to_id: Dict[relay.Expr, Union[int, str]],
+    ) -> Tuple[Union[VizNode, None], List[VizEdge]]:
+        try:
+            node_info, edge_info = self._render_rules[type(node)](node, relay_param, node_to_id)
+        except KeyError:
+            node_info = VizNode(
+                node_to_id[node], UNKNOWN_TYPE, f"don't know how to parse {type(node)}"
+            )
+            edge_info = []
+        return node_info, edge_info
+
+    def _var_node(
+        self,
+        node: relay.Expr,
+        relay_param: Dict[str, tvm.runtime.NDArray],
+        node_to_id: Dict[relay.Expr, Union[int, str]],
+    ) -> Tuple[Union[VizNode, None], List[VizEdge]]:
+        """Render rule for a relay var node"""
+        node_id = node_to_id[node]
+        name_hint = node.name_hint
+        node_detail = f"name_hint: {name_hint}"
+        node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)"
+        if node.type_annotation is not None:
+            if hasattr(node.type_annotation, "shape"):
+                shape = tuple(map(int, node.type_annotation.shape))
+                dtype = node.type_annotation.dtype
+                node_detail = f"name_hint: {name_hint}\nshape: {shape}\ndtype: {dtype}"
+            else:
+                node_detail = f"name_hint: {name_hint}\ntype_annotation: {node.type_annotation}"
+        node_info = VizNode(node_id, node_type, node_detail)
+        edge_info = []
+        return node_info, edge_info
+
+    def _function_node(
+        self,
+        node: relay.Expr,
+        _: Dict[str, tvm.runtime.NDArray],  # relay_param

Review comment:
       Done.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org