You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2020/12/21 13:49:21 UTC

[tvm] branch main updated: [TFLite] add support for float16 (#7093)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9914685  [TFLite] add support for float16 (#7093)
9914685 is described below

commit 991468502f3a629560a4c284b73ce52094573523
Author: eric <eu...@samsung.com>
AuthorDate: Mon Dec 21 22:49:11 2020 +0900

    [TFLite] add support for float16 (#7093)
    
    * [TFLite] add support for float16
    
    * add test case
    
    * add test case
    
    * add comments
---
 python/tvm/relay/frontend/tflite.py          | 61 +++++++++++++++++++++-------
 tests/python/frontend/tflite/test_forward.py | 35 +++++++++++++++-
 2 files changed, 79 insertions(+), 17 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 54eeb9d..a55eb16 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -325,6 +325,7 @@ class OperatorConverter(object):
             return {
                 TensorType.UINT8: np.uint8,
                 TensorType.INT8: np.int8,
+                TensorType.FLOAT16: np.float16,
                 TensorType.FLOAT32: np.float32,
                 TensorType.INT32: np.int32,
                 TensorType.INT64: np.int64,
@@ -362,6 +363,8 @@ class OperatorConverter(object):
             return "int8"
         if tensor_type == TensorType.UINT8:
             return "uint8"
+        if tensor_type == TensorType.FLOAT16:
+            return "float16"
         if tensor_type == TensorType.FLOAT32:
             return "float32"
         if tensor_type == TensorType.INT32:
@@ -1991,20 +1994,33 @@ class OperatorConverter(object):
         weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
 
         in_expr = self.get_expr(input_tensor_idx)
-        weight_value = self.get_tensor_value(weight_tensor)
-
-        # TFLite kernel layout:
-        # convolution:
-        # OC KH KW IC, we require KH KW IC OC (HWIO)
-        # depthwise convolution:
-        # 1 KH KW C(input_c * depth_multiplier), we require
-        # KH KW IC M (depth_multiplier) (HWOI)
-        if is_depthwise_conv:
-            weight_value = weight_value.reshape(kernel_h, kernel_w, input_c, depth_multiplier)
+
+        # When TFLite converts a float32 model to float16, float32 constants
+        # (weights, biases, etc.) are stored as float16 and fed through a
+        # Dequantize op. So the conv op may receive its weight and bias as
+        # expressions (Dequantize outputs) instead of constant tensor values.
+        if self.has_expr(weight_tensor.tensor_idx):
+            weight_expr = self.get_expr(weight_tensor.tensor_idx)
+            if is_depthwise_conv:
+                weight_expr = _op.reshape(
+                    weight_expr, (kernel_h, kernel_w, input_c, depth_multiplier)
+                )
+            else:
+                weight_expr = _op.transpose(weight_expr, axes=(1, 2, 3, 0))
         else:
-            weight_value = weight_value.transpose((1, 2, 3, 0))
+            weight_value = self.get_tensor_value(weight_tensor)
+            # TFLite kernel layout:
+            # convolution:
+            # OC KH KW IC, we require KH KW IC OC (HWIO)
+            # depthwise convolution:
+            # 1 KH KW C(input_c * depth_multiplier), we require
+            # KH KW IC M (depth_multiplier) (HWOI)
+            if is_depthwise_conv:
+                weight_value = weight_value.reshape(kernel_h, kernel_w, input_c, depth_multiplier)
+            else:
+                weight_value = weight_value.transpose((1, 2, 3, 0))
 
-        weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
+            weight_expr = self.exp_tab.new_const(weight_value, dtype=weight_tensor_type_str)
 
         if padding == Padding.VALID:
             pass
@@ -2039,9 +2055,12 @@ class OperatorConverter(object):
             # bias tensor type should be INT32 (quantization) or FLOAT32
             assert bias_tensor_type in (TensorType.INT32, TensorType.FLOAT32)
             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
-            bias_expr = self.exp_tab.new_const(
-                self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
-            )
+            if self.has_expr(bias_tensor.tensor_idx):
+                bias_expr = self.get_expr(bias_tensor.tensor_idx)
+            else:
+                bias_expr = self.exp_tab.new_const(
+                    self.get_tensor_value(bias_tensor), dtype=bias_tensor_type_str
+                )
             channel_axis = 3
             out = _op.nn.bias_add(out, bias_expr, axis=channel_axis)
 
@@ -2870,10 +2889,22 @@ class OperatorConverter(object):
 
     def convert_dequantize(self, op):
         """Convert TFLite Dequantize"""
+        try:
+            from tflite.TensorType import TensorType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
 
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 1, "input tensors length should be 1"
         input_tensor = input_tensors[0]
+
+        if input_tensor.tensor.Type() == TensorType.FLOAT16:
+            dtype = self.get_tensor_type_str(input_tensor.tensor.Type())
+            input_value = self.get_tensor_value(input_tensor)
+            in_expr = self.exp_tab.new_const(input_value, dtype=dtype)
+            out = relay.cast(in_expr, dtype="float32")
+            return out
+
         in_expr = self.get_expr(input_tensor.tensor_idx)
 
         # The input must be quantized
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 56c50c3..7675768 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -71,13 +71,13 @@ def convert_to_list(x):
 #######################################################################
 # Get a real image for e2e testing
 # --------------------------------
-def get_real_image(im_height, im_width):
+def get_real_image(im_height, im_width, quantized=True):
     repo_base = "https://github.com/dmlc/web-data/raw/main/tensorflow/models/InceptionV1/"
     img_name = "elephant-299.jpg"
     image_url = os.path.join(repo_base, img_name)
     img_path = download_testdata(image_url, img_name, module="data")
     image = Image.open(img_path).resize((im_height, im_width))
-    x = np.array(image).astype("uint8")
+    x = np.array(image).astype("uint8") if quantized else np.array(image).astype("float32")
     data = np.reshape(x, (1, im_height, im_width, 3))
     return data
 
@@ -3792,6 +3792,35 @@ def test_forward_tflite2_qnn_mobilenet_v2():
         tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
 
+def test_forward_tflite_float16():
+    """Compare TVM and TFLite top-3 labels for a float16-quantized MobileNet V1 model."""
+    # MobilenetV1 (mobilenet_v1_0.25_128), converted below to a float16 TFLite model
+    tflite_model_file = tf_testing.get_workload_official(
+        "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz",
+        "mobilenet_v1_0.25_128_frozen.pb",
+    )
+
+    converter = tf.lite.TFLiteConverter.from_frozen_graph(
+        tflite_model_file, ["input"], ["MobilenetV1/Predictions/Reshape_1"]
+    )
+    converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    converter.target_spec.supported_types = [tf.float16]
+    tflite_model_buf = converter.convert()
+
+    # Compare the top-3 labels instead of raw output values, since TFLite and Relay
+    # implementations may differ numerically for this model. A real image is used
+    # instead of random inputs so that the top labels are meaningful.
+    data = get_real_image(128, 128, quantized=False)
+
+    tflite_output = run_tflite_graph(tflite_model_buf, data)
+    tflite_predictions = np.squeeze(tflite_output)
+    tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "input")
+    tvm_predictions = np.squeeze(tvm_output)
+    tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+    tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+
 #######################################################################
 # Quantized SSD Mobilenet
 # -----------------------
@@ -4057,3 +4086,5 @@ if __name__ == "__main__":
     test_forward_tflite2_qnn_resnet50()
     test_forward_tflite2_qnn_inception_v1()
     test_forward_tflite2_qnn_mobilenet_v2()
+
+    test_forward_tflite_float16()