You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by mb...@apache.org on 2021/06/30 14:55:11 UTC
[tvm] branch main updated: Fix issue with importing models using
Tensorflow Lite 2.4.x schema (#8375)
This is an automated email from the ASF dual-hosted git repository.
mbrookhart pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8d4df91 Fix issue with importing models using Tensorflow Lite 2.4.x schema (#8375)
8d4df91 is described below
commit 8d4df91836bac8ee416adf29141d051c952802a7
Author: Ramana Radhakrishnan <ra...@arm.com>
AuthorDate: Wed Jun 30 15:54:53 2021 +0100
Fix issue with importing models using Tensorflow Lite 2.4.x schema (#8375)
TensorFlow Lite has changed the opcode for BuiltinOperators
to be represented as 32-bit integers instead of 8-bit integers
in the schema.
This is an attempt to fix this in a way that cleanly handles
multiple versions of TensorFlow Lite in the frontend.
---
python/tvm/relay/frontend/tflite.py | 25 ++++++++++++++++++++++++-
1 file changed, 24 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 7e21739..a47fdf0 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -251,7 +251,30 @@ class OperatorConverter(object):
raise ImportError("The tflite package must be installed")
op_code_list_idx = op.OpcodeIndex()
- op_code_id = self.model.OperatorCodes(op_code_list_idx).BuiltinCode()
+
+ op_c = self.model.OperatorCodes(op_code_list_idx)
+ # In TFlite 2.4.x there was a change where the type of the field that contained
+ # the builtin code changed from int8 to int32 in the flat buffer representation.
+ # However to retain support for old flat buffers that were created, they retained
+ # the original 8 bit encoding for the operator but in a new field accessed by the
+ # DeprecatedBuiltinCode method.
+ # This means that the API function BuiltinCode() is used on an operator
+ # which was originally encoded as an 8 bit quantity it would look for the
+ # code in the new int32 field in the schema and this creates the need
+ # for the check for the magic number of 127 which is indicated by
+ # BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES
+ # Remember however that this value came into existence only after Tensorflow
+ # lite 2.4.x and hence encase it in a try -except block.
+ # Phew !
+ try:
+ if op_c.BuiltinCode() < BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES:
+ opc = op_c.DeprecatedBuiltinCode()
+ else:
+ opc = op_c.BuiltinCode()
+ except AttributeError:
+ opc = op_c.BuiltinCode()
+
+ op_code_id = opc
try:
op_code_str = self.builtin_op_code[op_code_id]
except KeyError: