You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2021/04/15 04:09:42 UTC
[tvm] branch main updated: Fix PyTorch matmul conversion when given
(2-dim, N-dim) input pair (#7845)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b24fbe7 Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair (#7845)
b24fbe7 is described below
commit b24fbe7dee10195b6585634c0625828fe8624d5f
Author: liyuchao <xi...@163.com>
AuthorDate: Thu Apr 15 12:09:14 2021 +0800
Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair (#7845)
* [AutoScheduler] Fix incorrectly array context device and hide info at the beginning
* Lint fix
* Lint fix
* update repo
* Fix PyTorch matmul conversion when given (2-dim, N-dim) input pair
* update measure.py
* Lint fix
* fix bug && add ut for pytorch matmul
* update ut
* Lint fix
* update commit
* Lint fix
---
python/tvm/relay/frontend/pytorch.py | 24 ++++++++++++++++-----
tests/python/frontend/pytorch/test_forward.py | 31 +++++++++++++++++++++++----
2 files changed, 46 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index cb9ea6a..a31c44a 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1580,7 +1580,7 @@ class PyTorchOpConverter:
b_shape = self.infer_shape_with_prelude(inputs_1)
# When performing a batch matmul, we need to properly handle N-dim shapes.
- if len(a_shape) > 2 or len(b_shape) > 2:
+ if len(a_shape) > 2 and len(b_shape) > 2:
# Convert a into a 3 dimensional tensors.
need_reshape_output = False
if len(a_shape) != 3:
@@ -1606,18 +1606,32 @@ class PyTorchOpConverter:
if need_reshape_output:
return _op.reshape(output, [*a_shape[:-2], a_shape[-2], b_shape[-1]])
return output
+ elif len(a_shape) > 2:
+ inputs_0 = _op.reshape(inputs_0, [-1, a_shape[-1]])
- # Otherwise a simple dense op will get the job done.
- if len(b_shape) == 1:
- input_1 = _op.expand_dims(inputs_1, 0, 1)
- else:
+ if len(b_shape) > 2:
+ trans_axes = list(range(len(b_shape)))
+ trans_axes[-2], trans_axes[-1] = trans_axes[-1], trans_axes[-2]
+ input_1 = _op.reshape(_op.transpose(inputs_1, trans_axes), [-1, b_shape[-2]])
+ elif len(b_shape) == 2:
input_1 = _op.transpose(inputs_1, axes=(1, 0))
+ elif len(b_shape) == 1:
+ input_1 = _op.expand_dims(inputs_1, 0, 1)
out = _op.nn.dense(inputs_0, input_1)
if len(b_shape) == 1:
out = _op.squeeze(out, axis=[-1])
+ # Reshape output into a N dimensional tensor when a or b dim > 2
+ if len(a_shape) > 2:
+ out = _op.reshape(out, [*a_shape[:-1], b_shape[-1]])
+ elif len(b_shape) > 2:
+ out = _op.reshape(out, [a_shape[-2], -1, b_shape[-1]])
+ out = _op.reshape(
+ _op.transpose(out, [1, 0, 2]), [*b_shape[:-2], a_shape[-2], b_shape[-1]]
+ )
+
return out
def expand(self, inputs, input_types):
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 9ec5298..bff5bb6 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -162,7 +162,9 @@ def measure_latency(model, input_shapes, output_shapes, thresh, dryruns=40):
return est
-def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5):
+def verify_model(
+ model_name, input_data=[], custom_convert_map={}, rtol=1e-5, atol=1e-5, expected_ops=[]
+):
"""Assert that the output of a compiled model matches with that of its
baseline."""
if isinstance(model_name, str):
@@ -219,6 +221,20 @@ def verify_model(model_name, input_data=[], custom_convert_map={}, rtol=1e-5, at
assert_shapes_match(baseline_output, compiled_output)
tvm.testing.assert_allclose(baseline_output, compiled_output, rtol=rtol, atol=atol)
+
+ if expected_ops:
+
+ def visit(op):
+ if isinstance(op, tvm.ir.op.Op):
+ if op.name in expected_ops:
+ expected_ops.remove(op.name)
+
+ tvm.relay.analysis.post_order_visit(mod["main"].body, visit)
+
+ if expected_ops:
+ msg = "TVM Relay do not contain expected ops {}"
+ raise AssertionError(msg.format(expected_ops))
+
del model_name
del baseline_model
torch.cuda.empty_cache()
@@ -3304,17 +3320,24 @@ def test_forward_matmul():
# matrix x matrix
tensor1 = torch.randn(10, 4)
tensor2 = torch.randn(4, 10)
- verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+ verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
- verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+ verify_model(
+ MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.batch_matmul"]
+ )
# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
- verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2])
+ verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
+
+ # broadcasted matrix x batched matrix
+ tensor1 = torch.randn(10, 4)
+ tensor2 = torch.randn(3, 4, 5)
+ verify_model(MatMul1().float().eval(), input_data=[tensor1, tensor2], expected_ops=["nn.dense"])
# batched matrix x batched matrix
tensor1 = torch.randn(1, 12, 14, 64)