You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by "mehrdadh (via GitHub)" <gi...@apache.org> on 2023/03/13 18:37:21 UTC

[GitHub] [tvm] mehrdadh commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

mehrdadh commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1134420805


##########
apps/microtvm/gemmini/template_project/crt_config/crt_config.h:
##########
@@ -0,0 +1,57 @@
+/*

Review Comment:
   We no longer check in a crt_config.h for each project; instead, you can use https://github.com/apache/tvm/blob/main/cmake/utils/CRTConfig.cmake to generate crt_config.h for this project at build time.
   Here is the PR that made this change: https://github.com/apache/tvm/pull/13955
   



##########
apps/microtvm/gemmini/template_project/microtvm_api_server.py:
##########
@@ -0,0 +1,286 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MicroTVM API Server for Gemmini baremetal tests on the Spike simulator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import atexit
+import collections
+import functools
+import json
+import logging
+import os
+import os.path
+import pathlib
+import re
+import shlex
+import shutil
+import shlex, subprocess
+import sys
+import tarfile
+import tempfile
+import time
+from string import Template
+import re
+from distutils.dir_util import copy_tree
+import subprocess
+import serial
+
+# import serial.tools.list_ports
+from tvm.micro.project_api import server
+
+from subprocess import PIPE
+
+_LOG = logging.getLogger(__name__)
+
+MODEL_LIBRARY_FORMAT_RELPATH = pathlib.Path("src") / "model" / "model.tar"
+API_SERVER_DIR = pathlib.Path(os.path.dirname(__file__) or os.path.getcwd())
+BUILD_DIR = API_SERVER_DIR / "build"
+MODEL_LIBRARY_FORMAT_PATH = API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH
+
+IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists()
+
+PROJECT_TYPES = [
+    "dense_example",
+    "conv2d_example",
+    "dwconv2d_example",
+    "add_example",
+    "maxpool2d_example",
+    "mobilenet_example",
+]
+
+PROJECT_OPTIONS = [

Review Comment:
   Please follow the same convention used in the other API servers: https://github.com/apache/tvm/blob/9a99fc89a2970b9fca151a573de7a5e409b5d9ee/apps/microtvm/zephyr/template_project/microtvm_api_server.py#LL299C50-L299C50



##########
apps/microtvm/gemmini/template_project/src/Makefile:
##########
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one

Review Comment:
   We highly recommend using CMake (a CMakeLists.txt) instead of a Makefile so that the project builds cross-platform.



##########
python/tvm/contrib/gemmini/__init__.py:
##########
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Gemmini package is a TVM backend extension to support the Gemmini hardware accelerator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import tvm._ffi.base
+
+from tvm.relay.backend.contrib.gemmini import *
+from .environment import Environment
+from .build_module import build_config, lower, build, preprocess_pass
+from .helpers import create_header_file
+from .utils import *
+
+__version__ = "0.1.0"

Review Comment:
   Can this be removed?



##########
gallery/tutorial/micro_gemmini_add.py:
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Running TVM on the Gemmini accelerator - A single add layer example
+======================================================================================
+**Author**:
+`Federico Peccia <https://fPecc.github.io/>`_
+
+This tutorials shows how a quantized add layer can be compiled to be executed on the Gemmini accelerator. The generated baremetal C code is then tested on the Spike RISC-V ISA simulator. Before starting this tutorial, you should have downloaded the Chipyard repository and installed the Spike simulator with the Gemmini extension.
+
+Note: This is an **experimental** layer!
+"""
+
+import tensorflow as tf
+from tensorflow.keras import layers
+import numpy as np
+import os
+import tvm.contrib.gemmini as gemmini
+from tvm import relay
+import tvm
+
+##################################
+# Pre-requisites
+# --------------------------------
+#
+# After the installation of the Chipyard development tools, you should have an env.sh file in your Chipyard home directory. This file needs to be sourced before running this tutorial:
+#
+# .. code-block:: bash
+#
+#   source <your chipyard home path>/env.sh
+#
+# WARNING: if you have installed TVM in a virtual environment, FIRST activate the Chipyard environment, and THEN activate the tvm entironment.
+
+##################################
+# Baseline generation
+# --------------------------------
+#
+# In this section, we will generate the baseline input and expected output, which we are going to use to compare with the actual obtained output after running on the Gemmini accelerator.
+
+# Then we define the parameters of the layer we want to test. In this case:
+input_height = 16
+input_width = 16
+input_channels = 16
+activation = 0
+
+# We will generate a prequantized TFLite model, because for now the Gemmini integration only supports models that were quantized with specific flags as input.
+class Model(tf.Module):
+    def __init__(self, name=None):
+        super().__init__(name)
+
+    @tf.function(
+        input_signature=[
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+        ]
+    )
+    def add(self, x, y):
+        if activation == 0:
+            return x + y
+        else:
+            return layers.Activation("relu")(x + y)
+
+
+model = Model()
+
+# Convert the concrete functions using TFLiteConverter
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+
+def representative_data_gen():
+    dataset = [
+        (
+            np.array(
+                np.random.randint(-127, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+            np.array(
+                np.random.randint(0, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+        )
+        for s in range(100)
+    ]
+    for input_value in dataset:
+        yield [input_value[0], input_value[1]]
+
+
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.inference_input_type = tf.uint8
+converter.inference_output_type = tf.int8
+converter.representative_dataset = representative_data_gen
+converter._experimental_disable_per_channel = True
+
+tflite_model = converter.convert()
+
+# Save the model.
+with open("add.tflite", "wb") as f:
+    f.write(tflite_model)
+
+# Now that we have created the model, we import the model and run it. We store the output, in order to compare it with the output that will be later obtained from the Gemmini accelerator.
+
+os.system("rm -rf model.tar dev/ include/ generated-project/")
+
+tflite_file = "./add.tflite"
+tflite_model_buf = open(tflite_file, "rb").read()
+input_tensor = "layer1_input"
+input_dtype = "uint8"
+
+os.system("mkdir -p include")
+
+try:
+    import tflite
+
+    tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+except AttributeError:
+    import tflite.Model
+
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+
+# Load the TFLite model and allocate tensors.
+interpreter = tf.lite.Interpreter(model_path=tflite_file, experimental_preserve_all_tensors=True)
+interpreter.allocate_tensors()
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+tensor_details = interpreter.get_tensor_details()
+
+input_matrix_1 = np.random.randint(
+    0, 255, (1, input_height, input_width, input_channels), dtype=np.uint8
+)
+input_matrix_2 = np.random.randint(
+    0, 255, (1, input_height, input_width, input_channels), dtype=np.uint8
+)
+
+interpreter.set_tensor(input_details[0]["index"], input_matrix_1)
+interpreter.set_tensor(input_details[1]["index"], input_matrix_2)
+
+interpreter.invoke()
+expected_output = interpreter.get_tensor(output_details[0]["index"])
+
+# Here, we create C files and headers with the inputs and expected output, so that we can then execute the same operation on the Gemmini accelerator, and compare the expected output with the actual predicted one.
+gemmini.create_header_file("inputs", "data", "input_1", input_matrix_2, "./include")

Review Comment:
   You could add the generated files to the Model Library Format file (`model.tar`) and pass that file to the project. This way you don't need extra steps to add the generated header to your project.



##########
apps/microtvm/gemmini/template_project/microtvm_api_server.py:
##########
@@ -0,0 +1,286 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MicroTVM API Server for Gemmini baremetal tests on the Spike simulator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_

Review Comment:
   I don't think this author line is necessary; we only use it in tutorials.



##########
apps/microtvm/gemmini/template_project/src/add.c:
##########
@@ -0,0 +1,69 @@
+/*

Review Comment:
   Source files for each project type should be in a separate sub-directory. Please follow the same convention.



##########
python/tvm/contrib/gemmini/__init__.py:
##########
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Gemmini package is a TVM backend extension to support the Gemmini hardware accelerator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_

Review Comment:
   Please remove the author line from Python files.



##########
apps/microtvm/gemmini/template_project/src/dwconv2d.c:
##########
@@ -0,0 +1,67 @@
+/*

Review Comment:
   Same comment applies to all of these files.



##########
gallery/tutorial/micro_gemmini_conv2d.py:
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one

Review Comment:
   Is it necessary to have a tutorial per operator? Wouldn't a single tutorial be enough to show the general approach?



##########
python/tvm/contrib/gemmini/helpers.py:
##########
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Miscellaneous helpers
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import pathlib
+from typing import List
+from six.moves import range
+import numpy as np
+from .environment import Environment
+
+
+ENV = Environment.instance()
+
+
+def create_header_file(

Review Comment:
   why not reuse the existing `create_header_file` function in TVM?



##########
gallery/tutorial/micro_gemmini_conv2d.py:
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Running TVM on the Gemmini accelerator - A single 2d convolutional layer example
+======================================================================================
+**Author**:
+`Federico Peccia <https://fPecc.github.io/>`_
+
+This tutorials shows how a quantized 2d convolution layer can be compiled to be executed on the Gemmini accelerator. The generated baremetal C code is then tested on the Spike RISC-V ISA simulator. Before starting this tutorial, you should have downloaded the Chipyard repository and installed the Spike simulator with the Gemmini extension.
+
+"""
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import numpy as np
+import os
+import tvm.contrib.gemmini as gemmini
+from tvm import relay
+import tvm
+
+##################################
+# Pre-requisites
+# --------------------------------
+#
+# After the installation of the Chipyard development tools, you should have an env.sh file in your Chipyard home directory. This file needs to be sourced before running this tutorial:
+#
+# .. code-block:: bash
+#
+#   source <your chipyard home path>/env.sh
+#
+# WARNING: if you have installed TVM in a virtual environment, FIRST activate the Chipyard environment, and THEN activate the tvm entironment.
+
+##################################
+# Baseline generation
+# --------------------------------
+#
+# In this section, we will generate the baseline input and expected output, which we are going to use to compare with the actual obtained output after running on the Gemmini accelerator.
+
+# Then we define the parameters of the layer we want to test. In this case:
+input_height = 16
+input_width = 16
+input_channels = 16
+output_channels = 16
+kernel_size = 3
+stride = 1
+padding = "valid"
+activation = None
+bias = True
+
+# We can add a max pooling layer after the convolution. This can be merged by the integration and can be executed together with the convolution on the Gemmini accelerator.
+pool_size = 1
+pool_stride = 1
+pool_padding = "valid"
+use_pool = False
+
+# We will generate a prequantized TFLite model, because for now the Gemmini integration only supports models that were quantized with specific flags as input.
+
+layer_sequence = [
+    layers.Conv2D(
+        output_channels,
+        kernel_size=kernel_size,
+        padding=padding,
+        activation=activation,
+        use_bias=True,
+        bias_initializer="ones",
+        input_shape=(input_height, input_width, input_channels),
+        strides=stride,
+    )
+]
+if use_pool:
+    layer_sequence.append(
+        layers.MaxPool2D(pool_size=pool_size, strides=pool_stride, padding=pool_padding)
+    )
+
+model = keras.Sequential(layer_sequence)
+
+# Convert the concrete functions using TFLiteConverter
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+
+def representative_data_gen():
+    dataset = [
+        np.array(
+            np.random.randint(0, 10, size=(100, input_height, input_width, input_channels)),
+            dtype=np.float32,
+        )
+        for s in range(10)
+    ]
+    for input_value in dataset:
+        # Model has only one input so each data point has one element.s
+        yield [input_value]
+
+
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.inference_input_type = tf.uint8
+converter.inference_output_type = tf.int8
+converter.representative_dataset = representative_data_gen
+converter._experimental_disable_per_channel = True
+
+tflite_model = converter.convert()
+
+# Save the model.
+with open("conv.tflite", "wb") as f:
+    f.write(tflite_model)
+
+# Now that we have created the model, we import the model and run it. We store the output, in order to compare it with the output that will be later obtained from the Gemmini accelerator.
+
+os.system("rm -rf model.tar dev/ include/ generated-project/")

Review Comment:
   Use `tvm.contrib.utils.tempdir()` for temporary files instead of creating and deleting them in the working directory.



##########
gallery/tutorial/micro_gemmini_add.py:
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Running TVM on the Gemmini accelerator - A single add layer example
+======================================================================================
+**Author**:
+`Federico Peccia <https://fPecc.github.io/>`_
+
+This tutorials shows how a quantized add layer can be compiled to be executed on the Gemmini accelerator. The generated baremetal C code is then tested on the Spike RISC-V ISA simulator. Before starting this tutorial, you should have downloaded the Chipyard repository and installed the Spike simulator with the Gemmini extension.
+
+Note: This is an **experimental** layer!
+"""
+
+import tensorflow as tf
+from tensorflow.keras import layers
+import numpy as np
+import os
+import tvm.contrib.gemmini as gemmini
+from tvm import relay
+import tvm
+
+##################################
+# Pre-requisites
+# --------------------------------
+#
+# After the installation of the Chipyard development tools, you should have an env.sh file in your Chipyard home directory. This file needs to be sourced before running this tutorial:
+#
+# .. code-block:: bash
+#
+#   source <your chipyard home path>/env.sh
+#
+# WARNING: if you have installed TVM in a virtual environment, FIRST activate the Chipyard environment, and THEN activate the tvm entironment.
+
+##################################
+# Baseline generation
+# --------------------------------
+#
+# In this section, we will generate the baseline input and expected output, which we are going to use to compare with the actual obtained output after running on the Gemmini accelerator.
+
+# Then we define the parameters of the layer we want to test. In this case:
+input_height = 16
+input_width = 16
+input_channels = 16
+activation = 0
+
+# We will generate a prequantized TFLite model, because for now the Gemmini integration only supports models that were quantized with specific flags as input.
+class Model(tf.Module):
+    def __init__(self, name=None):
+        super().__init__(name)
+
+    @tf.function(
+        input_signature=[
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+        ]
+    )
+    def add(self, x, y):
+        if activation == 0:
+            return x + y
+        else:
+            return layers.Activation("relu")(x + y)
+
+
+model = Model()
+
+# Convert the concrete functions using TFLiteConverter
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+
+def representative_data_gen():
+    dataset = [
+        (
+            np.array(
+                np.random.randint(-127, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+            np.array(
+                np.random.randint(0, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+        )
+        for s in range(100)
+    ]
+    for input_value in dataset:
+        yield [input_value[0], input_value[1]]
+
+
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.inference_input_type = tf.uint8
+converter.inference_output_type = tf.int8
+converter.representative_dataset = representative_data_gen
+converter._experimental_disable_per_channel = True
+
+tflite_model = converter.convert()
+
+# Save the model.
+with open("add.tflite", "wb") as f:
+    f.write(tflite_model)
+
+# Now that we have created the model, we import the model and run it. We store the output, in order to compare it with the output that will be later obtained from the Gemmini accelerator.
+
+os.system("rm -rf model.tar dev/ include/ generated-project/")
+
+tflite_file = "./add.tflite"
+tflite_model_buf = open(tflite_file, "rb").read()
+input_tensor = "layer1_input"
+input_dtype = "uint8"
+
+os.system("mkdir -p include")

Review Comment:
   If you add the generated header files to `model.tar`, you can avoid this step, which is very user-unfriendly for a tutorial. Here is an example:
   https://github.com/apache/tvm/blob/9a99fc89a2970b9fca151a573de7a5e409b5d9ee/gallery/how_to/work_with_microtvm/micro_mlperftiny.py#L213
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org