Posted to commits@tvm.apache.org by mb...@apache.org on 2021/06/30 14:55:11 UTC

[tvm] branch main updated: Fix issue with importing models using Tensorflow Lite 2.4.x schema (#8375)

This is an automated email from the ASF dual-hosted git repository.

mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8d4df91  Fix issue with importing models using Tensorflow Lite 2.4.x schema (#8375)
8d4df91 is described below

commit 8d4df91836bac8ee416adf29141d051c952802a7
Author: Ramana Radhakrishnan <ra...@arm.com>
AuthorDate: Wed Jun 30 15:54:53 2021 +0100

    Fix issue with importing models using Tensorflow Lite 2.4.x schema (#8375)
    
    TensorFlow Lite has changed the opcode for BuiltinOperators
    to be represented as 32-bit integers instead of 8-bit integers
    in the schema.
    
    This is an attempt to fix this in a way that cleanly handles
    multiple versions of TensorFlow Lite in the frontend.
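As a rough illustration of the schema change, the sketch below shows how one operator code could be laid out across the two fields. It assumes a writer that always stores the real code in the int32 field and, in the legacy int8 field, stores either the same code or the placeholder value 127 when the code does not fit below it. The function `encode_opcode` and the module-level constant are hypothetical names; only the value 127 and its meaning (`BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES`) come from the patch.

```python
# Hypothetical module constant mirroring BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES.
PLACEHOLDER_FOR_GREATER_OP_CODES = 127


def encode_opcode(code):
    """Return (deprecated_builtin_code, builtin_code) for one operator code.

    Sketch only: assumes the legacy int8 field holds the real code when it is
    below the placeholder and the placeholder otherwise, while the int32 field
    always holds the real code.
    """
    if not 0 <= code < 2**31:
        raise ValueError("operator code must fit in a non-negative int32")
    deprecated = min(code, PLACEHOLDER_FOR_GREATER_OP_CODES)
    return deprecated, code


print(encode_opcode(3))    # (3, 3)
print(encode_opcode(150))  # (127, 150)
```

Under this assumption, a reader that only understands the 8-bit field still decodes every operator below 127 correctly and sees the placeholder for anything newer.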
---
 python/tvm/relay/frontend/tflite.py | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 7e21739..a47fdf0 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -251,7 +251,30 @@ class OperatorConverter(object):
             raise ImportError("The tflite package must be installed")
 
         op_code_list_idx = op.OpcodeIndex()
-        op_code_id = self.model.OperatorCodes(op_code_list_idx).BuiltinCode()
+
+        op_c = self.model.OperatorCodes(op_code_list_idx)
+        # In TFLite 2.4.x the type of the field that holds the builtin code
+        # changed from int8 to int32 in the flatbuffer representation.
+        # To keep supporting flatbuffers created earlier, the original 8-bit
+        # encoding of the operator was retained in a field accessed through the
+        # DeprecatedBuiltinCode method.
+        # This means that if BuiltinCode() is called on an operator that was
+        # originally encoded as an 8-bit quantity, it reads the code from the
+        # new int32 field in the schema. Hence the check against the magic
+        # number 127, defined as
+        # BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES.
+        # This constant (and DeprecatedBuiltinCode) only exist from TensorFlow
+        # Lite 2.4.x onward, so the lookup is wrapped in a try/except block.
+        # Phew!
+        try:
+            if op_c.BuiltinCode() < BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES:
+                opc = op_c.DeprecatedBuiltinCode()
+            else:
+                opc = op_c.BuiltinCode()
+        except AttributeError:
+            opc = op_c.BuiltinCode()
+
+        op_code_id = opc
         try:
             op_code_str = self.builtin_op_code[op_code_id]
         except KeyError:
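The selection rule in the patch can be exercised without the tflite package installed. The sketch below uses hypothetical stub classes standing in for the generated `BuiltinOperator` enum and `OperatorCode` accessors: an "old" pair mimicking a pre-2.4 package (no placeholder constant, no `DeprecatedBuiltinCode()`), and a "new" pair mimicking a 2.4+ package. For an old flatbuffer read with the new package, the stub assumes the absent int32 field reads back as its default of 0. The names `resolve_opcode`, `OldBuiltinOperator`, `NewBuiltinOperator`, `OldOperatorCode` and `NewOperatorCode` are illustrative only.

```python
class OldBuiltinOperator:
    """Hypothetical stand-in for a pre-2.4 tflite BuiltinOperator: no placeholder."""
    ADD = 0


class NewBuiltinOperator:
    """Hypothetical stand-in for a 2.4+ tflite BuiltinOperator."""
    ADD = 0
    PLACEHOLDER_FOR_GREATER_OP_CODES = 127


class OldOperatorCode:
    """Pre-2.4 accessor stub: BuiltinCode() reads the original int8 field."""

    def __init__(self, code):
        self._code = code

    def BuiltinCode(self):
        return self._code


class NewOperatorCode:
    """2.4+ accessor stub exposing both fields."""

    def __init__(self, deprecated_code, code=0):
        # An old flatbuffer has no int32 field, so it reads back as the default 0.
        self._deprecated = deprecated_code
        self._code = code

    def DeprecatedBuiltinCode(self):
        return self._deprecated

    def BuiltinCode(self):
        return self._code


def resolve_opcode(op_c, builtin_operator):
    """Same decision rule as the patch, with the BuiltinOperator class injected."""
    try:
        if op_c.BuiltinCode() < builtin_operator.PLACEHOLDER_FOR_GREATER_OP_CODES:
            return op_c.DeprecatedBuiltinCode()
        return op_c.BuiltinCode()
    except AttributeError:
        # Older tflite package: the placeholder constant does not exist.
        return op_c.BuiltinCode()


# Old package, old model: the fallback path returns the int8 code.
print(resolve_opcode(OldOperatorCode(3), OldBuiltinOperator))  # 3
# New package, old model: BuiltinCode() is 0, so the int8 field is used.
print(resolve_opcode(NewOperatorCode(3), NewBuiltinOperator))  # 3
# New package, new model with a code above the placeholder.
print(resolve_opcode(NewOperatorCode(127, 150), NewBuiltinOperator))  # 150
```

Assuming writers store the same value in both fields for codes below 127, this rule gives the same result as `max(op_c.BuiltinCode(), op_c.DeprecatedBuiltinCode())` whenever both accessors exist; the explicit comparison in the patch keeps the role of the placeholder visible.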