You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by si...@apache.org on 2020/07/03 01:14:08 UTC

[incubator-tvm] branch master updated: [TFLite] QNN support for TFLite 2.1.0 quantized models (#5848)

This is an automated email from the ASF dual-hosted git repository.

sijusamuel pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 575a383  [TFLite] QNN support for TFLite 2.1.0 quantized models (#5848)
575a383 is described below

commit 575a3835315a533a19e871ee913f01142befebab
Author: Animesh Jain <an...@umich.edu>
AuthorDate: Thu Jul 2 18:13:56 2020 -0700

    [TFLite] QNN support for TFLite 2.1.0 quantized models (#5848)
    
    * [TFLite] TFLite 2.x parser quantization support.
    
    * Address comments. Fix a bug for depthwise conv
    
    * Added tests for relu, conv, quantize. Address comments.
    
    * Using web-data. Minor refactoring.
    
    * Removing TF hub package
    
    * Trigger CI.
    
    * Handle TFLite input layer naming.
    
    * Addressing reviews.
    
    * Retrigger CI.
---
 python/tvm/relay/frontend/tflite.py          | 195 +++++++++++++----
 src/relay/qnn/op/convolution.cc              |  16 +-
 tests/python/frontend/tflite/test_forward.py | 305 ++++++++++++++++++++++-----
 3 files changed, 419 insertions(+), 97 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 2fc82d7..36221b7 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -244,42 +244,92 @@ class OperatorConverter(object):
             qnn_params = None
             tflite_qnn_params = tensor.Quantization()
             if tflite_qnn_params is not None:
-                scale = float(tflite_qnn_params.ScaleAsNumpy())
-                zero_point = int(tflite_qnn_params.ZeroPointAsNumpy())
+                # TFLite supports both per-tensor and per-axis (aka channel) quantization.  For
+                # per-tensor quantization, scale and zero points are scalar values.  For per-axis
+                # quantization, scale and zero points for the weights are tensors (activations are
+                # per-tensor quantized). However, the TFLite quantization spec puts restrictions on
+                # zero points for per-axis quantization.  Specifically, the zero point is a tensor
+                # but all values are 0. More information can be found here -
+                # https://www.tensorflow.org/lite/performance/quantization_spec
+
+                tflite_scale = tflite_qnn_params.ScaleAsNumpy()
+                tflite_zero_point = tflite_qnn_params.ZeroPointAsNumpy()
+                is_qnn_params_valid = True
+
+                # Handle Per-axis and per-tensor cases
+                if isinstance(tflite_scale, np.ndarray):
+                    assert isinstance(tflite_zero_point, np.ndarray)
+
+                    # Tensor - Per-axis quantization
+                    if tflite_scale.size != 1 and tflite_zero_point.size != 1:
+                        scale = tflite_scale
+                        # Ensure that all zero points are zeros
+                        zero_point = tflite_zero_point
+                        if not np.all(zero_point == 0):
+                            raise tvm.error.OpAttributeInvalid(\
+                                    "TFLite per-axis quantization restricts all zero points to be"
+                                    + " 0, but a non-zero value is observed")
+                        zero_point = int(zero_point[0])
+
+                    # Scalar - Per-tensor quantization
+                    elif tflite_scale.size == 1 and tflite_zero_point.size == 1:
+                        scale = float(tflite_scale[0])
+                        zero_point = int(tflite_zero_point[0])
+
+                    else:
+                        raise NotImplementedError(\
+                                "Quantized type {} (scale) and  {} (zero point) not supported"
+                                .format(type(tflite_scale), type(tflite_zero_point)))
+                elif tflite_scale == 0 and tflite_zero_point == 0:
+                    # Handle corner case for ops like quantized reshape whose second operand (shape)
+                    # has zero scale and zero zero point. This is not used.
+                    is_qnn_params_valid = False
+                else:
+                    raise NotImplementedError("Quantized type {} not supported"
+                                              .format(type(tflite_scale)))
+
                 # Check that the scale and zero points are valid.
-                if scale != 0 or zero_point != 0:
+                if is_qnn_params_valid:
                     qnn_params = dict()
                     qnn_params['scale'] = relay.const(scale, 'float32')
                     qnn_params['zero_point'] = relay.const(zero_point, 'int32')
             return_list.append(TensorWrapper(tensor_idx, tensor, buffer, qnn_params))
         return return_list
 
-    def get_tensor_value(self, tensor_wrapper):
-        """Get tensor buffer value from given tensor wrapper"""
+
+    def get_tensor_type_as_numpy(self, tensor_wrapper):
+        """Returns np.dtype out of TensorType"""
         assert isinstance(tensor_wrapper, TensorWrapper)
 
         try:
             from tflite.TensorType import TensorType
+            return {TensorType.UINT8: np.uint8,
+                    TensorType.INT8: np.int8,
+                    TensorType.FLOAT32: np.float32,
+                    TensorType.INT32: np.int32,
+                    TensorType.INT64: np.int64,
+                    TensorType.BOOL: np.bool_}[tensor_wrapper.tensor.Type()]
         except ImportError:
             raise ImportError("The tflite package must be installed")
+        except KeyError:
+            raise NotImplementedError("Tensor type '{}' currently not supported"
+                                      .format(tensor_wrapper.tensor.Type()))
+
+
+    def get_tensor_value(self, tensor_wrapper):
+        """Get tensor buffer value from given tensor wrapper"""
+        assert isinstance(tensor_wrapper, TensorWrapper)
+
+        dtype = self.get_tensor_type_as_numpy(tensor_wrapper)
+        data = tensor_wrapper.buffer.DataAsNumpy()
+
+        if tensor_wrapper.tensor.ShapeLength() != 0:
+            shape = tensor_wrapper.tensor.ShapeAsNumpy()
+        else:
+            shape = []
+
+        return np.frombuffer(data, dtype=dtype).reshape(shape)
 
-        if tensor_wrapper.tensor.Type() == TensorType.UINT8:
-            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.uint8).reshape(
-                tensor_wrapper.tensor.ShapeAsNumpy())
-        if tensor_wrapper.tensor.Type() == TensorType.FLOAT32:
-            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.float32).reshape(
-                tensor_wrapper.tensor.ShapeAsNumpy())
-        if tensor_wrapper.tensor.Type() == TensorType.INT32:
-            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int32).reshape(
-                tensor_wrapper.tensor.ShapeAsNumpy())
-        if tensor_wrapper.tensor.Type() == TensorType.INT64:
-            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.int64).reshape(
-                tensor_wrapper.tensor.ShapeAsNumpy())
-        if tensor_wrapper.tensor.Type() == TensorType.BOOL:
-            return np.frombuffer(tensor_wrapper.buffer.DataAsNumpy(), dtype=np.bool_).reshape(
-                tensor_wrapper.tensor.ShapeAsNumpy())
-        raise NotImplementedError("Tensor type {} is currently not supported"
-                                  .format(str(tensor_wrapper.tensor.Type())))
 
     def get_tensor_type_str(self, tensor_type):
         """Get tensor type string representation when given TFLite tensor type"""
