Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2023/01/12 11:56:42 UTC

[GitHub] [tvm] fzi-peccia opened a new pull request, #13770: Gemmini code generation using microTVM

fzi-peccia opened a new pull request, #13770:
URL: https://github.com/apache/tvm/pull/13770

   Added an integration to generate C code able to execute neural networks on the Gemmini accelerator. More information can be found in [this post](https://discuss.tvm.apache.org/t/presenting-the-generation-of-code-for-the-gemmini-accelerator/14107).


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


[GitHub] [tvm] fzi-peccia commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1149196380


##########
gallery/tutorial/micro_gemmini_add.py:
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Running TVM on the Gemmini accelerator - A single add layer example
+======================================================================================
+**Author**:
+`Federico Peccia <https://fPecc.github.io/>`_
+
+This tutorial shows how a quantized add layer can be compiled to run on the Gemmini accelerator. The generated bare-metal C code is then tested on the Spike RISC-V ISA simulator. Before starting this tutorial, you should have downloaded the Chipyard repository and installed the Spike simulator with the Gemmini extension.
+
+Note: This is an **experimental** layer!
+"""
+
+import tensorflow as tf
+from tensorflow.keras import layers
+import numpy as np
+import os
+import tvm.contrib.gemmini as gemmini
+from tvm import relay
+import tvm
+
+##################################
+# Pre-requisites
+# --------------------------------
+#
+# After the installation of the Chipyard development tools, you should have an env.sh file in your Chipyard home directory. This file needs to be sourced before running this tutorial:
+#
+# .. code-block:: bash
+#
+#   source <your chipyard home path>/env.sh
+#
+# WARNING: if you have installed TVM in a virtual environment, FIRST activate the Chipyard environment, and THEN activate the TVM environment.
+
+##################################
+# Baseline generation
+# --------------------------------
+#
+# In this section, we generate the baseline input and expected output, which we will later compare against the output actually obtained from the Gemmini accelerator.
+
+# First, we define the parameters of the layer we want to test. In this case:
+input_height = 16
+input_width = 16
+input_channels = 16
+activation = 0
+
+# We will generate a prequantized TFLite model, because for now the Gemmini integration only accepts models that were quantized with specific converter flags.
+class Model(tf.Module):
+    def __init__(self, name=None):
+        super().__init__(name)
+
+    @tf.function(
+        input_signature=[
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+        ]
+    )
+    def add(self, x, y):
+        if activation == 0:
+            return x + y
+        else:
+            return layers.Activation("relu")(x + y)
+
+
+model = Model()
+
+# Convert the concrete functions using TFLiteConverter
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+
+def representative_data_gen():
+    dataset = [
+        (
+            np.array(
+                np.random.randint(-127, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+            np.array(
+                np.random.randint(0, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+        )
+        for s in range(100)
+    ]
+    for input_value in dataset:
+        yield [input_value[0], input_value[1]]
+
+
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.inference_input_type = tf.uint8
+converter.inference_output_type = tf.int8
+converter.representative_dataset = representative_data_gen
+converter._experimental_disable_per_channel = True
+
+tflite_model = converter.convert()
+
+# Save the model.
+with open("add.tflite", "wb") as f:
+    f.write(tflite_model)
+
+# Now that we have created the model, we import and run it. We store the output so we can later compare it with the output obtained from the Gemmini accelerator.
+
+os.system("rm -rf model.tar dev/ include/ generated-project/")
+
+tflite_file = "./add.tflite"
+tflite_model_buf = open(tflite_file, "rb").read()
+input_tensor = "layer1_input"
+input_dtype = "uint8"
+
+os.system("mkdir -p include")

Review Comment:
   I added this method for all tutorials.





[GitHub] [tvm] fzi-peccia commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1149201497


##########
apps/microtvm/gemmini/template_project/src/dwconv2d.c:
##########
@@ -0,0 +1,67 @@
+/*

Review Comment:
   Yes, I applied the suggestion for all of them.



##########
python/tvm/contrib/gemmini/__init__.py:
##########
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Gemmini package is a TVM backend extension to support the Gemmini hardware accelerator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_

Review Comment:
   I removed it from all python files as requested.





[GitHub] [tvm] fzi-peccia commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1149198727


##########
gallery/tutorial/micro_gemmini_conv2d.py:
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Running TVM on the Gemmini accelerator - A single 2d convolutional layer example
+======================================================================================
+**Author**:
+`Federico Peccia <https://fPecc.github.io/>`_
+
+This tutorial shows how a quantized 2D convolution layer can be compiled to run on the Gemmini accelerator. The generated bare-metal C code is then tested on the Spike RISC-V ISA simulator. Before starting this tutorial, you should have downloaded the Chipyard repository and installed the Spike simulator with the Gemmini extension.
+
+"""
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import numpy as np
+import os
+import tvm.contrib.gemmini as gemmini
+from tvm import relay
+import tvm
+
+##################################
+# Pre-requisites
+# --------------------------------
+#
+# After the installation of the Chipyard development tools, you should have an env.sh file in your Chipyard home directory. This file needs to be sourced before running this tutorial:
+#
+# .. code-block:: bash
+#
+#   source <your chipyard home path>/env.sh
+#
+# WARNING: if you have installed TVM in a virtual environment, FIRST activate the Chipyard environment, and THEN activate the TVM environment.
+
+##################################
+# Baseline generation
+# --------------------------------
+#
+# In this section, we generate the baseline input and expected output, which we will later compare against the output actually obtained from the Gemmini accelerator.
+
+# First, we define the parameters of the layer we want to test. In this case:
+input_height = 16
+input_width = 16
+input_channels = 16
+output_channels = 16
+kernel_size = 3
+stride = 1
+padding = "valid"
+activation = None
+bias = True
+
+# We can add a max pooling layer after the convolution. The integration can merge it with the convolution so that both are executed together on the Gemmini accelerator.
+pool_size = 1
+pool_stride = 1
+pool_padding = "valid"
+use_pool = False
+
+# We will generate a prequantized TFLite model, because for now the Gemmini integration only accepts models that were quantized with specific converter flags.
+
+layer_sequence = [
+    layers.Conv2D(
+        output_channels,
+        kernel_size=kernel_size,
+        padding=padding,
+        activation=activation,
+        use_bias=True,
+        bias_initializer="ones",
+        input_shape=(input_height, input_width, input_channels),
+        strides=stride,
+    )
+]
+if use_pool:
+    layer_sequence.append(
+        layers.MaxPool2D(pool_size=pool_size, strides=pool_stride, padding=pool_padding)
+    )
+
+model = keras.Sequential(layer_sequence)
+
+# Convert the concrete functions using TFLiteConverter
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+
+def representative_data_gen():
+    dataset = [
+        np.array(
+            np.random.randint(0, 10, size=(100, input_height, input_width, input_channels)),
+            dtype=np.float32,
+        )
+        for s in range(10)
+    ]
+    for input_value in dataset:
+        # Model has only one input, so each data point has one element.
+        yield [input_value]
+
+
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.inference_input_type = tf.uint8
+converter.inference_output_type = tf.int8
+converter.representative_dataset = representative_data_gen
+converter._experimental_disable_per_channel = True
+
+tflite_model = converter.convert()
+
+# Save the model.
+with open("conv.tflite", "wb") as f:
+    f.write(tflite_model)
+
+# Now that we have created the model, we import and run it. We store the output so we can later compare it with the output obtained from the Gemmini accelerator.
+
+os.system("rm -rf model.tar dev/ include/ generated-project/")

Review Comment:
   Thanks, I applied this suggestion to all temporary files in all examples.



##########
python/tvm/contrib/gemmini/__init__.py:
##########
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Gemmini package is a TVM backend extension to support the Gemmini hardware accelerator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import tvm._ffi.base
+
+from tvm.relay.backend.contrib.gemmini import *
+from .environment import Environment
+from .build_module import build_config, lower, build, preprocess_pass
+from .helpers import create_header_file
+from .utils import *
+
+__version__ = "0.1.0"

Review Comment:
   Removed version





[GitHub] [tvm] mehrdadh commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "mehrdadh (via GitHub)" <gi...@apache.org>.
mehrdadh commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1134420805


##########
apps/microtvm/gemmini/template_project/crt_config/crt_config.h:
##########
@@ -0,0 +1,57 @@
+/*

Review Comment:
   We no longer check in a crt_config.h per project; instead, you can use https://github.com/apache/tvm/blob/main/cmake/utils/CRTConfig.cmake to generate crt_config for this project at build time.
   Here is the PR that made this change: https://github.com/apache/tvm/pull/13955
   



##########
apps/microtvm/gemmini/template_project/microtvm_api_server.py:
##########
@@ -0,0 +1,286 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MicroTVM API Server for Gemmini baremetal tests on the Spike simulator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import atexit
+import collections
+import functools
+import json
+import logging
+import os
+import os.path
+import pathlib
+import re
+import shlex
+import shutil
+import shlex, subprocess
+import sys
+import tarfile
+import tempfile
+import time
+from string import Template
+import re
+from distutils.dir_util import copy_tree
+import subprocess
+import serial
+
+# import serial.tools.list_ports
+from tvm.micro.project_api import server
+
+from subprocess import PIPE
+
+_LOG = logging.getLogger(__name__)
+
+MODEL_LIBRARY_FORMAT_RELPATH = pathlib.Path("src") / "model" / "model.tar"
+API_SERVER_DIR = pathlib.Path(os.path.dirname(__file__) or os.path.getcwd())
+BUILD_DIR = API_SERVER_DIR / "build"
+MODEL_LIBRARY_FORMAT_PATH = API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH
+
+IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists()
+
+PROJECT_TYPES = [
+    "dense_example",
+    "conv2d_example",
+    "dwconv2d_example",
+    "add_example",
+    "maxpool2d_example",
+    "mobilenet_example",
+]
+
+PROJECT_OPTIONS = [

Review Comment:
   Please follow the same convention as the other API servers: https://github.com/apache/tvm/blob/9a99fc89a2970b9fca151a573de7a5e409b5d9ee/apps/microtvm/zephyr/template_project/microtvm_api_server.py#LL299C50-L299C50
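   For reference, a minimal sketch of that convention (only `project_type` comes from this PR; the extra `verbose` option and its fields are purely illustrative):

   ```python
   from tvm.micro.project_api import server

   PROJECT_TYPES = ["dense_example", "conv2d_example"]  # abbreviated for illustration

   PROJECT_OPTIONS = [
       server.ProjectOption(
           "project_type",
           required=["generate_project"],
           choices=tuple(PROJECT_TYPES),
           type="str",
           help="Type of project to generate.",
       ),
       # Every option states where it applies (required/optional), its type,
       # a default where it makes sense, and a help string.
       server.ProjectOption(
           "verbose",
           optional=["build"],
           type="bool",
           default=False,
           help="Run the build with verbose output.",
       ),
   ]
   ```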



##########
apps/microtvm/gemmini/template_project/src/Makefile:
##########
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one

Review Comment:
   It is highly recommended to use CMake instead of a Makefile to make it cross-platform compatible.



##########
python/tvm/contrib/gemmini/__init__.py:
##########
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Gemmini package is a TVM backend extension to support the Gemmini hardware accelerator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import tvm._ffi.base
+
+from tvm.relay.backend.contrib.gemmini import *
+from .environment import Environment
+from .build_module import build_config, lower, build, preprocess_pass
+from .helpers import create_header_file
+from .utils import *
+
+__version__ = "0.1.0"

Review Comment:
   remove?



##########
gallery/tutorial/micro_gemmini_add.py:
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Running TVM on the Gemmini accelerator - A single add layer example
+======================================================================================
+**Author**:
+`Federico Peccia <https://fPecc.github.io/>`_
+
+This tutorial shows how a quantized add layer can be compiled to run on the Gemmini accelerator. The generated bare-metal C code is then tested on the Spike RISC-V ISA simulator. Before starting this tutorial, you should have downloaded the Chipyard repository and installed the Spike simulator with the Gemmini extension.
+
+Note: This is an **experimental** layer!
+"""
+
+import tensorflow as tf
+from tensorflow.keras import layers
+import numpy as np
+import os
+import tvm.contrib.gemmini as gemmini
+from tvm import relay
+import tvm
+
+##################################
+# Pre-requisites
+# --------------------------------
+#
+# After the installation of the Chipyard development tools, you should have an env.sh file in your Chipyard home directory. This file needs to be sourced before running this tutorial:
+#
+# .. code-block:: bash
+#
+#   source <your chipyard home path>/env.sh
+#
+# WARNING: if you have installed TVM in a virtual environment, FIRST activate the Chipyard environment, and THEN activate the TVM environment.
+
+##################################
+# Baseline generation
+# --------------------------------
+#
+# In this section, we generate the baseline input and expected output, which we will later compare against the output actually obtained from the Gemmini accelerator.
+
+# First, we define the parameters of the layer we want to test. In this case:
+input_height = 16
+input_width = 16
+input_channels = 16
+activation = 0
+
+# We will generate a prequantized TFLite model, because for now the Gemmini integration only accepts models that were quantized with specific converter flags.
+class Model(tf.Module):
+    def __init__(self, name=None):
+        super().__init__(name)
+
+    @tf.function(
+        input_signature=[
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+        ]
+    )
+    def add(self, x, y):
+        if activation == 0:
+            return x + y
+        else:
+            return layers.Activation("relu")(x + y)
+
+
+model = Model()
+
+# Convert the concrete functions using TFLiteConverter
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+
+def representative_data_gen():
+    dataset = [
+        (
+            np.array(
+                np.random.randint(-127, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+            np.array(
+                np.random.randint(0, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+        )
+        for s in range(100)
+    ]
+    for input_value in dataset:
+        yield [input_value[0], input_value[1]]
+
+
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.inference_input_type = tf.uint8
+converter.inference_output_type = tf.int8
+converter.representative_dataset = representative_data_gen
+converter._experimental_disable_per_channel = True
+
+tflite_model = converter.convert()
+
+# Save the model.
+with open("add.tflite", "wb") as f:
+    f.write(tflite_model)
+
+# Now that we have created the model, we import and run it. We store the output so we can later compare it with the output obtained from the Gemmini accelerator.
+
+os.system("rm -rf model.tar dev/ include/ generated-project/")
+
+tflite_file = "./add.tflite"
+tflite_model_buf = open(tflite_file, "rb").read()
+input_tensor = "layer1_input"
+input_dtype = "uint8"
+
+os.system("mkdir -p include")
+
+try:
+    import tflite
+
+    tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+except AttributeError:
+    import tflite.Model
+
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+
+# Load the TFLite model and allocate tensors.
+interpreter = tf.lite.Interpreter(model_path=tflite_file, experimental_preserve_all_tensors=True)
+interpreter.allocate_tensors()
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+tensor_details = interpreter.get_tensor_details()
+
+input_matrix_1 = np.random.randint(
+    0, 255, (1, input_height, input_width, input_channels), dtype=np.uint8
+)
+input_matrix_2 = np.random.randint(
+    0, 255, (1, input_height, input_width, input_channels), dtype=np.uint8
+)
+
+interpreter.set_tensor(input_details[0]["index"], input_matrix_1)
+interpreter.set_tensor(input_details[1]["index"], input_matrix_2)
+
+interpreter.invoke()
+expected_output = interpreter.get_tensor(output_details[0]["index"])
+
+# Here, we create C files and headers with the inputs and expected output, so that we can then execute the same operation on the Gemmini accelerator, and compare the expected output with the actual predicted one.
+gemmini.create_header_file("inputs", "data", "input_1", input_matrix_2, "./include")

Review Comment:
   You could pass the Model Library Format file and add the generated files to the `model.tar` archive. This way you don't need extra steps to add the generated headers to your project.
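   A minimal sketch of that idea, assuming the headers were already generated by `gemmini.create_header_file` (the archive name and header paths below are hypothetical):

   ```python
   import tarfile

   # Append the generated headers to the Model Library Format archive so the
   # project API server unpacks them together with the model sources.
   with tarfile.open("model.tar", "a") as tar:
       tar.add("include/inputs.h", arcname="include/tvm/inputs.h")
       tar.add("include/expected_output.h", arcname="include/tvm/expected_output.h")
   ```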



##########
apps/microtvm/gemmini/template_project/microtvm_api_server.py:
##########
@@ -0,0 +1,286 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MicroTVM API Server for Gemmini baremetal tests on the Spike simulator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_

Review Comment:
   I don't think this is necessary; we only use this in tutorials.



##########
apps/microtvm/gemmini/template_project/src/add.c:
##########
@@ -0,0 +1,69 @@
+/*

Review Comment:
   src files for each project type should be in a separate sub-directory. Please follow the same convention.



##########
python/tvm/contrib/gemmini/__init__.py:
##########
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Gemmini package is a TVM backend extension to support the Gemmini hardware accelerator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_

Review Comment:
   Please remove this from python files.



##########
apps/microtvm/gemmini/template_project/src/dwconv2d.c:
##########
@@ -0,0 +1,67 @@
+/*

Review Comment:
   same comment for all of these.



##########
gallery/tutorial/micro_gemmini_conv2d.py:
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one

Review Comment:
   Is it necessary to have a tutorial per operator? Wouldn't one tutorial show the general approach?



##########
python/tvm/contrib/gemmini/helpers.py:
##########
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Miscellaneous helpers
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import pathlib
+from typing import List
+from six.moves import range
+import numpy as np
+from .environment import Environment
+
+
+ENV = Environment.instance()
+
+
+def create_header_file(

Review Comment:
   why not reuse the existing `create_header_file` function in TVM?



##########
gallery/tutorial/micro_gemmini_conv2d.py:
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Running TVM on the Gemmini accelerator - A single 2d convolutional layer example
+======================================================================================
+**Author**:
+`Federico Peccia <https://fPecc.github.io/>`_
+
+This tutorial shows how a quantized 2D convolution layer can be compiled to run on the Gemmini accelerator. The generated bare-metal C code is then tested on the Spike RISC-V ISA simulator. Before starting this tutorial, you should have downloaded the Chipyard repository and installed the Spike simulator with the Gemmini extension.
+
+"""
+
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+import numpy as np
+import os
+import tvm.contrib.gemmini as gemmini
+from tvm import relay
+import tvm
+
+##################################
+# Pre-requisites
+# --------------------------------
+#
+# After the installation of the Chipyard development tools, you should have an env.sh file in your Chipyard home directory. This file needs to be sourced before running this tutorial:
+#
+# .. code-block:: bash
+#
+#   source <your chipyard home path>/env.sh
+#
+# WARNING: if you have installed TVM in a virtual environment, FIRST activate the Chipyard environment, and THEN activate the TVM environment.
+
+##################################
+# Baseline generation
+# --------------------------------
+#
+# In this section, we generate the baseline input and expected output, which we will later compare against the output actually obtained from the Gemmini accelerator.
+
+# First, we define the parameters of the layer we want to test. In this case:
+input_height = 16
+input_width = 16
+input_channels = 16
+output_channels = 16
+kernel_size = 3
+stride = 1
+padding = "valid"
+activation = None
+bias = True
+
+# We can add a max pooling layer after the convolution. The integration can merge it with the convolution so that both are executed together on the Gemmini accelerator.
+pool_size = 1
+pool_stride = 1
+pool_padding = "valid"
+use_pool = False
+
+# We will generate a prequantized TFLite model, because for now the Gemmini integration only accepts models that were quantized with specific converter flags.
+
+layer_sequence = [
+    layers.Conv2D(
+        output_channels,
+        kernel_size=kernel_size,
+        padding=padding,
+        activation=activation,
+        use_bias=True,
+        bias_initializer="ones",
+        input_shape=(input_height, input_width, input_channels),
+        strides=stride,
+    )
+]
+if use_pool:
+    layer_sequence.append(
+        layers.MaxPool2D(pool_size=pool_size, strides=pool_stride, padding=pool_padding)
+    )
+
+model = keras.Sequential(layer_sequence)
+
+# Convert the concrete functions using TFLiteConverter
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+
+def representative_data_gen():
+    dataset = [
+        np.array(
+            np.random.randint(0, 10, size=(100, input_height, input_width, input_channels)),
+            dtype=np.float32,
+        )
+        for s in range(10)
+    ]
+    for input_value in dataset:
+        # Model has only one input, so each data point has one element.
+        yield [input_value]
+
+
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.inference_input_type = tf.uint8
+converter.inference_output_type = tf.int8
+converter.representative_dataset = representative_data_gen
+converter._experimental_disable_per_channel = True
+
+tflite_model = converter.convert()
+
+# Save the model.
+with open("conv.tflite", "wb") as f:
+    f.write(tflite_model)
+
+# Now that we have created the model, we import and run it. We store the output so we can later compare it with the output obtained from the Gemmini accelerator.
+
+os.system("rm -rf model.tar dev/ include/ generated-project/")

Review Comment:
   use `tvm.contrib.utils.tempdir()` for temporary files.
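   A minimal sketch of how the tutorial step above could use it (the file name follows the tutorial, the rest is illustrative):

   ```python
   from tvm.contrib import utils

   # A temporary directory that is cleaned up automatically, replacing the
   # os.system("rm -rf ...") and "mkdir -p include" calls in the tutorial.
   temp_dir = utils.tempdir()
   tflite_file = temp_dir.relpath("conv.tflite")
   with open(tflite_file, "wb") as f:
       f.write(tflite_model)  # tflite_model comes from the converter step above
   ```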



##########
gallery/tutorial/micro_gemmini_add.py:
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Running TVM on the Gemmini accelerator - A single add layer example
+======================================================================================
+**Author**:
+`Federico Peccia <https://fPecc.github.io/>`_
+
+This tutorial shows how a quantized add layer can be compiled to run on the Gemmini accelerator. The generated bare-metal C code is then tested on the Spike RISC-V ISA simulator. Before starting this tutorial, you should have downloaded the Chipyard repository and installed the Spike simulator with the Gemmini extension.
+
+Note: This is an **experimental** layer!
+"""
+
+import tensorflow as tf
+from tensorflow.keras import layers
+import numpy as np
+import os
+import tvm.contrib.gemmini as gemmini
+from tvm import relay
+import tvm
+
+##################################
+# Pre-requisites
+# --------------------------------
+#
+# After the installation of the Chipyard development tools, you should have an env.sh file in your Chipyard home directory. This file needs to be sourced before running this tutorial:
+#
+# .. code-block:: bash
+#
+#   source <your chipyard home path>/env.sh
+#
+# WARNING: if you have installed TVM in a virtual environment, FIRST activate the Chipyard environment, and THEN activate the TVM environment.
+
+##################################
+# Baseline generation
+# --------------------------------
+#
+# In this section, we generate the baseline input and expected output, which we will later compare against the output actually obtained from the Gemmini accelerator.
+
+# First, we define the parameters of the layer we want to test. In this case:
+input_height = 16
+input_width = 16
+input_channels = 16
+activation = 0
+
+# We will generate a prequantized TFLite model, because for now the Gemmini integration only accepts models that were quantized with specific converter flags.
+class Model(tf.Module):
+    def __init__(self, name=None):
+        super().__init__(name)
+
+    @tf.function(
+        input_signature=[
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+        ]
+    )
+    def add(self, x, y):
+        if activation == 0:
+            return x + y
+        else:
+            return layers.Activation("relu")(x + y)
+
+
+model = Model()
+
+# Convert the concrete functions using TFLiteConverter
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+
+def representative_data_gen():
+    dataset = [
+        (
+            np.array(
+                np.random.randint(-127, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+            np.array(
+                np.random.randint(0, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+        )
+        for s in range(100)
+    ]
+    for input_value in dataset:
+        yield [input_value[0], input_value[1]]
+
+
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.inference_input_type = tf.uint8
+converter.inference_output_type = tf.int8
+converter.representative_dataset = representative_data_gen
+converter._experimental_disable_per_channel = True
+
+tflite_model = converter.convert()
+
+# Save the model.
+with open("add.tflite", "wb") as f:
+    f.write(tflite_model)
+
+# Now that we have created the model, we import and run it. We store the output so we can later compare it with the output obtained from the Gemmini accelerator.
+
+os.system("rm -rf model.tar dev/ include/ generated-project/")
+
+tflite_file = "./add.tflite"
+tflite_model_buf = open(tflite_file, "rb").read()
+input_tensor = "layer1_input"
+input_dtype = "uint8"
+
+os.system("mkdir -p include")

Review Comment:
   If you add the generated header files to model.tar, you can avoid this step, which is very user-unfriendly for a tutorial. Here is an example:
   https://github.com/apache/tvm/blob/9a99fc89a2970b9fca151a573de7a5e409b5d9ee/gallery/how_to/work_with_microtvm/micro_mlperftiny.py#L213
   





[GitHub] [tvm] tvm-bot commented on pull request #13770: Gemmini code generation using microTVM

Posted by GitBox <gi...@apache.org>.
tvm-bot commented on PR #13770:
URL: https://github.com/apache/tvm/pull/13770#issuecomment-1380219284

   <!---bot-comment-->
   
   Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from [Reviewers](https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers) by @-ing them in a comment.
   
   <!--bot-comment-ccs-start-->
    * No users to auto-tag found, no teams are specified in PR title <sub>See [#10317](https://github.com/apache/tvm/issues/10317) for details</sub><!--bot-comment-ccs-end-->
   
   <sub>Generated by [tvm-bot](https://github.com/apache/tvm/blob/main/ci/README.md#github-actions)</sub>




[GitHub] [tvm] fzi-peccia commented on pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on PR #13770:
URL: https://github.com/apache/tvm/pull/13770#issuecomment-1512500436

   Hi @mehrdadh, all tests have passed except these two:
   
   - cortexm/pr-head: some problem with Zephyr and mlperf-tiny; this has nothing to do with my changes.
   - gpu/pr-head: some problem building the tutorials documentation; do you know a workaround for this?




[GitHub] [tvm] mehrdadh commented on pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by GitBox <gi...@apache.org>.
mehrdadh commented on PR #13770:
URL: https://github.com/apache/tvm/pull/13770#issuecomment-1386094125

   @fPecc FYI, I plan to review this PR this week.




[GitHub] [tvm] kimjungwow commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "kimjungwow (via GitHub)" <gi...@apache.org>.
kimjungwow commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1127773350


##########
cmake/modules/contrib/Gemmini.cmake:
##########
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(USE_GEMMINI)
+  message(STATUS "Add Gemmini for microTVM")
+
+  function(microtvm_add_gemmini)
+    list(
+      APPEND
+      GEMMINI_FILE_COPY_JOBS
+      "apps/microtvm/gemmini/template_project microtvm_api_server.py -> gemmini"
+      "apps/microtvm/gemmini/template_project/crt_config *.h -> gemmini/crt_config"
+
+      # Dense example project generation
+      "apps/microtvm/gemmini/template_project/src dense.c -> gemmini/src/dense_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/dense_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/dense_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/dense_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/dense_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/dense_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/dense_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/dense_example/rocc-software/src"
+
+      # CONV2D example project generation
+      "apps/microtvm/gemmini/template_project/src conv2d.c -> gemmini/src/conv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/conv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/conv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/conv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/conv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/conv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/conv2d_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/conv2d_example/rocc-software/src"
+
+      # DW CONV2D example project generation
+      "apps/microtvm/gemmini/template_project/src dwconv2d.c -> gemmini/src/dwconv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/dwconv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/dwconv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/dwconv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/dwconv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/dwconv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/dwconv2d_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/dwconv2d_example/rocc-software/src"
+
+      # ADD example project generation
+      "apps/microtvm/gemmini/template_project/src add.c -> gemmini/src/add_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/add_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/add_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/add_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/add_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/add_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/add_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/add_example/rocc-software/src"
+
+      # Max pooling 2d example project generation
+      "apps/microtvm/gemmini/template_project/src maxpool2d.c -> gemmini/src/maxpool2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/maxpool2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/maxpool2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/maxpool2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/maxpool2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/maxpool2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/maxpool2d_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/maxpool2d_example/rocc-software/src"
+
+      # Mobilenet example project generation
+      "apps/microtvm/gemmini/template_project/src mobilenet.c -> gemmini/src/mobilenet_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/mobilenet_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/mobilenet_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/mobilenet_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/mobilenet_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/mobilenet_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/mobilenet_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/mobilenet_example/rocc-software/src"
+    )
+
+    foreach(job_spec IN LISTS GEMMINI_FILE_COPY_JOBS)
+      string(REPLACE " " ";" job_spec "${job_spec}")
+      list(LENGTH job_spec job_spec_length)
+      math(EXPR job_spec_length_mod "${job_spec_length} % 3")
+      if(NOT "${job_spec_length_mod}" EQUAL 1)
+        message(
+          FATAL_ERROR
+            "Gemmini copy job spec list length is ${job_spec_length}; parsed job spec is ${job_spec}"
+        )
+      endif()
+      math(EXPR job_spec_stop "${job_spec_length} - 3")
+
+      list(GET job_spec 0 job_src_base)
+      set(job_src_base "${CMAKE_SOURCE_DIR}/${job_src_base}")
+      foreach(copy_pattern_index RANGE 1 "${job_spec_stop}" 3)
+        list(GET job_spec ${copy_pattern_index} copy_pattern)
+        math(EXPR copy_dest_index "${copy_pattern_index} + 2")
+        list(GET job_spec ${copy_dest_index} copy_dest)
+
+        file(
+          GLOB_RECURSE copy_files
+          RELATIVE "${job_src_base}"
+          "${job_src_base}/${copy_pattern}")
+        list(LENGTH copy_files copy_files_length)
+        if("${copy_files_length}" EQUAL 0)
+          message(
+            FATAL_ERROR
+              "Gemmini copy job matched 0 files: ${job_src_base}/${copy_pattern} -> ${copy_dest}"
+          )
+        endif()
+        foreach(copy_src IN LISTS copy_files)
+          get_filename_component(
+            dest_path "${MICROTVM_TEMPLATE_PROJECTS}/${copy_dest}/${copy_src}"
+            ABSOLUTE)
+          tvm_micro_add_copy_file(gemmini_template_deps
+                                  ${job_src_base}/${copy_src} ${dest_path})
+        endforeach()
+      endforeach()
+    endforeach()
+
+    add_custom_target(gemmini DEPENDS ${gemmini_template_deps})
+  endfunction()
+
+  microtvm_add_gemmini()
+
+endif(USE_MICRO)

Review Comment:
   Hi. I have a small suggestion.
   
   ```suggestion
   endif(USE_GEMMINI)
   ```





[GitHub] [tvm] fzi-peccia commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1148911732


##########
gallery/tutorial/micro_gemmini_conv2d.py:
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one

Review Comment:
   We could have just one example for an entire network, the MobileNet one, and that would be enough to show that the integration works. But having one per operator makes it easier to understand the internal workings of the code and adds some debugging capability, because one can run specific layers with specific parameters to see whether they work.





[GitHub] [tvm] mehrdadh commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by GitBox <gi...@apache.org>.
mehrdadh commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1072913504


##########
apps/microtvm/gemmini/template_project/src/makefiles/conv2d/Makefile:
##########
@@ -0,0 +1,68 @@
+include $(abs_top_srcdir)/Makefrag

Review Comment:
   Could you consolidate all the Makefiles into a single Makefile.template and modify it based on the project type in the `generate_project` step?
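   A minimal sketch of that approach using `string.Template`, which the API server already imports (the template variable names are hypothetical):

   ```python
   from string import Template

   def _render_makefile(template_path, output_path, project_type):
       # Fill a single Makefile.template with project-specific values instead of
       # keeping one Makefile per project type.
       with open(template_path) as f:
           makefile = Template(f.read())
       with open(output_path, "w") as f:
           f.write(makefile.substitute(project_type=project_type, main_source=f"{project_type}.c"))
   ```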



##########
apps/microtvm/gemmini/template_project/microtvm_api_server.py:
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MicroTVM API Server for Gemmini baremetal tests on the Spike simulator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import atexit
+import collections
+import functools
+import json
+import logging
+import os
+import os.path
+import pathlib
+import re
+import shlex
+import shutil
+import shlex, subprocess
+import sys
+import tarfile
+import tempfile
+import time
+from string import Template
+import re
+from distutils.dir_util import copy_tree
+import subprocess
+import serial
+
+# import serial.tools.list_ports
+from tvm.micro.project_api import server
+
+from subprocess import PIPE
+
+_LOG = logging.getLogger(__name__)
+
+MODEL_LIBRARY_FORMAT_RELPATH = pathlib.Path("src") / "model" / "model.tar"
+API_SERVER_DIR = pathlib.Path(os.path.dirname(__file__) or os.path.getcwd())
+BUILD_DIR = API_SERVER_DIR / "build"
+MODEL_LIBRARY_FORMAT_PATH = API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH
+
+IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists()
+
+PROJECT_TYPES = [
+    "dense_example",
+    "conv2d_example",
+    "dwconv2d_example",
+    "add_example",
+    "maxpool2d_example",
+    "mobilenet_example",
+]
+
+PROJECT_OPTIONS = [
+    server.ProjectOption(
+        "project_type",
+        required=["generate_project"],
+        choices=tuple(PROJECT_TYPES),
+        type="str",
+        help="Type of project to generate.",
+    )
+]
+
+
+class Handler(server.ProjectAPIHandler):
+    def __init__(self):
+        super(Handler, self).__init__()
+        self._proc = None
+        self._port = None
+        self._transport = None
+        self._project_dir = None
+        self._qemu_instance = None
+
+    def server_info_query(self, tvm_version):
+        return server.ServerInfo(
+            platform_name="gemmini",
+            is_template=IS_TEMPLATE,
+            model_library_format_path="" if IS_TEMPLATE else MODEL_LIBRARY_FORMAT_PATH,
+            project_options=PROJECT_OPTIONS,
+        )
+
+    def _copy_project_files(self, api_server_dir, project_dir, project_type):
+        """Copies the files for project_type into project_dir.
+
+        Notes
+        -----
+        template_dir is NOT a project type, and that directory is never copied
+        in this function. template_dir only holds this file and its unit tests,
+        so this file is copied separately in generate_project.
+
+        """
+        for item in (API_SERVER_DIR / "src" / project_type).iterdir():
+            dest = project_dir / "src" / item.name
+            if item.is_dir():
+                shutil.copytree(item, dest)
+            else:
+                shutil.copy2(item, dest)
+
+    CRT_COPY_ITEMS = ("include", "src")
+
+    def _copy_standalone_crt(self, source_dir, standalone_crt_dir):
+        output_crt_dir = source_dir / "standalone_crt"
+        for item in self.CRT_COPY_ITEMS:
+            src_path = os.path.join(standalone_crt_dir, item)
+            dst_path = output_crt_dir / item
+            if os.path.isdir(src_path):
+                shutil.copytree(src_path, dst_path)
+            else:
+                shutil.copy2(src_path, dst_path)
+
+    # Example project is the "minimum viable project",
+    # and doesn't need a fancy RPC server
+    EXAMPLE_PROJECT_UNUSED_COMPONENTS = []
+
+    def _remove_unused_components(self, source_dir, project_type):
+        unused_components = []
+        if project_type == "example_project":
+            unused_components = self.EXAMPLE_PROJECT_UNUSED_COMPONENTS
+
+        for component in unused_components:
+            shutil.rmtree(source_dir / "standalone_crt" / component)
+
+    def _disassemble_mlf(self, mlf_tar_path, source_dir):
+        with tempfile.TemporaryDirectory() as mlf_unpacking_dir_str:
+            mlf_unpacking_dir = pathlib.Path(mlf_unpacking_dir_str)
+            with tarfile.open(mlf_tar_path, "r:") as tar:
+                tar.extractall(mlf_unpacking_dir)
+
+            model_dir = source_dir / "model"
+            model_dir.mkdir()
+
+            # Copy C files from model. The filenames and quantity
+            # depend on the target string, so we just copy all C files
+            source_dir = mlf_unpacking_dir / "codegen" / "host" / "src"
+            for file in source_dir.rglob(f"*.c"):
+                shutil.copy(file, model_dir)
+
+            source_dir = mlf_unpacking_dir / "codegen" / "host" / "include"
+            for file in source_dir.rglob(f"*.h"):
+                shutil.copy(file, model_dir)
+
+            # Return metadata.json for use in templating
+            with open(os.path.join(mlf_unpacking_dir, "metadata.json")) as f:
+                metadata = json.load(f)
+        return metadata
+
+    def _template_model_header(self, source_dir, metadata):
+        with open(source_dir / "model.h", "r") as f:
+            model_h_template = Template(f.read())
+
+        assert (
+            metadata["style"] == "full-model"
+        ), "when generating AOT, expect only full-model Model Library Format"
+
+        template_values = {
+            "workspace_size_bytes": metadata["memory"]["functions"]["main"][0][
+                "workspace_size_bytes"
+            ],
+        }
+
+        with open(source_dir / "model.h", "w") as f:
+            f.write(model_h_template.substitute(template_values))
+
+    # Arduino ONLY recognizes .ino, .ccp, .c, .h

Review Comment:
   remove?



##########
cmake/modules/contrib/Gemmini.cmake:
##########
@@ -0,0 +1,117 @@
+if(USE_MICRO)

Review Comment:
   I think this should be a separate flag which is disabled by default, maybe use `USE_GEMMINI`



##########
apps/microtvm/gemmini/README.md:
##########
@@ -0,0 +1,3 @@
+This directory contains code to generate code for the Gemmini accelerator using microTVM. The generated tests are then executed on the Spike RISC-V ISA simulator.
+
+In order to use this correctly, the Spike simulator has to be installed. This can be done by following the steps described in the Chipyard repository.

Review Comment:
   Link to the instructions is missing



##########
python/tvm/contrib/gemmini/tutorials/single_operators/add-tutorial.ipynb:
##########
@@ -0,0 +1,395 @@
+{

Review Comment:
   Tutorial files should be moved somewhere under `gallery/how_to/`. Also, you need to change the format to a .py file and write it in Sphinx format. We now support notebook generation and Google Colab, so you can even add cells that install all the dependencies and run the tutorial in Google Colab.



##########
apps/microtvm/gemmini/template_project/microtvm_api_server.py:
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MicroTVM API Server for Gemmini baremetal tests on the Spike simulator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import atexit
+import collections
+import functools
+import json
+import logging
+import os
+import os.path
+import pathlib
+import re
+import shlex
+import shutil
+import shlex, subprocess
+import sys
+import tarfile
+import tempfile
+import time
+from string import Template
+import re
+from distutils.dir_util import copy_tree
+import subprocess
+import serial
+
+# import serial.tools.list_ports
+from tvm.micro.project_api import server
+
+from subprocess import PIPE
+
+_LOG = logging.getLogger(__name__)
+
+MODEL_LIBRARY_FORMAT_RELPATH = pathlib.Path("src") / "model" / "model.tar"
+API_SERVER_DIR = pathlib.Path(os.path.dirname(__file__) or os.getcwd())
+BUILD_DIR = API_SERVER_DIR / "build"
+MODEL_LIBRARY_FORMAT_PATH = API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH
+
+IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists()
+
+PROJECT_TYPES = [
+    "dense_example",
+    "conv2d_example",
+    "dwconv2d_example",
+    "add_example",
+    "maxpool2d_example",
+    "mobilenet_example",
+]
+
+PROJECT_OPTIONS = [
+    server.ProjectOption(
+        "project_type",
+        required=["generate_project"],
+        choices=tuple(PROJECT_TYPES),
+        type="str",
+        help="Type of project to generate.",
+    )
+]
+
+
+class Handler(server.ProjectAPIHandler):
+    def __init__(self):
+        super(Handler, self).__init__()
+        self._proc = None
+        self._port = None
+        self._transport = None
+        self._project_dir = None
+        self._qemu_instance = None
+
+    def server_info_query(self, tvm_version):
+        return server.ServerInfo(
+            platform_name="gemmini",
+            is_template=IS_TEMPLATE,
+            model_library_format_path="" if IS_TEMPLATE else MODEL_LIBRARY_FORMAT_PATH,
+            project_options=PROJECT_OPTIONS,
+        )
+
+    def _copy_project_files(self, api_server_dir, project_dir, project_type):
+        """Copies the files for project_type into project_dir.
+
+        Notes
+        -----
+        template_dir is NOT a project type, and that directory is never copied
+        in this function. template_dir only holds this file and its unit tests,
+        so this file is copied separately in generate_project.
+
+        """
+        for item in (API_SERVER_DIR / "src" / project_type).iterdir():
+            dest = project_dir / "src" / item.name
+            if item.is_dir():
+                shutil.copytree(item, dest)
+            else:
+                shutil.copy2(item, dest)
+
+    CRT_COPY_ITEMS = ("include", "src")
+
+    def _copy_standalone_crt(self, source_dir, standalone_crt_dir):
+        output_crt_dir = source_dir / "standalone_crt"
+        for item in self.CRT_COPY_ITEMS:
+            src_path = os.path.join(standalone_crt_dir, item)
+            dst_path = output_crt_dir / item
+            if os.path.isdir(src_path):
+                shutil.copytree(src_path, dst_path)
+            else:
+                shutil.copy2(src_path, dst_path)
+
+    # Example project is the "minimum viable project",
+    # and doesn't need a fancy RPC server
+    EXAMPLE_PROJECT_UNUSED_COMPONENTS = []
+
+    def _remove_unused_components(self, source_dir, project_type):
+        unused_components = []
+        if project_type == "example_project":
+            unused_components = self.EXAMPLE_PROJECT_UNUSED_COMPONENTS
+
+        for component in unused_components:
+            shutil.rmtree(source_dir / "standalone_crt" / component)
+
+    def _disassemble_mlf(self, mlf_tar_path, source_dir):
+        with tempfile.TemporaryDirectory() as mlf_unpacking_dir_str:
+            mlf_unpacking_dir = pathlib.Path(mlf_unpacking_dir_str)
+            with tarfile.open(mlf_tar_path, "r:") as tar:
+                tar.extractall(mlf_unpacking_dir)
+
+            model_dir = source_dir / "model"
+            model_dir.mkdir()
+
+            # Copy C files from model. The filenames and quantity
+            # depend on the target string, so we just copy all C files
+            source_dir = mlf_unpacking_dir / "codegen" / "host" / "src"
+            for file in source_dir.rglob(f"*.c"):
+                shutil.copy(file, model_dir)
+
+            source_dir = mlf_unpacking_dir / "codegen" / "host" / "include"
+            for file in source_dir.rglob(f"*.h"):
+                shutil.copy(file, model_dir)
+
+            # Return metadata.json for use in templating
+            with open(os.path.join(mlf_unpacking_dir, "metadata.json")) as f:
+                metadata = json.load(f)
+        return metadata
+
+    def _template_model_header(self, source_dir, metadata):
+        with open(source_dir / "model.h", "r") as f:
+            model_h_template = Template(f.read())
+
+        assert (
+            metadata["style"] == "full-model"
+        ), "when generating AOT, expect only full-model Model Library Format"
+
+        template_values = {
+            "workspace_size_bytes": metadata["memory"]["functions"]["main"][0][
+                "workspace_size_bytes"
+            ],
+        }
+
+        with open(source_dir / "model.h", "w") as f:
+            f.write(model_h_template.substitute(template_values))
+
+    # Arduino ONLY recognizes .ino, .ccp, .c, .h
+
+    CPP_FILE_EXTENSION_SYNONYMS = ("cc", "cxx")
+
+    def _change_cpp_file_extensions(self, source_dir):
+        for ext in self.CPP_FILE_EXTENSION_SYNONYMS:
+            for filename in source_dir.rglob(f"*.{ext}"):
+                filename.rename(filename.with_suffix(".cpp"))
+
+        for filename in source_dir.rglob(f"*.inc"):
+            filename.rename(filename.with_suffix(".h"))
+
+    def _convert_includes(self, project_dir, source_dir):
+        """Changes all #include statements in project_dir to be relevant to their
+        containing file's location.
+
+        Arduino only supports includes relative to a file's location, so this

Review Comment:
   fix the function description
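
   For reference, a possible rewording of that docstring (an editorial suggestion only, not text from the PR) that drops the Arduino-specific wording while keeping the documented behaviour:

       def _convert_includes(self, project_dir, source_dir):
           """Rewrite #include statements in project_dir so they are relative to
           their containing file's location, as assumed by the generated
           baremetal project's build."""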





[GitHub] [tvm] fzi-peccia commented on pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on PR #13770:
URL: https://github.com/apache/tvm/pull/13770#issuecomment-1484569343

   Thanks for the feedback @mehrdadh. I will work on these changes this week and let you know when everything is applied
   




[GitHub] [tvm] fzi-peccia commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1149196854


##########
gallery/tutorial/micro_gemmini_add.py:
##########
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Running TVM on the Gemmini accelerator - A single add layer example
+======================================================================================
+**Author**:
+`Federico Peccia <https://fPecc.github.io/>`_
+
+This tutorial shows how a quantized add layer can be compiled to be executed on the Gemmini accelerator. The generated baremetal C code is then tested on the Spike RISC-V ISA simulator. Before starting this tutorial, you should have downloaded the Chipyard repository and installed the Spike simulator with the Gemmini extension.
+
+Note: This is an **experimental** layer!
+"""
+
+import tensorflow as tf
+from tensorflow.keras import layers
+import numpy as np
+import os
+import tvm.contrib.gemmini as gemmini
+from tvm import relay
+import tvm
+
+##################################
+# Pre-requisites
+# --------------------------------
+#
+# After the installation of the Chipyard development tools, you should have an env.sh file in your Chipyard home directory. This file needs to be sourced before running this tutorial:
+#
+# .. code-block:: bash
+#
+#   source <your chipyard home path>/env.sh
+#
+# WARNING: if you have installed TVM in a virtual environment, FIRST activate the Chipyard environment, and THEN activate the tvm environment.
+
+##################################
+# Baseline generation
+# --------------------------------
+#
+# In this section, we will generate the baseline input and expected output, which we are going to use to compare with the actual obtained output after running on the Gemmini accelerator.
+
+# Then we define the parameters of the layer we want to test. In this case:
+input_height = 16
+input_width = 16
+input_channels = 16
+activation = 0
+
+# We will generate a prequantized TFLite model, because for now the Gemmini integration only supports models that were quantized with specific flags as input.
+class Model(tf.Module):
+    def __init__(self, name=None):
+        super().__init__(name)
+
+    @tf.function(
+        input_signature=[
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+            tf.TensorSpec(
+                shape=[1, input_height, input_width, input_channels],
+                dtype=tf.float32,
+            ),
+        ]
+    )
+    def add(self, x, y):
+        if activation == 0:
+            return x + y
+        else:
+            return layers.Activation("relu")(x + y)
+
+
+model = Model()
+
+# Convert the concrete functions using TFLiteConverter
+converter = tf.lite.TFLiteConverter.from_keras_model(model)
+
+
+def representative_data_gen():
+    dataset = [
+        (
+            np.array(
+                np.random.randint(-127, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+            np.array(
+                np.random.randint(0, 128, size=(1, input_height, input_width, input_channels)),
+                dtype=np.float32,
+            ),
+        )
+        for s in range(100)
+    ]
+    for input_value in dataset:
+        yield [input_value[0], input_value[1]]
+
+
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
+converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+converter.inference_input_type = tf.uint8
+converter.inference_output_type = tf.int8
+converter.representative_dataset = representative_data_gen
+converter._experimental_disable_per_channel = True
+
+tflite_model = converter.convert()
+
+# Save the model.
+with open("add.tflite", "wb") as f:
+    f.write(tflite_model)
+
+# Now that we have created the model, we import the model and run it. We store the output, in order to compare it with the output that will be later obtained from the Gemmini accelerator.
+
+os.system("rm -rf model.tar dev/ include/ generated-project/")
+
+tflite_file = "./add.tflite"
+tflite_model_buf = open(tflite_file, "rb").read()
+input_tensor = "layer1_input"
+input_dtype = "uint8"
+
+os.system("mkdir -p include")
+
+try:
+    import tflite
+
+    tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+except AttributeError:
+    import tflite.Model
+
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+
+# Load the TFLite model and allocate tensors.
+interpreter = tf.lite.Interpreter(model_path=tflite_file, experimental_preserve_all_tensors=True)
+interpreter.allocate_tensors()
+input_details = interpreter.get_input_details()
+output_details = interpreter.get_output_details()
+tensor_details = interpreter.get_tensor_details()
+
+input_matrix_1 = np.random.randint(
+    0, 255, (1, input_height, input_width, input_channels), dtype=np.uint8
+)
+input_matrix_2 = np.random.randint(
+    0, 255, (1, input_height, input_width, input_channels), dtype=np.uint8
+)
+
+interpreter.set_tensor(input_details[0]["index"], input_matrix_1)
+interpreter.set_tensor(input_details[1]["index"], input_matrix_2)
+
+interpreter.invoke()
+expected_output = interpreter.get_tensor(output_details[0]["index"])
+
+# Here, we create C files and headers with the inputs and expected output, so that we can then execute the same operation on the Gemmini accelerator, and compare the expected output with the actual predicted one.
+gemmini.create_header_file("inputs", "data", "input_1", input_matrix_2, "./include")

Review Comment:
   As recommended, I added this method to all tutorials.





[GitHub] [tvm] fzi-peccia commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1149200559


##########
apps/microtvm/gemmini/template_project/microtvm_api_server.py:
##########
@@ -0,0 +1,286 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MicroTVM API Server for Gemmini baremetal tests on the Spike simulator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import atexit
+import collections
+import functools
+import json
+import logging
+import os
+import os.path
+import pathlib
+import re
+import shlex
+import shutil
+import shlex, subprocess
+import sys
+import tarfile
+import tempfile
+import time
+from string import Template
+import re
+from distutils.dir_util import copy_tree
+import subprocess
+import serial
+
+# import serial.tools.list_ports
+from tvm.micro.project_api import server
+
+from subprocess import PIPE
+
+_LOG = logging.getLogger(__name__)
+
+MODEL_LIBRARY_FORMAT_RELPATH = pathlib.Path("src") / "model" / "model.tar"
+API_SERVER_DIR = pathlib.Path(os.path.dirname(__file__) or os.getcwd())
+BUILD_DIR = API_SERVER_DIR / "build"
+MODEL_LIBRARY_FORMAT_PATH = API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH
+
+IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists()
+
+PROJECT_TYPES = [
+    "dense_example",
+    "conv2d_example",
+    "dwconv2d_example",
+    "add_example",
+    "maxpool2d_example",
+    "mobilenet_example",
+]
+
+PROJECT_OPTIONS = [

Review Comment:
   I applied the suggestion by separating each example into a different folder and using default_project_options to initialize the project's options.
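
   For readers following along, the other microTVM template projects compose their option list roughly like this (a sketch that assumes the `server.default_project_options` helper behaves as in those projects; the extra option shown here is made up for illustration):

       from tvm.micro.project_api import server

       PROJECT_TYPES = ["dense_example", "conv2d_example", "add_example"]

       PROJECT_OPTIONS = server.default_project_options(
           project_type={"choices": tuple(PROJECT_TYPES)},
       ) + [
           server.ProjectOption(
               "spike_extra_args",  # hypothetical option, for illustration only
               optional=["open_transport"],
               type="str",
               default="",
               help="Extra arguments passed to the Spike simulator invocation.",
           ),
       ]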





[GitHub] [tvm] fPecc commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by GitBox <gi...@apache.org>.
fPecc commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1082254341


##########
python/tvm/contrib/gemmini/tutorials/single_operators/add-tutorial.ipynb:
##########
@@ -0,0 +1,395 @@
+{

Review Comment:
   Thanks for the feedback, I am on it





[GitHub] [tvm] fzi-peccia commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1149202288


##########
apps/microtvm/gemmini/template_project/crt_config/crt_config.h:
##########
@@ -0,0 +1,57 @@
+/*

Review Comment:
   Thanks, I changed it so that this file is generated at build time.



##########
cmake/modules/contrib/Gemmini.cmake:
##########
@@ -0,0 +1,134 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+if(USE_GEMMINI)
+  message(STATUS "Add Gemmini for microTVM")
+
+  function(microtvm_add_gemmini)
+    list(
+      APPEND
+      GEMMINI_FILE_COPY_JOBS
+      "apps/microtvm/gemmini/template_project microtvm_api_server.py -> gemmini"
+      "apps/microtvm/gemmini/template_project/crt_config *.h -> gemmini/crt_config"
+
+      # Dense example project generation
+      "apps/microtvm/gemmini/template_project/src dense.c -> gemmini/src/dense_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/dense_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/dense_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/dense_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/dense_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/dense_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/dense_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/dense_example/rocc-software/src"
+
+      # CONV2D example project generation
+      "apps/microtvm/gemmini/template_project/src conv2d.c -> gemmini/src/conv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/conv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/conv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/conv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/conv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/conv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/conv2d_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/conv2d_example/rocc-software/src"
+
+      # DW CONV2D example project generation
+      "apps/microtvm/gemmini/template_project/src dwconv2d.c -> gemmini/src/dwconv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/dwconv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/dwconv2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/dwconv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/dwconv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/dwconv2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/dwconv2d_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/dwconv2d_example/rocc-software/src"
+
+      # ADD example project generation
+      "apps/microtvm/gemmini/template_project/src add.c -> gemmini/src/add_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/add_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/add_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/add_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/add_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/add_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/add_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/add_example/rocc-software/src"
+
+      # Max pooling 2d example project generation
+      "apps/microtvm/gemmini/template_project/src maxpool2d.c -> gemmini/src/maxpool2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/maxpool2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/maxpool2d_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/maxpool2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/maxpool2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/maxpool2d_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/maxpool2d_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/maxpool2d_example/rocc-software/src"
+
+      # Mobilenet example project generation
+      "apps/microtvm/gemmini/template_project/src mobilenet.c -> gemmini/src/mobilenet_example"
+      "apps/microtvm/gemmini/template_project/src Makefile -> gemmini/src/mobilenet_example"
+      "apps/microtvm/gemmini/template_project/src Makefile.in -> gemmini/src/mobilenet_example"
+      "apps/microtvm/gemmini/template_project/src Makefrag.mk -> gemmini/src/mobilenet_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests build.sh -> gemmini/src/mobilenet_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests configure.ac -> gemmini/src/mobilenet_example"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/include *.h -> gemmini/src/mobilenet_example/include"
+      "3rdparty/gemmini/software/gemmini-rocc-tests/rocc-software/src *.h -> gemmini/src/mobilenet_example/rocc-software/src"
+    )
+
+    foreach(job_spec IN LISTS GEMMINI_FILE_COPY_JOBS)
+      string(REPLACE " " ";" job_spec "${job_spec}")
+      list(LENGTH job_spec job_spec_length)
+      math(EXPR job_spec_length_mod "${job_spec_length} % 3")
+      if(NOT "${job_spec_length_mod}" EQUAL 1)
+        message(
+          FATAL_ERROR
+            "Gemmini copy job spec list length is ${job_spec_length}; parsed job spec is ${job_spec}"
+        )
+      endif()
+      math(EXPR job_spec_stop "${job_spec_length} - 3")
+
+      list(GET job_spec 0 job_src_base)
+      set(job_src_base "${CMAKE_SOURCE_DIR}/${job_src_base}")
+      foreach(copy_pattern_index RANGE 1 "${job_spec_stop}" 3)
+        list(GET job_spec ${copy_pattern_index} copy_pattern)
+        math(EXPR copy_dest_index "${copy_pattern_index} + 2")
+        list(GET job_spec ${copy_dest_index} copy_dest)
+
+        file(
+          GLOB_RECURSE copy_files
+          RELATIVE "${job_src_base}"
+          "${job_src_base}/${copy_pattern}")
+        list(LENGTH copy_files copy_files_length)
+        if("${copy_files_length}" EQUAL 0)
+          message(
+            FATAL_ERROR
+              "Gemmini copy job matched 0 files: ${job_src_base}/${copy_pattern} -> ${copy_dest}"
+          )
+        endif()
+        foreach(copy_src IN LISTS copy_files)
+          get_filename_component(
+            dest_path "${MICROTVM_TEMPLATE_PROJECTS}/${copy_dest}/${copy_src}"
+            ABSOLUTE)
+          tvm_micro_add_copy_file(gemmini_template_deps
+                                  ${job_src_base}/${copy_src} ${dest_path})
+        endforeach()
+      endforeach()
+    endforeach()
+
+    add_custom_target(gemmini DEPENDS ${gemmini_template_deps})
+  endfunction()
+
+  microtvm_add_gemmini()
+
+endif(USE_MICRO)

Review Comment:
   Thanks, I committed the suggested change.





[GitHub] [tvm] fzi-peccia commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1149198166


##########
python/tvm/contrib/gemmini/helpers.py:
##########
@@ -0,0 +1,181 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Miscellaneous helpers
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_
+"""
+
+import pathlib
+from typing import List
+from six.moves import range
+import numpy as np
+from .environment import Environment
+
+
+ENV = Environment.instance()
+
+
+def create_header_file(

Review Comment:
   I changed it to use the standard create_header_file from tvm.micro.testing.utils, but I changed a line in it to generate a #define instead of a const. I think this should be a #define, but if that is not the case, I will need to continue using my own create_header_file.
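
   To make the difference concrete, the change only affects how the tensor length is emitted into the generated header. A minimal sketch of both variants follows; both emitted strings are approximations for illustration, not copied from the TVM helper:

       import numpy as np

       def emit_length(f, tensor_name: str, data: np.ndarray, use_define: bool):
           if use_define:
               # variant described in this comment: a preprocessor define
               f.write(f"#define {tensor_name.upper()}_LEN {data.size}\n")
           else:
               # variant with a const symbol, as the stock helper is said to produce
               f.write(f"const size_t {tensor_name}_len = {data.size};\n")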





[GitHub] [tvm] fzi-peccia commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1149201073


##########
apps/microtvm/gemmini/template_project/microtvm_api_server.py:
##########
@@ -0,0 +1,286 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+MicroTVM API Server for Gemmini baremetal tests on the Spike simulator
+=====================
+**Author**: `Federico Peccia <https://fPecc.github.io/>`_

Review Comment:
   I removed it from all Python files as requested.





[GitHub] [tvm] fzi-peccia commented on a diff in pull request #13770: [microTVM]Gemmini code generation using microTVM

Posted by "fzi-peccia (via GitHub)" <gi...@apache.org>.
fzi-peccia commented on code in PR #13770:
URL: https://github.com/apache/tvm/pull/13770#discussion_r1149199231


##########
apps/microtvm/gemmini/template_project/src/add.c:
##########
@@ -0,0 +1,69 @@
+/*

Review Comment:
   Thanks, I applied the suggestion.


