Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/05/04 20:50:42 UTC

[GitHub] [incubator-tvm] u99127 commented on a change in pull request #5510: [FRONTEND][TFLite] Fully connected op conversion made in sync with TFLite

u99127 commented on a change in pull request #5510:
URL: https://github.com/apache/incubator-tvm/pull/5510#discussion_r419715441



##########
File path: tests/python/frontend/tflite/test_forward.py
##########
@@ -419,6 +419,31 @@ def test_forward_cast():
     _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8)
     _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)
 
+#######################################################################
+# Batch Mat Mul
+# ----
+def _test_batch_matmul(A_shape, B_shape, dtype, adjoint_a=False, adjoint_b=False):
+    with tf.Graph().as_default():
+        A = array_ops.placeholder(shape=A_shape, dtype=dtype, name='A')
+        B = array_ops.placeholder(shape=B_shape, dtype=dtype, name='B')
+        result = math_ops.matmul(A, B, adjoint_a=adjoint_a,
+                           adjoint_b=adjoint_b, name='batchmatmul')
+
+        A_np = np.random.uniform(high=5.0, size=A_shape).astype(dtype)
+        B_np = np.random.uniform(high=5.0, size=B_shape).astype(dtype)
+        compare_tflite_with_tvm([A_np, B_np], [A.name, B.name], [A, B], [result])
+
+
+def test_forward_batch_matmul():
+    """ BATCH_MAT_MUL """
+    print("Jai hanuman!!!")
+    _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32')
+    _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True)
+    _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', True, False)
+    _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True)
+    _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32')
+    print("Jai hanuman!!!")

Review comment:
       Unneeded print.
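
For illustration, a minimal sketch of the fix the reviewer is asking for: the
same test body with the debug prints dropped. It relies only on the helper
_test_batch_matmul added in this diff:

    def test_forward_batch_matmul():
        """ BATCH_MAT_MUL """
        # Batched 3-D cases, with and without adjoints.
        _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32')
        _test_batch_matmul((3, 5, 4), (3, 4, 5), 'float32', True, True)
        _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', True, False)
        _test_batch_matmul((3, 5, 4), (3, 5, 4), 'float32', False, True)
        # Higher-rank batch dimensions.
        _test_batch_matmul((2, 3, 4, 5, 6), (2, 3, 4, 6, 5), 'float32')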

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1329,16 +1329,20 @@ def convert_fully_connected(self, op):
         input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
         weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy()
 
-        # reshape input tensor from N H W C to N H*W*C
-        input_size_per_batch = 1
-        for s in range(1, len(input_tensor_shape)):
-            input_size_per_batch *= input_tensor_shape[s]
-        assert input_size_per_batch == weight_tensor_shape[1], \
-            "input size and weight size are mismatched"
-        target_shape = tuple((input_tensor_shape[0], input_size_per_batch))
+        # Weight should have only 2 dimensions (TFLite convention)
+        assert len(weight_tensor_shape) == 2, "Weight should be only 2-dim"
+
+        input_size = 1
+        for s in range(0, len(input_tensor_shape)):
+            input_size *= input_tensor_shape[s]
+

Review comment:
       A comment here would be appropriate, or indeed a reference to the TFLite documentation.
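
A hedged sketch of the kind of comment the reviewer is asking for; the wording
and the source reference are illustrative, not taken from the PR:

    # TFLite FULLY_CONNECTED flattens an input of any rank to 2-D: the
    # effective batch size is the total element count divided by
    # weight_tensor_shape[1] (the number of input units). See the reference
    # kernel in tensorflow/lite/kernels/fully_connected.cc.
    input_size = 1
    for s in range(0, len(input_tensor_shape)):
        input_size *= input_tensor_shape[s]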

##########
File path: python/tvm/relay/frontend/tflite.py
##########
@@ -1329,16 +1329,20 @@ def convert_fully_connected(self, op):
         input_tensor_shape = input_tensor.tensor.ShapeAsNumpy()
         weight_tensor_shape = weight_tensor.tensor.ShapeAsNumpy()
 
-        # reshape input tensor from N H W C to N H*W*C
-        input_size_per_batch = 1
-        for s in range(1, len(input_tensor_shape)):
-            input_size_per_batch *= input_tensor_shape[s]
-        assert input_size_per_batch == weight_tensor_shape[1], \
-            "input size and weight size are mismatched"
-        target_shape = tuple((input_tensor_shape[0], input_size_per_batch))

Review comment:
       Can you explain in your covering note why this assert is wrong or needs to be removed?
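
For context, a hedged sketch of why the removed assert can reject valid TFLite
models; the shapes below are illustrative, not taken from the PR:

    # A valid TFLite FULLY_CONNECTED: input (1, 10, 256), weight (8, 256).
    input_tensor_shape = (1, 10, 256)
    weight_tensor_shape = (8, 256)

    # The old code treated dim 0 as the batch and flattened the rest, giving
    # 10 * 256 = 2560 != weight_tensor_shape[1] (256), so the assert fired.
    # TFLite instead reshapes the input to (-1, weight_tensor_shape[1]):
    input_size = 1
    for s in input_tensor_shape:
        input_size *= s
    batch_size = input_size // weight_tensor_shape[1]    # (1*10*256)/256 = 10
    target_shape = (batch_size, weight_tensor_shape[1])  # (10, 256)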




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org