You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2022/04/28 09:17:25 UTC

[tvm] branch main updated: [CMSIS-NN] Moved TFLite model making to common area (#10939)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 72e11baabb [CMSIS-NN] Moved TFLite model making to common area (#10939)
72e11baabb is described below

commit 72e11baabb0e3a7e311c4b3490b729641c489555
Author: Ashutosh Parkhi <86...@users.noreply.github.com>
AuthorDate: Thu Apr 28 10:17:15 2022 +0100

    [CMSIS-NN] Moved TFLite model making to common area (#10939)
    
    * [CMSIS-NN] Moved TFLite model making to common area
    
    Change-Id: Ic4dbc1919ff0b481c05daf7e57cf9b055c714c9c
    
    * Fixed lint issues with tensorflow import
    
    Change-Id: I7a520beec9c244e9c790d3e82733c2fb476f7e5e
    
    * Resolved merge conflict with main
    
    Change-Id: Iefe58dd321efae6eae26cd54a31c5923d0f1e32b
    
    * Made TFLite layer creation explicit
    
    Change-Id: I7fbf6a5a2163c1fada49477f86d84f1bc09bd57c
    
    * Lint fix: added a missing docstring
    
    Change-Id: If1fb8bb09c538c04e333ccab65a20cff247a504d
---
 python/tvm/relay/testing/tflite.py               | 161 +++++++++++++++++++++++
 tests/python/contrib/test_cmsisnn/test_conv2d.py |  19 +--
 tests/python/contrib/test_cmsisnn/utils.py       | 131 ------------------
 3 files changed, 172 insertions(+), 139 deletions(-)

diff --git a/python/tvm/relay/testing/tflite.py b/python/tvm/relay/testing/tflite.py
new file mode 100644
index 0000000000..df40130ceb
--- /dev/null
+++ b/python/tvm/relay/testing/tflite.py
@@ -0,0 +1,161 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Common utilities for creating TFLite models"""
+from distutils.version import LooseVersion
+import numpy as np
+import pytest
+import tvm
+
+pytest.importorskip("tflite")
+pytest.importorskip("tensorflow")
+import tflite.Model  # pylint: disable=wrong-import-position
+import tensorflow as tf  # pylint: disable=wrong-import-position
+
+
+class TFLiteModel:
+    """Creates TFLite Model and facilitates reference data generation"""
+
+    def __init__(self, dtype):
+        self.serial_model = None  # This is what TFLite convert() provides
+        self.dtype = dtype  # This is the dtype of graph inputs
+        self.shape_dict = {}
+        self.dtype_dict = {}
+
+    def create_conv2d_single(self, kernel_shape, strides, padding, dilation, activation):
+        """Returns tf.function that creates TFLite Conv2d layer"""
+
+        @tf.function
+        def conv2d_single_function(ifm_tensor):
+            """Returns TFLite Conv2d layer"""
+            op = tf.nn.conv2d(
+                ifm_tensor,
+                filters=tf.constant(
+                    np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]),
+                    dtype=tf.float32,
+                ),
+                strides=[1, strides[0], strides[1], 1],
+                padding=padding,
+                dilations=dilation,
+            )
+            if activation == "RELU":
+                op = tf.nn.relu(op)
+            elif activation == "NONE":
+                pass
+            else:
+                assert False, "Unsupported activation {}".format(activation)
+            return op
+
+        return conv2d_single_function
+
+    def create_tflite_model(self, tfl_function, shapes, ranges=None):
+        """Creates TFLite serial graph"""
+        tensor_specs = []
+        for i, shape in enumerate(shapes):
+            input_name = "input_" + str(i)
+            self.shape_dict.update({input_name: shape})
+            self.dtype_dict.update({input_name: self.dtype})
+            tensor_specs.append(tf.TensorSpec(shape, dtype=tf.float32, name=input_name))
+        concrete_func = tfl_function.get_concrete_function(*tensor_specs)
+
+        if not ranges:
+            ranges = [(0, 1) for _ in shapes]
+
+        def representative_dataset():
+            for _ in range(100):
+                inputs = []
+                for i, shape in enumerate(shapes):
+                    data = np.random.uniform(
+                        low=ranges[i][0], high=ranges[i][1], size=tuple(shape)
+                    ).astype("float32")
+                    inputs.append(data)
+
+                yield inputs
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        self.serial_model = converter.convert()
+
+    def convert_to_relay(self):
+        """Converts TFLite serialized graph into Relay"""
+        assert self.serial_model is not None, "TFLite model is empty!"
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(self.serial_model, 0)
+        relay_module, relay_params = tvm.relay.frontend.from_tflite(
+            tflite_model, self.shape_dict, self.dtype_dict
+        )
+        return relay_module, relay_params
+
+    def generate_randomized_input_data(self, seed, shape, dtype):
+        """Generates seeded random data of `shape`: uniform(-1, 1) floats or full-range ints."""
+        random_state = np.random.RandomState(seed)
+        random_data = None
+        if dtype == np.float32:
+            random_data = random_state.uniform(-1, 1, shape).astype(dtype)
+        else:
+            low = np.iinfo(dtype).min
+            high = np.iinfo(dtype).max + 1  # randint's upper bound is exclusive
+            random_data = random_state.randint(low, high, shape, dtype)
+        return random_data
+
+    # pylint: disable=import-outside-toplevel
+    def generate_reference_data(self):
+        """
+        This method uses TFLite reference kernels to generate reference output.
+        It returns randomized inputs and reference outputs.
+        """
+        assert self.serial_model is not None, "TFLite model was not created."
+
+        output_tolerance = None
+        if tf.__version__ < LooseVersion("2.5.0"):
+            output_tolerance = 1
+            interpreter = tf.lite.Interpreter(model_content=self.serial_model)
+        else:
+            output_tolerance = 0
+            interpreter = tf.lite.Interpreter(
+                model_content=self.serial_model,
+                experimental_op_resolver_type=tf.lite.experimental.OpResolverType.BUILTIN_REF,
+                experimental_preserve_all_tensors=False,
+            )
+
+        interpreter.allocate_tensors()
+        input_details = interpreter.get_input_details()
+        output_details = interpreter.get_output_details()
+
+        # Generate predictable randomized input
+        seed = 0
+        input_data = {}
+        for input_detail in input_details:
+            input_values = self.generate_randomized_input_data(
+                seed, input_detail["shape"], input_detail["dtype"]
+            )
+            interpreter.set_tensor(input_detail["index"], input_values)
+            input_data.update({input_detail["name"]: input_values})
+
+        interpreter.invoke()
+
+        # Obtain the expected output from interpreter
+        expected_output_data = {}
+        for output_detail in output_details:
+            expected_output_data.update(
+                {output_detail["name"]: interpreter.get_tensor(output_detail["index"])}
+            )
+
+        return input_data, expected_output_data, output_tolerance
diff --git a/tests/python/contrib/test_cmsisnn/test_conv2d.py b/tests/python/contrib/test_cmsisnn/test_conv2d.py
index 6c8f53666e..47245f60e1 100644
--- a/tests/python/contrib/test_cmsisnn/test_conv2d.py
+++ b/tests/python/contrib/test_cmsisnn/test_conv2d.py
@@ -35,14 +35,12 @@ from tests.python.relay.aot.aot_test_utils import (
 from utils import (
     skip_if_no_reference_system,
     make_module,
-    create_conv2d_tflite_relay_models,
     get_range_for_dtype_str,
     get_same_padding,
     get_conv2d_qnn_params,
     make_qnn_relu,
     assert_partitioned_function,
     assert_no_external_function,
-    generate_ref_data_tflite,
 )
 
 
@@ -314,25 +312,30 @@ def test_conv2d_int8_tflite(ifm_shape, kernel_shape, strides, dilation, padding,
     interface_api = "c"
     use_unpacked_api = True
     test_runner = AOT_USMP_CORSTONE300_RUNNER
-
     dtype = "int8"
-    tflite_model, relay_mod, params = create_conv2d_tflite_relay_models(
-        ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype
+
+    from tvm.relay.testing.tflite import TFLiteModel
+
+    tfl_model = TFLiteModel(dtype)
+    conv2d_function = tfl_model.create_conv2d_single(
+        kernel_shape, strides, padding, dilation, activation
     )
+    tfl_model.create_tflite_model(conv2d_function, [ifm_shape])
+    relay_mod, relay_params = tfl_model.convert_to_relay()
 
-    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, params)
+    cmsisnn_mod = cmsisnn.partition_for_cmsisnn(relay_mod, relay_params)
 
     # validate pattern matching
     assert_partitioned_function(relay_mod, cmsisnn_mod)
 
     # validate CMSIS-NN output against TFLite output
-    input_map, output_map, output_tolerance = generate_ref_data_tflite(tflite_model)
+    input_map, output_map, output_tolerance = tfl_model.generate_reference_data()
     compile_and_run(
         AOTTestModel(
             module=cmsisnn_mod,
             inputs=input_map,
             outputs=output_map,
-            params=params,
+            params=relay_params,
             output_tolerance=output_tolerance,
         ),
         test_runner,
diff --git a/tests/python/contrib/test_cmsisnn/utils.py b/tests/python/contrib/test_cmsisnn/utils.py
index 6bd375db1f..83c67cd95b 100644
--- a/tests/python/contrib/test_cmsisnn/utils.py
+++ b/tests/python/contrib/test_cmsisnn/utils.py
@@ -225,134 +225,3 @@ def make_qnn_relu(expr, fused_activation_fn, scale, zero_point, dtype):
         )
     if fused_activation_fn == "RELU":
         return tvm.relay.op.clip(expr, a_min=max(qmin, quantize(0.0)), a_max=qmax)
-
-
-def generate_random_input_data(seed, shape, dtype):
-    """
-    Generates randomized input numpy arrays based on shape and dtype
-    """
-    random_state = np.random.RandomState(seed)
-    if dtype == np.float32:
-        return random_state.uniform(-1, 1, size).astype(dtype)
-    else:
-        low = np.iinfo(dtype).min
-        high = np.iinfo(dtype).max + 1
-        return random_state.randint(low, high, shape, dtype)
-
-
-def generate_ref_data_tflite(model):
-    """
-    This method uses TFLite reference kernels to generate reference output.
-    Random input generator is used to get the input data.
-    It returns randomized inputs and reference outputs.
-    """
-    import tensorflow as tf
-    from distutils.version import LooseVersion
-
-    output_tolerance = None
-    if tf.__version__ < LooseVersion("2.5.0"):
-        output_tolerance = 1
-        interpreter = tf.lite.Interpreter(model_content=model)
-    else:
-        from tensorflow.lite.python.interpreter import OpResolverType
-
-        output_tolerance = 0
-        interpreter = tf.lite.Interpreter(
-            model_content=model,
-            experimental_op_resolver_type=OpResolverType.BUILTIN_REF,
-            experimental_preserve_all_tensors=False,
-        )
-
-    interpreter.allocate_tensors()
-    input_details = interpreter.get_input_details()
-    output_details = interpreter.get_output_details()
-
-    # Generate predictable randomized input
-    seed = 0
-    input_data = {}
-    for input_detail in input_details:
-        input_values = generate_random_input_data(
-            seed, input_detail["shape"], input_detail["dtype"]
-        )
-        interpreter.set_tensor(input_detail["index"], input_values)
-        input_data.update({input_detail["name"]: input_values})
-
-    interpreter.invoke()
-
-    # Obtain the expected output from interpreter
-    expected_output_data = {}
-    for output_detail in output_details:
-        expected_output_data.update(
-            {output_detail["name"]: interpreter.get_tensor(output_detail["index"])}
-        )
-
-    return input_data, expected_output_data, output_tolerance
-
-
-def create_conv2d_tflite_model(ifm_shape, kernel_shape, strides, dilation, padding, activation):
-    """This method prepares TFlite graph with a single Conv2d layer"""
-    import tensorflow as tf
-
-    class Model(tf.Module):
-        @tf.function
-        def tf_function(self, x):
-            # Use tf.nn API to create the model
-            tf_strides = [1, strides[0], strides[1], 1]
-            op = tf.nn.conv2d(
-                x,
-                filters=tf.constant(
-                    np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]),
-                    dtype=tf.float32,
-                ),
-                strides=tf_strides,
-                padding=padding,
-                dilations=dilation,
-            )
-            if activation:
-                op = tf.nn.relu(op)
-            return op
-
-    model = Model()
-    concrete_func = model.tf_function.get_concrete_function(
-        tf.TensorSpec(ifm_shape, dtype=tf.float32)
-    )
-
-    def representative_dataset():
-        for _ in range(100):
-            data = np.random.rand(*tuple(ifm_shape))
-            yield [data.astype(np.float32)]
-
-    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
-    converter.optimizations = [tf.lite.Optimize.DEFAULT]
-    converter.representative_dataset = representative_dataset
-    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
-    converter.inference_input_type = tf.int8
-    converter.inference_output_type = tf.int8
-    tflite_model = converter.convert()
-    return tflite_model
-
-
-def create_conv2d_tflite_relay_models(
-    ifm_shape, kernel_shape, strides, dilation, padding, activation, dtype
-):
-    """
-    This method creates a conv2d TFLite layer and prepared TFLite model from it.
-    Converts that into the Relay module and params.
-    Returns TFLite model, Relay module and params.
-    """
-    pytest.importorskip("tflite")
-    import tflite.Model
-
-    serialized_tflite_model = create_conv2d_tflite_model(
-        ifm_shape, kernel_shape, strides, dilation, padding, activation
-    )
-
-    tflite_model = tflite.Model.Model.GetRootAsModel(serialized_tflite_model, 0)
-
-    relay_module, params = relay.frontend.from_tflite(
-        tflite_model,
-        shape_dict={"input": ifm_shape},
-        dtype_dict={"input": dtype},
-    )
-
-    return serialized_tflite_model, relay_module, params