You are viewing a plain-text version of this content. The canonical link to the original message was not preserved in this copy.
Posted to commits@tvm.apache.org by jw...@apache.org on 2021/05/04 15:57:40 UTC

[tvm] branch main updated: [Relay][Autoscheduler] Fix autoscheduler matmul without units. (#7957)

This is an automated email from the ASF dual-hosted git repository.

jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3ed83b6  [Relay][Autoscheduler] Fix autoscheduler matmul without units. (#7957)
3ed83b6 is described below

commit 3ed83b6f041cd15aa75c1ac0599d75792bdaae8b
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Tue May 4 08:56:57 2021 -0700

    [Relay][Autoscheduler] Fix autoscheduler matmul without units. (#7957)
    
    * Fix autoscheduler matmul without units.
    
    * Fix lint.
---
 src/relay/op/nn/nn.h                               | 27 +++++++++++++++-------
 .../test_auto_scheduler_layout_rewrite_networks.py |  2 +-
 2 files changed, 20 insertions(+), 9 deletions(-)

diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 8802cd9..38cb763 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -49,9 +49,9 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   ICHECK(static_cast<int>(data->shape.size()) != 0);
 
-  Array<tvm::PrimExpr> oshape = data->shape;
+  Array<tvm::PrimExpr> dshape = data->shape;
+  Array<tvm::PrimExpr> oshape = dshape;
   if (param->units.defined()) {
-    Array<tvm::PrimExpr> dshape = data->shape;
     // validate the weight shape is proper if defined
     // Assign weight type
     Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
@@ -72,13 +72,24 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   } else {
     if (weight == nullptr) return false;
     Array<tvm::PrimExpr> wshape = weight->shape;
-    ICHECK(static_cast<int>(weight->shape.size()) == 2);
-    if (!data->shape.back().as<tir::AnyNode>()) {
-      ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
-          << "DenseRel: input dimension doesn't match,"
-          << " data shape=" << data->shape << ", weight shape=" << weight->shape;
+    // When weight's layout has been rewritten, figure it out based on the
+    // total number of elements and input dimensions.
+    if (param->auto_scheduler_rewritten_layout.size() != 0) {
+      PrimExpr weight_elements = 1;
+      for (size_t i = 0; i < wshape.size(); i++) {
+        weight_elements = weight_elements * wshape[i];
+      }
+      oshape.Set(oshape.size() - 1, weight_elements / dshape[dshape.size() - 1]);
+      // Otherwise just pull it out of the weight shape directly.
+    } else {
+      ICHECK(static_cast<int>(weight->shape.size()) == 2);
+      if (!data->shape.back().as<tir::AnyNode>()) {
+        ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
+            << "DenseRel: input dimension doesn't match,"
+            << " data shape=" << data->shape << ", weight shape=" << weight->shape;
+      }
+      oshape.Set((oshape.size() - 1), wshape[0]);
     }
-    oshape.Set((oshape.size() - 1), wshape[0]);
   }
 
   DataType out_dtype = param->out_dtype;
diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py
index 8466fc1..106b4bb 100644
--- a/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py
+++ b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py
@@ -117,7 +117,7 @@ def get_relay_dense(m=128, n=128, k=128):
     dtype = "float32"
     d = relay.var("data", shape=(m, k), dtype=dtype)
     w = relay.var("weight", shape=(n, k), dtype=dtype)
-    y = relay.nn.dense(d, w, units=n)
+    y = relay.nn.dense(d, w)
     mod = tvm.IRModule()
     mod["main"] = relay.Function([d, w], y)
     data, weight = get_np_array(d, dtype), get_np_array(w, dtype)