You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2021/02/19 11:16:04 UTC
[tvm] branch main updated: [AutoScheduler] Fix the type inference for conv3d (#7475)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e204209 [AutoScheduler] Fix the type inference for conv3d (#7475)
e204209 is described below
commit e2042093cddcd2249bf1a7b7659cda6d39046a1c
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Fri Feb 19 19:15:31 2021 +0800
[AutoScheduler] Fix the type inference for conv3d (#7475)
---
src/relay/op/nn/convolution.h | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index c08d355..5b4850e 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -24,6 +24,7 @@
#ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_
#define TVM_RELAY_OP_NN_CONVOLUTION_H_
+#include <tvm/auto_scheduler/compute_dag.h>
#include <tvm/support/logging.h>
#include <tvm/tir/analysis.h>
@@ -369,7 +370,18 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
- auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+
+ Array<PrimExpr> wshape;
+ if (param->auto_scheduler_rewritten_layout.size() == 0) {
+ wshape = weight->shape;
+ } else {
+ // works for the default kernel layout "DHWIO"
+ ICHECK_EQ(param->kernel_layout, "DHWIO");
+ wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
+ {"rd", "rh", "rw", "rc", "cc"});
+ }
+
+ wshape = trans_kernel_layout.ForwardShape(wshape);
if (param->kernel_size.defined()) {
ICHECK_EQ(param->kernel_size.size(), 3);
// check the size