You are viewing a plain text version of this content; the canonical (linked) version is available in the mailing list archive.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/06/12 23:14:50 UTC
[tvm] branch main updated: [Bugfix] Shape inference of weight for grouped `nn.conv3d` (#11681)
This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8341e33d05 [Bugfix] Shape inference of weight for grouped `nn.conv3d` (#11681)
8341e33d05 is described below
commit 8341e33d05868b7bb8496c913679b7951836f3b9
Author: WANG Zihan <wz...@126.com>
AuthorDate: Mon Jun 13 07:14:43 2022 +0800
[Bugfix] Shape inference of weight for grouped `nn.conv3d` (#11681)
* Fix `nn.conv3d` weight shape inference.
* Add test for conv3d type inference with groups.
---
src/relay/op/nn/convolution.cc | 14 ++------------
tests/python/relay/test_op_level2.py | 7 +++++++
2 files changed, 9 insertions(+), 12 deletions(-)
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 0c882589e9..a6f6390b21 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -438,18 +438,8 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (param->kernel_size.defined() && param->channels.defined()) {
ICHECK_EQ(param->kernel_size.size(), 3);
ICHECK_EQ(param->dilation.size(), 3);
- Array<IndexExpr> wshape;
- tvm::tir::ExprDeepEqual expr_equal;
-
- if (expr_equal(param->channels, param->groups) && !expr_equal(param->channels, 1)) {
- // infer weight's shape for depthwise convolution
- wshape = {{dshape_ncdhw[1], indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0],
- param->kernel_size[1], param->kernel_size[2]}};
- } else {
- wshape = {{param->channels, indexdiv(dshape_ncdhw[1], param->groups), param->kernel_size[0],
- param->kernel_size[1], param->kernel_size[2]}};
- }
-
+ Array<IndexExpr> wshape({param->channels, indexdiv(dshape_ncdhw[1], param->groups),
+ param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]});
wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels;
dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index f547565464..dd6a54b959 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -522,6 +522,13 @@ def test_conv3d_infer_type():
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, d, h, w, 16), "int32")
+ # Infer with groups
+ x = relay.var("x", relay.TensorType((1, 16, 224, 224, 224), "float32"))
+ w = relay.var("w", relay.TensorType((4, 4, 1, 1, 1), "float32"))
+ y = relay.nn.conv3d(x, w, groups=4, kernel_size=(1, 1, 1), channels=4)
+ yy = run_infer_type(y)
+ assert yy.checked_type == relay.TensorType((1, 4, 224, 224, 224), "float32")
+
@tvm.testing.uses_gpu
def test_conv3d_run():