You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by zh...@apache.org on 2021/04/19 20:03:23 UTC

[incubator-mxnet] branch master updated: [BUGFIX] Add check to make sure num_group is non-zero (#20186)

This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5da68f7  [BUGFIX] Add check to make sure num_group is non-zero (#20186)
5da68f7 is described below

commit 5da68f725d80e3f1102ec4b0059005f7cf47c886
Author: herewj <41...@users.noreply.github.com>
AuthorDate: Tue Apr 20 04:01:13 2021 +0800

    [BUGFIX] Add check to make sure num_group is non-zero (#20186)
    
    * add a check that num_group is not equal to zero
    
    * num_group in convolution must be positive
---
 src/operator/nn/convolution.cc | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 556918a..cbfadf9 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -99,6 +99,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
     // 1d conv
     CHECK_EQ(dshp.ndim(), 3U) << "Input data should be 3D in batch-num_filter-x";
     Shape<3> dshape = ConvertLayout(dshp.get<3>(), param_.layout.value(), kNCW);
+    CHECK_GT(param_.num_group, 0U) \
+      << "Range only supports num_group > 0, received " << param_.num_group;
     Shape<3> wshape = Shape3(param_.num_filter / param_.num_group,
         mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
         param_.kernel[0]);
@@ -149,6 +151,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(dshp.ndim(), 4U) \
       << "Input data should be 4D in batch-num_filter-y-x";
     Shape<4> dshape = ConvertLayout(dshp.get<4>(), param_.layout.value(), kNCHW);
+    CHECK_GT(param_.num_group, 0U) \
+      << "Range only supports num_group > 0, received " << param_.num_group;
     Shape<4> wshape = Shape4(param_.num_filter / param_.num_group,
         mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
         param_.kernel[0], param_.kernel[1]);
@@ -208,6 +212,8 @@ static bool ConvolutionShape(const nnvm::NodeAttrs& attrs,
     CHECK_EQ(dshp.ndim(), 5U) \
       << "Input data should be 5D in batch-num_filter-depth-y-x";
     Shape<5> dshape = ConvertLayout(dshp.get<5>(), param_.layout.value(), kNCDHW);
+    CHECK_GT(param_.num_group, 0U) \
+      << "Range only supports num_group > 0, received " << param_.num_group;
     Shape<5> wshape = Shape5(param_.num_filter / param_.num_group,
         mxnet::dim_size_is_known(dshape, 1) ? dshape[1] / param_.num_group : -1,
         param_.kernel[0], param_.kernel[1], param_.kernel[2]);