Posted to commits@tvm.apache.org by kp...@apache.org on 2022/04/23 15:01:04 UTC

[tvm] branch main updated: [Hexagon] Add mobilenet test (#11104)

This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 822d863770 [Hexagon] Add mobilenet test (#11104)
822d863770 is described below

commit 822d863770f17d0aa2e37fb128438eb4b483d1f1
Author: Mehrdad Hessar <mh...@octoml.ai>
AuthorDate: Sat Apr 23 08:00:55 2022 -0700

    [Hexagon] Add mobilenet test (#11104)
    
    * Add mobilenet test on Hexagon
    
    * Address comments
    
    * Fix import and remove extra function
---
 python/tvm/contrib/hexagon/build.py                | 29 ++++++++
 python/tvm/relay/op/strategy/hexagon.py            | 75 ++++++++++++-------
 python/tvm/topi/hexagon/conv2d.py                  |  8 ++
 python/tvm/topi/hexagon/injective.py               |  8 ++
 tests/python/contrib/test_hexagon/test_launcher.py | 13 ----
 tests/python/contrib/test_hexagon/test_models.py   | 85 ++++++++++++++++++++++
 6 files changed, 178 insertions(+), 40 deletions(-)
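
For reviewers who want to exercise the change locally, below is a minimal sketch of running just the new test file; it assumes a Hexagon toolchain plus the hexagon_session fixture from the test_hexagon conftest are available (the invocation is illustrative, not part of this commit):

    # Sketch only: run the new MobileNet test in isolation. Requires the Hexagon
    # toolchain and an attached device or simulator for the session fixture.
    import sys
    import pytest

    sys.exit(pytest.main(["tests/python/contrib/test_hexagon/test_models.py", "-k", "test_mobilenet"]))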

diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py
index 1664ee4b11..fa20a2fa7d 100644
--- a/python/tvm/contrib/hexagon/build.py
+++ b/python/tvm/contrib/hexagon/build.py
@@ -257,6 +257,37 @@ class HexagonLauncherRPC(metaclass=abc.ABCMeta):
         graph_mod = self.load_module(module_name, session)
         return tvm.contrib.graph_executor.create(graph_json, graph_mod, session.device)
 
+    def get_graph_debug_executor(
+        self,
+        graph_json: str,
+        module_name: Union[str, pathlib.Path],
+        session: Session,
+        dump_root: Union[str, pathlib.Path] = None,
+    ):
+        """Create a local GraphModuleDebug which consumes a remote libmod.
+
+        Parameters
+        ----------
+        graph_json : str
+            The string with the graph JSON.
+        module_name : str or pathlib.Path
+            Remote module filename. Same restrictions apply as in load_module().
+        session : Session
+            Remote session. The session must be established (via __enter__)
+            prior to calling this function.
+        dump_root : str or pathlib.Path, optional
+            Directory where the debug executor writes its per-node dump files.
+
+        Returns
+        -------
+        GraphModuleDebug :
+            Runtime debug graph module that can be used to debug the graph.
+        """
+        graph_mod = self.load_module(module_name, session)
+        return tvm.contrib.debugger.debug_executor.create(
+            graph_json, graph_mod, session.device, dump_root=str(dump_root) if dump_root else None
+        )
+
     def get_aot_executor(self, module_name: Union[str, pathlib.Path], session: Session):
         """Create a local AoTModule which consumes a remote libmod.
 
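
For context, a minimal sketch of how the new get_graph_debug_executor() hook might be driven, assuming a HexagonLauncher that has already uploaded the built module and a graph-executor factory `lowered` produced by tvm.relay.build (the launcher, module and input names are illustrative, not part of this commit):

    # Sketch only: `launcher` is assumed to be a HexagonLauncher that has uploaded
    # "test_module.so", and `lowered` came from tvm.relay.build(...).
    with launcher.start_session() as session:
        debug_mod = launcher.get_graph_debug_executor(
            lowered.get_graph_json(), "test_module.so", session, dump_root="./tvmdbg"
        )
        debug_mod.set_input("input", data_in)
        debug_mod.run()  # per-node tensors are dumped under dump_root for inspection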
diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py
index fd5ee97e88..cfd9a8b5dd 100644
--- a/python/tvm/relay/op/strategy/hexagon.py
+++ b/python/tvm/relay/op/strategy/hexagon.py
@@ -22,7 +22,6 @@ from tvm import topi
 from .generic import *
 from .. import op as _op
 
-
 # --- Op strategy registration
 
 
@@ -44,27 +43,49 @@ def conv2d_strategy_hexagon(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     data_layout = attrs.data_layout
     kernel_layout = attrs.kernel_layout
+    groups = attrs.groups
+    data, kernel = inputs
+    layout = attrs.data_layout
+
+    if groups == 1:
+        if data_layout == "NHWC" and kernel_layout == "HWIO":
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_nhwc),
+                wrap_topi_schedule(topi.hexagon.schedule_conv2d_nhwc),
+                name="conv2d_nhwc.hexagon",
+            )
+        elif data_layout == "NCHW" and kernel_layout == "OIHW":
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.conv2d_nchw),
+                wrap_topi_schedule(topi.hexagon.schedule_conv2d_nchw),
+                name="conv2d_nchw.hexagon",
+            )
+        else:
+            raise RuntimeError(
+                f"Unsupported layouts: data_layout:{data_layout}, kernel_layout:{kernel_layout}, "
+                f"groups:{attrs.groups}"
+            )
+    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
+        if layout == "NCHW":
+            assert kernel_layout == "OIHW"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
+                wrap_topi_schedule(topi.hexagon.schedule_depthwise_conv2d_nchw),
+                name="depthwise_conv2d_nchw.generic",
+            )
+        elif layout == "NHWC":
+            assert kernel_layout == "HWOI"
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                wrap_topi_schedule(topi.hexagon.schedule_depthwise_conv2d_nhwc),
+                name="depthwise_conv2d_nhwc.generic",
+            )
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
+    else:  # group_conv2d
+        raise RuntimeError(f"Unsupported group_conv2d layout {layout}")
 
-    if data_layout == "NHWC" and kernel_layout == "HWIO":
-        strategy.add_implementation(
-            wrap_compute_conv2d(topi.nn.conv2d_nhwc),
-            wrap_topi_schedule(topi.hexagon.schedule_conv2d_nhwc),
-            name="conv2d_nhwc.hexagon",
-        )
-        return strategy
-
-    if data_layout == "NCHW" and kernel_layout == "OIHW":
-        strategy.add_implementation(
-            wrap_compute_conv2d(topi.nn.conv2d_nchw),
-            wrap_topi_schedule(topi.hexagon.schedule_conv2d_nchw),
-            name="conv2d_nchw.hexagon",
-        )
-        return strategy
-
-    raise RuntimeError(
-        f"Unsupported layouts: data_layout:{data_layout}, kernel_layout:{kernel_layout}, "
-        f"groups:{attrs.groups}"
-    )
+    return strategy
 
 
 @dense_strategy.register("hexagon")
@@ -101,16 +122,16 @@ def schedule_adaptive_pool_hexagon(attrs, outs, target):
         return topi.hexagon.schedule_adaptive_pool(outs)
 
 
-@schedule_concatenate.register("hexagon")
-def schedule_concatenate_hexagon(attrs, outs, target):
-    """Schedule concatenate ops for Hexagon"""
+@schedule_injective.register("hexagon")
+def schedule_injective_hexagon(attrs, outs, target):
+    """Schedule injective ops for Hexagon"""
     with target:
         return topi.hexagon.schedule_injective(outs)
 
 
-@schedule_injective.register("hexagon")
-def schedule_injective_hexagon(attrs, outs, target):
-    """Schedule injective ops for Hexagon"""
+@schedule_concatenate.register("hexagon")
+def schedule_concatenate_hexagon(attrs, outs, target):
+    """Schedule concatenate ops for Hexagon"""
     with target:
         return topi.hexagon.schedule_injective(outs)
 
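
To illustrate what the reworked conv2d strategy now covers, here is a small sketch of a depthwise conv2d (groups equal to input channels, NCHW/OIHW) that previously fell through to the "Unsupported layouts" error and is now routed to the new depthwise schedules; the shapes and the v68 target are illustrative, and building it requires an LLVM with the Hexagon backend:

    import tvm
    from tvm import relay

    # Depthwise conv2d: groups == in_channels == 32 with an OIHW kernel of shape
    # (32, 1, 3, 3), which is_depthwise_conv2d() recognizes in the new NCHW branch.
    data = relay.var("data", shape=(1, 32, 56, 56), dtype="float32")
    weight = relay.var("weight", shape=(32, 1, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(
        data, weight, channels=32, groups=32, kernel_size=(3, 3),
        padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW",
    )
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

    target_hexagon = tvm.target.hexagon("v68")
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, tvm.target.Target(target_hexagon, host=target_hexagon))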
diff --git a/python/tvm/topi/hexagon/conv2d.py b/python/tvm/topi/hexagon/conv2d.py
index 6df15f8b8c..4f564faa0a 100644
--- a/python/tvm/topi/hexagon/conv2d.py
+++ b/python/tvm/topi/hexagon/conv2d.py
@@ -52,3 +52,13 @@ def schedule_conv2d(outs, layout="NHWC"):
         return schedule_conv2d_nchw(outs)
 
     raise ValueError(f"Unexpected layout={layout}")
+
+
+def schedule_depthwise_conv2d_nchw(outs):
+    """Schedule depthwise conv2d in NCHW layout by reusing the conv2d NCHW schedule."""
+    return schedule_conv2d_nchw(outs)
+
+
+def schedule_depthwise_conv2d_nhwc(outs):
+    """Schedule depthwise conv2d in NHWC layout by reusing the conv2d NHWC schedule."""
+    return schedule_conv2d_nhwc(outs)
diff --git a/python/tvm/topi/hexagon/injective.py b/python/tvm/topi/hexagon/injective.py
index 88e0f40640..34a9fb9a05 100644
--- a/python/tvm/topi/hexagon/injective.py
+++ b/python/tvm/topi/hexagon/injective.py
@@ -42,3 +42,13 @@ def schedule_injective(outs):
 
 def schedule_softmax(outs):
     return schedule_injective(outs)
+
+
+def schedule_elemwise(outs):
+    """Schedule element-wise ops by reusing the generic injective schedule."""
+    return schedule_injective(outs)
+
+
+def schedule_broadcast(outs):
+    """Schedule broadcast ops by reusing the generic injective schedule."""
+    return schedule_injective(outs)
diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py
index 48b3dac2a2..861ad4f15b 100644
--- a/tests/python/contrib/test_hexagon/test_launcher.py
+++ b/tests/python/contrib/test_hexagon/test_launcher.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import os
 import sys
 import pytest
 import numpy as np
@@ -24,7 +23,6 @@ import tvm.testing
 from tvm import te
 from tvm import relay
 from tvm.relay.backend import Executor, Runtime
-import tvm.contrib.hexagon as hexagon
 
 from .conftest import requires_hexagon_toolchain
 
@@ -256,17 +254,6 @@ def test_graph_executor_multiple_conv2d(hexagon_session):
     tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5)
 
 
-def _workaround_create_aot_shared():
-    # The C codegen uses TVM/RT functions directly. On Hexagon it should use
-    # functions pointers via __TVMxyz variables. This workaround makes the
-    # runtime symbols visible to the compiled shared library.
-    extra_link_flags = os.environ.get("HEXAGON_SHARED_LINK_FLAGS")
-    extra_options = str(extra_link_flags).split() if extra_link_flags else []
-    return lambda so_name, files, hexagon_arch, options: hexagon.create_aot_shared(
-        so_name, files, hexagon_arch, options=extra_options + options
-    )
-
-
 @requires_hexagon_toolchain
 def test_aot_executor(hexagon_session, aot_host_target, aot_target):
     dtype = "float32"
diff --git a/tests/python/contrib/test_hexagon/test_models.py b/tests/python/contrib/test_hexagon/test_models.py
new file mode 100644
index 0000000000..5b4f6059f7
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_models.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import sys
+import pytest
+import numpy as np
+
+import tvm.testing
+from tvm import te
+from tvm import relay
+from tvm.relay.backend import Executor, Runtime
+
+from .conftest import requires_hexagon_toolchain
+
+
+@requires_hexagon_toolchain
+def test_mobilenet(hexagon_session):
+    import onnx
+
+    dtype = "float32"
+    model_url = "https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx"
+    model_path = tvm.contrib.download.download_testdata(
+        model_url, "mobilenetv2-7.onnx", module="onnx"
+    )
+    onnx_model = onnx.load(model_path)
+
+    target_hexagon = tvm.target.hexagon("v68")
+    target_llvm = tvm.target.Target("llvm")
+    runtime = Runtime("cpp")
+    executor = Executor("graph", {"link-params": True})
+
+    data_in = np.random.rand(1, 3, 224, 224).astype(dtype=dtype)
+
+    input_name = "input"
+    shape_dict = {input_name: data_in.shape}
+    relay_mod, params = relay.frontend.from_onnx(onnx_model, shape_dict, freeze_params=True)
+    inputs = {input_name: data_in}
+
+    with tvm.transform.PassContext(opt_level=3):
+        hexagon_lowered = tvm.relay.build(
+            relay_mod,
+            tvm.target.Target(target_hexagon, host=target_hexagon),
+            runtime=runtime,
+            executor=executor,
+            params=params,
+        )
+
+        llvm_lowered = tvm.relay.build(
+            relay_mod,
+            tvm.target.Target(target_llvm, host=target_llvm),
+            runtime=runtime,
+            executor=executor,
+            params=params,
+        )
+
+    graph_mod = hexagon_session.get_executor_from_factory(hexagon_lowered)
+    graph_mod.set_input(**inputs)
+    graph_mod.run()
+    hexagon_output = graph_mod.get_output(0).numpy()
+
+    llvm_graph_mod = tvm.contrib.graph_executor.GraphModule(llvm_lowered["default"](tvm.cpu(0)))
+    llvm_graph_mod.set_input(**inputs)
+    llvm_graph_mod.run()
+    expected_output = llvm_graph_mod.get_output(0).numpy()
+
+    tvm.testing.assert_allclose(hexagon_output, expected_output, rtol=1e-4, atol=1e-5)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))