Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/06/25 05:40:13 UTC

[GitHub] [incubator-tvm] siju-samuel commented on a change in pull request #5848: [TFLite] QNN support for TFLite 2.1.0 quantized models

siju-samuel commented on a change in pull request #5848:
URL: https://github.com/apache/incubator-tvm/pull/5848#discussion_r445310994



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -262,21 +298,25 @@ def get_tensor_value(self, tensor_wrapper):
         except ImportError:
             raise ImportError("The tflite package must be installed")
 
+        data = tensor_wrapper.buffer.DataAsNumpy()
+        shape = tensor_wrapper.tensor.ShapeAsNumpy()
+
+        # Set shape to 1 if the data is a scalar type
+        if data.shape == (1,) and isinstance(shape, int) and shape == 0:
+            shape = (1,)
+
+        if tensor_wrapper.tensor.Type() == TensorType.INT8:
+            return np.frombuffer(data, dtype=np.int8).reshape(shape)
         if tensor_wrapper.tensor.Type() == TensorType.UINT8:
-            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
-                tensor_wrapper.tensor.ShapeAsNumpy())
-        if tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
-            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape(
-                tensor_wrapper.tensor.ShapeAsNumpy())
+            return np.frombuffer(data, dtype=np.uint8).reshape(shape)

Review comment:
       Can we add support for int16/float16 as well?
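       For reference, a minimal sketch of what those extra branches might look like, following the same np.frombuffer pattern as the existing INT8/UINT8 branches (this assumes TensorType.INT16 and TensorType.FLOAT16 are available in the installed tflite package; data and shape are the locals defined above):

           if tensor_wrapper.tensor.Type() == TensorType.INT16:
               # Hypothetical branch: reinterpret the raw buffer as int16.
               return np.frombuffer(data, dtype=np.int16).reshape(shape)
           if tensor_wrapper.tensor.Type() == TensorType.FLOAT16:
               # Hypothetical branch: reinterpret the raw buffer as float16.
               return np.frombuffer(data, dtype=np.float16).reshape(shape)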

##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -2445,6 +2467,112 @@ def test_forward_qnn_mobilenet_v3_net():
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
 
+
+def _quantize_tf_hub_keras_model(url, height, width):
+    keras_model = tf.keras.Sequential([hub.KerasLayer(url, output_shape=[1001])])
+    data = pre_processed_image(height, width)
+
+    # Set the input shapes of the keras model
+    keras_model._set_inputs(data)
+
+    # Get the converter
+    converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
+
+    # To create quantized values with dynamic range of activations, needs representative dataset
+    def representative_data_gen():
+        for i in range(1):
+            yield [data]
+
+    converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+    converter.representative_dataset = representative_data_gen
+    return converter.convert()
+
+
+def test_forward_tflite2_qnn_resnet50():
+    """Test the Quantized TFLite version 2.1.0 Resnet50 model."""
+    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
+        # Quantize the model
+        url = "https://tfhub.dev/tensorflow/resnet_50/classification/1"
+        tflite_model_buf = _quantize_tf_hub_keras_model(url, 224, 224)
+        data = pre_processed_image(224, 224)
+
+        tflite_output = run_tflite_graph(tflite_model_buf, data)
+        tflite_predictions = np.squeeze(tflite_output)
+        tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
+        tvm_predictions = np.squeeze(tvm_output)
+        tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+        tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)

Review comment:
       Just for info: how different are the output values? If they are comparable, we can compare the output values themselves. The rounding method is different in TFLite and TVM, and because of that there can be some difference in the outputs.
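       If the raw scores turn out to be close, a direct value comparison might look like the sketch below. The tolerances are illustrative placeholders, not values verified against this model, and would need tuning to absorb the TFLite-vs-TVM rounding differences:

           # Hypothetical tolerance-based comparison of the raw predictions;
           # rtol/atol are placeholders, not validated against this PR's models.
           tvm.testing.assert_allclose(tvm_predictions, tflite_predictions,
                                       rtol=1e-2, atol=1e-2)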

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1605,7 +1645,7 @@ def convert_fully_connected(self, op):
 
         # weight tensor type should be UINT8 (quantization) or FLOAT32
         weight_tensor_type = weight_tensor.tensor.Type()
-        assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+        assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)

Review comment:
       Change the comment on line 1646 as well to include INT8, and check for similar comments in other places too.
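       For example, the comment and assert could read (suggested wording only):

           # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
           weight_tensor_type = weight_tensor.tensor.Type()
           assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8,
                                         TensorType.FLOAT32)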

##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -2445,6 +2467,112 @@ def test_forward_qnn_mobilenet_v3_net():
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
 
+
+def _quantize_tf_hub_keras_model(url, height, width):

Review comment:
       Add function headers (docstrings) similar to the previous test cases.
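       A possible header, mirroring the style of the existing test helpers (wording is only a suggestion):

           def _quantize_tf_hub_keras_model(url, height, width):
               """Utility function to quantize a Keras model loaded from the
               given TF Hub url and return the converted TFLite model buffer."""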

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -2723,6 +2779,15 @@ def get_scalar_from_constant(expr):
         "value must be float32/int32"
     return np.asscalar(value)
 
+def get_vector_from_constant(expr):
+    """ Returns scalar value from Relay constant scalar. """

Review comment:
       Update the docstring; it still describes the scalar helper above.
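       For example (suggested wording only):

           def get_vector_from_constant(expr):
               """ Returns vector value from Relay constant node. """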

##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -73,6 +74,28 @@ def get_real_image(im_height, im_width):
     data = np.reshape(x, (1, im_height, im_width, 3))
     return data
 
+
+def pre_processed_image(height, width):
+    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
+    img_name = 'elephant-299.jpg'
+    image_url = os.path.join(repo_base, img_name)
+    img_path = download_testdata(image_url, img_name, module='data')
+    image = tf.io.read_file(img_path)
+    image = tf.image.decode_jpeg(image, channels=3)
+    with tf.name_scope('eval_image'):
+        if image.dtype != tf.float32:
+            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+        image = tf.image.central_crop(image, central_fraction=0.875)
+    # Resize the image to the specified height and width.
+    image = tf.expand_dims(image, 0)
+    image = tf.image.resize(image, [height, width],
+                            align_corners=False)
+    image = tf.image.resize(image, [height, width])
+    image = tf.squeeze(image, [0])
+    image = tf.expand_dims(image, axis=0)

Review comment:
        I think this squeeze and expand_dims pair is not required, since both are applied at axis 0.
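        If that pair indeed cancels out, the tail of the helper could simply keep the batched tensor, e.g. (sketch based on the lines above):

            image = tf.expand_dims(image, 0)
            image = tf.image.resize(image, [height, width])
            # image already has the batch dimension at axis 0, so the
            # squeeze/expand_dims pair can be dropped.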




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org