You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ss...@apache.org on 2022/12/15 02:51:14 UTC

[tvm] branch main updated: [Relay] Remove overwriting of matmul shapes when they are static (#13615)

This is an automated email from the ASF dual-hosted git repository.

sslyu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cc0f27a8b1 [Relay] Remove overwriting of matmul shapes when they are static (#13615)
cc0f27a8b1 is described below

commit cc0f27a8b130f1dee5047f4fa6bfdbe82ed1a24a
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Wed Dec 14 18:51:08 2022 -0800

    [Relay] Remove overwriting of matmul shapes when they are static (#13615)
    
    In the Relay Matmul shape relation, we are a little overenthusiastic about unifying dynamic shapes. If one of the shapes is static, it does not need to be unified. This change rewrites only dynamic shapes, replacing them with the required static constraints.
    
    * Remove overwriting of matmul shapes when they are static
    
    * Simplify nesting
    
    * Add shape check to dense tests.
---
 src/relay/op/nn/nn.h                 | 33 +++++++++++++++++++++------------
 tests/python/relay/test_op_level1.py |  3 +++
 2 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index f5497a4603..cf601ff5f1 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -113,23 +113,32 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
         std::vector<PrimExpr> B_shape(tensor_b->shape.begin(), tensor_b->shape.end());
         auto sa = A_shape.size();
         auto sb = B_shape.size();
+        size_t index_swap_A;
+        size_t index_swap_B;
         if (transpose_a && transpose_b) {
-          auto tmp = A_shape[sa - 2];
-          A_shape[sa - 2] = B_shape[sb - 1];
-          B_shape[sb - 1] = tmp;
+          index_swap_A = sa - 2;
+          index_swap_B = sb - 1;
         } else if (transpose_a) {
-          auto tmp = A_shape[sa - 2];
-          A_shape[sa - 2] = B_shape[sb - 2];
-          B_shape[sb - 2] = tmp;
+          index_swap_A = sa - 2;
+          index_swap_B = sb - 2;
         } else if (transpose_b) {
-          auto tmp = A_shape[sa - 1];
-          A_shape[sa - 1] = B_shape[sb - 1];
-          B_shape[sb - 1] = tmp;
+          index_swap_A = sa - 1;
+          index_swap_B = sb - 1;
         } else {
-          auto tmp = A_shape[sa - 1];
-          A_shape[sa - 1] = B_shape[sb - 2];
-          B_shape[sb - 2] = tmp;
+          index_swap_A = sa - 1;
+          index_swap_B = sb - 2;
         }
+
+        // Rewrite dynamic axes to static where constraints allow.
+        auto tmp = A_shape[index_swap_A];
+        if (A_shape[index_swap_A].as<tir::AnyNode>()) {
+          A_shape[index_swap_A] = B_shape[index_swap_B];
+        }
+        if (B_shape[index_swap_B].as<tir::AnyNode>()) {
+          B_shape[index_swap_B] = tmp;
+        }
+
+        // Update input types with new constrained shapes.
         reporter->Assign(types[0], TensorType(A_shape, tensor_a->dtype));
         reporter->Assign(types[1], TensorType(B_shape, tensor_b_dtype));
       }
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 30d9d88ad7..bd4e1b72c3 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -25,6 +25,7 @@ from tvm.relay.testing import run_infer_type
 import tvm.topi.testing
 from tvm.contrib.nvcc import have_fp16
 import tvm.testing
+from tvm.topi.utils import get_const_tuple
 
 executor_kind = tvm.testing.parameter("graph", "vm")
 
@@ -695,6 +696,8 @@ def test_dense(executor_kind):
         w = relay.var("w", relay.TensorType((k, n), dtype))
         y = relay.nn.dense(x, w)
         yy = run_infer_type(y)
+        # Confirm that input shape has not been rewritten to become dynamic.
+        assert get_const_tuple(yy.type_args[0].shape) == (4, 2)
 
         n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
         x = relay.var("x", relay.TensorType((n, c, h, w), dtype))