You are viewing a plain-text version of this content; the canonical link to it was not preserved in this copy.
Posted to commits@tvm.apache.org by jw...@apache.org on 2021/05/04 15:57:40 UTC
[tvm] branch main updated: [Relay][Autoscheduler] Fix autoscheduler matmul without units. (#7957)
This is an automated email from the ASF dual-hosted git repository.
jwfromm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3ed83b6 [Relay][Autoscheduler] Fix autoscheduler matmul without units. (#7957)
3ed83b6 is described below
commit 3ed83b6f041cd15aa75c1ac0599d75792bdaae8b
Author: Josh Fromm <jw...@octoml.ai>
AuthorDate: Tue May 4 08:56:57 2021 -0700
[Relay][Autoscheduler] Fix autoscheduler matmul without units. (#7957)
* Fix autoscheduler matmul without units.
* Fix lint.
---
src/relay/op/nn/nn.h | 27 +++++++++++++++-------
.../test_auto_scheduler_layout_rewrite_networks.py | 2 +-
2 files changed, 20 insertions(+), 9 deletions(-)
diff --git a/src/relay/op/nn/nn.h b/src/relay/op/nn/nn.h
index 8802cd9..38cb763 100644
--- a/src/relay/op/nn/nn.h
+++ b/src/relay/op/nn/nn.h
@@ -49,9 +49,9 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
ICHECK(static_cast<int>(data->shape.size()) != 0);
- Array<tvm::PrimExpr> oshape = data->shape;
+ Array<tvm::PrimExpr> dshape = data->shape;
+ Array<tvm::PrimExpr> oshape = dshape;
if (param->units.defined()) {
- Array<tvm::PrimExpr> dshape = data->shape;
// validate the weight shape is proper if defined
// Assign weight type
Array<IndexExpr> wshape({param->units, dshape[dshape.size() - 1]});
@@ -72,13 +72,24 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
} else {
if (weight == nullptr) return false;
Array<tvm::PrimExpr> wshape = weight->shape;
- ICHECK(static_cast<int>(weight->shape.size()) == 2);
- if (!data->shape.back().as<tir::AnyNode>()) {
- ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
- << "DenseRel: input dimension doesn't match,"
- << " data shape=" << data->shape << ", weight shape=" << weight->shape;
+ // When weight's layout has been rewritten, figure it out based on the
+ // total number of elements and input dimensions.
+ if (param->auto_scheduler_rewritten_layout.size() != 0) {
+ PrimExpr weight_elements = 1;
+ for (size_t i = 0; i < wshape.size(); i++) {
+ weight_elements = weight_elements * wshape[i];
+ }
+ oshape.Set(oshape.size() - 1, weight_elements / dshape[dshape.size() - 1]);
+ // Otherwise just pull it out of the weight shape directly.
+ } else {
+ ICHECK(static_cast<int>(weight->shape.size()) == 2);
+ if (!data->shape.back().as<tir::AnyNode>()) {
+ ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
+ << "DenseRel: input dimension doesn't match,"
+ << " data shape=" << data->shape << ", weight shape=" << weight->shape;
+ }
+ oshape.Set((oshape.size() - 1), wshape[0]);
}
- oshape.Set((oshape.size() - 1), wshape[0]);
}
DataType out_dtype = param->out_dtype;
diff --git a/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py
index 8466fc1..106b4bb 100644
--- a/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py
+++ b/tests/python/relay/test_auto_scheduler_layout_rewrite_networks.py
@@ -117,7 +117,7 @@ def get_relay_dense(m=128, n=128, k=128):
dtype = "float32"
d = relay.var("data", shape=(m, k), dtype=dtype)
w = relay.var("weight", shape=(n, k), dtype=dtype)
- y = relay.nn.dense(d, w, units=n)
+ y = relay.nn.dense(d, w)
mod = tvm.IRModule()
mod["main"] = relay.Function([d, w], y)
data, weight = get_np_array(d, dtype), get_np_array(w, dtype)