You are viewing a plain-text version of this content. The link to the canonical version ("here") was lost in the plain-text conversion.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/02/25 18:27:47 UTC

[tvm] branch main updated: Introduce module_loader to AutoTVM. (#7337)

This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b111695  Introduce module_loader to AutoTVM. (#7337)
b111695 is described below

commit b1116954f532d869a7ce8d9eb24745f368b66e59
Author: Andrew Reusch <ar...@octoml.ai>
AuthorDate: Thu Feb 25 10:27:06 2021 -0800

    Introduce module_loader to AutoTVM. (#7337)
    
    * Introduce code_loader to AutoTVM.
    
     * Prepares for autotuning with microTVM, and provides extension hook
       for VTA.
    
    * add vta hook
    
    * git-black
    
    * pylint
    
    * Add missing import
    
    * Fix import problem
    
    * add missing import
    
    * rename code_loader to module_loader
    
    * rename remote_kw to remote_kwargs
    
    * black format
---
 python/tvm/autotvm/measure/__init__.py        |   8 +-
 python/tvm/autotvm/measure/measure_methods.py | 138 +++++++++++++++++---------
 vta/python/vta/__init__.py                    |   1 +
 vta/python/vta/autotvm.py                     |  52 ++++++++++
 vta/tutorials/autotvm/tune_relay_vta.py       |   1 +
 5 files changed, 150 insertions(+), 50 deletions(-)

diff --git a/python/tvm/autotvm/measure/__init__.py b/python/tvm/autotvm/measure/__init__.py
index 0c32ae0..c4c0dc9 100644
--- a/python/tvm/autotvm/measure/__init__.py
+++ b/python/tvm/autotvm/measure/__init__.py
@@ -23,6 +23,12 @@ from .measure import (
     measure_option,
     create_measure_batch,
 )
-from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
+from .measure_methods import (
+    LocalBuilder,
+    LocalRunner,
+    RPCRunner,
+    default_module_loader,
+    request_remote,
+)
 from .executor import Executor
 from .local_executor import LocalExecutor
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index ffe4b97..62fd811 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -22,11 +22,13 @@ These functions are responsible for building the tvm module, uploading it to
 remote devices, recording the running time costs, and checking the correctness of the output.
 """
 
+import contextlib
 import logging
 import shutil
 import os
 import threading
 import time
+import typing
 from random import getrandbits
 from collections import namedtuple
 import tempfile
@@ -199,6 +201,9 @@ class RPCRunner(Runner):
         its actual latency during end-to-end inference.
         To make this option effective, the argument `number` should also be set to 1.
         This is only has effect on CPU task.
+    module_loader : ModuleLoader
+        If given, a context manager that loads the module to be timed into the remote runtime.
+        If not given, default_module_loader is used.
     """
 
     def __init__(
@@ -214,6 +219,7 @@ class RPCRunner(Runner):
         min_repeat_ms=0,
         cooldown_interval=0.1,
         enable_cpu_cache_flush=False,
+        module_loader=None,
     ):
         super(RPCRunner, self).__init__(timeout, n_parallel)
 
@@ -229,6 +235,7 @@ class RPCRunner(Runner):
 
         self.enable_cpu_cache_flush = enable_cpu_cache_flush
         self.cooldown_interval = cooldown_interval
+        self.module_loader = module_loader
 
         self.executor = LocalExecutor(timeout=timeout * (self.n_parallel + 1))
 
@@ -280,6 +287,11 @@ class RPCRunner(Runner):
             for measure_inp, build_res in zip(
                 measure_inputs[i : i + self.n_parallel], build_results[i : i + self.n_parallel]
             ):
+                module_loader = (
+                    self.module_loader
+                    if self.module_loader is not None
+                    else default_module_loader()
+                )
                 ret = self.executor.submit(
                     run_through_rpc,
                     measure_inp,
@@ -290,6 +302,7 @@ class RPCRunner(Runner):
                     self.cooldown_interval,
                     remote_args,
                     self.enable_cpu_cache_flush,
+                    module_loader,
                 )
                 futures.append(ret)
 
@@ -352,6 +365,7 @@ class LocalRunner(RPCRunner):
         min_repeat_ms=0,
         cooldown_interval=0.1,
         enable_cpu_cache_flush=False,
+        module_loader=None,
     ):
         super(LocalRunner, self).__init__(
             "",
@@ -365,6 +379,7 @@ class LocalRunner(RPCRunner):
             min_repeat_ms=min_repeat_ms,
             cooldown_interval=cooldown_interval,
             enable_cpu_cache_flush=enable_cpu_cache_flush,
+            module_loader=module_loader,
         )
         self.tracker = None
         self.server = None
@@ -473,6 +488,11 @@ class _WrappedBuildFunc:
         return BuildResult(filename, arg_info, None, time.time() - tic)
 
 
+ModuleLoader = typing.Callable[
+    [dict, dict], typing.ContextManager[typing.Tuple[tvm.rpc.RPCSession, tvm.runtime.Module]]
+]
+
+
 def run_through_rpc(
     measure_input,
     build_result,
@@ -480,8 +500,9 @@ def run_through_rpc(
     repeat,
     min_repeat_ms,
     cooldown_interval,
-    remote_args,
+    remote_kwargs,
     enable_cpu_cache_flush=False,
+    module_loader=None,
 ):
     """Run a generated library through rpc
 
@@ -509,14 +530,16 @@ def run_through_rpc(
         will be automatically increased.
     cooldown_interval: float
         The cool down interval between two measurements
-    remote_args: Tuple
-        The argument for request_remote
+    remote_kwargs: dict
+        Passed to module_loader(). Ultimately, keyword args to request_remote().
     enable_cpu_cache_flush: bool
         Whether to flush cache on CPU between repeated measurements.
         Flushing cache can make the measured latency of one operator closer to
         its actual latency during end-to-end inference.
         To make this option effective, the argument `number` should also be set to 1.
         This is only has effect on CPU task.
+    module_loader: ModuleLoader
+        A function that returns a ContextManager used to establish and teardown the remote session.
     """
     if isinstance(build_result, MeasureResult):
         return build_result
@@ -525,55 +548,38 @@ def run_through_rpc(
     errno = MeasureErrorNo.NO_ERROR
     try:
         # upload built module
-        remote = request_remote(*remote_args)
-        # Program the FPGA every single time when targeting VTA
-        if (
-            hasattr(measure_input.target, "device_name")
-            and measure_input.target.device_name == "vta"
-        ):
-            # pylint: disable=import-outside-toplevel
-            from vta import program_fpga, reconfig_runtime
-
-            program_fpga(remote, None)
-            reconfig_runtime(remote)
-        remote.upload(build_result.filename)
-        func = remote.load_module(os.path.split(build_result.filename)[1])
-        ctx = remote.context(str(measure_input.target), 0)
-
-        # Limitation:
-        # We can not get PackFunction directly in the remote mode as it is wrapped
-        # under the std::function. We could lift the restriction later once we fold
-        # the PackedFunc as an object. Currently, we pass function name to work
-        # around it.
-        f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
-        time_f = func.time_evaluator(
-            func.entry_name,
-            ctx,
-            number=number,
-            repeat=repeat,
-            min_repeat_ms=min_repeat_ms,
-            f_preproc=f_prepare,
-        )
-
-        try:
-            random_fill = remote.get_function("tvm.contrib.random.random_fill")
-        except AttributeError:
-            raise AttributeError(
-                "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
+        with module_loader(remote_kwargs, build_result) as (remote, mod):
+            ctx = remote.context(str(measure_input.target), 0)
+
+            # Limitation:
+            # We can not get PackFunction directly in the remote mode as it is wrapped
+            # under the std::function. We could lift the restriction later once we fold
+            # the PackedFunc as an object. Currently, we pass function name to work
+            # around it.
+            f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
+            time_f = mod.time_evaluator(
+                mod.entry_name,
+                ctx,
+                number=number,
+                repeat=repeat,
+                min_repeat_ms=min_repeat_ms,
+                f_preproc=f_prepare,
             )
-        args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info]
-        if "scatter" not in measure_input.task.name:
-            # the index tensor of scatter op cannot be randomly initialized
-            for arg in args:
-                random_fill(arg)
-        ctx.sync()
 
-        costs = time_f(*args).results
+            try:
+                random_fill = remote.get_function("tvm.contrib.random.random_fill")
+            except AttributeError:
+                raise AttributeError(
+                    "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
+                )
+            args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info]
+            if "scatter" not in measure_input.task.name:
+                # the index tensor of scatter op cannot be randomly initialized
+                for arg in args:
+                    random_fill(arg)
+            ctx.sync()
 
-        # clean up remote files
-        remote.remove(build_result.filename)
-        remote.remove(os.path.splitext(build_result.filename)[0] + ".so")
-        remote.remove("")
+            costs = time_f(*args).results
 
         if len(costs) > 2:  # remove largest and smallest value to reduce variance
             costs = list(costs)
@@ -592,6 +598,40 @@ def run_through_rpc(
     return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp)
 
 
+def default_module_loader(pre_load_function=None):
+    """Returns a default function that can be passed as module_loader to run_through_rpc.
+
+    Parameters
+    ----------
+    pre_load_function : Optional[Function[tvm.rpc.Session, tvm.runtime.Module]]
+        Invoked after a session is established and before the default code-loading RPC calls are
+        issued. Allows performing pre-upload actions, e.g. resetting the remote runtime environment.
+
+    Returns
+    -------
+    ModuleLoader :
+        A function that can be passed as module_loader to run_through_rpc.
+    """
+
+    @contextlib.contextmanager
+    def default_module_loader_mgr(remote_kwargs, build_result):
+        remote = request_remote(**remote_kwargs)
+        if pre_load_function is not None:
+            pre_load_function(remote, build_result)
+
+        remote.upload(build_result.filename)
+        try:
+            yield remote, remote.load_module(os.path.split(build_result.filename)[1])
+
+        finally:
+            # clean up remote files
+            remote.remove(build_result.filename)
+            remote.remove(os.path.splitext(build_result.filename)[0] + ".so")
+            remote.remove("")
+
+    return default_module_loader_mgr
+
+
 def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
     """Request a remote session
 
diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py
index d143c4d..5fce768 100644
--- a/vta/python/vta/__init__.py
+++ b/vta/python/vta/__init__.py
@@ -22,6 +22,7 @@ configure the hardware environment and access remote device through RPC.
 """
 import sys
 
+from .autotvm import module_loader
 from .bitstream import get_bitstream_path, download_bitstream
 from .environment import get_env, Environment
 from .rpc_client import reconfig_runtime, program_fpga
diff --git a/vta/python/vta/autotvm.py b/vta/python/vta/autotvm.py
new file mode 100644
index 0000000..9aa7390
--- /dev/null
+++ b/vta/python/vta/autotvm.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines AutoTVM components used with VTA."""
+
+from tvm.autotvm.measure import default_module_loader
+from . import rpc_client
+
+
+def module_loader(bitstream=None):
+    """Construct a ModuleLoader implementation specialized for VTA.
+
+    Parameters
+    ----------
+    bitsream : Optional[str]
+        Path to the bitstream to write prior to uploading code.
+
+    Returns
+    -------
+    ModuleLoader :
+        The ModuleLoader instance.
+    """
+
+    def reprogram_fpga(remote, _build_result):
+        """default_module_loader callback which reprograms the FPGA.
+
+        Parameters
+        ----------
+        remote : tvm.rpc.RPCSession
+            RPC session established to the remote device.
+
+        _build_result : tvm.autotvm.measure.measure_methods.BuildResult
+            Artifact from the build phase, unused here.
+        """
+        rpc_client.program_bitstream(remote, bitstream)
+        rpc_client.reconfig_runtime(remote)
+
+    return default_module_loader(reprogram_fpga)
diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py
index c5885b6..ed2671c 100644
--- a/vta/tutorials/autotvm/tune_relay_vta.py
+++ b/vta/tutorials/autotvm/tune_relay_vta.py
@@ -215,6 +215,7 @@ tuning_option = {
             port=tracker_port,
             number=5,
             timeout=60,
+            module_loader=vta.module_loader(),
             # check_correctness=True, # TODO: re-enable when check_correctness works again.
         ),
     ),