Posted to commits@tvm.apache.org by ma...@apache.org on 2021/09/09 09:51:58 UTC

[tvm] branch main updated: [microTVM] Add support for AutoTVM (#8715)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new aa2b37d  [microTVM] Add support for AutoTVM (#8715)
aa2b37d is described below

commit aa2b37d35b29791f80d63403565b936b7208ca55
Author: Mehrdad Hessar <mh...@octoml.ai>
AuthorDate: Thu Sep 9 11:51:44 2021 +0200

    [microTVM] Add support for AutoTVM (#8715)
    
    * Initial commit of API server impl.
    
    * initial commit of api client
    
    * Add TVM-side glue code to use Project API
    
    * Change tvm.micro.Session to use Project API
    
    * Rework how crt_config.h is used on the host.
    
     * use template crt_config.h for host test runtime; delete
       src/runtime/crt/host/crt_config.h so that it doesn't diverge from
       the template
     * bring template crt_config.h inline with the one actually in use
      * rename to MAX_STRLEN_DLTYPE
     * Create a dedicated TVM-side host crt_config.h in src/runtime/micro
    
    * Modify Transport infrastructure to work with Project API
    
    * Add host microTVM API server
    
    * Zephyr implementation of microTVM API server
    
     * move all zephyr projects to apps/microtvm/zephyr/template_project
    
    * consolidate CcompilerAnnotator
    
    * Allow model library format with c backend, add test.
    
    * Update unit tests
    
    * fix incorrect doc
    
    * Delete old Zephyr build infrastructure
    
    * Delete old build abstractions
    
    * Delete old Transport implementations and simplify module
    
    * lint
    
    * ASF header
    
    * address gromero comments
    
    * final fixes?
    
    * fix is_shutdown
    
    * fix user-facing API
    
    * fix TempDirectory / operator
    
    * Update micro_tflite tutorial
    
    * lint
    
    * fix test_crt and test_link_params
    
    * undo global micro import, hopefully fix fixture
    
    * lint
    
    * fix more tests
    
    * Add session_constructor_args to tracker request() function.
    
     * Allows tracker clients to open non-traditional RPC sessions
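
      For illustration, a minimal sketch of the new argument in use; the
      tracker address, device key, and payload below are assumptions, not
      part of this commit:

          from tvm import rpc

          # Hypothetical: bytes of a Model Library Format archive.
          with open("model.tar", "rb") as f:
              mlf_tar_bytes = f.read()

          tracker = rpc.connect_tracker("127.0.0.1", 9190)
          remote = tracker.request(
              "my-board",
              priority=1,
              session_timeout=60,
              # The first element names a registered session constructor;
              # the remaining elements are its positional arguments.
              session_constructor_args=[
                  "tvm.micro.compile_and_create_micro_session",
                  mlf_tar_bytes,
                  "/path/to/template_project",  # hypothetical path
                  '{"verbose": 1}',  # JSON-encoded project options
              ],
          )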
    
    * Generate entry_func symbol in C host codegen.
    
     * Needed for AutoTVM.
    
    * print MeasureErrorNo enum value in MeasureResult repr
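
      A quick illustration of the new repr with constructed values (output
      shown for the Python versions of this era, where str() of an IntEnum
      member includes the member name):

          from tvm.autotvm.measure import MeasureErrorNo, MeasureResult

          result = MeasureResult(
              costs=(), error_no=MeasureErrorNo.COMPILE_HOST, all_cost=10.0, timestamp=0.0
          )
          print(result)
          # MeasureResult(costs=(), error_no=MeasureErrorNo.COMPILE_HOST, all_cost=10.0, timestamp=0.0)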
    
    * Add microTVM session constructor.
    
     * This constructor is to be called from the RPC driver to flash and
       connect to the RPC server on the microcontroller.
    
    * add build_kwargs as a Builder constructor arg.
    
     * build_kwargs is derived from pre-configured args, the runner, and
       now from the script.
     * user-supplied build kwargs override the other two, and a warning is
       printed if any key is overridden.
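
      A hedged sketch of the precedence rule; the task object is assumed to
      come from autotvm task extraction, which is elided here:

          from tvm import autotvm

          builder = autotvm.LocalBuilder(
              n_parallel=1,
              build_kwargs={"build_option": {"tir.disable_vectorize": True}},
          )
          # Runner-derived kwargs normally arrive via set_task() inside the
          # measurement harness; we call it directly here to illustrate. The
          # colliding "build_option" key keeps the user-supplied value, and a
          # warning names each overridden key.
          # `task` is a tvm.autotvm.task.Task obtained elsewhere (assumed).
          builder.set_task(task, build_kwargs={"build_option": {}})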
    
    * Add do_fork option to Builder, to support stateful builders
    
     * When AutoTVM builder forks, any global state modified by the
       build_func is lost between builds
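
      Sketched usage with a hypothetical build hook that keeps in-process
      state (it is handed to export_library as the compile function, so the
      exact signature is elided):

          from tvm import autotvm

          def my_stateful_build_func(*args, **kwargs):
              """Hypothetical compile hook relying on in-process caches."""

          my_stateful_build_func.output_format = "tar"

          builder = autotvm.LocalBuilder(
              n_parallel=1,  # do_fork=False requires n_parallel of None or 1
              do_fork=False,  # keep global state alive between builds
              build_func=my_stateful_build_func,
          )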
    
    * Checkin module_loader used to build and flash microTVM for autotuning.
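
      The loader plugs into the usual autotvm measurement flow; a minimal
      sketch assuming a template project path (the other values mirror the
      tests added below):

          import tvm

          module_loader = tvm.micro.AutoTvmModuleLoader(
              template_project_dir="/path/to/template_project",  # hypothetical
              project_options={},
          )
          builder = tvm.autotvm.LocalBuilder(
              n_parallel=1,
              build_kwargs={"build_option": {"tir.disable_vectorize": True}},
              do_fork=True,
              build_func=tvm.micro.autotvm_build_func,
          )
          runner = tvm.autotvm.LocalRunner(
              number=1, repeat=1, timeout=0, module_loader=module_loader
          )
          measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)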
    
    * Import micro into top-level when enabled.
    
     * AutoTVM RPC server needs to load the micro session constructor.
    
    * Add tvm.contrib.random.random_fill to microTVM.
    
     * Allows autotuning with random data.
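
      A sketch of what this enables during measurement; the remote session
      (`remote`) and device handle (`dev`) are assumed to exist:

          import tvm

          # The C runtime now registers this PackedFunc, backed by the weak
          # TVMPlatformGenerateRandom symbol on the platform side.
          random_fill = remote.get_function("tvm.contrib.random.random_fill")
          arg = tvm.nd.empty((1, 3, 16, 16), "float32", dev)
          random_fill(arg)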
    
    * Move compilation to runner :O
    
    * Add a tutorial for AutoTVM with microcontrollers.
    
    * Fix si_prefix in autotuner callback
    
    * black format and git-clang-format
    
    * Switch tutorial back to qemu version
    
    * improve error reporting so CI will show test error
    
    * black format
    
    * autotvm is working
    
    * fix tutorial
    
    * fix dependencies
    
    * fix auto tune issue
    
    * lint
    
    * address comments
    
    * fix lint
    
    * test crt and zephyr added
    
    * fix func registry size
    
    * moved autotune test and fixed
    
    * fix crt test
    
    * address comments
    
    * change relay text
    
    * change relay in test_zephyr
    
    * class added
    
    * changed relay module in tutorial and cleanup
    
    * address comments
    
    * address TK comments
    
    * change fork
    
    * final comments
    
    * retrigger due to flaky test
    
    * fix tutorial
    
    * retrigger
    
    * fix changes due to merge
    
    Co-authored-by: Andrew Reusch <ar...@octoml.ai>
---
 apps/bundle_deploy/crt_config/crt_config.h         |   2 +-
 apps/microtvm/pyproject.toml                       |   1 +
 include/tvm/runtime/crt/error_codes.h              |   1 +
 python/tvm/__init__.py                             |   3 +
 python/tvm/autotvm/measure/measure.py              |  33 ++-
 python/tvm/autotvm/measure/measure_methods.py      |  27 ++-
 python/tvm/autotvm/tuner/callback.py               |   4 +-
 python/tvm/micro/__init__.py                       |   2 +
 python/tvm/micro/build.py                          |  55 +++++
 python/tvm/micro/project.py                        |  17 +-
 python/tvm/micro/session.py                        |  73 +++++-
 python/tvm/rpc/client.py                           |  17 +-
 python/tvm/support.py                              |   9 +-
 src/runtime/crt/common/crt_runtime_api.c           |  23 ++
 src/runtime/crt/common/ndarray.c                   |  25 ++-
 src/runtime/crt/crt_config-template.h              |   2 +-
 .../tvm/runtime/crt/internal/common/ndarray.h      |   4 +
 src/target/source/codegen_c_host.cc                |  11 +
 tests/micro/zephyr/test_zephyr.py                  | 142 +++++++++++-
 tests/python/unittest/test_crt.py                  | 103 +++++++++
 tests/python/unittest/test_runtime_module_load.py  |   7 +-
 tutorials/micro/micro_autotune.py                  | 250 +++++++++++++++++++++
 22 files changed, 778 insertions(+), 33 deletions(-)

diff --git a/apps/bundle_deploy/crt_config/crt_config.h b/apps/bundle_deploy/crt_config/crt_config.h
index 58f9235..b89bedb 100644
--- a/apps/bundle_deploy/crt_config/crt_config.h
+++ b/apps/bundle_deploy/crt_config/crt_config.h
@@ -43,7 +43,7 @@
 #define TVM_CRT_MAX_REGISTERED_MODULES 2
 
 /*! Size of the global function registry, in bytes. */
-#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200
+#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512
 
 /*! Maximum packet size, in bytes, including the length header. */
 #define TVM_CRT_MAX_PACKET_SIZE_BYTES 512
diff --git a/apps/microtvm/pyproject.toml b/apps/microtvm/pyproject.toml
index 8bfae0a..98c769b 100644
--- a/apps/microtvm/pyproject.toml
+++ b/apps/microtvm/pyproject.toml
@@ -111,6 +111,7 @@ tensorflow-estimator = {version = "^2.1", optional = true}
 # TFLite frontend
 tflite = {version = "2.1.0", optional = true}
 wheel = "*"
+cloudpickle = "^1.6.0"
 
 
 [tool.poetry.extras]
diff --git a/include/tvm/runtime/crt/error_codes.h b/include/tvm/runtime/crt/error_codes.h
index d1a8619..776691c 100644
--- a/include/tvm/runtime/crt/error_codes.h
+++ b/include/tvm/runtime/crt/error_codes.h
@@ -93,6 +93,7 @@ typedef enum {
   kTvmErrorFunctionCallNumArguments = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 0),
   kTvmErrorFunctionCallWrongArgType = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 1),
   kTvmErrorFunctionCallNotImplemented = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 2),
+  kTvmErrorFunctionCallInvalidArg = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 3),
 
   // Time Evaluator - times functions for use with debug runtime.
   kTvmErrorTimeEvaluatorBadHandle = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryTimeEvaluator, 0),
diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py
index 55a2288..57374c5 100644
--- a/python/tvm/__init__.py
+++ b/python/tvm/__init__.py
@@ -68,6 +68,9 @@ from . import support
 # Contrib initializers
 from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
 
+if support.libinfo().get("USE_MICRO", "OFF") == "ON":
+    from . import micro
+
 # NOTE: This file should be python2 compatible so we can
 # raise proper error message when user run the package using
 # an older version of the python
diff --git a/python/tvm/autotvm/measure/measure.py b/python/tvm/autotvm/measure/measure.py
index 8438b80..ea7de35 100644
--- a/python/tvm/autotvm/measure/measure.py
+++ b/python/tvm/autotvm/measure/measure.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
 """User facing API for specifying how to measure the generated code"""
+import enum
+import logging
 import multiprocessing
 from collections import namedtuple
 
@@ -52,8 +53,19 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
         The absolute time stamp when we finish measurement.
     """
 
+    def __repr__(self):
+        error_no_str = (
+            str(self.error_no)
+            if self.error_no not in MeasureErrorNo
+            else str(MeasureErrorNo(self.error_no))
+        )
+        return (
+            f"{self.__class__.__name__}(costs={self.costs!r}, error_no={error_no_str}, "
+            f"all_cost={self.all_cost}, timestamp={self.timestamp!r})"
+        )
 
-class MeasureErrorNo(object):
+
+class MeasureErrorNo(enum.IntEnum):
     """Error type for MeasureResult"""
 
     NO_ERROR = 0  # no error
@@ -77,12 +89,15 @@ class Builder(object):
     n_parallel: int, optional
         The number of tasks submitted in parallel
         By default it will use all cpu cores
+    build_kwargs: dict, optional
+        Keyword args given to the build function.
     """
 
-    def __init__(self, timeout=10, n_parallel=None):
+    def __init__(self, timeout=10, n_parallel=None, build_kwargs=None):
         self.timeout = timeout
         self.n_parallel = n_parallel or multiprocessing.cpu_count()
-        self.build_kwargs = {}
+        self.user_build_kwargs = build_kwargs if build_kwargs is not None else {}
+        self.runner_build_kwargs = None
         self.task = None
 
     def set_task(self, task, build_kwargs=None):
@@ -97,7 +112,17 @@ class Builder(object):
             The additional kwargs for build function
         """
         self.task = task
-        self.build_kwargs = build_kwargs
+        self.build_kwargs = dict(build_kwargs.items()) if build_kwargs is not None else {}
+        if any(k in self.build_kwargs for k in self.user_build_kwargs):
+            logging.warning(
+                "Overriding these runner-supplied kwargs with user-supplied:\n%s",
+                "\n".join(
+                    f" * {k}: from {build_kwargs[k]!r} to {self.user_build_kwargs[k]!r}"
+                    for k in sorted([k for k in build_kwargs if k in self.user_build_kwargs])
+                ),
+            )
+        for k, v in self.user_build_kwargs.items():
+            self.build_kwargs[k] = v
 
     def build(self, measure_inputs):
         """Build programs
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index 42e046a..efe45da 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -79,15 +79,22 @@ class LocalBuilder(Builder):
         The timeout of a compilation
     n_parallel: int
         The number of tasks run in parallel. "None" will use all cpu cores
+    build_kwargs: dict
+        If supplied, additional kwargs passed to build_func. Overrides any build_kwargs supplied
+        by the Runner.
     build_func: callable or str
         If is 'default', use default build function
         If is 'ndk', use function for android ndk
         If is 'stackvm', use function for stackvm
         If is callable, use it as custom build function, expect lib_format field.
+    do_fork: bool
+        If False, do not fork when building. Requires n_parallel=1.
     """
 
-    def __init__(self, timeout=10, n_parallel=None, build_func="default"):
-        super(LocalBuilder, self).__init__(timeout, n_parallel)
+    def __init__(
+        self, timeout=10, n_parallel=None, build_kwargs=None, build_func="default", do_fork=False
+    ):
+        super(LocalBuilder, self).__init__(timeout, n_parallel, build_kwargs)
 
         if isinstance(build_func, str):
             if build_func == "default":
@@ -99,6 +106,11 @@ class LocalBuilder(Builder):
             else:
                 raise ValueError("Invalid build_func: " + build_func)
         self.build_func = _WrappedBuildFunc(build_func)
+        if not do_fork:
+            assert n_parallel in (
+                None,
+                1,
+            ), f"if do_fork=False, need n_parallel=None or 1; got {n_parallel}"
         self.executor = PopenPoolExecutor(
             timeout=timeout, initializer=reset_global_scope, initargs=(AutotvmGlobalScope.current,)
         )
@@ -518,7 +530,16 @@ class _WrappedBuildFunc:
             )
             # TODO(tvm-team) consider inlining _build_func_common
             func, arg_info = _build_func_common(measure_input, **kwargs)
-            func.export_library(filename, self.build_func)
+            if self.build_func.output_format == ".model-library-format":
+                # Late import so AutoTVM still works when USE_MICRO is OFF
+                try:
+                    from tvm import micro  # pylint: disable=import-outside-toplevel
+                except ImportError:
+                    raise ImportError("Requires USE_MICRO")
+
+                micro.export_model_library_format(func, filename)
+            else:
+                func.export_library(filename, self.build_func)
         except Exception as e:  # pylint: disable=broad-except
             return BuildResult(None, None, e, time.time() - tic)
         return BuildResult(filename, arg_info, None, time.time() - tic)
diff --git a/python/tvm/autotvm/tuner/callback.py b/python/tvm/autotvm/tuner/callback.py
index dc75de2..40ee24e 100644
--- a/python/tvm/autotvm/tuner/callback.py
+++ b/python/tvm/autotvm/tuner/callback.py
@@ -145,8 +145,8 @@ def progress_bar(total, prefix="", si_prefix="G"):
 
     if logger.level < logging.DEBUG:  # only print progress bar in non-debug mode
         sys.stdout.write(
-            "\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) "
-            "| %.2f s" % (prefix, 0, 0, 0, total, time.time() - tic)
+            "\r%s Current/Best: %7.2f/%7.2f %sFLOPS | Progress: (%d/%d) "
+            "| %.2f s" % (prefix, 0, 0, si_prefix, 0, total, time.time() - tic)
         )
         sys.stdout.flush()
 
diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py
index 88dcde8..2aea9d3 100644
--- a/python/tvm/micro/__init__.py
+++ b/python/tvm/micro/__init__.py
@@ -16,6 +16,8 @@
 # under the License.
 """MicroTVM module for bare-metal backends"""
 
+from .build import autotvm_build_func
+from .build import AutoTvmModuleLoader
 from .build import get_standalone_crt_dir
 from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError
 from .project import generate_project, GeneratedProject, TemplateProject
diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py
index 16e7ed2..7da9daf 100644
--- a/python/tvm/micro/build.py
+++ b/python/tvm/micro/build.py
@@ -17,10 +17,13 @@
 
 """Defines top-level glue functions for building microTVM artifacts."""
 
+import json
 import logging
 import os
+import pathlib
 
 from .._ffi import libinfo
+from ..error import TVMError
+from .. import rpc as _rpc
 
 
 _LOG = logging.getLogger(__name__)
@@ -57,3 +60,55 @@ def get_standalone_crt_dir() -> str:
             raise CrtNotFoundError()
 
     return STANDALONE_CRT_DIR
+
+
+class AutoTvmModuleLoader:
+    """MicroTVM AutoTVM Module Loader
+
+    Parameters
+    ----------
+    template_project_dir : str
+        Path to a template project.
+
+    project_options : dict
+        Project generation options.
+    """
+
+    def __init__(self, template_project_dir: str, project_options: dict = None):
+        self._project_options = project_options
+
+        if isinstance(template_project_dir, pathlib.Path):
+            self._template_project_dir = str(template_project_dir)
+        elif isinstance(template_project_dir, str):
+            self._template_project_dir = template_project_dir
+        else:
+            raise TypeError(f"Incorrect type {type(template_project_dir)}.")
+
+    def __call__(self, remote_kw, build_result):
+        with open(build_result.filename, "rb") as build_file:
+            build_result_bin = build_file.read()
+
+        tracker = _rpc.connect_tracker(remote_kw["host"], remote_kw["port"])
+        remote = tracker.request(
+            remote_kw["device_key"],
+            priority=remote_kw["priority"],
+            session_timeout=remote_kw["timeout"],
+            session_constructor_args=[
+                "tvm.micro.compile_and_create_micro_session",
+                build_result_bin,
+                self._template_project_dir,
+                json.dumps(self._project_options),
+            ],
+        )
+        system_lib = remote.get_function("runtime.SystemLib")()
+        yield remote, system_lib
+        try:
+            remote.get_function("tvm.micro.destroy_micro_session")()
+        except TVMError as exception:
+            _LOG.warning("Error destroying remote session: %s", str(exception), exc_info=1)
+
+
+def autotvm_build_func():
+    """A dummy build function which causes autotvm to use a different export format."""
+
+
+# A sentinel value for the output format.
+autotvm_build_func.output_format = ".model-library-format"
diff --git a/python/tvm/micro/project.py b/python/tvm/micro/project.py
index b1f2b49..8a62c9b 100644
--- a/python/tvm/micro/project.py
+++ b/python/tvm/micro/project.py
@@ -101,14 +101,9 @@ class TemplateProject:
         if not self._info["is_template"]:
             raise NotATemplateProjectError()
 
-    def generate_project(self, graph_executor_factory, project_dir, options):
-        """Generate a project given GraphRuntimeFactory."""
-        model_library_dir = utils.tempdir()
-        model_library_format_path = model_library_dir.relpath("model.tar")
-        export_model_library_format(graph_executor_factory, model_library_format_path)
-
+    def generate_project_from_mlf(self, model_library_format_path, project_dir, options):
         self._api_client.generate_project(
-            model_library_format_path=model_library_format_path,
+            model_library_format_path=str(model_library_format_path),
             standalone_crt_dir=get_standalone_crt_dir(),
             project_dir=project_dir,
             options=options,
@@ -119,6 +114,14 @@ class TemplateProject:
     def info(self):
         return self._info
 
+    def generate_project(self, graph_executor_factory, project_dir, options):
+        """Generate a project given GraphRuntimeFactory."""
+        model_library_dir = utils.tempdir()
+        model_library_format_path = model_library_dir.relpath("model.tar")
+        export_model_library_format(graph_executor_factory, model_library_format_path)
+
+        return self.generate_project_from_mlf(model_library_format_path, project_dir, options)
+
 
 def generate_project(
     template_project_dir: typing.Union[pathlib.Path, str],
diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py
index d4ad5b8..abe7aff 100644
--- a/python/tvm/micro/session.py
+++ b/python/tvm/micro/session.py
@@ -17,14 +17,17 @@
 
 """Defines a top-level glue class that operates the Transport and Flasher classes."""
 
+import json
 import logging
 import sys
 
 from ..error import register_error
-from .._ffi import get_global_func
+from .._ffi import get_global_func, register_func
 from ..contrib import graph_executor
+from ..contrib import utils
 from ..contrib.debugger import debug_executor
 from ..rpc import RPCSession
+from . import project
 from .transport import IoTimeoutError
 from .transport import TransportLogger
 
@@ -234,3 +237,71 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None):
         graph_json_str,
         dump_root=dump_root,
     )
+
+
+RPC_SESSION = None
+
+
+@register_func("tvm.micro.compile_and_create_micro_session")
+def compile_and_create_micro_session(
+    mod_src_bytes: bytes,
+    template_project_dir: str,
+    project_options: dict = None,
+):
+    """Compile the given libraries and sources into a MicroBinary, then invoke create_micro_session.
+
+    Parameters
+    ----------
+    mod_src_bytes : bytes
+        The content of a Model Library Format tarfile which contains the TVM-generated sources
+        that together form the SystemLib. The tar is extracted into a directory and used to
+        generate, build, and flash a project via the Project API.
+
+    template_project_dir: str
+        The path to a template microTVM Project API project which is used to generate the embedded
+        project that is built and flashed onto the target device.
+
+    project_options: dict
+        Options for the microTVM API Server contained in template_project_dir.
+    """
+    global RPC_SESSION
+
+    temp_dir = utils.tempdir()
+    # Keep the temp directory so the generated project can be inspected for debugging
+    temp_dir.set_keep_for_debug(True)
+    model_library_format_path = temp_dir / "model.tar.gz"
+    with open(model_library_format_path, "wb") as mlf_f:
+        mlf_f.write(mod_src_bytes)
+
+    try:
+        template_project = project.TemplateProject.from_directory(template_project_dir)
+        generated_project = template_project.generate_project_from_mlf(
+            model_library_format_path,
+            temp_dir / "generated-project",
+            options=json.loads(project_options),
+        )
+    except Exception as exception:
+        logging.error("Project Generate Error: %s", str(exception))
+        raise exception
+
+    generated_project.build()
+    generated_project.flash()
+    transport = generated_project.transport()
+
+    RPC_SESSION = Session(transport_context_manager=transport)
+    RPC_SESSION.__enter__()
+    return RPC_SESSION._rpc._sess
+
+
+@register_func
+def destroy_micro_session():
+    """Destroy RPC session for microTVM autotune."""
+    global RPC_SESSION
+
+    if RPC_SESSION is not None:
+        exc_type, exc_value, traceback = RPC_SESSION.__exit__(None, None, None)
+        RPC_SESSION = None
+        if (exc_type, exc_value, traceback) != (None, None, None):
+            exc = exc_type(exc_value)  # See PEP 3109
+            exc.__traceback__ = traceback
+            raise exc
diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py
index a983439..045bf79 100644
--- a/python/tvm/rpc/client.py
+++ b/python/tvm/rpc/client.py
@@ -366,7 +366,9 @@ class TrackerSession(object):
         res += separate_line
         return res
 
-    def request(self, key, priority=1, session_timeout=0, max_retry=5):
+    def request(
+        self, key, priority=1, session_timeout=0, max_retry=5, session_constructor_args=None
+    ):
         """Request a new connection from the tracker.
 
         Parameters
@@ -384,6 +386,11 @@ class TrackerSession(object):
 
         max_retry : int, optional
             Maximum number of times to retry before giving up.
+
+        session_constructor_args : list, optional
+            List of additional arguments passed to the remote session constructor.
+            The first element of the list is always a string specifying the name of
+            the session constructor; the following args are positional args to that function.
         """
         last_err = None
         for _ in range(max_retry):
@@ -395,7 +402,13 @@ class TrackerSession(object):
                 if value[0] != base.TrackerCode.SUCCESS:
                     raise RuntimeError("Invalid return value %s" % str(value))
                 url, port, matchkey = value[1]
-                return connect(url, port, matchkey, session_timeout)
+                return connect(
+                    url,
+                    port,
+                    matchkey,
+                    session_timeout,
+                    session_constructor_args=session_constructor_args,
+                )
             except socket.error as err:
                 self.close()
                 last_err = err
diff --git a/python/tvm/support.py b/python/tvm/support.py
index 800bfe4..1adbee0 100644
--- a/python/tvm/support.py
+++ b/python/tvm/support.py
@@ -29,7 +29,14 @@ def libinfo():
     info: Dict[str, str]
         The dictionary of compile-time info.
     """
-    return {k: v for k, v in GetLibInfo().items()}  # pylint: disable=unnecessary-comprehension
+    get_lib_info_func = get_global_func("support.GetLibInfo", allow_missing=True)
+    if get_lib_info_func is not None:
+        lib_info = get_lib_info_func()
+        if lib_info is None:
+            return {}
+    else:
+        return {}
+    return dict(lib_info.items())
 
 
 class FrontendTestModule(Module):
diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c
index 04721ee..ea986a3 100644
--- a/src/runtime/crt/common/crt_runtime_api.c
+++ b/src/runtime/crt/common/crt_runtime_api.c
@@ -395,6 +395,8 @@ int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMVal
   return 0;
 }
 
+int TVMContribRandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
+                         int* ret_type_code);
 tvm_crt_error_t TVMInitializeRuntime() {
   int idx = 0;
   tvm_crt_error_t error = kTvmErrorNoError;
@@ -432,6 +434,10 @@ tvm_crt_error_t TVMInitializeRuntime() {
     error = TVMFuncRegisterGlobal("tvm.rpc.server.GetCRTMaxPacketSize", &RPCGetCRTMaxPacketSize, 0);
   }
 
+  if (error == kTvmErrorNoError) {
+    error = TVMFuncRegisterGlobal("tvm.contrib.random.random_fill", &TVMContribRandomFill, 0);
+  }
+
   if (error != kTvmErrorNoError) {
     TVMPlatformMemoryFree(registry_backing_memory, dev);
   }
@@ -563,3 +569,20 @@ release_and_return : {
 __attribute__((weak)) tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
   return kTvmErrorFunctionCallNotImplemented;
 }
+
+// Fill the tensor in args[0] with random data using TVMPlatformGenerateRandom.
+// Named to correspond with the analogous function in the C++ runtime.
+int TVMContribRandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
+                         int* ret_type_code) {
+  if (num_args != 1) {
+    return kTvmErrorFunctionCallNumArguments;
+  }
+
+  if (type_codes[0] != kTVMDLTensorHandle) {
+    return kTvmErrorFunctionCallWrongArgType;
+  }
+
+  DLTensor* tensor = (DLTensor*)args[0].v_handle;
+  TVMNDArray arr = {*tensor};
+  return TVMNDArray_RandomFill(&arr);
+}
diff --git a/src/runtime/crt/common/ndarray.c b/src/runtime/crt/common/ndarray.c
index c97f765..16bde32 100644
--- a/src/runtime/crt/common/ndarray.c
+++ b/src/runtime/crt/common/ndarray.c
@@ -47,18 +47,22 @@ int TVMNDArray_Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype,
   return 0;
 }
 
+int64_t TVMNDArray_DataSizeBytes(TVMNDArray* array) {
+  int64_t num_elems = 1;
+  int32_t idx;
+  for (idx = 0; idx < array->dl_tensor.ndim; ++idx) {
+    num_elems *= array->dl_tensor.shape[idx];
+  }
+  return (num_elems * array->dl_tensor.dtype.bits + 7) / 8;
+}
+
 int TVMNDArray_Empty(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
                      TVMNDArray* array) {
   int status = TVMNDArray_Create(ndim, shape, dtype, dev, array);
   if (status != 0) {
     return status;
   }
-  int64_t num_elems = 1;
-  int32_t idx;
-  for (idx = 0; idx < array->dl_tensor.ndim; ++idx) {
-    num_elems *= shape[idx];
-  }
-  int total_elem_bytes = (num_elems * dtype.bits + 7) / 8;
+  int total_elem_bytes = TVMNDArray_DataSizeBytes(array);
   array->dl_tensor.data =
       TVMBackendAllocWorkspace(kDLCPU, 0, total_elem_bytes, dtype.code, dtype.bits);
   memset(array->dl_tensor.data, 0, total_elem_bytes);
@@ -136,6 +140,15 @@ int TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, int32_t ndi
   return 0;
 }
 
+int TVMNDArray_RandomFill(TVMNDArray* arr) {
+  int64_t num_bytes = TVMNDArray_DataSizeBytes(arr);
+  if (num_bytes < 0 || num_bytes > SIZE_MAX) {
+    return kTvmErrorFunctionCallInvalidArg;
+  }
+
+  return TVMPlatformGenerateRandom(arr->dl_tensor.data, (size_t)num_bytes);
+}
+
 int TVMNDArray_Release(TVMNDArray* arr) {
   tvm_crt_error_t err;
   DLDevice dev = {kDLCPU, 0};
diff --git a/src/runtime/crt/crt_config-template.h b/src/runtime/crt/crt_config-template.h
index 7949aea..aa718a3 100644
--- a/src/runtime/crt/crt_config-template.h
+++ b/src/runtime/crt/crt_config-template.h
@@ -37,7 +37,7 @@
 #define TVM_CRT_MAX_ARGS 10
 
 /*! Size of the global function registry, in bytes. */
-#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 250
+#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512
 
 /*! Maximum number of registered modules. */
 #define TVM_CRT_MAX_REGISTERED_MODULES 2
diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h b/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h
index f878477..e5869ed 100644
--- a/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h
+++ b/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h
@@ -44,6 +44,10 @@ typedef struct TVMNDArray {
 int TVMNDArray_Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
                       TVMNDArray* array);
 
+int64_t TVMNDArray_DataSizeBytes(TVMNDArray* array);
+
+int TVMNDArray_RandomFill(TVMNDArray* array);
+
 int TVMNDArray_Empty(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev,
                      TVMNDArray* array);
 
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index dc849b8..80ace92 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -59,6 +59,17 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) {
   function_names_.push_back(global_symbol.value());
 
   CodeGenC::AddFunction(f);
+  if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+    function_names_.push_back(runtime::symbol::tvm_module_main);
+    stream << "// CodegenC: NOTE: Auto-generated entry function\n";
+    PrintFuncPrefix();
+    stream << " " << tvm::runtime::symbol::tvm_module_main
+           << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, "
+           << "int* out_ret_tcode, void* resource_handle) {\n";
+    stream << "  return " << global_symbol.value()
+           << "(args, arg_type_ids, num_args, out_ret_value, out_ret_tcode, resource_handle);\n";
+    stream << "}\n";
+  }
 }
 
 void CodeGenCHost::DeclareParameters(Map<String, LinkedParam> params) {
diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py
index 5a7e69e..6085318 100644
--- a/tests/micro/zephyr/test_zephyr.py
+++ b/tests/micro/zephyr/test_zephyr.py
@@ -15,10 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import contextlib
-import copy
-import datetime
-import glob
 import logging
 import os
 import pathlib
@@ -401,5 +397,143 @@ def test_rpc_large_array(temp_dir, platform, west_cmd, tvm_debug, shape):
         test_tensors(sess)
 
 
+@tvm.testing.requires_micro
+def test_autotune_conv2d(temp_dir, platform, west_cmd, tvm_debug):
+    """Test AutoTune for microTVM Zephyr"""
+    import tvm.relay as relay
+
+    model, zephyr_board = PLATFORMS[platform]
+
+    # Create a Relay model
+    data_shape = (1, 3, 16, 16)
+    weight_shape = (8, 3, 5, 5)
+    data = relay.var("data", relay.TensorType(data_shape, "float32"))
+    weight = relay.var("weight", relay.TensorType(weight_shape, "float32"))
+    y = relay.nn.conv2d(
+        data,
+        weight,
+        padding=(2, 2),
+        kernel_size=(5, 5),
+        kernel_layout="OIHW",
+        out_dtype="float32",
+    )
+    f = relay.Function([data, weight], y)
+    mod = tvm.IRModule.from_expr(f)
+    mod = relay.transform.InferType()(mod)
+
+    data_sample = np.random.rand(data_shape[0], data_shape[1], data_shape[2], data_shape[3]).astype(
+        "float32"
+    )
+    weight_sample = np.random.rand(
+        weight_shape[0], weight_shape[1], weight_shape[2], weight_shape[3]
+    ).astype("float32")
+    params = {mod["main"].params[1].name_hint: weight_sample}
+
+    target = tvm.target.target.micro(model)
+    pass_context = tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True})
+    with pass_context:
+        tasks = tvm.autotvm.task.extract_from_program(mod["main"], {}, target)
+    assert len(tasks) > 0
+
+    repo_root = pathlib.Path(
+        subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip()
+    )
+    template_project_dir = repo_root / "apps" / "microtvm" / "zephyr" / "template_project"
+    module_loader = tvm.micro.AutoTvmModuleLoader(
+        template_project_dir=template_project_dir,
+        project_options={
+            "zephyr_board": zephyr_board,
+            "west_cmd": west_cmd,
+            "verbose": 1,
+            "project_type": "host_driven",
+        },
+    )
+    builder = tvm.autotvm.LocalBuilder(
+        n_parallel=1,
+        build_kwargs={"build_option": {"tir.disable_vectorize": True}},
+        do_fork=True,
+        build_func=tvm.micro.autotvm_build_func,
+    )
+    runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
+
+    measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)
+
+    log_path = pathlib.Path("zephyr_autotune.log")
+    if log_path.exists():
+        log_path.unlink()
+
+    n_trial = 10
+    for task in tasks:
+        tuner = tvm.autotvm.tuner.GATuner(task)
+        tuner.tune(
+            n_trial=n_trial,
+            measure_option=measure_option,
+            callbacks=[
+                tvm.autotvm.callback.log_to_file(str(log_path)),
+                tvm.autotvm.callback.progress_bar(n_trial, si_prefix="M"),
+            ],
+            si_prefix="M",
+        )
+
+    # Build without tuning
+    with pass_context:
+        lowered = tvm.relay.build(mod, target=target, params=params)
+
+    temp_dir = utils.tempdir()
+    project = tvm.micro.generate_project(
+        str(template_project_dir),
+        lowered,
+        temp_dir / "project",
+        {
+            "zephyr_board": zephyr_board,
+            "west_cmd": west_cmd,
+            "verbose": 1,
+            "project_type": "host_driven",
+        },
+    )
+    project.build()
+    project.flash()
+
+    with tvm.micro.Session(project.transport()) as session:
+        graph_mod = tvm.micro.create_local_graph_executor(
+            lowered.get_graph_json(), session.get_system_lib(), session.device
+        )
+        graph_mod.set_input(**lowered.get_params())
+        graph_mod.run(data=data_sample)
+        expected_output = graph_mod.get_output(0).numpy()
+        del graph_mod
+
+    # Build using autotune logs
+    with tvm.autotvm.apply_history_best(str(log_path)):
+        with pass_context:
+            lowered_tuned = tvm.relay.build(mod, target=target, params=params)
+
+    temp_dir = utils.tempdir()
+    project = tvm.micro.generate_project(
+        str(template_project_dir),
+        lowered_tuned,
+        temp_dir / "project",
+        {
+            "zephyr_board": zephyr_board,
+            "west_cmd": west_cmd,
+            "verbose": 1,
+            "project_type": "host_driven",
+        },
+    )
+    project.build()
+    project.flash()
+
+    with tvm.micro.Session(project.transport()) as session:
+        graph_mod = tvm.micro.create_local_graph_executor(
+            lowered_tuned.get_graph_json(), session.get_system_lib(), session.device
+        )
+        graph_mod.set_input(**lowered_tuned.get_params())
+        graph_mod.run(data=data_sample)
+        output = graph_mod.get_output(0).numpy()
+        del graph_mod
+
+    tvm.testing.assert_allclose(output, expected_output, rtol=1e-4, atol=1e-5)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py
index 586e9fb..5c6eb92 100644
--- a/tests/python/unittest/test_crt.py
+++ b/tests/python/unittest/test_crt.py
@@ -219,5 +219,108 @@ def test_platform_timer():
         assert len(result.results) == 3
 
 
+@tvm.testing.requires_micro
+def test_autotune():
+    """Verify that autotune works with micro."""
+    import tvm.relay as relay
+
+    data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32"))
+    weight = relay.var("weight", relay.TensorType((8, 3, 5, 5), "float32"))
+    y = relay.nn.conv2d(
+        data,
+        weight,
+        padding=(2, 2),
+        kernel_size=(5, 5),
+        kernel_layout="OIHW",
+        out_dtype="float32",
+    )
+    f = relay.Function([data, weight], y)
+    mod = tvm.IRModule.from_expr(f)
+    mod = relay.transform.InferType()(mod)
+
+    main_func = mod["main"]
+    shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params}
+    type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params}
+
+    weight_data = np.ones(shape_dict["weight"]).astype(type_dict["weight"])
+    input_data = np.ones(shape_dict["data"]).astype(type_dict["data"])
+    params = {"weight": weight_data}
+    inputs = {"data": input_data}
+
+    target = tvm.target.target.micro("host")
+    template_project_dir = os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host")
+
+    pass_context = tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True})
+    with pass_context:
+        tasks = tvm.autotvm.task.extract_from_program(mod["main"], {}, target)
+    assert len(tasks) > 0
+
+    module_loader = tvm.micro.AutoTvmModuleLoader(
+        template_project_dir=template_project_dir,
+        project_options={},
+    )
+    builder = tvm.autotvm.LocalBuilder(
+        n_parallel=1,
+        build_kwargs={"build_option": {"tir.disable_vectorize": True}},
+        do_fork=True,
+        build_func=tvm.micro.autotvm_build_func,
+    )
+    runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
+
+    measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)
+
+    tune_log_file = pathlib.Path("crt_autotune.log")
+    if tune_log_file.exists():
+        tune_log_file.unlink()
+
+    num_trials = 10
+    for task in tasks:
+        tuner = tvm.autotvm.tuner.GATuner(task)
+        tuner.tune(
+            n_trial=num_trials,
+            measure_option=measure_option,
+            callbacks=[
+                tvm.autotvm.callback.log_to_file(str(tune_log_file)),
+                tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"),
+            ],
+            si_prefix="M",
+        )
+
+    # Build without tuning
+    with pass_context:
+        lowered = tvm.relay.build(mod, target=TARGET, params=params)
+
+    temp_dir = tvm.contrib.utils.tempdir()
+    project = tvm.micro.generate_project(template_project_dir, lowered, temp_dir / "project")
+    project.build()
+    with tvm.micro.Session(project.transport()) as session:
+        graph_mod = tvm.micro.create_local_graph_executor(
+            lowered.get_graph_json(), session.get_system_lib(), session.device
+        )
+        graph_mod.set_input(**lowered.get_params())
+        graph_mod.run(**inputs)
+        expected_output = graph_mod.get_output(0).numpy()
+        del graph_mod
+
+    # Build using autotune logs
+    with tvm.autotvm.apply_history_best(str(tune_log_file)):
+        with pass_context:
+            lowered_tuned = tvm.relay.build(mod, target=target, params=params)
+
+    temp_dir = tvm.contrib.utils.tempdir()
+    project = tvm.micro.generate_project(template_project_dir, lowered_tuned, temp_dir / "project")
+    project.build()
+    with tvm.micro.Session(project.transport()) as session:
+        graph_mod = tvm.micro.create_local_graph_executor(
+            lowered_tuned.get_graph_json(), session.get_system_lib(), session.device
+        )
+        graph_mod.set_input(**lowered_tuned.get_params())
+        graph_mod.run(**inputs)
+        output = graph_mod.get_output(0).numpy()
+        del graph_mod
+
+    tvm.testing.assert_allclose(output, expected_output, rtol=1e-4, atol=1e-5)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))
diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py
index 5230654..7bf4d72 100644
--- a/tests/python/unittest/test_runtime_module_load.py
+++ b/tests/python/unittest/test_runtime_module_load.py
@@ -88,7 +88,12 @@ def test_dso_module_load():
     with open(path_runtime_py, "w") as fo:
         fo.write(runtime_py)
 
-    subprocess.check_call("python3 %s %s %s" % (path_runtime_py, path_dso, dtype), shell=True)
+    proc = subprocess.run(
+        [sys.executable, path_runtime_py, path_dso, dtype],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+    )
+    assert proc.returncode == 0, f"{proc.args} exited with {proc.returncode}: {proc.stdout}"
 
 
 @tvm.testing.requires_gpu
diff --git a/tutorials/micro/micro_autotune.py b/tutorials/micro/micro_autotune.py
new file mode 100644
index 0000000..136bcfe
--- /dev/null
+++ b/tutorials/micro/micro_autotune.py
@@ -0,0 +1,250 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+.. _tutorial-micro-autotune:
+
+Autotuning with micro TVM
+=========================
+**Author**: `Andrew Reusch <https://github.com/areusch>`_, `Mehrdad Hessar <https://github.com/mehrdadh>`_
+
+This tutorial explains how to autotune a model using the C runtime.
+"""
+
+import numpy as np
+import subprocess
+import pathlib
+
+import tvm
+
+####################
+# Defining the model
+####################
+#
+# To begin with, define a model in Relay to be executed on-device. Then create an IRModule from the
+# Relay model and fill the parameters with random numbers.
+#
+
+data_shape = (1, 3, 10, 10)
+weight_shape = (6, 3, 5, 5)
+
+data = tvm.relay.var("data", tvm.relay.TensorType(data_shape, "float32"))
+weight = tvm.relay.var("weight", tvm.relay.TensorType(weight_shape, "float32"))
+
+y = tvm.relay.nn.conv2d(
+    data,
+    weight,
+    padding=(2, 2),
+    kernel_size=(5, 5),
+    kernel_layout="OIHW",
+    out_dtype="float32",
+)
+f = tvm.relay.Function([data, weight], y)
+
+relay_mod = tvm.IRModule.from_expr(f)
+relay_mod = tvm.relay.transform.InferType()(relay_mod)
+
+weight_sample = np.random.rand(
+    weight_shape[0], weight_shape[1], weight_shape[2], weight_shape[3]
+).astype("float32")
+params = {"weight": weight_sample}
+
+######################
+# Defining the target
+######################
+# Now we define the TVM target that describes the execution environment. This looks very similar
+# to target definitions from other microTVM tutorials.
+#
+# When running on physical hardware, choose a target and a board that
+# describe the hardware. There are multiple hardware targets that can be selected from
+# the PLATFORM list in this tutorial. You can choose the platform by passing the
+# --platform argument when running this tutorial.
+#
+TARGET = tvm.target.target.micro("host")
+
+# Compiling for physical hardware
+# --------------------------------------------------------------------------
+#  When running on physical hardware, choose a TARGET and a BOARD that describe the hardware. The
+#  STM32L4R5ZI Nucleo target and board is chosen in the example below.
+#
+#    TARGET = tvm.target.target.micro("stm32l4r5zi")
+#    BOARD = "nucleo_l4r5zi"
+
+#########################
+# Extracting tuning tasks
+#########################
+# Not all operators in the Relay program printed above can be tuned. Some are so trivial that only
+# a single implementation is defined; others don't make sense as tuning tasks. Using
+# `extract_from_program`, you can produce a list of tunable tasks.
+#
+# Because task extraction involves running the compiler, we first configure the compiler's
+# transformation passes; we'll apply the same configuration later on during autotuning.
+
+pass_context = tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True})
+with pass_context:
+    tasks = tvm.autotvm.task.extract_from_program(relay_mod["main"], {}, TARGET)
+assert len(tasks) > 0
+
+######################
+# Configuring microTVM
+######################
+# Before autotuning, we need to define a module loader and then pass that to
+# a `tvm.autotvm.LocalBuilder`. Then we create a `tvm.autotvm.LocalRunner` and use
+# both the builder and runner to generate multiple measurements for the autotuner.
+#
+# In this tutorial, we can use the x86 host as an example or use different targets
+# from Zephyr RTOS. If you pass `--platform=host` to this tutorial, it will use x86. You
+# can choose other options from the `PLATFORM` list.
+#
+
+repo_root = pathlib.Path(
+    subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip()
+)
+
+module_loader = tvm.micro.AutoTvmModuleLoader(
+    template_project_dir=repo_root / "src" / "runtime" / "crt" / "host",
+    project_options={},
+)
+builder = tvm.autotvm.LocalBuilder(
+    n_parallel=1,
+    build_kwargs={"build_option": {"tir.disable_vectorize": True}},
+    do_fork=True,
+    build_func=tvm.micro.autotvm_build_func,
+)
+runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
+
+measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)
+
+# Compiling for physical hardware
+# --------------------------------------------------------------------------
+#    module_loader = tvm.micro.AutoTvmModuleLoader(
+#        template_project_dir=repo_root / "apps" / "microtvm" / "zephyr" / "template_project",
+#        project_options={
+#            "zephyr_board": BOARD,
+#            "west_cmd": "west",
+#            "verbose": 1,
+#            "project_type": "host_driven",
+#        },
+#    )
+#    builder = tvm.autotvm.LocalBuilder(
+#        n_parallel=1,
+#        build_kwargs={"build_option": {"tir.disable_vectorize": True}},
+#        do_fork=False,
+#        build_func=tvm.micro.autotvm_build_func,
+#    )
+#    runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader)
+
+#    measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner)
+
+################
+# Run Autotuning
+################
+# Now we can run autotuning separately on each extracted task.
+
+num_trials = 10
+for task in tasks:
+    tuner = tvm.autotvm.tuner.GATuner(task)
+    tuner.tune(
+        n_trial=num_trials,
+        measure_option=measure_option,
+        callbacks=[
+            tvm.autotvm.callback.log_to_file("microtvm_autotune.log"),
+            tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"),
+        ],
+        si_prefix="M",
+    )
+
+############################
+# Timing the untuned program
+############################
+# For comparison, let's compile and run the graph without imposing any autotuning schedules. TVM
+# will select a randomly-tuned implementation for each operator, which should not perform as well as
+# the tuned operator.
+
+with pass_context:
+    lowered = tvm.relay.build(relay_mod, target=TARGET, params=params)
+
+temp_dir = tvm.contrib.utils.tempdir()
+
+project = tvm.micro.generate_project(
+    str(repo_root / "src" / "runtime" / "crt" / "host"), lowered, temp_dir / "project"
+)
+
+# Compiling for physical hardware
+# --------------------------------------------------------------------------
+#    project = tvm.micro.generate_project(
+#        str(repo_root / "apps" / "microtvm" / "zephyr" / "template_project"),
+#        lowered,
+#        temp_dir / "project",
+#        {
+#            "zephyr_board": BOARD,
+#            "west_cmd": "west",
+#            "verbose": 1,
+#            "project_type": "host_driven",
+#        },
+#    )
+
+project.build()
+project.flash()
+with tvm.micro.Session(project.transport()) as session:
+    debug_module = tvm.micro.create_local_debug_executor(
+        lowered.get_graph_json(), session.get_system_lib(), session.device
+    )
+    debug_module.set_input(**lowered.get_params())
+    print("########## Build without Autotuning ##########")
+    debug_module.run()
+    del debug_module
+
+##########################
+# Timing the tuned program
+##########################
+# Once autotuning completes, you can time execution of the entire program using the Debug Runtime:
+
+with tvm.autotvm.apply_history_best("microtvm_autotune.log"):
+    with pass_context:
+        lowered_tuned = tvm.relay.build(relay_mod, target=TARGET, params=params)
+
+temp_dir = tvm.contrib.utils.tempdir()
+
+project = tvm.micro.generate_project(
+    str(repo_root / "src" / "runtime" / "crt" / "host"), lowered_tuned, temp_dir / "project"
+)
+
+# Compiling for physical hardware
+# --------------------------------------------------------------------------
+#    project = tvm.micro.generate_project(
+#        str(repo_root / "apps" / "microtvm" / "zephyr" / "template_project"),
+#        lowered_tuned,
+#        temp_dir / "project",
+#        {
+#            "zephyr_board": BOARD,
+#            "west_cmd": "west",
+#            "verbose": 1,
+#            "project_type": "host_driven",
+#        },
+#    )
+
+project.build()
+project.flash()
+with tvm.micro.Session(project.transport()) as session:
+    debug_module = tvm.micro.create_local_debug_executor(
+        lowered_tuned.get_graph_json(), session.get_system_lib(), session.device
+    )
+    debug_module.set_input(**lowered_tuned.get_params())
+    print("########## Build with Autotuning ##########")
+    debug_module.run()
+    del debug_module