You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2019/11/28 04:11:23 UTC

[GitHub] [incubator-tvm] apivovarov commented on a change in pull request #4440: [TFLite] Add transpose_conv to TFLite parser

apivovarov commented on a change in pull request #4440: [TFLite] Add transpose_conv to TFLite parser
URL: https://github.com/apache/incubator-tvm/pull/4440#discussion_r351585989
 
 

 ##########
 File path: python/tvm/relay/frontend/tflite.py
 ##########
 @@ -1370,6 +1371,85 @@ def convert_prelu(self, op):
 
         return out
 
+    def convert_transpose_conv(self, op):
+        """Convert TFLite TRANSPOSE_CONV"""
+        try:
+            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.TensorType import TensorType
+            from tflite.Operator import Operator
+            from tflite.TransposeConvOptions import TransposeConvOptions
+            from tflite.Padding import Padding
+        except ImportError:
+            raise ImportError("The tflite package must be installed")
+
+        assert isinstance(op, Operator)
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 3, "input tensors length should be 3"
+
+        # Input (data) Tensor. NHWC layout
+        input_tensor = input_tensors[2]
+        _, _, _, input_c = input_tensor.tensor.ShapeAsNumpy()
+        # Weights tensor. TFLite uses OHWI layout
+        weights_tensor = input_tensors[1]
+        out_channels, kernel_h, kernel_w, in_channels = weights_tensor.tensor.ShapeAsNumpy()
+        assert input_c == in_channels, \
+            "Input channel in the filter should match to channel in the input"
+        # output_shape Tensor. NHWC layout
+        output_shape_tensor = input_tensors[0]
+
+        output_tensors = self.get_output_tensors(op)
+        assert len(output_tensors) == 1, "output tensors length should be 1"
+        output_tensor = output_tensors[0]
+        output_tensor_type = output_tensor.tensor.Type()
+        output_tensor_type_str = self.get_tensor_type_str(output_tensor_type)
+
+        assert op.BuiltinOptionsType() == BuiltinOptions.TransposeConvOptions
+        op_options = op.BuiltinOptions()
+        deconv_options = TransposeConvOptions()
+        deconv_options.Init(op_options.Bytes, op_options.Pos)
+
+        padding = deconv_options.Padding()
+        stride_h = deconv_options.StrideH()
+        stride_w = deconv_options.StrideW()
+        assert padding in (Padding.VALID, Padding.SAME), \
+            'Padding format {} is not supported for operator TRANSPOSE_CONV'.format(padding)
+
+        # Data
+        in_expr = self.get_expr(input_tensor.tensor_idx)
+
+        # Weights
+        weights_tensor_type = weights_tensor.tensor.Type()
+        # weights tensor type should be UINT8 (quantization) or FLOAT32
+        assert weights_tensor_type in (TensorType.UINT8, TensorType.FLOAT32)
+        weight_tensor_type_str = self.get_tensor_type_str(weights_tensor_type)
+        weight_value_ohwi = self.get_tensor_value(weights_tensor)
+        # Relay kernel_layout should be OIHW
+        # Relay weights layout should be different from kernel_layout - it should be IOHW
+        weight_value_iohw = np.transpose(weight_value_ohwi, (3, 0, 1, 2))
+        weight_expr_iohw = self.exp_tab.new_const(weight_value_iohw, dtype=weight_tensor_type_str)
+
+        # Output shape value
+        output_shape_value = self.get_tensor_value(output_shape_tensor)
+        # Relay expects filter output channel to match to output tensor channel.
+        assert out_channels == output_shape_value[3], \
+            "Output channel in the filter should match to channel in the output_shape"
+
+        # TF frontend supports 'SAME' padding for kernel 1x1 only. Lets do the same here
+        if padding == Padding.SAME:
+            assert (kernel_h, kernel_w) == (1, 1), \
+                "SAME padding is supported for kernel (1,1) only"
 
 Review comment:
   Where? In TF or in TFLite?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services