You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2021/01/13 14:20:13 UTC

[tvm] branch main updated: [Frontend][TFLite] Densify Op added (#7048)

This is an automated email from the ASF dual-hosted git repository.

zhaowu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 006b9b5  [Frontend][TFLite] Densify Op added (#7048)
006b9b5 is described below

commit 006b9b53ab97e677933011d8b36a98a5a4ac7723
Author: ANSHUMAN TRIPATHY <an...@huawei.com>
AuthorDate: Wed Jan 13 19:50:02 2021 +0530

    [Frontend][TFLite] Densify Op added (#7048)
    
    * [Frontend][TFLite] Densify Op added
    
    * [1] Review comments handled
    
    * TODO added for sparse_to_dense Op usage
    
    * stale comments removed
---
 python/tvm/relay/frontend/tflite.py          | 215 +++++++++++++++++++++++++--
 tests/python/frontend/tflite/test_forward.py |  48 ++++++
 2 files changed, 253 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index 7a2aada..525fb41 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -65,6 +65,7 @@ class OperatorConverter(object):
         self.builtin_op_code = build_str_map(BuiltinOperator())
         self.activation_fn_type = build_str_map(ActivationFunctionType())
         self.builtin_options = build_str_map(BuiltinOptions())
+        self.prefetched_nodes = {}
 
         # Add more operators
         self.convert_map = {
@@ -80,6 +81,7 @@ class OperatorConverter(object):
             "CONCATENATION": self.convert_concatenation,
             "CONV_2D": self.convert_conv2d,
             "COS": self.convert_cos,
+            "DENSIFY": self.convert_densify,
             "DEPTH_TO_SPACE": self.convert_depth_to_space,
             "DEPTHWISE_CONV_2D": self.convert_depthwise_conv2d,
             "DEQUANTIZE": self.convert_dequantize,
@@ -200,6 +202,10 @@ class OperatorConverter(object):
             assert isinstance(op, Operator)
             ret = self.convert_map[op_code_str](op)
 
+            # In case the Op can be prefetched, the output can be optimized out
+            if ret is None:
+                continue
+
             if len(output_tensors) == 1:
                 tensor_idx = output_tensors[0].tensor_idx
                 self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret)
@@ -338,7 +344,8 @@ class OperatorConverter(object):
                 "Tensor type '{}' currently not supported".format(tensor_wrapper.tensor.Type())
             )
 
-    def get_tensor_value(self, tensor_wrapper):
+    # pylint: disable=no-else-return
+    def get_tensor_value(self, tensor_wrapper, is_sparse=False):
         """Get tensor buffer value from given tensor wrapper"""
         assert isinstance(tensor_wrapper, TensorWrapper)
 
@@ -350,7 +357,10 @@ class OperatorConverter(object):
         else:
             shape = []
 
-        return np.frombuffer(data, dtype=dtype).reshape(shape)
+        if is_sparse:
+            return np.frombuffer(data, dtype=dtype)
+        else:
+            return np.frombuffer(data, dtype=dtype).reshape(shape)
 
     def get_tensor_type_str(self, tensor_type):
         """Get tensor type string representation when given TFLite tensor type"""
@@ -1662,11 +1672,15 @@ class OperatorConverter(object):
         axis = tuple(axis_value) if len(axis_value.shape) > 0 else tuple((axis_value.item(),))
 
         # Options - keep_dims (bool)
-        assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions
-        reduce_options = ReducerOptions()
-        op_options = op.BuiltinOptions()
-        reduce_options.Init(op_options.Bytes, op_options.Pos)
-        keep_dims = reduce_options.KeepDims()
+        # In case Options are not present, set keep_dims to False(default)
+        if op.BuiltinOptionsType():
+            assert op.BuiltinOptionsType() == BuiltinOptions.ReducerOptions
+            reduce_options = ReducerOptions()
+            op_options = op.BuiltinOptions()
+            reduce_options.Init(op_options.Bytes, op_options.Pos)
+            keep_dims = reduce_options.KeepDims()
+        else:
+            keep_dims = False
 
         if input_tensor.qnn_params:
             in_expr = _op.cast(in_expr, "int32")
@@ -2026,7 +2040,11 @@ class OperatorConverter(object):
             else:
                 weight_expr = _op.transpose(weight_expr, axes=(1, 2, 3, 0))
         else:
-            weight_value = self.get_tensor_value(weight_tensor)
+            if self.is_prefetched(weight_tensor.tensor_idx):
+                weight_value = self.get_prefetched_node(weight_tensor.tensor_idx)
+            else:
+                weight_value = self.get_tensor_value(weight_tensor)
+
             # TFLite kernel layout:
             # convolution:
             # OC KH KW IC, we require KH KW IC OC (HWIO)
@@ -3196,22 +3214,199 @@ class OperatorConverter(object):
         out = _op.matrix_set_diag(input_expr, diagonal_expr)
         return out
 
+    def convert_densify(self, op):
+        """Convert TFLite DENSIFY"""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+
+        sparse_weight_tensor = input_tensors[0]
+        sparse_weight_tensor_type_str = self.get_tensor_type_str(sparse_weight_tensor.tensor.Type())
+
+        # NOTE: With the current TFLite implementation, the Densify Op does not need to be
+        # present at runtime, so the dense weight is computed here while converting the model.
+        # TODO(ANSHUMAN87): once the stack corruption issue in the sparse_to_dense Op is
+        # resolved, use the sparse_indices output of the function below (discarded as `_`
+        # for now) together with the sparse_to_dense Op.
+        _, dense_weight = prepare_dense_matrix_from_sparse(
+            sparse_weight_tensor.tensor,
+            self.get_tensor_value(sparse_weight_tensor, is_sparse=True),  # flat buffer, no reshape
+            sparse_weight_tensor_type_str,
+        )
+
+        self.set_prefetched_node(output_tensor.tensor_idx, dense_weight)  # nothing is returned
+
     def get_expr(self, input_tensor_idx):
         return self.exp_tab.get_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
     def has_expr(self, input_tensor_idx):
         return self.exp_tab.has_expr(get_tensor_name(self.subgraph, input_tensor_idx))
 
-    def get_tensor_expr(self, tensor):
+    def is_prefetched(self, input_tensor_idx):  # a stored None counts as "not prefetched"
+        return (
+            self.prefetched_nodes.get(get_tensor_name(self.subgraph, input_tensor_idx)) is not None
+        )
+
+    def set_prefetched_node(self, input_tensor_idx, value):  # keyed by tensor name, not index
+        self.prefetched_nodes[get_tensor_name(self.subgraph, input_tensor_idx)] = value
+
+    def get_prefetched_node(self, input_tensor_idx):  # plain dict indexing by tensor name
+        return self.prefetched_nodes[get_tensor_name(self.subgraph, input_tensor_idx)]
+
+    def get_tensor_expr(self, tensor, is_sparse=False):  # is_sparse goes to get_tensor_value
         """ Return the Relay expr for tensor. """
         if self.has_expr(tensor.tensor_idx):
             expr = self.get_expr(tensor.tensor_idx)
         else:
             type_str = self.get_tensor_type_str(tensor.tensor.Type())
-            expr = self.exp_tab.new_const(self.get_tensor_value(tensor), dtype=type_str)
+            expr = self.exp_tab.new_const(self.get_tensor_value(tensor, is_sparse), dtype=type_str)
         return expr
 
 
+# pylint: disable=no-else-return
+def prepare_dense_matrix_from_sparse(sparse_tensor, sparse_tensor_value, sparse_tensor_type):
+    """ Prepare sparse indices and dense matrix from TFLite sparse parameters. """
+    # The function is implemented based on the TFLite sparsity parameter specification.
+    # Please refer to
+    # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs#L89
+    # for details about each parameter.
+    sparsity = sparse_tensor.Sparsity()
+    dense_shape = sparse_tensor.ShapeAsNumpy()
+    orig_rank = len(dense_shape)
+
+    # Traversal order of the dimensions defined in the `shape` field of the tensor being densified.
+    traversal_order = sparsity.TraversalOrderAsNumpy()
+
+    # For an n-dimensional tensor with a k-dimensional block (0 <= k <= n),
+    # stores how a block dimension in (dn, ..., dn+k-1) maps to the original
+    # tensor dimension in (d0, ..., dn). It's stored in the order of (dn, ..., dn+k-1).
+    # If not block-sparse, this field is NULL.
+    block_map = sparsity.BlockMapAsNumpy()
+
+    total_rank = sparsity.TraversalOrderLength()
+    dense_mat = np.full(shape=dense_shape, fill_value=0, dtype=sparse_tensor_type).flatten()  # flat zero buffer; reshaped back on return
+
+    from enum import Enum
+
+    # NOTE: Here the Vector term is borrowed from TFLite spec.
+    class VectorType(Enum):
+        Empty = 0
+        Int32 = 1
+        Uint16 = 2
+        Uint8 = 3
+
+    def _get_vector_flag(v_type):
+        if VectorType(v_type) == VectorType.Int32:
+            return N.Int32Flags
+        elif VectorType(v_type) == VectorType.Uint16:
+            return N.Uint16Flags
+        elif VectorType(v_type) == VectorType.Uint8:
+            return N.Uint8Flags
+        else:
+            raise tvm.error.OpNotImplemented("The provided type {} is not supported".format(v_type))
+
+    def _get_flattened_index(indices, shape):
+        index = 0
+        sub_elements = 1
+        for i in reversed(range(0, len(dense_shape))):  # NOTE(review): reads the enclosing dense_shape, not the `shape` parameter
+            index += indices[i] * sub_elements
+            sub_elements *= shape[i]
+        return index
+
+    # DimensionMetadata per dimension: the metadata needed for
+    #     each dimension to locate the non-zero values in the original dense tensor
+    #     inline with traversal order parameter.
+    #
+    # sp_format has 2 possible values: {DENSE = 0, SPARSE_CSR = 1}
+    # If format = DENSE{0} : DenseSize represents size of that dimension
+    # If format = SPARSE_CSR{1} : array_segments represents how to segment the indices array,
+    #      each segment corresponds to one element in the previous dimension. array_indices
+    #      represents the index of the non-zero elements within this dimension
+    #      (as those in the CSR matrix format, where the first array is row pointers
+    #       and the second array is column indices).
+    sp_format = np.zeros(sparsity.DimMetadataLength())  # per-dimension format code (0 = dense)
+    dim_metadata = [None] * (2 * sparsity.DimMetadataLength())  # slot 2*i: DenseSize or segments; slot 2*i+1: indices
+
+    # The loop below fetches the metadata of each dimension according to its format type
+    # (Dense or Sparse) and stores it in a format-agnostic list for easy access
+    # while preparing the dense buffer or indices.
+    for i in range(sparsity.DimMetadataLength()):
+        sp_format[i] = sparsity.DimMetadata(i).Format()
+        if sp_format[i] == 0:
+            dim_metadata[2 * i] = [sparsity.DimMetadata(i).DenseSize()]
+        else:
+            from flatbuffers import number_types as N  # read by _get_vector_flag through the closure
+
+            dim_metadata[2 * i] = (
+                sparsity.DimMetadata(i)
+                .ArraySegments()
+                .GetVectorAsNumpy(  # NOTE(review): off=4 is presumably the vtable offset of the vector field; confirm
+                    flags=_get_vector_flag(sparsity.DimMetadata(i).ArraySegmentsType()), off=4
+                )
+            )
+            dim_metadata[2 * i + 1] = (
+                sparsity.DimMetadata(i)
+                .ArrayIndices()
+                .GetVectorAsNumpy(
+                    flags=_get_vector_flag(sparsity.DimMetadata(i).ArrayIndicesType()), off=4
+                )
+            )
+
+    block_dim = 0
+    block_size = np.zeros(sparsity.BlockMapLength())  # dense size of each block dimension
+
+    # Block size parameter if encoded in BSR format
+    for i in range(orig_rank):
+        if block_dim < sparsity.BlockMapLength() and block_map[block_dim] == i:
+            orig_dim = traversal_order[orig_rank + block_dim]
+            block_size[block_dim] = sparsity.DimMetadata(orig_dim).DenseSize()
+            block_dim += 1
+
+    indices_list = []
+
+    # The function below iterates through the applicable indices of each dimension, based on
+    # the format type specified, and finally produces the dense matrix and the NZ indices.
+    def _def_prepare_dense_matrix_from_sparse(indices, level, prev_idx):
+        if level == len(indices):
+            start_pos = 0
+            orig_idx = np.zeros(orig_rank, dtype="int32")
+            while start_pos < orig_rank:
+                orig_idx[traversal_order[start_pos]] = indices[start_pos]
+                start_pos += 1
+            while start_pos < len(indices):
+                block_idx = traversal_order[start_pos] - orig_rank
+                orig_dim = block_map[block_idx]
+                orig_idx[orig_dim] = orig_idx[orig_dim] * block_size[block_idx] + indices[start_pos]
+                start_pos += 1
+            indices_list.append(orig_idx)
+            nonlocal value_idx  # running position in sparse_tensor_value
+            dense_mat[_get_flattened_index(orig_idx, dense_shape)] = sparse_tensor_value[value_idx]
+            value_idx += 1
+        else:
+            metadata_idx = 2 * level
+            if sp_format[level] == 0:
+                shape_of_level = dim_metadata[metadata_idx][0]
+                for idx in range(shape_of_level):
+                    indices[level] = idx
+                    _def_prepare_dense_matrix_from_sparse(
+                        indices, level + 1, prev_idx * shape_of_level + idx
+                    )
+            else:
+                array_segments = dim_metadata[metadata_idx]
+                array_indices = dim_metadata[metadata_idx + 1]
+                for idx in range(array_segments[prev_idx], array_segments[prev_idx + 1]):
+                    indices[level] = array_indices[idx]
+                    _def_prepare_dense_matrix_from_sparse(indices, level + 1, idx)
+
+    indices = np.zeros(total_rank)  # one slot per traversal-order entry
+    value_idx = 0
+    _def_prepare_dense_matrix_from_sparse(indices, 0, 0)
+    return np.array(indices_list, dtype="int32"), dense_mat.reshape(dense_shape)  # (indices of stored values, dense matrix)
+
+
 def get_scalar_from_constant(expr):
     """ Returns scalar value from Relay constant scalar. """
     assert (
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index c8bd094..f365301 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -3692,6 +3692,50 @@ def test_forward_mobilenet_v3():
 
 
 #######################################################################
+# Mobilenet V1 Sparse
+# -------------------
+
+
+def test_forward_sparse_mobilenet_v1():
+    """Test the Sparse version of Mobilenet V1 TF Lite model."""
+    # Sparse MobilenetV1
+    tflite_model_file = download_testdata(
+        "https://storage.googleapis.com/fast-convnets/tflite-models/mbv1_140_90_12b4_720.tflite",
+        "mbv1_140_90_12b4_720.tflite",
+    )
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+    data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")  # random float32 input
+    tflite_output = run_tflite_graph(tflite_model_buf, data)  # reference result
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "float_image_input")
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
+    )
+
+
+#######################################################################
+# Mobilenet V2 Sparse
+# -------------------
+
+
+def test_forward_sparse_mobilenet_v2():
+    """Test the Sparse version of Mobilenet V2 TF Lite model."""
+    # Sparse MobilenetV2
+    tflite_model_file = download_testdata(
+        "https://storage.googleapis.com/fast-convnets/tflite-models/mbv2_200_85_11-16b2_744.tflite",
+        "mbv2_200_85_11-16b2_744.tflite",
+    )
+    with open(tflite_model_file, "rb") as f:
+        tflite_model_buf = f.read()
+    data = np.random.uniform(size=(1, 224, 224, 3)).astype("float32")  # random float32 input
+    tflite_output = run_tflite_graph(tflite_model_buf, data)  # reference result
+    tvm_output = run_tvm_graph(tflite_model_buf, data, "float_image_input")
+    tvm.testing.assert_allclose(
+        np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-5
+    )
+
+
+#######################################################################
 # Inception
 # ---------
 
@@ -4197,6 +4241,10 @@ if __name__ == "__main__":
     test_forward_coco_ssd_mobilenet_v1()
     test_forward_mediapipe_hand_landmark()
 
+    # End to End Sparse models
+    test_forward_sparse_mobilenet_v1()
+    test_forward_sparse_mobilenet_v2()
+
     # End to End quantized
     test_forward_qnn_inception_v1_net()
     test_forward_qnn_mobilenet_v1_net()