@@ -651,12 +701,43 @@ class OperatorConverter(object):
 
     def convert_relu(self, op):
         """Convert TFLite ReLU"""
+        try:
+            from tflite.ActivationFunctionType import ActivationFunctionType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 1, "input tensors length should be 1"
-
         input_tensor = input_tensors[0]
         in_expr = self.get_expr(input_tensor.tensor_idx)
-        out = _op.nn.relu(in_expr)
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+
+        if input_tensor.qnn_params:
+            # Quantize a float value to a quantized integer value
+            scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
+            zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
+
+            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+            out = self.convert_qnn_fused_activation_function(\
+                    expr=in_expr,
+                    fused_activation_fn=ActivationFunctionType.RELU,
+                    scale=scale_val,
+                    zero_point=zero_point_val,
+                    dtype=output_tensor_type_str)
+        else:
+            out = _op.nn.relu(in_expr)
+
+        if output_tensor.qnn_params:
+            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+            out = _qnn.op.requantize(out,
+                                     input_scale=input_tensor.qnn_params['scale'],
+                                     input_zero_point=input_tensor.qnn_params['zero_point'],
+                                     output_scale=output_tensor.qnn_params['scale'],
+                                     output_zero_point=output_tensor.qnn_params['zero_point'],
+                                     out_dtype=output_tensor_type_str)
 
         return out
 
@@ -692,6 +773,11 @@ class OperatorConverter(object):
 
     def convert_relu6(self, op):
         """Convert TFLite ReLU6"""
+        try:
+            from tflite.ActivationFunctionType import ActivationFunctionType
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 1, "input tensors length should be 1"
         input_tensor = input_tensors[0]
@@ -705,17 +791,14 @@ class OperatorConverter(object):
             # Quantize a float value to an quantized integer value
             scale_val = get_scalar_from_constant(input_tensor.qnn_params['scale'])
             zero_point_val = get_scalar_from_constant(input_tensor.qnn_params['zero_point'])
-            quantize = lambda x: float(int(round(x / scale_val)) + zero_point_val)
 
-            # Get min/max of the input dtype. This will be used to ensure that
-            # clip a_min/a_max are not beyond the dtype range.
-            input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
-            qmin = float(tvm.tir.op.min_value(input_tensor_type_str).value)
-            qmax = float(tvm.tir.op.max_value(input_tensor_type_str).value)
-
-            out = _op.clip(in_expr,
-                           a_min=max(qmin, quantize(0)),
-                           a_max=min(qmax, quantize(6.0)))
+            output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
+            out = self.convert_qnn_fused_activation_function(\
+                    expr=in_expr,
+                    fused_activation_fn=ActivationFunctionType.RELU6,
+                    scale=scale_val,
+                    zero_point=zero_point_val,
+                    dtype=output_tensor_type_str)
         else:
             out = _op.clip(in_expr, a_min=0, a_max=6)
 
@@ -1604,9 +1687,9 @@ class OperatorConverter(object):
         fully_connected_options.Init(op_options.Bytes, op_options.Pos)
         fused_activation_fn = fully_connected_options.FusedActivationFunction()
 
-        # weight tensor type should be UINT8 (quantization) or FLOAT32
+        # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
         weight_tensor_type = weight_tensor.tensor.Type()
-        assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+        assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
         weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
 
         if self.has_expr(weight_tensor.tensor_idx):
@@ -1795,9 +1878,9 @@ class OperatorConverter(object):
             params['channels'] = int(output_channels)
             params['kernel_layout'] = 'HWIO'
 
-        # weight tensor type should be UINT8 (quantization) or FLOAT32
+        # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
         weight_tensor_type = weight_tensor.tensor.Type()
-        assert weight_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+        assert weight_tensor_type in (TensorType.INT8, TensorType.UINT8, TensorType.FLOAT32)
         weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
 
         in_expr = self.get_expr(input_tensor_idx)
@@ -1856,9 +1939,15 @@ class OperatorConverter(object):
         if output_tensor.qnn_params:
             # Calculate the intermediate scale and zero point of the int32 output.
             data_scale = input_tensor.qnn_params['scale']
-            weight_scale = weight_tensor.qnn_params['scale']
             data_scale_val = get_scalar_from_constant(data_scale)
-            weight_scale_val = get_scalar_from_constant(weight_scale)
+
+            weight_scale = weight_tensor.qnn_params['scale']
+            # If weight scale is scalar, it is per-tensor quantization
+            if isinstance(weight_scale, float):
+                weight_scale_val = get_scalar_from_constant(weight_scale)
+            else:
+                weight_scale_val = get_tensor_from_constant(weight_scale)
+
             new_input_scale_val = data_scale_val * weight_scale_val
             new_input_scale = relay.const(new_input_scale_val, 'float32')
             new_input_zero_point = relay.const(0, 'int32')
