You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by me...@apache.org on 2023/01/19 23:33:03 UTC

[tvm] branch main updated: [microTVM] Add tutorial on how to generate MLPerfTiny submissions (#13783)

This is an automated email from the ASF dual-hosted git repository.

mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cfa65b26c1 [microTVM] Add tutorial on how to generate MLPerfTiny submissions (#13783)
cfa65b26c1 is described below

commit cfa65b26c1bd975daaef78c60b16989be0d23970
Author: Mehrdad Hessar <mh...@octoml.ai>
AuthorDate: Thu Jan 19 15:32:51 2023 -0800

    [microTVM] Add tutorial on how to generate MLPerfTiny submissions (#13783)
    
    This PR adds a tutorial on how to generate an MLPerfTiny submission on Zephyr OS using microTVM.
---
 docs/conf.py                                       |   5 +-
 .../how_to/work_with_microtvm/micro_mlperftiny.py  | 312 +++++++++++++++++++++
 python/tvm/micro/testing/utils.py                  |  44 ++-
 tests/micro/zephyr/utils.py                        |  37 +--
 tests/scripts/request_hook/request_hook.py         |   1 +
 tests/scripts/task_python_microtvm.sh              |   7 +
 6 files changed, 368 insertions(+), 38 deletions(-)

diff --git a/docs/conf.py b/docs/conf.py
index 08fbedb8ff..eb2b39d4b1 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -550,6 +550,9 @@ def force_gc(gallery_conf, fname):
     gc.collect()
 
 
+# Skip executing the MLPerfTiny tutorial by default to avoid dependency issues.
+# The dot is escaped so only the literal file name "micro_mlperftiny.py" is excluded;
+# setting TVM_TUTORIAL_EXEC_PATTERN overrides this default entirely.
+filename_pattern_default = r"^(?!.*micro_mlperftiny\.py).*$"
+
 sphinx_gallery_conf = {
     "backreferences_dir": "gen_modules/backreferences",
     "doc_module": ("tvm", "numpy"),
@@ -562,7 +565,7 @@ sphinx_gallery_conf = {
     "within_subsection_order": WithinSubsectionOrder,
     "gallery_dirs": gallery_dirs,
     "subsection_order": subsection_order,
-    "filename_pattern": os.environ.get("TVM_TUTORIAL_EXEC_PATTERN", ".py"),
+    "filename_pattern": os.environ.get("TVM_TUTORIAL_EXEC_PATTERN", filename_pattern_default),
     "download_all_examples": False,
     "min_reported_time": 60,
     "expected_failing_examples": [],
diff --git a/gallery/how_to/work_with_microtvm/micro_mlperftiny.py b/gallery/how_to/work_with_microtvm/micro_mlperftiny.py
new file mode 100644
index 0000000000..79308e0723
--- /dev/null
+++ b/gallery/how_to/work_with_microtvm/micro_mlperftiny.py
@@ -0,0 +1,312 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+.. _tutorial-micro-MLPerfTiny:
+
+Creating Your MLPerfTiny Submission with microTVM
+=================================================
+**Authors**:
+`Mehrdad Hessar <https://github.com/mehrdadh>`_
+
+This tutorial shows how to build an MLPerfTiny submission using microTVM. It walks
+through importing a TFLite model from the MLPerfTiny benchmark models, compiling it
+with TVM, and generating a Zephyr project that can be flashed to a Zephyr-supported
+board to benchmark the model using the EEMBC runner.
+"""
+
+######################################################################
+#
+#     .. include:: ../../../../gallery/how_to/work_with_microtvm/install_dependencies.rst
+#
+
+import os
+import pathlib
+import tarfile
+import tempfile
+import shutil
+
+######################################################################
+#
+#     .. include:: ../../../../gallery/how_to/work_with_microtvm/install_zephyr.rst
+#
+
+
+######################################################################
+#
+# **Note:** Install CMSIS-NN only if you are interested in generating this submission
+# using the CMSIS-NN code generator.
+#
+
+######################################################################
+#
+#     .. include:: ../../../../gallery/how_to/work_with_microtvm/install_cmsis.rst
+#
+
+######################################################################
+# Import Python dependencies
+# -------------------------------
+#
+import tensorflow as tf
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.relay.backend import Executor, Runtime
+from tvm.contrib.download import download_testdata
+from tvm.micro import export_model_library_format
+from tvm.micro.model_library_format import generate_c_interface_header
+from tvm.micro.testing.utils import (
+    create_header_file,
+    mlf_extract_workspace_size_bytes,
+)
+
+######################################################################
+# Import Visual Wake Word Model
+# --------------------------------------------------------------------
+#
+# To begin with, download and import the Visual Wake Word (VWW) TFLite model from MLPerfTiny.
+# This model is originally from `MLPerf Tiny repository <https://github.com/mlcommons/tiny>`_.
+# We also capture metadata information from the TFLite model such as input/output name,
+# quantization parameters, etc. which will be used in following steps.
+#
+# We use an index for each model to build the submission. To build another model, you
+# need to update the model URL, the short name and the index number. The indices are
+# defined as follows:
+#
+#   * Keyword Spotting (KWS): 1
+#   * Visual Wake Word (VWW): 2
+#   * Anomaly Detection (AD): 3
+#   * Image Classification (IC): 4
+#
+# If you would like to build the submission with CMSIS-NN, set the TVM_USE_CMSIS
+# environment variable.
+#
+#   .. code-block:: bash
+#
+#     export TVM_USE_CMSIS=1
+#
+
+MODEL_URL = "https://github.com/mlcommons/tiny/raw/bceb91c5ad2e2deb295547d81505721d3a87d578/benchmark/training/visual_wake_words/trained_models/vww_96_int8.tflite"
+MODEL_PATH = download_testdata(MODEL_URL, "vww_96_int8.tflite", module="model")
+
+MODEL_SHORT_NAME = "VWW"
+MODEL_INDEX = 2
+
+# CMSIS-NN is enabled by setting TVM_USE_CMSIS to any value other than "", "0" or
+# "false" (case-insensitive). Using the raw string as a boolean would treat
+# TVM_USE_CMSIS=0 as "enabled" because any non-empty string is truthy.
+USE_CMSIS = os.environ.get("TVM_USE_CMSIS", "").strip().lower() not in ("", "0", "false")
+
+# Read the raw flatbuffer; read_bytes() closes the file handle right away.
+tflite_model_buf = pathlib.Path(MODEL_PATH).read_bytes()
+try:
+    import tflite
+
+    tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+except AttributeError:
+    # Some tflite package versions expose the Model class one module deeper.
+    import tflite.Model
+
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+
+# The TFLite interpreter is used here only to read tensor metadata
+# (names, shapes, dtypes and quantization parameters).
+interpreter = tf.lite.Interpreter(model_path=str(MODEL_PATH))
+interpreter.allocate_tensors()
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+
+input_name = input_details[0]["name"]
+input_shape = tuple(input_details[0]["shape"])
+input_dtype = np.dtype(input_details[0]["dtype"]).name
+output_name = output_details[0]["name"]
+output_shape = tuple(output_details[0]["shape"])
+output_dtype = np.dtype(output_details[0]["dtype"]).name
+
+# Extract the output quantization parameters from the TFLite model. They are
+# required for every model except Anomaly Detection: for the other models the host
+# sends already-quantized data, whereas the AD model receives floating-point data
+# and the quantization happens on the microcontroller.
+if MODEL_SHORT_NAME != "AD":
+    quant_output_scale = output_details[0]["quantization_parameters"]["scales"][0]
+    quant_output_zero_point = output_details[0]["quantization_parameters"]["zero_points"][0]
+
+relay_mod, params = relay.frontend.from_tflite(
+    tflite_model, shape_dict={input_name: input_shape}, dtype_dict={input_name: input_dtype}
+)
+
+######################################################################
+# Defining Target, Runtime and Executor
+# --------------------------------------------------------------------
+#
+# Now we need to define the target, runtime and executor to compile this model. In this tutorial,
+# we use Ahead-of-Time (AoT) compilation and we build a standalone project. This differs from
+# using AoT in host-driven mode, where the target communicates with the host through the
+# host-driven AoT executor to run inference.
+#
+
+# The model runs on the C runtime ("crt").
+RUNTIME = Runtime("crt")
+
+# Ahead-of-Time executor configuration:
+#   * "interface-api": "c" makes the compiler generate C-style function APIs.
+#   * "unpacked-api": True makes the compiler generate minimal, unpacked inputs,
+#     which reduces stack usage when the model's inference layers are invoked.
+#   * "workspace-byte-alignment" is set to 8.
+EXECUTOR = Executor(
+    "aot",
+    {
+        "unpacked-api": True,
+        "interface-api": "c",
+        "workspace-byte-alignment": 8,
+    },
+)
+
+# Zephyr board to build for; override it with the TVM_MICRO_BOARD environment variable.
+BOARD = os.environ.get("TVM_MICRO_BOARD", "nucleo_l4r5zi")
+
+# Full TVM target description for the selected board.
+TARGET = tvm.micro.testing.get_target("zephyr", BOARD)
+
+######################################################################
+# Compile the model and export model library format
+# --------------------------------------------------------------------
+#
+# Now, we compile the model for the target. Then, we generate model
+# library format for the compiled model. We also need to calculate the
+# workspace size that is required for the compiled model.
+#
+#
+
+# Vectorization is disabled for this build.
+pass_config = {"tir.disable_vectorize": True}
+
+if USE_CMSIS:
+    from tvm.relay.op.contrib import cmsisnn
+
+    # Pass the target CPU to the CMSIS-NN codegen and partition the graph so that
+    # supported operators are handed to CMSIS-NN.
+    pass_config["relay.ext.cmsisnn.options"] = {"mcpu": TARGET.mcpu}
+    relay_mod = cmsisnn.partition_for_cmsisnn(relay_mod, params, mcpu=TARGET.mcpu)
+
+with tvm.transform.PassContext(opt_level=3, config=pass_config):
+    module = tvm.relay.build(
+        relay_mod,
+        target=TARGET,
+        params=params,
+        runtime=RUNTIME,
+        executor=EXECUTOR,
+    )
+
+# Export the compiled module in Model Library Format and read back the workspace
+# size recorded in its metadata; the project options below depend on it.
+temp_dir = tvm.contrib.utils.tempdir()
+model_tar_path = temp_dir / "model.tar"
+export_model_library_format(module, model_tar_path)
+workspace_size = mlf_extract_workspace_size_bytes(model_tar_path)
+
+######################################################################
+# Generate input/output header files
+# --------------------------------------------------------------------
+#
+# To create a microTVM standalone project with AoT, we need to generate
+# input and output header files. These header files are used to connect
+# the input and output API from generated code to the rest of the
+# standalone project. For this specific submission, we only need to generate
+# the output header file since the input API call is handled differently.
+#
+
+extra_tar_dir = tvm.contrib.utils.tempdir()
+extra_tar_file = extra_tar_dir / "extra.tar"
+
+# The archive handle is deliberately NOT named `tf`: that would shadow the
+# `tensorflow as tf` import for the rest of the script.
+with tarfile.open(extra_tar_file, "w:gz") as extra_tar:
+    with tempfile.TemporaryDirectory() as tar_temp_dir:
+        # Generate the model's C interface header into <tmp>/include and add it to
+        # the archive with a path relative to the temporary directory.
+        model_files_path = os.path.join(tar_temp_dir, "include")
+        os.mkdir(model_files_path)
+        header_path = generate_c_interface_header(
+            module.libmod_name, [input_name], [output_name], [], {}, [], 0, model_files_path, {}, {}
+        )
+        extra_tar.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir))
+
+    # Zero-initialized output buffer, stored in the archive as include/output_data.h.
+    create_header_file(
+        "output_data",
+        np.zeros(
+            shape=output_shape,
+            dtype=output_dtype,
+        ),
+        "include",
+        extra_tar,
+    )
+
+######################################################################
+# Create the project, build and prepare the project tar file
+# --------------------------------------------------------------------
+#
+# Now that we have the compiled model as a model library format,
+# we can generate the full project using the Zephyr template project. First,
+# we prepare the project options, then build the project. Finally, we
+# clean up the temporary files and package the submission project as
+# project.tar in the current working directory, where it can be downloaded
+# and used on your development kit.
+#
+
+# Total number of elements in the model input (product of its shape).
+input_total_size = int(np.prod(input_shape))
+
+template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr"))
+project_options = {
+    "extra_files_tar": str(extra_tar_file),
+    "project_type": "mlperftiny",
+    "board": BOARD,
+    "compile_definitions": [
+        # Memory workspace size. The extra 512 bytes are a temporary margin because
+        # the workspace size calculation is not accurate.
+        f"-DWORKSPACE_SIZE={workspace_size + 512}",
+        # Sets the model index for project compilation.
+        f"-DTARGET_MODEL={MODEL_INDEX}",
+        # Sets the model version; this is required by the MLPerfTiny API.
+        f"-DTH_MODEL_VERSION=EE_MODEL_VERSION_{MODEL_SHORT_NAME}01",
+        # Maximum size of the input data array.
+        f"-DMAX_DB_INPUT_SIZE={input_total_size}",
+    ],
+    # Note: You might need to adjust this based on the board that you are using.
+    "config_main_stack_size": 4000,
+}
+
+if MODEL_SHORT_NAME != "AD":
+    # Output quantization parameters; these are not extracted for the AD model.
+    project_options["compile_definitions"] += [
+        f"-DOUT_QUANT_SCALE={quant_output_scale}",
+        f"-DOUT_QUANT_ZERO={quant_output_zero_point}",
+    ]
+
+if USE_CMSIS:
+    project_options["compile_definitions"].append("-DCOMPILE_WITH_CMSISNN=1")
+    project_options["cmsis_path"] = os.environ.get("CMSIS_PATH", "/content/cmsis")
+
+generated_project_dir = temp_dir / "project"
+project = tvm.micro.project.generate_project_from_mlf(
+    template_project_path, generated_project_dir, model_tar_path, project_options
+)
+project.build()
+
+# Remove the build directory and the model.tar archive so the submission tarball
+# contains only the project sources.
+shutil.rmtree(generated_project_dir / "build")
+(generated_project_dir / "model.tar").unlink()
+
+# Package the project as ./project.tar with a top-level "project" directory.
+project_tar_path = pathlib.Path.cwd() / "project.tar"
+with tarfile.open(project_tar_path, "w:tar") as tar:
+    tar.add(generated_project_dir, arcname="project")
+
+print(f"The generated project is located here: {project_tar_path}")
+
+######################################################################
+# Use this project with your board
+# --------------------------------------------------------------------
+#
+# Now that we have the generated project, you can use it locally
+# to flash your board and prepare it for the EEMBC runner software.
+# To do this, follow these steps:
+#
+#   .. code-block:: bash
+#
+#     tar -xf project.tar
+#     cd project
+#     mkdir build
+#     cd build
+#     cmake ..
+#     make -j2
+#     west flash
+#
+# Now you can connect your board to the EEMBC runner by following these
+# `instructions <https://github.com/eembc/energyrunner>`_
+# and benchmark this model on your board.
+#
diff --git a/python/tvm/micro/testing/utils.py b/python/tvm/micro/testing/utils.py
index 097fbf283a..170c576314 100644
--- a/python/tvm/micro/testing/utils.py
+++ b/python/tvm/micro/testing/utils.py
@@ -17,6 +17,7 @@
 
 """Defines the test methods used with microTVM."""
 
+import io
 from functools import lru_cache
 import json
 import logging
@@ -24,6 +25,7 @@ from pathlib import Path
 import tarfile
 import time
 from typing import Union
+import numpy as np
 
 import tvm
 from tvm import relay
@@ -102,7 +104,7 @@ def mlf_extract_workspace_size_bytes(mlf_tar_path: Union[Path, str]) -> int:
 
     workspace_size = 0
     with tarfile.open(mlf_tar_path, "r:*") as tar_file:
-        tar_members = [ti.name for ti in tar_file.getmembers()]
+        tar_members = [tar_info.name for tar_info in tar_file.getmembers()]
         assert "./metadata.json" in tar_members
         with tar_file.extractfile("./metadata.json") as f:
             metadata = json.load(f)
@@ -133,3 +135,43 @@ def get_conv2d_relay_module():
     mod = tvm.IRModule.from_expr(f)
     mod = relay.transform.InferType()(mod)
     return mod
+
+
+def _npy_dtype_to_ctype(data: np.ndarray) -> str:
+    """Return the C element type used to declare an array holding ``data``.
+
+    Raises ValueError when ``data.dtype`` is not one of int8, int32, uint8 or float32.
+    """
+    for npy_name, c_name in (
+        ("int8", "int8_t"),
+        ("int32", "int32_t"),
+        ("uint8", "uint8_t"),
+        ("float32", "float"),
+    ):
+        if data.dtype == npy_name:
+            return c_name
+    raise ValueError(f"Data type {data.dtype} not expected.")
+
+
+def create_header_file(
+    tensor_name: str,
+    npy_data: np.ndarray,
+    output_path: Union[Path, str],
+    tar_file: tarfile.TarFile,
+) -> None:
+    """Write ``npy_data`` as a C array definition into a header inside ``tar_file``.
+
+    The header defines ``const size_t <tensor_name>_len`` (the element count) and
+    ``<ctype> <tensor_name>[]`` initialized with the array's elements in row-major
+    (C) order. It is used to capture tensor data (for both inputs and outputs).
+
+    Parameters
+    ----------
+    tensor_name : str
+        C identifier for the array; the header file is named ``<tensor_name>.h``.
+    npy_data : np.ndarray
+        Data to embed. Its dtype must be int8, int32, uint8 or float32.
+    output_path : Union[Path, str]
+        Directory inside the archive where the header is placed.
+    tar_file : tarfile.TarFile
+        Archive opened for writing; the header is added to it as a regular file.
+
+    Raises
+    ------
+    ValueError
+        If ``npy_data`` has an unsupported dtype.
+    """
+    c_type = _npy_dtype_to_ctype(npy_data)
+    header_text = "".join(
+        [
+            "#include <stddef.h>\n",
+            "#include <stdint.h>\n",
+            "#include <dlpack/dlpack.h>\n",
+            f"const size_t {tensor_name}_len = {npy_data.size};\n",
+            f"{c_type} {tensor_name}[] =",
+            "{",
+            # ``ndarray.flat`` walks the elements in C order (same as ``np.ndindex``);
+            # the trailing ", " after the last element is valid in a C initializer.
+            "".join(f"{value}, " for value in npy_data.flat),
+            "};\n\n",
+        ]
+    )
+    header_file_bytes = header_text.encode("utf-8")
+
+    tar_info = tarfile.TarInfo(name=str(Path(output_path) / f"{tensor_name}.h"))
+    tar_info.size = len(header_file_bytes)
+    tar_info.mode = 0o644
+    tar_info.type = tarfile.REGTYPE
+    tar_file.addfile(tar_info, io.BytesIO(header_file_bytes))
diff --git a/tests/micro/zephyr/utils.py b/tests/micro/zephyr/utils.py
index 42419b637f..bdac4e9c63 100644
--- a/tests/micro/zephyr/utils.py
+++ b/tests/micro/zephyr/utils.py
@@ -32,6 +32,7 @@ import requests
 import tvm.micro
 from tvm.micro import export_model_library_format
 from tvm.micro.model_library_format import generate_c_interface_header
+from tvm.micro.testing.utils import create_header_file
 from tvm.micro.testing.utils import (
     mlf_extract_workspace_size_bytes,
     aot_transport_init_wait,
@@ -106,42 +107,6 @@ def build_project(
     return project, project_dir
 
 
-def create_header_file(tensor_name, npy_data, output_path, tar_file):
-    """
-    This method generates a header file containing the data contained in the numpy array provided.
-    It is used to capture the tensor data (for both inputs and expected outputs).
-    """
-    header_file = io.StringIO()
-    header_file.write("#include <stddef.h>\n")
-    header_file.write("#include <stdint.h>\n")
-    header_file.write("#include <dlpack/dlpack.h>\n")
-    header_file.write(f"const size_t {tensor_name}_len = {npy_data.size};\n")
-
-    if npy_data.dtype == "int8":
-        header_file.write(f"int8_t {tensor_name}[] =")
-    elif npy_data.dtype == "int32":
-        header_file.write(f"int32_t {tensor_name}[] = ")
-    elif npy_data.dtype == "uint8":
-        header_file.write(f"uint8_t {tensor_name}[] = ")
-    elif npy_data.dtype == "float32":
-        header_file.write(f"float {tensor_name}[] = ")
-    else:
-        raise ValueError("Data type not expected.")
-
-    header_file.write("{")
-    for i in np.ndindex(npy_data.shape):
-        header_file.write(f"{npy_data[i]}, ")
-    header_file.write("};\n\n")
-
-    header_file_bytes = bytes(header_file.getvalue(), "utf-8")
-    raw_path = pathlib.Path(output_path) / f"{tensor_name}.h"
-    ti = tarfile.TarInfo(name=str(raw_path))
-    ti.size = len(header_file_bytes)
-    ti.mode = 0o644
-    ti.type = tarfile.REGTYPE
-    tar_file.addfile(ti, io.BytesIO(header_file_bytes))
-
-
 # TODO move CMSIS integration to microtvm_api_server.py
 # see https://discuss.tvm.apache.org/t/tvm-capturing-dependent-libraries-of-code-generated-tir-initially-for-use-in-model-library-format/11080
 def loadCMSIS(temp_dir):
diff --git a/tests/scripts/request_hook/request_hook.py b/tests/scripts/request_hook/request_hook.py
index 4e3db220e0..b033f1ca84 100644
--- a/tests/scripts/request_hook/request_hook.py
+++ b/tests/scripts/request_hook/request_hook.py
@@ -208,6 +208,7 @@ URL_MAP = {
     "https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5": f"{BASE}/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5",
     "https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels.h5": f"{BASE}/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels.h5",
     "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz": f"{BASE}/tensorflow/tf-keras-datasets/mnist.npz",
+    "https://github.com/mlcommons/tiny/raw/bceb91c5ad2e2deb295547d81505721d3a87d578/benchmark/training/visual_wake_words/trained_models/vww_96_int8.tflite": f"{BASE}/mlcommons/tiny/benchmark/training/visual_wake_words/trained_models/vww_96_int8.tflite",
 }
 
 
diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh
index 6153cdf823..0b43c9c1fa 100755
--- a/tests/scripts/task_python_microtvm.sh
+++ b/tests/scripts/task_python_microtvm.sh
@@ -51,6 +51,13 @@ python3 gallery/how_to/work_with_microtvm/micro_aot.py
 python3 gallery/how_to/work_with_microtvm/micro_pytorch.py
 ./gallery/how_to/work_with_microtvm/micro_tvmc.sh
 
+# without CMSIS-NN (forced empty so an inherited TVM_USE_CMSIS cannot enable it)
+TVM_USE_CMSIS= python3 gallery/how_to/work_with_microtvm/micro_mlperftiny.py
+# with CMSIS-NN; the variable is set only for this command so it does not
+# leak into the tutorials that run after this point.
+TVM_USE_CMSIS=1 python3 gallery/how_to/work_with_microtvm/micro_mlperftiny.py
+
 # Tutorials running with Zephyr
 export TVM_MICRO_USE_HW=1
 export TVM_MICRO_BOARD=qemu_x86