You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mo...@apache.org on 2021/02/25 18:27:47 UTC
[tvm] branch main updated: Introduce module_loader to AutoTVM.
(#7337)
This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b111695 Introduce module_loader to AutoTVM. (#7337)
b111695 is described below
commit b1116954f532d869a7ce8d9eb24745f368b66e59
Author: Andrew Reusch <ar...@octoml.ai>
AuthorDate: Thu Feb 25 10:27:06 2021 -0800
Introduce module_loader to AutoTVM. (#7337)
* Introduce code_loader to AutoTVM.
* Prepares for autotuning with microTVM, and provides extension hook
for VTA.
* add vta hook
* git-black
* pylint
* Add missing import
* Fix import problem
* add missing import
* rename code_loader to module_loader
* rename remote_kw to remote_kwargs
* black format
---
python/tvm/autotvm/measure/__init__.py | 8 +-
python/tvm/autotvm/measure/measure_methods.py | 138 +++++++++++++++++---------
vta/python/vta/__init__.py | 1 +
vta/python/vta/autotvm.py | 52 ++++++++++
vta/tutorials/autotvm/tune_relay_vta.py | 1 +
5 files changed, 150 insertions(+), 50 deletions(-)
diff --git a/python/tvm/autotvm/measure/__init__.py b/python/tvm/autotvm/measure/__init__.py
index 0c32ae0..c4c0dc9 100644
--- a/python/tvm/autotvm/measure/__init__.py
+++ b/python/tvm/autotvm/measure/__init__.py
@@ -23,6 +23,12 @@ from .measure import (
measure_option,
create_measure_batch,
)
-from .measure_methods import LocalBuilder, LocalRunner, RPCRunner, request_remote
+from .measure_methods import (
+ LocalBuilder,
+ LocalRunner,
+ RPCRunner,
+ default_module_loader,
+ request_remote,
+)
from .executor import Executor
from .local_executor import LocalExecutor
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index ffe4b97..62fd811 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -22,11 +22,13 @@ These functions are responsible for building the tvm module, uploading it to
remote devices, recording the running time costs, and checking the correctness of the output.
"""
+import contextlib
import logging
import shutil
import os
import threading
import time
+import typing
from random import getrandbits
from collections import namedtuple
import tempfile
@@ -199,6 +201,9 @@ class RPCRunner(Runner):
its actual latency during end-to-end inference.
To make this option effective, the argument `number` should also be set to 1.
This only has an effect on CPU tasks.
+ module_loader : ModuleLoader
+ If given, a function that returns a context manager which loads the module to be timed
+ into the remote runtime. If not given, default_module_loader is used.
"""
def __init__(
@@ -214,6 +219,7 @@ class RPCRunner(Runner):
min_repeat_ms=0,
cooldown_interval=0.1,
enable_cpu_cache_flush=False,
+ module_loader=None,
):
super(RPCRunner, self).__init__(timeout, n_parallel)
@@ -229,6 +235,7 @@ class RPCRunner(Runner):
self.enable_cpu_cache_flush = enable_cpu_cache_flush
self.cooldown_interval = cooldown_interval
+ self.module_loader = module_loader
self.executor = LocalExecutor(timeout=timeout * (self.n_parallel + 1))
@@ -280,6 +287,11 @@ class RPCRunner(Runner):
for measure_inp, build_res in zip(
measure_inputs[i : i + self.n_parallel], build_results[i : i + self.n_parallel]
):
+ module_loader = (
+ self.module_loader
+ if self.module_loader is not None
+ else default_module_loader()
+ )
ret = self.executor.submit(
run_through_rpc,
measure_inp,
@@ -290,6 +302,7 @@ class RPCRunner(Runner):
self.cooldown_interval,
remote_args,
self.enable_cpu_cache_flush,
+ module_loader,
)
futures.append(ret)
@@ -352,6 +365,7 @@ class LocalRunner(RPCRunner):
min_repeat_ms=0,
cooldown_interval=0.1,
enable_cpu_cache_flush=False,
+ module_loader=None,
):
super(LocalRunner, self).__init__(
"",
@@ -365,6 +379,7 @@ class LocalRunner(RPCRunner):
min_repeat_ms=min_repeat_ms,
cooldown_interval=cooldown_interval,
enable_cpu_cache_flush=enable_cpu_cache_flush,
+ module_loader=module_loader,
)
self.tracker = None
self.server = None
@@ -473,6 +488,11 @@ class _WrappedBuildFunc:
return BuildResult(filename, arg_info, None, time.time() - tic)
+ModuleLoader = typing.Callable[
+ [dict, dict], typing.ContextManager[typing.Tuple[tvm.rpc.RPCSession, tvm.runtime.Module]]
+]
+
+
def run_through_rpc(
measure_input,
build_result,
@@ -480,8 +500,9 @@ def run_through_rpc(
repeat,
min_repeat_ms,
cooldown_interval,
- remote_args,
+ remote_kwargs,
enable_cpu_cache_flush=False,
+ module_loader=None,
):
"""Run a generated library through rpc
@@ -509,14 +530,16 @@ def run_through_rpc(
will be automatically increased.
cooldown_interval: float
The cool down interval between two measurements
- remote_args: Tuple
- The argument for request_remote
+ remote_kwargs: dict
+ Passed to module_loader(). Ultimately, keyword args to request_remote().
enable_cpu_cache_flush: bool
Whether to flush cache on CPU between repeated measurements.
Flushing cache can make the measured latency of one operator closer to
its actual latency during end-to-end inference.
To make this option effective, the argument `number` should also be set to 1.
This only has an effect on CPU tasks.
+ module_loader: ModuleLoader
+ A function that returns a ContextManager used to establish and tear down the remote session.
"""
if isinstance(build_result, MeasureResult):
return build_result
@@ -525,55 +548,38 @@ def run_through_rpc(
errno = MeasureErrorNo.NO_ERROR
try:
# upload built module
- remote = request_remote(*remote_args)
- # Program the FPGA every single time when targeting VTA
- if (
- hasattr(measure_input.target, "device_name")
- and measure_input.target.device_name == "vta"
- ):
- # pylint: disable=import-outside-toplevel
- from vta import program_fpga, reconfig_runtime
-
- program_fpga(remote, None)
- reconfig_runtime(remote)
- remote.upload(build_result.filename)
- func = remote.load_module(os.path.split(build_result.filename)[1])
- ctx = remote.context(str(measure_input.target), 0)
-
- # Limitation:
- # We can not get PackFunction directly in the remote mode as it is wrapped
- # under the std::function. We could lift the restriction later once we fold
- # the PackedFunc as an object. Currently, we pass function name to work
- # around it.
- f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
- time_f = func.time_evaluator(
- func.entry_name,
- ctx,
- number=number,
- repeat=repeat,
- min_repeat_ms=min_repeat_ms,
- f_preproc=f_prepare,
- )
-
- try:
- random_fill = remote.get_function("tvm.contrib.random.random_fill")
- except AttributeError:
- raise AttributeError(
- "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
+ with module_loader(remote_kwargs, build_result) as (remote, mod):
+ ctx = remote.context(str(measure_input.target), 0)
+
+ # Limitation:
+ # We can not get PackFunction directly in the remote mode as it is wrapped
+ # under the std::function. We could lift the restriction later once we fold
+ # the PackedFunc as an object. Currently, we pass function name to work
+ # around it.
+ f_prepare = "cache_flush_cpu_non_first_arg" if enable_cpu_cache_flush else ""
+ time_f = mod.time_evaluator(
+ mod.entry_name,
+ ctx,
+ number=number,
+ repeat=repeat,
+ min_repeat_ms=min_repeat_ms,
+ f_preproc=f_prepare,
)
- args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info]
- if "scatter" not in measure_input.task.name:
- # the index tensor of scatter op cannot be randomly initialized
- for arg in args:
- random_fill(arg)
- ctx.sync()
- costs = time_f(*args).results
+ try:
+ random_fill = remote.get_function("tvm.contrib.random.random_fill")
+ except AttributeError:
+ raise AttributeError(
+ "Please make sure USE_RANDOM is ON in the config.cmake " "on the remote devices"
+ )
+ args = [nd.array(np.zeros(x[0], dtype=x[1]), ctx=ctx) for x in build_result.arg_info]
+ if "scatter" not in measure_input.task.name:
+ # the index tensor of scatter op cannot be randomly initialized
+ for arg in args:
+ random_fill(arg)
+ ctx.sync()
- # clean up remote files
- remote.remove(build_result.filename)
- remote.remove(os.path.splitext(build_result.filename)[0] + ".so")
- remote.remove("")
+ costs = time_f(*args).results
if len(costs) > 2: # remove largest and smallest value to reduce variance
costs = list(costs)
@@ -592,6 +598,40 @@ def run_through_rpc(
return MeasureResult(costs, errno, tstamp - tic + build_result.time_cost, tstamp)
+def default_module_loader(pre_load_function=None):
+ """Returns a default function that can be passed as module_loader to run_through_rpc.
+
+ Parameters
+ ----------
+ pre_load_function : Optional[Callable[[tvm.rpc.RPCSession, BuildResult], None]]
+ Invoked after a session is established and before the default code-loading RPC calls are
+ issued. Allows performing pre-upload actions, e.g. resetting the remote runtime environment.
+
+ Returns
+ -------
+ ModuleLoader :
+ A function that can be passed as module_loader to run_through_rpc.
+ """
+
+ @contextlib.contextmanager
+ def default_module_loader_mgr(remote_kwargs, build_result):
+ remote = request_remote(**remote_kwargs)
+ if pre_load_function is not None:
+ pre_load_function(remote, build_result)
+
+ remote.upload(build_result.filename)
+ try:
+ yield remote, remote.load_module(os.path.split(build_result.filename)[1])
+
+ finally:
+ # clean up remote files
+ remote.remove(build_result.filename)
+ remote.remove(os.path.splitext(build_result.filename)[0] + ".so")
+ remote.remove("")
+
+ return default_module_loader_mgr
+
+
def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
"""Request a remote session
diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py
index d143c4d..5fce768 100644
--- a/vta/python/vta/__init__.py
+++ b/vta/python/vta/__init__.py
@@ -22,6 +22,7 @@ configure the hardware environment and access remote device through RPC.
"""
import sys
+from .autotvm import module_loader
from .bitstream import get_bitstream_path, download_bitstream
from .environment import get_env, Environment
from .rpc_client import reconfig_runtime, program_fpga
diff --git a/vta/python/vta/autotvm.py b/vta/python/vta/autotvm.py
new file mode 100644
index 0000000..9aa7390
--- /dev/null
+++ b/vta/python/vta/autotvm.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Defines AutoTVM components used with VTA."""
+
+from tvm.autotvm.measure import default_module_loader
+from . import rpc_client
+
+
+def module_loader(bitstream=None):
+ """Construct a ModuleLoader implementation specialized for VTA.
+
+ Parameters
+ ----------
+ bitstream : Optional[str]
+ Path to the bitstream to write prior to uploading code.
+
+ Returns
+ -------
+ ModuleLoader :
+ The ModuleLoader instance.
+ """
+
+ def reprogram_fpga(remote, _build_result):
+ """default_module_loader callback which reprograms the FPGA.
+
+ Parameters
+ ----------
+ remote : tvm.rpc.RPCSession
+ RPC session established to the remote device.
+
+ _build_result : tvm.autotvm.measure.measure_methods.BuildResult
+ Artifact from the build phase, unused here.
+ """
+ rpc_client.program_bitstream(remote, bitstream)
+ rpc_client.reconfig_runtime(remote)
+
+ return default_module_loader(reprogram_fpga)
diff --git a/vta/tutorials/autotvm/tune_relay_vta.py b/vta/tutorials/autotvm/tune_relay_vta.py
index c5885b6..ed2671c 100644
--- a/vta/tutorials/autotvm/tune_relay_vta.py
+++ b/vta/tutorials/autotvm/tune_relay_vta.py
@@ -215,6 +215,7 @@ tuning_option = {
port=tracker_port,
number=5,
timeout=60,
+ module_loader=vta.module_loader(),
# check_correctness=True, # TODO: re-enable when check_correctness works again.
),
),