@@ -1869,7 +1958,8 @@ class OperatorConverter(object):
                                      input_zero_point=new_input_zero_point,
                                      output_scale=output_tensor.qnn_params['scale'],
                                      output_zero_point=output_tensor.qnn_params['zero_point'],
-                                     out_dtype=output_tensor_type_str)
+                                     out_dtype=output_tensor_type_str,
+                                     axis=3)
 
             # Call activation function
             output_scale_val = get_scalar_from_constant(output_tensor.qnn_params['scale'])
@@ -1882,7 +1972,6 @@ class OperatorConverter(object):
                     dtype=output_tensor_type_str)
         else:
             out = self.convert_fused_activation_function(out, fused_activation_fn)
-
         return out
 
     def convert_split(self, op):
@@ -2594,17 +2683,27 @@ class OperatorConverter(object):
         input_tensors = self.get_input_tensors(op)
         assert len(input_tensors) == 1, "input tensors length should be 1"
         input_tensor = input_tensors[0]
+        input_tensor_type_str = self.get_tensor_type_str(input_tensor.tensor.Type())
         in_expr = self.get_expr(input_tensor.tensor_idx)
 
         output_tensors = self.get_output_tensors(op)
         assert len(output_tensors) == 1, "output tensors length should be 1"
         output_tensor = output_tensors[0]
+        output_tensor_type_str = self.get_tensor_type_str(output_tensor.tensor.Type())
 
         # The output must be quantized
         assert output_tensor.qnn_params
-        # Quantize the input
-        out = self.quantize(in_expr, output_tensor)
 
+        # TFLite Quantize op can also act as Requantize op
+        if input_tensor_type_str == "float32":
+            out = self.quantize(in_expr, output_tensor)
+        else:
+            out = _qnn.op.requantize(in_expr,
+                                     input_scale=input_tensor.qnn_params['scale'],
+                                     input_zero_point=input_tensor.qnn_params['zero_point'],
+                                     output_scale=output_tensor.qnn_params['scale'],
+                                     output_zero_point=output_tensor.qnn_params['zero_point'],
+                                     out_dtype=output_tensor_type_str)
         return out
 
     def convert_dequantize(self, op):
@@ -2734,7 +2833,6 @@ class OperatorConverter(object):
         else:
             type_str = self.get_tensor_type_str(tensor.tensor.Type())
             expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)
-
         return expr
 
 
@@ -2747,6 +2845,13 @@ def get_scalar_from_constant(expr):
         "value must be float32/int32"
     return np.asscalar(value)
 
