Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/08 17:50:37 UTC

[GitHub] [incubator-tvm] d-smirnov opened a new pull request #6018: Added support for tflite quantized maximum and minimum

d-smirnov opened a new pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018


   Added support for tflite quantized maximum and minimum


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r453241061



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -1281,70 +1287,72 @@ def test_all_unary_elemwise():
 # Element-wise
 # ------------
 
-def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None):
+def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None, same_qnn_params=False):
     """ One iteration of elemwise """
 
     assert len(data) == 2
 
     # Test with two tensors
-    with tf.Graph().as_default():
-        in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0'),
-                   array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]
-
-        if quantized:
+    def __test_elemwise( in_data ):

Review comment:
       My suggestion would be to use the TFLite 2.1 way of creating quantization tests. I think it's cleaner.
   
   An example for Relu is here
   https://github.com/apache/incubator-tvm/blob/c9c77c6b76f7cff3bc6afbf9d3ef2200e3fdbb91/tests/python/frontend/tflite/test_forward.py#L2072-L2089





[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r466645505



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1089,7 +1093,7 @@ def convert_square(self, op):
 
         return out
 
-    def _convert_elemwise(self, relay_op, op):
+    def _convert_elemwise(self, relay_op, op, use_real_qnn=True):

Review comment:
       That makes sense now. Thanks for your patience. I would suggest renaming `use_real_qnn` to `ignore_qnn_params`.





[GitHub] [incubator-tvm] d-smirnov commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
d-smirnov commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r466509216



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -250,7 +256,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(
             sess, input_tensors, output_tensors)
-
+        converter.experimental_new_converter=experimental_new_converter

Review comment:
       I understand that it is no longer an experimental feature; however, the name "experimental_new_converter" was preserved. I don't see any harm in using this feature and having this test, especially if we plan to migrate to a newer version of TFLite.

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1089,7 +1093,7 @@ def convert_square(self, op):
 
         return out
 
-    def _convert_elemwise(self, relay_op, op):
+    def _convert_elemwise(self, relay_op, op, use_real_qnn=True):

Review comment:
       Extracting the "use_real_qnn" functionality into the _convert_minimum and _convert_maximum methods (L1225 and L1229) would require either changing the operation's parameters, stripping the qnn_attrs from at least the lhs input tensor, or adding an extra flag that forces _convert_elemwise to use the non-quantized version of the operation. Alternatively, I might not understand your point here.





[GitHub] [incubator-tvm] d-smirnov commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
d-smirnov commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r460440865



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -250,7 +256,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(
             sess, input_tensors, output_tensors)
-
+        converter.experimental_new_converter=experimental_new_converter

Review comment:
       The quantised versions of maximum and minimum cannot be converted using TOCO (TOCO is not supported anymore).





[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r460425703



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -250,7 +256,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(
             sess, input_tensors, output_tensors)
-
+        converter.experimental_new_converter=experimental_new_converter

Review comment:
       Why do we need the experimental converter here?





[GitHub] [incubator-tvm] d-smirnov commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
d-smirnov commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r469979193



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1089,7 +1093,7 @@ def convert_square(self, op):
 
         return out
 
-    def _convert_elemwise(self, relay_op, op):
+    def _convert_elemwise(self, relay_op, op, use_real_qnn=True):

Review comment:
       renamed





[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r465863240



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1089,7 +1093,7 @@ def convert_square(self, op):
 
         return out
 
-    def _convert_elemwise(self, relay_op, op):
+    def _convert_elemwise(self, relay_op, op, use_real_qnn=True):

Review comment:
       Can we skip use_real_qnn by moving the check to L1225 and L1229 and keeping _convert_elemwise unchanged? Adding use_real_qnn seems a little ad hoc.





[GitHub] [incubator-tvm] d-smirnov commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
d-smirnov commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r466594422



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1089,7 +1093,7 @@ def convert_square(self, op):
 
         return out
 
-    def _convert_elemwise(self, relay_op, op):
+    def _convert_elemwise(self, relay_op, op, use_real_qnn=True):

Review comment:
       I might not be correct here, but the whole idea of using the same qnn parameters is to be able to reuse the non-quantized version of the operation. In the case of the Slice op (#6217) there is only a check, and the non-quantized operation is always used. For maximum and minimum this is not possible without changes either to the operation's operands (the qnn_params would have to be stripped off) or, alternatively, to _convert_elemwise, to explicitly prevent it from going via the quantized version.
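The trade-off in this thread can be illustrated with a toy model. Everything below (`QnnParams`, `Tensor`, `convert_elemwise` and the label strings it returns) is hypothetical and is not TVM's frontend code; it is only a sketch, assuming an `ignore_qnn_params`-style flag, of how an element-wise converter could reuse the float operator when both inputs and the output share one (scale, zero_point).

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class QnnParams:
    scale: float
    zero_point: int


@dataclass(frozen=True)
class Tensor:
    # Hypothetical stand-in for a TFLite tensor wrapper (not TVM's TensorWrapper).
    name: str
    qnn_params: Optional[QnnParams] = None


def same_qnn_params(a: Tensor, b: Tensor) -> bool:
    # Scales are floats, so compare them with a tolerance; zero points are integers.
    if a.qnn_params is None or b.qnn_params is None:
        return False
    return (math.isclose(a.qnn_params.scale, b.qnn_params.scale, rel_tol=1e-6)
            and a.qnn_params.zero_point == b.qnn_params.zero_point)


def convert_elemwise(op_name: str, lhs: Tensor, rhs: Tensor, out: Tensor,
                     ignore_qnn_params: bool = False) -> str:
    """Return a label for the lowering this toy converter would choose."""
    if lhs.qnn_params is None:
        return op_name  # float model: plain operator
    if ignore_qnn_params:
        # Reusing the float operator on integer data is only valid when every
        # tensor uses the same quantization parameters.
        if not (same_qnn_params(lhs, rhs) and same_qnn_params(lhs, out)):
            raise ValueError(f"{op_name}: inputs and output must share qnn params")
        return op_name
    return f"quantized {op_name}"  # placeholder for a requantizing lowering
```

In this sketch a maximum/minimum converter would pass `ignore_qnn_params=True` while add/sub keep the default. Moving the check to the caller instead would leave the quantized branch reachable unless the caller also strips the qnn params, which is the trade-off raised above.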





[GitHub] [incubator-tvm] anijain2305 commented on pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#issuecomment-674224325


   Thanks @d-smirnov @u99127 This is merged



[GitHub] [incubator-tvm] d-smirnov commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
d-smirnov commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r465833724



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -250,7 +256,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(
             sess, input_tensors, output_tensors)
-
+        converter.experimental_new_converter=experimental_new_converter

Review comment:
       The MLIR-based converter (experimental_new_converter=True) is already in use in TensorFlow and TFLite. Why is there a need to postpone?





[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r466646173



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -250,7 +256,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(
             sess, input_tensors, output_tensors)
-
+        converter.experimental_new_converter=experimental_new_converter

Review comment:
       I am still not fully comfortable with this. https://www.tensorflow.org/api_docs/python/tf/lite/TFLiteConverter states that this API is subject to change. What do @u99127 @siju-samuel think about this?





[GitHub] [incubator-tvm] anijain2305 commented on pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#issuecomment-670157341


   @siju-samuel @FrozenGene Please review when you get time



[GitHub] [incubator-tvm] d-smirnov commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
d-smirnov commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r465831277



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1089,7 +1093,7 @@ def convert_square(self, op):
 
         return out
 
-    def _convert_elemwise(self, relay_op, op):
+    def _convert_elemwise(self, relay_op, op, use_real_qnn=True):

Review comment:
       "use_real_qnn=False" allows _convert_elemwise to use the non-quantized version of the operation if all inputs and the output of the operation have the same quantization parameters. Some other quantized TFLite operations are also expected to use this feature.

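Why sharing parameters is enough for maximum and minimum: with scale > 0, dequantization real = scale * (q - zero_point) is monotonically increasing in q, so picking the larger integer also picks the larger real value (and symmetrically for minimum). The pure-Python check below is an illustration written for this thread, not TVM code; it also shows how a naive integer max goes wrong when the two inputs carry different parameters.

```python
def dequantize(q: int, scale: float, zero_point: int) -> float:
    # Affine dequantization: real = scale * (q - zero_point).
    return scale * (q - zero_point)


def max_orders_agree(scale: float, zero_point: int, qmin: int = 0, qmax: int = 255) -> bool:
    """True if max-on-integers-then-dequantize equals dequantize-then-max for every pair."""
    for qa in range(qmin, qmax + 1):
        for qb in range(qmin, qmax + 1):
            int_first = dequantize(max(qa, qb), scale, zero_point)
            float_first = max(dequantize(qa, scale, zero_point),
                              dequantize(qb, scale, zero_point))
            if int_first != float_first:
                return False
    return True


# Mixed parameters: comparing raw integers picks the wrong element.
a_q, a_params = 10, (1.0, 0)   # represents 10.0
b_q, b_params = 30, (0.1, 0)   # represents roughly 3.0
naive_pick = dequantize(b_q, *b_params) if b_q > a_q else dequantize(a_q, *a_params)
true_max = max(dequantize(a_q, *a_params), dequantize(b_q, *b_params))
```

With shared parameters the integer operator can be reused as-is; with mixed parameters the inputs would first have to be requantized to a common scale.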




[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r460425660



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1089,7 +1093,7 @@ def convert_square(self, op):
 
         return out
 
-    def _convert_elemwise(self, relay_op, op):
+    def _convert_elemwise(self, relay_op, op, use_real_qnn=True):

Review comment:
       Why is `use_real_qnn` required? I would suggest moving the same_qnn_params test to L1225 and L1229 and keeping _convert_elemwise unchanged.





[GitHub] [incubator-tvm] anijain2305 merged pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 merged pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018


   



[GitHub] [incubator-tvm] d-smirnov commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
d-smirnov commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r454582120



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -358,12 +358,16 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
         rhs_scale = rhs_tensor.qnn_params['scale']
         lhs_zero_point = lhs_tensor.qnn_params['zero_point']
         rhs_zero_point = rhs_tensor.qnn_params['zero_point']
-        lhs_scale_value = get_scalar_from_constant(lhs_scale)
-        rhs_scale_value = get_scalar_from_constant(rhs_scale)
-        lhs_zero_point_value = get_scalar_from_constant(lhs_zero_point)
-        rhs_zero_point_value = get_scalar_from_constant(rhs_zero_point)
-        return lhs_scale_value == rhs_scale_value and \
-                lhs_zero_point_value == rhs_zero_point_value
+        # 0.1 + 0.2 != 0.3

Review comment:
       The main point here is replacing the exact comparison of floating-point values with an inexact one. Extending the function to tuples is a sort of bonus from numpy.
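The floating-point issue behind the `# 0.1 + 0.2 != 0.3` comment, and a tolerant comparison that accepts both per-tensor scalars and per-channel tuples, can be sketched in pure Python. This is a hypothetical stand-in for the numpy-based check in the PR: the helper names, the flat argument list and the default tolerance are assumptions, and unlike `numpy.allclose` it does not broadcast, so per-channel sequences must have equal length.

```python
import math
from typing import Sequence, Union

Number = Union[int, float]
QParam = Union[Number, Sequence[Number]]


def _as_tuple(v: QParam) -> tuple:
    # Per-tensor params are scalars; per-channel (weight) params are sequences.
    return tuple(v) if isinstance(v, (list, tuple)) else (v,)


def params_close(lhs: QParam, rhs: QParam, rel_tol: float = 1e-9, abs_tol: float = 0.0) -> bool:
    a, b = _as_tuple(lhs), _as_tuple(rhs)
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol) for x, y in zip(a, b))


def has_same_qnn_params(lhs_scale: QParam, lhs_zero_point: QParam,
                        rhs_scale: QParam, rhs_zero_point: QParam) -> bool:
    # Scales are floats, so compare them with a tolerance; zero points are integers.
    return (params_close(lhs_scale, rhs_scale)
            and _as_tuple(lhs_zero_point) == _as_tuple(rhs_zero_point))
```

An exact `==` on the scales would reject 0.1 + 0.2 versus 0.3 even though both describe the same quantization grid for practical purposes.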





[GitHub] [incubator-tvm] tqchen commented on pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
tqchen commented on pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#issuecomment-663867081


   @anijain2305 please follow up



[GitHub] [incubator-tvm] u99127 commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
u99127 commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r470636451



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -250,7 +256,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(
             sess, input_tensors, output_tensors)
-
+        converter.experimental_new_converter=experimental_new_converter

Review comment:
       @anijain2305 - sorry, for some reason GitHub is refusing to send me email notifications when I am tagged here :( That, in addition to holidays, is the reason for my delay in responding.
   
   AFAIUI there is no way of freezing a TFLite model that contains quantized max or min using the TOCO converter, so we need to use the API in this form to get the testsuite coverage. While this API is "subject to change", from my POV this is a use in the testsuite, not in the main code base, so using it is less risky.
   
   Note also that, per the latest docs, the TFLite converter in TensorFlow now defaults to the MLIR-based converter. So this use is still a conservative move forward: we are sticking to the default, and it gives us additional operator coverage.
   
   Does that help?
   
   Ramana
   





[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r470812443



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -250,7 +256,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(
             sess, input_tensors, output_tensors)
-
+        converter.experimental_new_converter=experimental_new_converter

Review comment:
       OK, I understand. If a CI upgrade to future TFLite versions causes problems, we can remove the flag.
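If a future TensorFlow release renames or drops `experimental_new_converter`, a plain assignment would not fail: Python objects accept new attributes silently, so the flag would simply be ignored. One defensive option for the test helper is a guard like the one below. The helper name is hypothetical, and it assumes the flag is exposed as an instance attribute on the converter object, as the assignment in the diff implies.

```python
def set_flag_if_supported(converter, name: str, value) -> bool:
    """Set an optional converter attribute only if the object already exposes it.

    Returns False when the attribute is missing, so the caller can skip the
    test or log a warning instead of silently creating an unused attribute.
    """
    if hasattr(converter, name):
        setattr(converter, name, value)
        return True
    return False
```

In `compare_tflite_with_tvm` this would replace the direct assignment with `set_flag_if_supported(converter, "experimental_new_converter", experimental_new_converter)` and act on a `False` result.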





[GitHub] [incubator-tvm] d-smirnov commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
d-smirnov commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r454579971



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -1281,70 +1287,72 @@ def test_all_unary_elemwise():
 # Element-wise
 # ------------
 
-def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None):
+def _test_elemwise(math_op, data, fused_activation_function=None, quantized=False, qnn_op=None, same_qnn_params=False):
     """ One iteration of elemwise """
 
     assert len(data) == 2
 
     # Test with two tensors
-    with tf.Graph().as_default():
-        in_data = [array_ops.placeholder(shape=data[0].shape, dtype='float32', name='in_0'),
-                   array_ops.placeholder(shape=data[1].shape, dtype='float32', name='in_1')]
-
-        if quantized:
+    def __test_elemwise( in_data ):

Review comment:
       I slightly simplified the unit test (using the current implementation). Is there an example, written using Keras, with different quantisation parameters for the two input tensors and the output tensor?
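For a test in which the two inputs and the output each carry different quantization parameters, one framework-independent option is to derive (scale, zero_point) per tensor from a chosen float range. The sketch below uses the common asymmetric uint8 scheme; the helper name and the example ranges are hypothetical, and TFLite's converter additionally nudges ranges, so the values it produces can differ slightly.

```python
def choose_uint8_qparams(rmin: float, rmax: float, qmin: int = 0, qmax: int = 255):
    """Asymmetric per-tensor (scale, zero_point) for a float range [rmin, rmax]."""
    # Extend the range to include 0.0 so that zero is exactly representable.
    rmin, rmax = min(rmin, 0.0), max(rmax, 0.0)
    if rmax == rmin:
        return 1.0, qmin
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmin - rmin / scale)
    # Clamp in case rounding pushes the zero point outside the integer range.
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, int(zero_point)


# Hypothetical ranges giving each tensor its own parameters.
in0_params = choose_uint8_qparams(-1.0, 1.0)
in1_params = choose_uint8_qparams(-4.0, 4.0)
out_params = choose_uint8_qparams(-4.0, 4.0)
```

Choosing different ranges for `in_0` and `in_1` exercises the path where the inputs do not share parameters.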





[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r453240904



##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -358,12 +358,16 @@ def has_same_qnn_params(self, lhs_tensor, rhs_tensor):
         rhs_scale = rhs_tensor.qnn_params['scale']
         lhs_zero_point = lhs_tensor.qnn_params['zero_point']
         rhs_zero_point = rhs_tensor.qnn_params['zero_point']
-        lhs_scale_value = get_scalar_from_constant(lhs_scale)
-        rhs_scale_value = get_scalar_from_constant(rhs_scale)
-        lhs_zero_point_value = get_scalar_from_constant(lhs_zero_point)
-        rhs_zero_point_value = get_scalar_from_constant(rhs_zero_point)
-        return lhs_scale_value == rhs_scale_value and \
-                lhs_zero_point_value == rhs_zero_point_value
+        # 0.1 + 0.2 != 0.3

Review comment:
       I am a little confused here. IIUC, scale and zero points are tuples only for weights. So, for maximum and minimum, we should not need to change this function.
   
   IMO, the changes should be very similar to the reshape op:
   https://github.com/apache/incubator-tvm/blob/c9c77c6b76f7cff3bc6afbf9d3ef2200e3fdbb91/python/tvm/relay/frontend/tflite.py#L472-L477
   
   
   Basically, check that the qnn params are the same and nothing else.
   
   Let me know if I understood something incorrectly.





[GitHub] [incubator-tvm] anijain2305 commented on a change in pull request #6018: Added support for tflite quantized maximum and minimum

Posted by GitBox <gi...@apache.org>.
anijain2305 commented on a change in pull request #6018:
URL: https://github.com/apache/incubator-tvm/pull/6018#discussion_r465862408



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -250,7 +256,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
         # convert to tflite model
         converter = tf.lite.TFLiteConverter.from_session(
             sess, input_tensors, output_tensors)
-
+        converter.experimental_new_converter=experimental_new_converter

Review comment:
       Doesn't the term "experimental" suggest that the feature is not mature yet? Typically, I have seen experimental features go through code churn; they can be deprecated, and the API may change before they mature. This is the main reason I am suggesting not to use it.



