You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/12 23:14:50 UTC

[tvm] branch main updated: [Bugfix] Shape inference of weight for grouped `nn.conv3d` (#11681)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8341e33d05 [Bugfix] Shape inference of weight for grouped `nn.conv3d` (#11681)
8341e33d05 is described below

commit 8341e33d05868b7bb8496c913679b7951836f3b9
Author: WANG Zihan <wz...@126.com>
AuthorDate: Mon Jun 13 07:14:43 2022 +0800

    [Bugfix] Shape inference of weight for grouped `nn.conv3d` (#11681)
    
    * Fix `nn.conv3d` weight shape inference.
    
    * Add test for conv3d type inference with groups.
---
 src/relay/op/nn/convolution.cc       | 14 ++------------
 tests/python/relay/test_op_level2.py |  7 +++++++
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 0c882589e9..a6f6390b21 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -438,18 +438,8 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   if (param->kernel_size.defined() && param->channels.defined()) {
     ICHECK_EQ(param->kernel_size.size(), 3);
     ICHECK_EQ(param->dilation.size(), 3);
-    Array<IndexExpr> wshape;
-    tvm::tir::ExprDeepEqual expr_equal;
-
-    if (expr_equal(param->channels, param->groups) && !expr_equal(param->channels, 1)) {
-      // infer weight's shape for depthwise convolution
-      wshape = {{dshape_ncdhw[1], indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0],
-                 param->kernel_size[1], param->kernel_size[2]}};
-    } else {
-      wshape = {{param->channels, indexdiv(dshape_ncdhw[1], param->groups), param->kernel_size[0],
-                 param->kernel_size[1], param->kernel_size[2]}};
-    }
-
+    Array<IndexExpr> wshape({param->channels, indexdiv(dshape_ncdhw[1], param->groups),
+                             param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]});
     wshape = trans_kernel_layout.BackwardShape(wshape);
     channels = param->channels;
     dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index f547565464..dd6a54b959 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -522,6 +522,13 @@ def test_conv3d_infer_type():
     yy = run_infer_type(y)
     assert yy.checked_type == relay.TensorType((n, d, h, w, 16), "int32")
 
+    # Infer with groups
+    x = relay.var("x", relay.TensorType((1, 16, 224, 224, 224), "float32"))
+    w = relay.var("w", relay.TensorType((4, 4, 1, 1, 1), "float32"))
+    y = relay.nn.conv3d(x, w, groups=4, kernel_size=(1, 1, 1), channels=4)
+    yy = run_infer_type(y)
+    assert yy.checked_type == relay.TensorType((1, 4, 224, 224, 224), "float32")
+
 
 @tvm.testing.uses_gpu
 def test_conv3d_run():