+def get_tensor_from_constant(expr):
+    """ Returns tensor of values from Relay constant node. """
+    assert isinstance(expr, _expr.Constant)
+    value = expr.data.asnumpy()
+    assert value.dtype == np.dtype(np.int32) or value.dtype == np.dtype(np.float32), \
+        "value must be float32/int32"
+    return value
 
 def build_str_map(obj):
     """Build string map of TFLite enum int value
diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index 9412ab4..5d2e360 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -61,9 +61,19 @@ bool QnnConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK(IsScalarType(types[3], DataType::Int(32)));    // kernel_zero_point
   CHECK(IsScalarType(types[4], DataType::Float(32)));  // input_scale
   // Kernel scale can be a vector of length output_channels or a scalar.
-  size_t axis = param->kernel_layout.find('O');
-  CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
-  AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
+  if (param->groups == 1) {
+    size_t axis = param->kernel_layout.find('O');
+    CHECK(axis != std::string::npos) << "Kernel layout attribute is not defined";
+    AssignType(types[5], DataType::Float(32), weight->shape[axis], reporter);  // kernel scale
+  } else {
+    // Here, the total number of output channels depends on the depth multiplier.
+    size_t o_axis = param->kernel_layout.find('O');
+    size_t i_axis = param->kernel_layout.find('I');
+    CHECK(o_axis != std::string::npos || i_axis != std::string::npos)
+        << "Kernel layout attribute is not defined";
+    AssignType(types[5], DataType::Float(32), weight->shape[i_axis] * weight->shape[o_axis],
+               reporter);  // kernel scale
+  }
 
   // Collect the input tensor and output tensor devoid of scale and zero points to reuse Relay
   // Conv2D infer type function.
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 166eb27..52491b2 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -73,6 +73,25 @@ def get_real_image(im_height, im_width):
     data = np.reshape(x, (1, im_height, im_width, 3))
     return data
 
+
+def pre_processed_image(height, width):
+    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
+    img_name = 'elephant-299.jpg'
+    image_url = os.path.join(repo_base, img_name)
+    img_path = download_testdata(image_url, img_name, module='data')
+    image = tf.io.read_file(img_path)
+    image = tf.image.decode_jpeg(image, channels=3)
+    with tf.name_scope('eval_image'):
+        if image.dtype != tf.float32:
+            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+        image = tf.image.central_crop(image, central_fraction=0.875)
+    # Resize the image to the specified height and width.
+    image = tf.image.resize(image, [height, width],
+                            align_corners=False)
+    image = tf.expand_dims(image, axis=0)
+    return image
+
+
 def get_real_image_object_detection(im_height, im_width):
     repo_base = 'https://github.com/dmlc/web-data/raw/master/gluoncv/detection/'
     img_name = 'street_small.jpg'
@@ -109,6 +128,18 @@ def vmobj_to_list(o):
     else:
         raise RuntimeError("Unknown object type: %s" % type(o))
 
+
+def _quantize_keras_model(keras_model, representative_data_gen):
+    """Utility function to quantize a Keras model using TFLite converter."""
+    converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
+    converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+    converter.representative_dataset = representative_data_gen
+    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+    converter.inference_input_type = tf.uint8
+    converter.inference_output_type = tf.uint8
+    return converter.convert()
+
+
 def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
                   out_names=None, mode='graph_runtime'):
     """ Generic function to compile on relay and execute on tvm """
@@ -717,6 +748,70 @@ def test_forward_l2_pool2d():
 # Convolution
 # -----------
 
+
+def _test_tflite2_quantized_convolution(input_shape, kernel_shape,
+        dilations, strides, padding, data_format):
+    """ One iteration of TFLite2 quantized convolution with given shapes and attributes """
+    data_format = "channels_last" if "NHWC" else "channels_first"
+    data = np.random.uniform(0, 1, input_shape).astype('float32')
+    kernel = np.random.uniform(0, 1, kernel_shape).astype('float32')
+
+    data_in = tf.keras.layers.Input(shape=data.shape[1:])
+    conv = tf.keras.layers.Conv2D(filters=kernel_shape[3],
+                                  kernel_size=(kernel_shape[0], kernel_shape[1]),
+                                  strides=strides,
+                                  padding=padding,
+                                  data_format=data_format,
+                                  activation='relu',
+                                  use_bias=False)(data_in)
+    keras_model = tf.keras.models.Model(data_in, conv)
+    keras_model.layers[1].set_weights([kernel])
+
+    # To create quantized values with dynamic range of activations, needs representative dataset
+    def representative_data_gen():
+        for i in range(1):
+            yield [data]
+
+    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
+
+    tflite_output = run_tflite_graph(tflite_model_quant, data)
+    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0",""))
+    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+                                rtol=1e-2, atol=1e-2)
+
+
+def _test_tflite2_quantized_depthwise_convolution(input_shape, kernel_shape,
+        dilations, strides, padding, data_format, depth_multiplier):
+    """One iteration of TFLite2 quantized depthwise convolution with given shapes and attributes"""
+    data_format = "channels_last" if "NHWC" else "channels_first"
+    data = np.random.uniform(0, 1, input_shape).astype('float32')
+    kernel = np.random.uniform(0, 1, kernel_shape).astype('float32')
+
+    data_in = tf.keras.layers.Input(shape=data.shape[1:])
+    conv = tf.keras.layers.DepthwiseConv2D(kernel_size=(kernel_shape[0], kernel_shape[1]),
+                                           strides=strides,
+                                           padding=padding,
+                                           data_format=data_format,
+                                           activation='relu',
+                                           use_bias=False,
+                                           depth_multiplier=depth_multiplier)(data_in)
+    keras_model = tf.keras.models.Model(data_in, conv)
+    keras_model.layers[1].set_weights([kernel])
+
+
+    # To create quantized values with dynamic range of activations, needs representative dataset
+    def representative_data_gen():
+        for i in range(1):
+            yield [data]
+
+    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
+
+    tflite_output = run_tflite_graph(tflite_model_quant, data)
+    tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0",""))
+    tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+                                rtol=1e-2, atol=1e-2)
+
+
 def _test_convolution(tensor_in_sizes, filter_in_sizes,
                       dilations, strides, padding, data_format,
                       is_depthwise=False, quantized=False):
@@ -757,24 +852,38 @@ def _test_convolution(tensor_in_sizes, filter_in_sizes,
                                 data_format=data_format)
 
         if quantized:
-            # For now only quantized conv2d is supported
-            assert not is_depthwise
-
-            # Quantized the inputs and feed them to the convolution
-            inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data')
-            inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter')
-            out = nn_ops.conv2d(inq_data,
-                                inq_filter,
-                                strides=strides,
-                                padding=padding,
-                                data_format=data_format)
-            out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
-
-            # Set the input quantization range
-            input_range = {'in_data': (-100, 100)} if quantized else None
-
-            # Compare
-            compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range)
+            if is_depthwise:
+                # Quantize the inputs and feed them to the convolution
+                inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data')
+                inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter')
+                out = nn_ops.depthwise_conv2d_native(inq_data,
+                                                     inq_filter,
+                                                     strides=strides,
+                                                     padding=padding,
+                                                     data_format=data_format)
+                out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
+
+                # Set the input quantization range
+                input_range = {'in_data': (-100, 100)} if quantized else None
+
+                # Compare
+                compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range)
+            else:
+                # Quantize the inputs and feed them to the convolution
+                inq_data = tf.quantization.fake_quant_with_min_max_args(in_data, min=-100, max=100, name='inq_data')
+                inq_filter = tf.quantization.fake_quant_with_min_max_args(in_filter, min=-100, max=100, name='inq_filter')
+                out = nn_ops.conv2d(inq_data,
+                                    inq_filter,
+                                    strides=strides,
+                                    padding=padding,
+                                    data_format=data_format)
+                out = tf.quantization.fake_quant_with_min_max_args(out, min=-200, max=200, name="out")
+
+                # Set the input quantization range
+                input_range = {'in_data': (-100, 100)} if quantized else None
+
+                # Compare
+                compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out], quantized=quantized, input_range=input_range)
         else:
             data_array = np.reshape(data_array, tensor_in_sizes).astype('float32')
             compare_tflite_with_tvm(data_array, 'in_data', [in_data], [out])
@@ -787,14 +896,30 @@ def test_forward_convolution():
         _test_convolution([4, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC', quantized=quantized)
         _test_convolution([4, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC', quantized=quantized)
 
-    # depthwise convolution
-    _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
-    _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
-    _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True)
-    _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True)
-    _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True)
-    # dephtwise convolution with single input channel
-    _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True)
+        # depthwise convolution
+        _test_convolution([4, 8, 8, 176], [1, 1, 176, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)
+        _test_convolution([4, 17, 17, 19], [3, 3, 19, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
+        _test_convolution([4, 17, 17, 124], [1, 1, 124, 1], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)
+        _test_convolution([4, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
+        _test_convolution([4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC', True, quantized=quantized)
+        # depthwise convolution with single input channel
+        _test_convolution([1, 76, 64, 1], [9, 5, 1, 96], [1, 1], [1, 1], 'SAME', 'NHWC', True, quantized=quantized)
+
+    # TFLite2 quantized convolution testing
+    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
+        _test_tflite2_quantized_convolution([1, 8, 8, 176], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME', 'NHWC')
+        _test_tflite2_quantized_convolution([1, 17, 17, 12], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID', 'NHWC')
+        _test_tflite2_quantized_convolution([1, 17, 17, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID', 'NHWC')
+        _test_tflite2_quantized_convolution([1, 17, 17, 124], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME', 'NHWC')
+
+        # depthwise convolution
+        _test_tflite2_quantized_depthwise_convolution([1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1],
+                                                      'SAME', 'NHWC', 1)
+        _test_tflite2_quantized_depthwise_convolution([1, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2],
+                                                      'VALID', 'NHWC', 1)
+        _test_tflite2_quantized_depthwise_convolution([1, 24, 24, 3], [7, 7, 3, 8], [1, 1], [2, 2],
+                                                      'SAME', 'NHWC', 8)
+
 
 
 #######################################################################
@@ -1686,38 +1811,33 @@ def test_forward_squeeze():
 def _test_quantize_dequantize(data):
     """ One iteration of quantize and dequantize """
 
-    # Define a dummy model
+    # Keras model to force TFLite converter to insert 2 TFLite quantize ops.
+    # First TFLite quantize op converts float32 tensor to int8 tensor - Qnn quantize.
+    # Second TFLite quantize op converts int8 tensor to int8 tensor - Qnn requantize.
     data_in = tf.keras.layers.Input(shape=data.shape[1:])
-    act_func =  tf.keras.layers.Activation('linear')
-    keras_model = tf.keras.models.Model(data_in, act_func(data_in))
-
-    # Load the model
-    converter = interpreter_wrapper.TFLiteConverter.from_keras_model(keras_model)
+    relu = tf.keras.layers.ReLU()(data_in)
+    add = tf.keras.layers.Add()([data_in, relu])
+    concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
+    keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
+    input_name = data_in.name.split(":")[0]
 
     # To create quantized values with dynamic range of activations, needs representative dataset
     def representative_data_gen():
-        for i in range(100):
+        for i in range(1):
             yield [data]
 
-    converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
-    converter.representative_dataset = representative_data_gen
-    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
-    converter.inference_input_type = tf.uint8
-    converter.inference_output_type = tf.uint8
-
-    # Convert the model to TensorFlow Lite format
-    tflite_model_quant = converter.convert()
+    tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
 
     tflite_output = run_tflite_graph(tflite_model_quant, data)
-    tvm_output = run_tvm_graph(tflite_model_quant, data, 'input_1')
+    tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
     tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
-                                rtol=1e-5, atol=1e-5)
+                                rtol=1e-5, atol=1e-2)
 
 
 def test_forward_quantize_dequantize():
     """ Quantize Dequantize """
     data = np.random.uniform(0, 1, (1, 4, 4, 3)).astype("float32")
-    if package_version.parse(tf.VERSION) >= package_version.parse('2.0.0'):
+    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
         _test_quantize_dequantize(data)
 
 
@@ -1945,16 +2065,38 @@ def test_forward_tanh():
 # ReLu
 # ----
 
-def _test_relu(data):
+def _test_relu(data, quantized=False):
     """ One iteration of ReLU """
-    with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
-        out = nn_ops.relu(in_data)
-        compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
+
+    if quantized:
+        if package_version.parse(tf.VERSION) < package_version.parse('2.1.0'):
+            pytest.skip("Testcase requires tflite version >= 2.1.0")
+        data_in = tf.keras.layers.Input(shape=data.shape[1:])
+        relu =  tf.keras.layers.ReLU()(data_in)
+        keras_model = tf.keras.models.Model(inputs=data_in, outputs=relu)
+        input_name = data_in.name.split(":")[0]
+
+        # Creating quantized values with dynamic range of activations needs a representative dataset
+        def representative_data_gen():
+            for i in range(1):
+                yield [data]
+
+        tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen)
+
+        tflite_output = run_tflite_graph(tflite_model_quant, data)
+        tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
+        tvm.testing.assert_allclose(np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]),
+                                    rtol=1e-5, atol=1e-5)
+    else:
+        with tf.Graph().as_default():
+            in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+            out = nn_ops.relu(in_data)
+            compare_tflite_with_tvm(data, 'Placeholder:0', [in_data], [out])
 
 def test_forward_relu():
     """ ReLU """
     _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)))
+    _test_relu(np.arange(6.0, dtype=np.float32).reshape((1, 6)), quantized=True)
 
 #######################################################################
 # ReLU6
@@ -2471,6 +2613,66 @@ def test_forward_qnn_mobilenet_v3_net():
     tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
 
 
+def test_forward_tflite2_qnn_resnet50():
+    """Test the Quantized TFLite version 2.1.0 Resnet50 model."""
+    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
+        tflite_model_file = download_testdata(
+            "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/resnet_50_quantized.tflite",
+            "resnet_50_quantized.tflite")
+        with open(tflite_model_file, "rb") as f:
+            tflite_model_buf = f.read()
+
+        data = pre_processed_image(224, 224)
+
+        tflite_output = run_tflite_graph(tflite_model_buf, data)
+        tflite_predictions = np.squeeze(tflite_output)
+        tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
+        tvm_predictions = np.squeeze(tvm_output)
+        tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+        tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+
+def test_forward_tflite2_qnn_inception_v1():
+    """Test the Quantized TFLite version 2.1.0 Inception V1 model."""
+    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
+        tflite_model_file = download_testdata(
+            "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/inception_v1_quantized.tflite",
+            "inception_v1_quantized.tflite")
+        with open(tflite_model_file, "rb") as f:
+            tflite_model_buf = f.read()
+
+        data = pre_processed_image(224, 224)
+
+        tflite_output = run_tflite_graph(tflite_model_buf, data)
+        tflite_predictions = np.squeeze(tflite_output)
+        tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
+        tvm_predictions = np.squeeze(tvm_output)
+        tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+        tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+
+def test_forward_tflite2_qnn_mobilenet_v2():
+    """Test the Quantized TFLite version 2.1.0 Mobilenet V2 model."""
+    if package_version.parse(tf.VERSION) >= package_version.parse('2.1.0'):
+        tflite_model_file = download_testdata(
+            "https://raw.githubusercontent.com/dmlc/web-data/master/tensorflow/models/Quantized/mobilenet_v2_quantized.tflite",
+            "mobilenet_v2_quantized.tflite")
+        with open(tflite_model_file, "rb") as f:
+            tflite_model_buf = f.read()
+
+        data = pre_processed_image(224, 224)
+
+        tflite_output = run_tflite_graph(tflite_model_buf, data)
+        tflite_predictions = np.squeeze(tflite_output)
+        tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+        tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), 'input_1')
+        tvm_predictions = np.squeeze(tvm_output)
+        tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
+        tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels)
+
+
 #######################################################################
 # Quantized SSD Mobilenet
 # -----------------------
@@ -2689,3 +2891,8 @@ if __name__ == '__main__':
     #with Tflite 1.15.2
     test_forward_qnn_mobilenet_v3_net()
     test_forward_qnn_coco_ssd_mobilenet_v1()
+
+    # TFLite 2.1.0 quantized tests
+    test_forward_tflite2_qnn_resnet50()
+    test_forward_tflite2_qnn_inception_v1()
+    test_forward_tflite2_qnn_mobilenet_v2()