You are viewing a plain-text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2021/02/19 11:16:04 UTC

[tvm] branch main updated: [AutoScheduler] Fix the type inference for conv3d (#7475)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e204209  [AutoScheduler] Fix the type inference for conv3d (#7475)
e204209 is described below

commit e2042093cddcd2249bf1a7b7659cda6d39046a1c
Author: Lianmin Zheng <li...@gmail.com>
AuthorDate: Fri Feb 19 19:15:31 2021 +0800

    [AutoScheduler] Fix the type inference for conv3d (#7475)
---
 src/relay/op/nn/convolution.h | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index c08d355..5b4850e 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -24,6 +24,7 @@
 #ifndef TVM_RELAY_OP_NN_CONVOLUTION_H_
 #define TVM_RELAY_OP_NN_CONVOLUTION_H_
 
+#include <tvm/auto_scheduler/compute_dag.h>
 #include <tvm/support/logging.h>
 #include <tvm/tir/analysis.h>
 
@@ -369,7 +370,18 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
-    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+
+    Array<PrimExpr> wshape;
+    if (param->auto_scheduler_rewritten_layout.size() == 0) {
+      wshape = weight->shape;
+    } else {
+      // works for the default kernel layout "DHWIO"
+      ICHECK_EQ(param->kernel_layout, "DHWIO");
+      wshape = auto_scheduler::GetShapeFromRewrittenLayout(param->auto_scheduler_rewritten_layout,
+                                                           {"rd", "rh", "rw", "rc", "cc"});
+    }
+
+    wshape = trans_kernel_layout.ForwardShape(wshape);
     if (param->kernel_size.defined()) {
       ICHECK_EQ(param->kernel_size.size(), 3);
       // check the size