You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/04/15 04:09:42 UTC

[tvm] branch main updated: Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair (#7845)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b24fbe7  Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair (#7845)
b24fbe7 is described below

commit b24fbe7dee10195b6585634c0625828fe8624d5f
Author: liyuchao <xi...@163.com>
AuthorDate: Thu Apr 15 12:09:14 2021 +0800

    Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair (#7845)
    
    * [AutoScheduler] Fix incorrect array context device and hide info at the beginning
    
    * Lint fix
    
    * Lint fix
    
    * update repo
    
    * Fix Pytorch matmul conversion when given (2-dim, N-dim) input pair
    
    * update measure.py
    
    * Lint fix
    
    * fix bug && add ut for pytorch matmul
    
    * update ut
    
    * Lint fix
    
    * update commit
    
    * Lint fix
---
 python/tvm/relay/frontend/pytorch.py          | 24 ++++++++++++++++-----
 tests/python/frontend/pytorch/test_forward.py | 31 +++++++++++++++++++++++----
 2 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index cb9ea6a..a31c44a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1580,7 +1580,7 @@ class PyTorchOpConverter:
         b_shape = self.infer_shape_with_prelude(inputs_1)
 
         # When performing a batch matmul, we need to properly handle N-dim shapes.
-        if len(a_shape) > 2 or len(b_shape) > 2:
+        if len(a_shape) > 2 and len(b_shape) > 2:
             # Convert a into a 3 dimensional tensors.
             need_reshape_output = False
             if len(a_shape) != 3:
@@ -1606,18 +1606,32 @@ class PyTorchOpConverter:
             if need_reshape_output:
                 return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
             return output
+        elif len(a_shape) > 2:
+            inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
 
-        # Otherwise a simple dense op will get the job done.
-        if len(b_shape) == 1:
-            input_1 = _op.expand_dims(inputs_1, 0, 1)
-        else:
+        if len(b_shape) > 2:
+            trans_axes = list(range(len(b_shape)))
+            trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+            input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
+        elif len(b_shape) == 2:
             input_1 = _op.transpose(inputs_1, axes=(1, 0))
+        elif len(b_shape) == 1:
+            input_1 = _op.expand_dims(inputs_1, 0, 1)
 
         out = _op.nn.dense(inputs_0, input_1)
 
         if len(b_shape) == 1:
             out = _op.squeeze(out, axis=[-1])
 
+        # Reshape the output back into an N-dimensional tensor when a or b has more than 2 dims
+        if len(a_shape) > 2:
+            out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+        elif len(b_shape) > 2:
+            out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+            out = _op.reshape(
+                _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
+            )
+
         return out
 
     def expand(self, inputs, input_types):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9ec5298..bff5bb6 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -162,7 +162,9 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
                 return est
 
 
-def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5):
+def verify_model(
+    model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5, expected_ops=[]
+):
     """Assert that the output of a compiled model matches with that of its
     baseline."""
     if isinstance(model_name, str):
@@ -219,6 +221,20 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
 
                 assert_shapes_match(baseline_output, compiled_output)
                 tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)
+
+    if expected_ops:
+
+        def visit(op):
+            if isinstance(op, tvm.ir.op.Op):
+                if op.name in expected_ops:
+                    expected_ops.remove(op.name)
+
+        tvm.relay.analysis.post_order_visit(mod["main"].body, visit)
+
+        if expected_ops:
+            msg = "TVM Relay do not contain expected ops {}"
+            raise AssertionError(msg.format(expected_ops))
+
     del model_name
     del baseline_model
     torch.cuda.empty_cache()
@@ -3304,17 +3320,24 @@ def test_forward_matmul():
     # matrix x matrix
     tensor1 = torch.randn(10, 4)
     tensor2 = torch.randn(4, 10)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
     # batched matrix x batched matrix
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(10, 4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(
+        MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+    )
 
     # batched matrix x broadcasted matrix
     tensor1 = torch.randn(10, 3, 4)
     tensor2 = torch.randn(4, 5)
-    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
+
+    # broadcasted matrix x batched matrix
+    tensor1 = torch.randn(10, 4)
+    tensor2 = torch.randn(3, 4, 5)
+    verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
 
     # batched matrix x batched matrix
     tensor1 = torch.randn(1, 12, 14, 64)