Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/21 23:03:55 UTC

[GitHub] [incubator-tvm] jainris opened a new pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

jainris opened a new pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523


   * Added a `dilation_value` attribute to the `dilate` operator of Relay/TOPI.
     (This allows a custom fill value for the dilated positions, instead of always 0.)
   * Added tests for the `dilation_value` attribute of the `dilate` operator in Relay and TOPI.
   * Added support for quantized input to the TRANSPOSE_CONV operator of TFLite.
   * Added tests for quantized input to the TRANSPOSE_CONV operator of TFLite.
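For reference, the `dilation_value` semantics described above can be sketched in NumPy. This is a minimal illustration, not the TOPI implementation: `dilate_np` is a hypothetical name, and the sketch assumes the same convention as the existing `dilate` (a stride of 1 means no dilation on that axis), with the new, optional fill value defaulting to 0.

```python
import numpy as np


def dilate_np(data, strides, dilation_value=0.0):
    """Hypothetical NumPy reference for dilate with a custom fill value.

    data : np.ndarray with n dimensions
    strides : sequence of n ints; 1 means no dilation on that axis
    dilation_value : optional scalar written into the inserted positions
        (defaults to 0.0, i.e. the previous always-zero behaviour)
    """
    if len(strides) != data.ndim:
        raise ValueError("strides must have one entry per dimension")
    # An axis of length n dilated by stride s becomes (n - 1) * s + 1 long.
    out_shape = tuple((n - 1) * s + 1 for n, s in zip(data.shape, strides))
    out = np.full(out_shape, dilation_value, dtype=data.dtype)
    # The original elements land on every s-th position along each axis.
    out[tuple(slice(None, None, s) for s in strides)] = data
    return out
```

For example, `dilate_np(np.array([[1, 2], [3, 4]], dtype="uint8"), (2, 2), dilation_value=128)` yields `[[1, 128, 2], [128, 128, 128], [3, 128, 4]]`, which is the kind of zero-point filling a quantized input needs.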


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] jainris closed pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
jainris closed pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523


   





[GitHub] [incubator-tvm] jainris commented on pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
jainris commented on pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523#issuecomment-698463367


   The quantized transpose convolution code needs some changes, so I am bringing in the `dilate` operator changes independently in #6550.





[GitHub] [incubator-tvm] jainris commented on pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
jainris commented on pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523#issuecomment-696425194


   cc @anijain2305 @mbaret @u99127 @FrozenGene @tqchen





[GitHub] [incubator-tvm] mbaret commented on a change in pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523#discussion_r492707575



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -2831,17 +2831,94 @@ def convert_transpose_conv(self, op):
         else:
             padding = (0, 0, 0, 0)
 
-        out = _op.nn.conv2d_transpose(
-            in_expr,
-            weight_expr_iohw,
-            strides=(stride_h, stride_w),
-            padding=padding,
-            channels=int(out_channels),
-            kernel_size=(int(kernel_h), int(kernel_w)),
-            data_layout="NHWC",
-            kernel_layout="OIHW",
-            out_dtype=output_tensor_type_str,
-        )
+        if input_tensor.qnn_params:
+            # Making use of qnn.conv2d

Review comment:
       I suppose I was less interested in a proof, and more just a statement of what manipulations are happening.







[GitHub] [incubator-tvm] jainris commented on a change in pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
jainris commented on a change in pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523#discussion_r492592531



##########
File path: include/tvm/relay/attrs/nn.h
##########
@@ -596,11 +596,13 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
 /*! \brief Attributes used in dilate operator */
 struct DilateAttrs : public tvm::AttrsNode<DilateAttrs> {
   Array<IndexExpr> strides;
+  double dilation_value;

Review comment:
       This parallels `pad_value` in `PadAttrs`, which is also a `double`.
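The thread only cites consistency with `PadAttrs::pad_value`. One further consideration, offered as an illustration rather than the author's stated reason: a `double` can hold every integer fill value an int32 tensor might need, while a 32-bit `float` cannot. A short NumPy check shows the rounding:

```python
import numpy as np

# 2**24 + 1 fits in int32 but needs 25 significant bits; float32 has only 24.
fill = 2**24 + 1

as_float32 = int(np.float32(fill))  # rounded to the nearest float32 value
as_float64 = int(np.float64(fill))  # float64 represents all integers up to 2**53 exactly

print(as_float32, as_float64)  # 16777216 16777217
```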







[GitHub] [incubator-tvm] jainris commented on a change in pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
jainris commented on a change in pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523#discussion_r492621816



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -2831,17 +2831,94 @@ def convert_transpose_conv(self, op):
         else:
             padding = (0, 0, 0, 0)
 
-        out = _op.nn.conv2d_transpose(
-            in_expr,
-            weight_expr_iohw,
-            strides=(stride_h, stride_w),
-            padding=padding,
-            channels=int(out_channels),
-            kernel_size=(int(kernel_h), int(kernel_w)),
-            data_layout="NHWC",
-            kernel_layout="OIHW",
-            out_dtype=output_tensor_type_str,
-        )
+        if input_tensor.qnn_params:
+            # Making use of qnn.conv2d

Review comment:
       This is essentially the same as the implementation of the Relay op `conv2d_transpose`, which applies the same transformations (with 0 instead of the zero point) to the `input` and `kernel` and then performs a convolution. So a mathematical justification would amount to proving that these transformations followed by a convolution are equivalent to a transpose convolution, which would take an unreasonable amount of space.
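The manipulations under discussion can be stated concretely with a small NumPy sketch. It assumes a single channel, 2-D data, no padding or output padding, and a kernel in a plain HW layout; all function names are hypothetical. A transpose convolution equals: (1) dilate the input by the strides, (2) pad each side by kernel size minus 1, (3) run an ordinary (cross-correlation style) convolution with the spatially flipped kernel. In the quantized case the dilated and padded positions are filled with the input zero point instead of 0, so that once the zero point is subtracted they contribute nothing, just as zeros do in the float case.

```python
import numpy as np


def transpose_conv_direct(x, w, stride):
    """Reference transpose conv: scatter each input value times the kernel."""
    H, W = x.shape
    kh, kw = w.shape
    sh, sw = stride
    out = np.zeros(((H - 1) * sh + kh, (W - 1) * sw + kw), dtype=np.int64)
    for i in range(H):
        for j in range(W):
            out[i * sh:i * sh + kh, j * sw:j * sw + kw] += x[i, j] * w
    return out


def transpose_conv_via_conv(xq, wq, stride, x_zp=0, w_zp=0):
    """Dilate + pad with the input zero point, then conv with a flipped kernel."""
    H, W = xq.shape
    kh, kw = wq.shape
    sh, sw = stride
    # 1. Dilate: inputs every `stride` positions, gaps filled with x_zp.
    xd = np.full(((H - 1) * sh + 1, (W - 1) * sw + 1), x_zp, dtype=np.int64)
    xd[::sh, ::sw] = xq
    # 2. Pad by (kernel - 1) on every side, again with x_zp.
    xp = np.pad(xd, ((kh - 1, kh - 1), (kw - 1, kw - 1)), constant_values=x_zp)
    # 3. QNN-style correlation, sum((x - x_zp) * (w - w_zp)), with the flipped kernel.
    wf = wq[::-1, ::-1].astype(np.int64)
    out = np.zeros((xp.shape[0] - kh + 1, xp.shape[1] - kw + 1), dtype=np.int64)
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum((xp[i:i + kh, j:j + kw] - x_zp) * (wf - w_zp))
    return out


rng = np.random.default_rng(0)
xq = rng.integers(0, 256, size=(4, 5), dtype=np.uint8)
wq = rng.integers(0, 256, size=(3, 2), dtype=np.uint8)  # asymmetric 3x2 kernel
x_zp, w_zp = 128, 120  # arbitrary example zero points
expected = transpose_conv_direct(xq.astype(np.int64) - x_zp, wq.astype(np.int64) - w_zp, (2, 1))
print(np.array_equal(expected, transpose_conv_via_conv(xq, wq, (2, 1), x_zp, w_zp)))  # True
```

Filling with 0 instead of `x_zp` in the quantized path would make every inserted position contribute `(0 - x_zp) * (w - w_zp)` to the sum, which appears to be why the fill value of `dilate` needed to be configurable.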







[GitHub] [incubator-tvm] mbaret commented on a change in pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523#discussion_r493558343



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -1114,53 +1116,124 @@ def _test_transpose_conv(tensor_in_sizes, filter_in_sizes, output_shape, strides
         total_size_1 *= s
     for s in filter_in_sizes:
         total_size_2 *= s
-    # Initializes the input tensor with array containing incrementing
-    # numbers from 1.
-    data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
-    filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
 
     with tf.Graph().as_default():
-        in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32")
-        in_filter = constant_op.constant(filter_array, shape=filter_in_sizes, dtype="float32")
-        strides = [1] + strides + [1]
-        # in_filter layout is HWOI
-        out = nn_ops.conv2d_transpose(
-            in_data, in_filter, output_shape=output_shape, strides=strides, padding=padding
-        )
-        data_array = np.reshape(data_array, tensor_in_sizes).astype("float32")
-        compare_tflite_with_tvm(data_array, "Placeholder:0", [in_data], [out])
+        if quantized:
+            # Initializes the input tensor with array containing incrementing
+            # numbers from 1, clipped to the uint8 maximum of 255.
+            data_array = [min(f, 255) for f in range(1, total_size_1 + 1)]
+            filter_array = [min(f, 255) for f in range(1, total_size_2 + 1)]
+            data_array = np.reshape(data_array, tensor_in_sizes).astype("uint8")
+            filter_array = np.reshape(filter_array, filter_in_sizes).astype("uint8")
+
+            in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32", name="in_data")
+            inq_data = tf.quantization.fake_quant_with_min_max_args(
+                in_data, min=-100, max=100, name="q_data"
+            )
+            input_range = {"q_data": (-100, 100)}
+
+            in_filter = constant_op.constant(
+                filter_array, shape=filter_in_sizes, dtype="float32", name="in_filter"
+            )
+            inq_filter = tf.quantization.fake_quant_with_min_max_args(
+                in_filter, min=-100, max=100, name="q_filter"
+            )
+
+            strides = [1] + strides + [1]
+
+            out = nn_ops.conv2d_transpose(
+                inq_data, inq_filter, output_shape=output_shape, strides=strides, padding=padding
+            )
+            out = tf.quantization.fake_quant_with_min_max_args(out, min=-100, max=100, name="out")
+            compare_tflite_with_tvm(
+                [data_array], ["q_data"], [inq_data], [out], quantized=True, input_range=input_range
+            )
+        else:
+            # Initializes the input tensor with array containing incrementing
+            # numbers from 1.
+            data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
+            filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
+
+            in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype="float32", name="in_data")
+            in_filter = constant_op.constant(
+                filter_array, shape=filter_in_sizes, dtype="float32", name="in_filter"
+            )
+            strides = [1] + strides + [1]
+            # in_filter layout is HWOI
+            out = nn_ops.conv2d_transpose(
+                in_data, in_filter, output_shape=output_shape, strides=strides, padding=padding
+            )
+            data_array = np.reshape(data_array, tensor_in_sizes).astype("float32")
+            compare_tflite_with_tvm([data_array], ["in_data"], [in_data], [out])
 
 
 def test_forward_transpose_conv():
-    # kernel 3x3, padding VALID
-    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 34, 34, 5], [1, 1], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 65, 5], [2, 2], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 65, 34, 5], [2, 1], "VALID")
-
-    # kernel 3x3, padding SAME
-    _test_transpose_conv([4, 32, 32, 16], [3, 3, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 64, 5], [2, 2], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [3, 3, 5, 16], [1, 64, 32, 5], [2, 1], "SAME")
-
-    # kernel 2x2, padding VALID
-    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 33, 33, 5], [1, 1], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 33, 5], [2, 1], "VALID")
-
-    # kernel 2x2, padding SAME
-    _test_transpose_conv([4, 32, 32, 16], [2, 2, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 64, 5], [2, 2], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [2, 2, 5, 16], [1, 64, 32, 5], [2, 1], "SAME")
-
-    # kernel 1x1, padding VALID
-    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "VALID")
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "VALID")
-
-    # kernel 1x1, padding SAME
-    _test_transpose_conv([4, 32, 32, 16], [1, 1, 5, 16], [4, 32, 32, 5], [1, 1], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 63, 5], [2, 2], "SAME")
-    _test_transpose_conv([1, 32, 32, 16], [1, 1, 5, 16], [1, 63, 32, 5], [2, 1], "SAME")
+    for quantized in [False, True]:

Review comment:
       It would be good to check asymmetric kernels here (e.g. 3x2) to see if they work, and if not, add an assert.
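One way to act on this suggestion is to generate asymmetric-kernel cases whose output shapes are consistent with TensorFlow's shape rules for `conv2d_transpose`. The sketch below assumes the rule that a forward convolution over the transpose output must reproduce the input size (`ceil((out - k + 1) / s)` for VALID, `ceil(out / s)` for SAME) and reuses the argument order and HWOI filter layout of `_test_transpose_conv` above; the helper names are hypothetical.

```python
import math


def is_valid_transpose_out(in_size, out_size, kernel, stride, padding):
    """Return True if a forward conv over `out_size` yields `in_size` (TF rule)."""
    if padding == "VALID":
        forward = math.ceil((out_size - kernel + 1) / stride)
    elif padding == "SAME":
        forward = math.ceil(out_size / stride)
    else:
        raise ValueError("padding must be 'VALID' or 'SAME'")
    return forward == in_size


def minimal_transpose_out(in_size, kernel, stride, padding):
    """Smallest output size along one axis that passes the check above."""
    if padding == "VALID":
        return (in_size - 1) * stride + kernel
    return (in_size - 1) * stride + 1


def asym_cases(batch=1, in_hw=(32, 32), in_c=16, out_c=5):
    """Hypothetical argument tuples for asymmetric kernels (filter layout HWOI)."""
    cases = []
    for kh, kw in [(3, 2), (2, 3), (3, 1)]:
        for strides in ([1, 1], [2, 2], [2, 1]):
            for padding in ("VALID", "SAME"):
                out_h = minimal_transpose_out(in_hw[0], kh, strides[0], padding)
                out_w = minimal_transpose_out(in_hw[1], kw, strides[1], padding)
                cases.append(
                    (
                        [batch, in_hw[0], in_hw[1], in_c],
                        [kh, kw, out_c, in_c],
                        [batch, out_h, out_w, out_c],
                        list(strides),
                        padding,
                    )
                )
    return cases
```

Each tuple could then be passed to `_test_transpose_conv` (the full new signature, including the `quantized` flag, is not visible in the hunk, so that call is an assumption). If TVM rejects a case, an explicit assert in the frontend would make the limitation visible.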







[GitHub] [incubator-tvm] mbaret commented on a change in pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
mbaret commented on a change in pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523#discussion_r492584474



##########
File path: include/tvm/relay/attrs/nn.h
##########
@@ -596,11 +596,13 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
 /*! \brief Attributes used in dilate operator */
 struct DilateAttrs : public tvm::AttrsNode<DilateAttrs> {
   Array<IndexExpr> strides;
+  double dilation_value;

Review comment:
       Why double vs float?

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -2809,7 +2809,7 @@ def convert_transpose_conv(self, op):
         # Weights
         weights_tensor_type = weights_tensor.tensor.Type()
         # weights tensor type should be UINT8 (quantization) or FLOAT32

Review comment:
       Update this comment to include INT8

##########
File path: python/tvm/topi/nn/dilate.py
##########
@@ -34,6 +34,9 @@ def dilate(data, strides, name="DilatedInput"):
     strides : list / tuple of n ints
         Dilation stride on each dimension, 1 means no dilation.
 
+    dilation_value : int/float

Review comment:
       document 'optional'

##########
File path: python/tvm/topi/testing/dilate_python.py
##########
@@ -30,6 +30,9 @@ def dilate_python(input_np, strides):
     strides : list / tuple of n ints
         Dilation stride on each dimension, 1 means no dilation.
 
+    dilation_value : int/float

Review comment:
       document 'optional'

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -2831,17 +2831,94 @@ def convert_transpose_conv(self, op):
         else:
             padding = (0, 0, 0, 0)
 
-        out = _op.nn.conv2d_transpose(
-            in_expr,
-            weight_expr_iohw,
-            strides=(stride_h, stride_w),
-            padding=padding,
-            channels=int(out_channels),
-            kernel_size=(int(kernel_h), int(kernel_w)),
-            data_layout="NHWC",
-            kernel_layout="OIHW",
-            out_dtype=output_tensor_type_str,
-        )
+        if input_tensor.qnn_params:
+            # Making use of qnn.conv2d

Review comment:
       It may be useful to document the mathematical approach taken here.








[GitHub] [incubator-tvm] mbaret commented on pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
mbaret commented on pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523#issuecomment-697350073


   also ping @siju-samuel 





[GitHub] [incubator-tvm] jainris commented on a change in pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
jainris commented on a change in pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523#discussion_r493343488



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -2831,17 +2831,94 @@ def convert_transpose_conv(self, op):
         else:
             padding = (0, 0, 0, 0)
 
-        out = _op.nn.conv2d_transpose(
-            in_expr,
-            weight_expr_iohw,
-            strides=(stride_h, stride_w),
-            padding=padding,
-            channels=int(out_channels),
-            kernel_size=(int(kernel_h), int(kernel_w)),
-            data_layout="NHWC",
-            kernel_layout="OIHW",
-            out_dtype=output_tensor_type_str,
-        )
+        if input_tensor.qnn_params:
+            # Making use of qnn.conv2d

Review comment:
       Okay, added that.







[GitHub] [incubator-tvm] anijain2305 edited a comment on pull request #6523: [QNN][TFLite] Added support for fused-bias and quantized input in TRANSPOSE_CONV for TFLite.

Posted by GitBox <gi...@apache.org>.
anijain2305 edited a comment on pull request #6523:
URL: https://github.com/apache/incubator-tvm/pull/6523#issuecomment-697949127


   Dilation part is good.
   
   I am not sure about the conv2d transpose portion. My concern is that we would then have to replicate this logic in each framework parser. My suggestion would be to add a `qnn.conv2d_transpose` op and perform the "dilation + qnn.op.conv2d" lowering in QNN Legalize (example here: https://github.com/apache/incubator-tvm/blob/master/python/tvm/relay/qnn/op/legalizations.py#L266).
   
   For now, we can apply the transformation for all targets, not just ARM.
   
   This will keep the option open to improve the schedule of conv2d_transpose as a whole if needed.

