Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/01/02 13:43:18 UTC

[GitHub] TaoLv closed pull request #13530: Integrate MKLDNN Conv1d and support 3d layout

URL: https://github.com/apache/incubator-mxnet/pull/13530

This is a PR merged from a forked repository. As GitHub hides the
original diff on merge, it is displayed below for the sake of provenance:
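
For orientation before reading the diff, here is a minimal sketch of the
1-D convolution path this change enables, assuming an MKL-DNN-enabled CPU
build of MXNet (the shapes and values are illustrative, not from the PR):

    import mxnet as mx
    import numpy as np

    # 3-D (NCW) input: batch=1, channels=4, width=9
    x = mx.nd.array(np.random.normal(size=(1, 4, 9)))
    w = mx.nd.random.normal(shape=(8, 4, 3))  # (num_filter, channels, kernel_w)
    b = mx.nd.zeros((8,))

    # kernel=(3,) requests a 1-D convolution; with this patch an MKL-DNN
    # build dispatches it to MKL-DNN convolution primitives instead of the
    # default CPU fallback.
    y = mx.nd.Convolution(data=x, weight=w, bias=b, num_filter=8, kernel=(3,))
    print(y.shape)  # (1, 8, 7)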

diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 5a4cb29bc21..251bfb3f0e1 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -453,17 +453,10 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
 
   mkldnn::memory::dims dims;
   // These are shapes supported by MKLDNN.
-  if (shape.ndim() == 1 || shape.ndim() == 2 || shape.ndim() == 4
-      || shape.ndim() == 5) {
+  if (shape.ndim() >= 1 && shape.ndim() <= 5) {
     dims.resize(shape.ndim());
     for (size_t i = 0; i < dims.size(); i++)
       dims[i] = shape[i];
-  } else if (shape.ndim() == 3) {
-    // If there are 3 dimensions, we'll force it to 4 dimensions.
-    dims.resize(shape.ndim() + 1);
-    dims[0] = 1;
-    for (size_t i = 0; i < shape.ndim(); i++)
-      dims[i + 1] = shape[i];
   } else {
     LOG(FATAL) << "MKLDNN doesn't support " << shape.ndim() << " dimensions";
   }
@@ -471,6 +464,7 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
   switch (dims.size()) {
     case 1: layout = mkldnn::memory::format::x; break;
     case 2: layout = mkldnn::memory::format::nc; break;
+    case 3: layout = mkldnn::memory::format::ncw; break;
     case 4: layout = mkldnn::memory::format::nchw; break;
     // This isn't the right layout when the data has 5 dimensions in MXNet.
     // MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index 305eeab2117..fb920c31ce3 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -97,9 +97,10 @@ static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
-  if (SupportMKLDNN(inputs[0])) {
+  if (SupportMKLDNNAct(param, inputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNActivationForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     MKLDNN_OPCHECK_RUN(ActivationCompute<cpu>, attrs, ctx, inputs, req, outputs);
@@ -115,7 +116,7 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& outputs) {
   const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), activation::GradNumInputs(param.act_type));
-  if (SupportMKLDNN(inputs[0])) {
+  if (SupportMKLDNNAct(param, inputs[0])) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     // XXX: for y = relu(x), y is passed as "in_data" to Backward()
     const bool relu = param.act_type == activation::kReLU;
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 440705884b3..8c64888b460 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -49,6 +49,15 @@ bool SupportMKLDNNAct(const ActivationParam& param) {
       || param.act_type == activation::kTanh;
 }
 
+bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
+  // MKL-DNN Activation supports 1d, 2d, 3d and 4d data layouts
+  if ((input.shape().ndim() < 1) ||
+      (input.shape().ndim() > 4) ||
+      (input.dtype() != mshadow::kFloat32))
+    return false;
+  return SupportMKLDNNAct(param);
+}
+
 static inline mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
   switch (param.act_type) {
     case activation::kReLU:
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 17e74094c2b..e367f42c188 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -175,10 +175,11 @@ struct ConvolutionParam;
 struct DeconvolutionParam;
 struct SoftmaxParam;
 bool SupportMKLDNNAct(const ActivationParam& param);
+bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
-}
+}  // namespace op
 
 static int GetTypeSize(int dtype) {
   int size = -1;
@@ -250,15 +251,24 @@ inline static mkldnn::memory::desc GetMemDesc(const NDArray &arr) {
 
 inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
                                                  int num_groups) {
+  auto ndim = arr.shape().ndim();
+  mkldnn::memory::dims tz = mkldnn::memory::dims{0};
   if (num_groups == 1) {
     return GetMemDesc(arr);
   } else {
-    CHECK_EQ(arr.shape().ndim(), 4U);
-    mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
-      static_cast<int>(arr.shape()[0] / num_groups),
-      static_cast<int>(arr.shape()[1]),
-      static_cast<int>(arr.shape()[2]),
-      static_cast<int>(arr.shape()[3])};
+    CHECK((ndim == 3) || (ndim == 4))
+        << "MKL-DNN weight currectly supports 3d and 4d layout";
+    const int N = 0, H = 2, W = 3, C = 1;
+    if (ndim == 3) {
+      tz = mkldnn::memory::dims{
+          num_groups, static_cast<int>(arr.shape()[N] / num_groups),
+          static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H])};
+    } else {
+      tz = mkldnn::memory::dims{
+          num_groups, static_cast<int>(arr.shape()[N] / num_groups),
+          static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H]),
+          static_cast<int>(arr.shape()[W])};
+    }
     return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
                                 mkldnn::memory::format::any};
   }
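
For reference, the group splitting above follows a simple rule: an MXNet
weight of shape (O, I, H, W) (or (O, I, W) in the 1-D case) becomes
(g, O/g, I, H, W) for MKL-DNN when num_groups = g > 1. A small sketch of
that mapping (grouped_weight_dims is a hypothetical helper, not part of
the PR):

    def grouped_weight_dims(shape, num_groups):
        # MXNet conv weights are (O, I, ...spatial); for num_groups > 1,
        # MKL-DNN expects a leading group axis: (g, O/g, I, ...spatial).
        O = shape[0]
        assert O % num_groups == 0
        return [num_groups, O // num_groups] + list(shape[1:])

    print(grouped_weight_dims((8, 4, 3, 3), 2))  # [2, 4, 4, 3, 3] -> goihw
    print(grouped_weight_dims((8, 4, 3), 2))     # [2, 4, 4, 3]    -> goiw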
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 5da55f4ca70..ccb9d7ec007 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -239,39 +239,49 @@ const mkldnn::memory *GetWeights(const NDArray &arr,
     return mem;
 
   mkldnn::memory::data_type type = get_mkldnn_type(arr.dtype());
+  mkldnn::memory::dims tz = mkldnn::memory::dims{0};
+  mkldnn::memory::format format = mkldnn::memory::format::format_undef;
   auto engine = CpuEngine::Get()->get_engine();
+  const int O = 0, I = 1, H = 2, W = 3;
   if (arr.shape().ndim() == 2) {
-    mkldnn::memory::dims tz = mkldnn::memory::dims{
-      static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1])};
-    mkldnn::memory::desc md =
-        mkldnn::memory::desc{tz, type, mkldnn::memory::format::oi};
-    mkldnn::memory::primitive_desc pd =
-        mkldnn::memory::primitive_desc{md, engine};
-    mem = arr.GetMKLDNNData(pd);
-  } else if (arr.shape().ndim() == 4 && num_groups == 1) {
-    mkldnn::memory::dims tz = mkldnn::memory::dims{
-      static_cast<int>(arr.shape()[0]), static_cast<int>(arr.shape()[1]),
-          static_cast<int>(arr.shape()[2]), static_cast<int>(arr.shape()[3])};
-    mkldnn::memory::desc md =
-        mkldnn::memory::desc{tz, type, mkldnn::memory::format::oihw};
-    mkldnn::memory::primitive_desc pd =
-        mkldnn::memory::primitive_desc{md, engine};
-    mem = arr.GetMKLDNNData(pd);
+    tz = mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
+                              static_cast<int>(arr.shape()[I])};
+    format = mkldnn::memory::format::oi;
+  } else if (arr.shape().ndim() == 3) {
+    tz = num_groups > 1
+             ? mkldnn::memory::dims{num_groups,
+                                    static_cast<int>(arr.shape()[O] /
+                                                     num_groups),
+                                    static_cast<int>(arr.shape()[I]),
+                                    static_cast<int>(arr.shape()[H])}
+             : mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
+                                    static_cast<int>(arr.shape()[I]),
+                                    static_cast<int>(arr.shape()[H])};
+    format = num_groups > 1 ? mkldnn::memory::format::goiw
+                            : mkldnn::memory::format::oiw;
   } else if (arr.shape().ndim() == 4) {
-    mkldnn::memory::dims tz = mkldnn::memory::dims{ num_groups,
-      static_cast<int>(arr.shape()[0] / num_groups),
-      static_cast<int>(arr.shape()[1]),
-      static_cast<int>(arr.shape()[2]),
-      static_cast<int>(arr.shape()[3])};
-    mkldnn::memory::desc md =
-        mkldnn::memory::desc{tz, type, mkldnn::memory::format::goihw};
-    mkldnn::memory::primitive_desc pd =
-        mkldnn::memory::primitive_desc{md, engine};
-    mem = arr.GetMKLDNNData(pd);
+    tz = num_groups > 1
+             ? mkldnn::memory::dims{num_groups,
+                                    static_cast<int>(arr.shape()[O] /
+                                                     num_groups),
+                                    static_cast<int>(arr.shape()[I]),
+                                    static_cast<int>(arr.shape()[H]),
+                                    static_cast<int>(arr.shape()[W])}
+             : mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
+                                    static_cast<int>(arr.shape()[I]),
+                                    static_cast<int>(arr.shape()[H]),
+                                    static_cast<int>(arr.shape()[W])};
+    format = num_groups > 1 ? mkldnn::memory::format::goihw
+                            : mkldnn::memory::format::oihw;
   } else {
     LOG(FATAL) << "The weight array has an unsupported number of dimensions";
     return nullptr;
   }
+  mkldnn::memory::desc md =
+      mkldnn::memory::desc{tz, type, format};
+  mkldnn::memory::primitive_desc pd =
+      mkldnn::memory::primitive_desc{md, engine};
+  mem = arr.GetMKLDNNData(pd);
   if (mem == nullptr)
     mem = arr.GetMKLDNNDataReorder(target_pd);
   if (mem->get_primitive_desc() == target_pd) return mem;
@@ -285,6 +295,7 @@ mkldnn_memory_format_t GetDefaultFormat(int num_dims) {
   switch (num_dims) {
     case 1: return mkldnn_x;
     case 2: return mkldnn_nc;
+    case 3: return mkldnn_ncw;
     case 4: return mkldnn_nchw;
     case 5: return mkldnn_goihw;
     default:
@@ -301,6 +312,30 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
       return mkldnn_oi;
     else
       return desc.data.format;
+  } else if (desc.data.ndims == 3) {
+    switch (desc.data.format) {
+      case mkldnn_ncw:
+      case mkldnn_nwc:
+      case mkldnn_nCw8c:
+      case mkldnn_nCw16c:
+        return mkldnn_ncw;
+      case mkldnn_oiw:
+      case mkldnn_wio:
+      case mkldnn_Owi8o:
+      case mkldnn_OIw8i8o:
+      case mkldnn_OIw8o8i:
+      case mkldnn_OIw16i16o:
+      case mkldnn_OIw16o16i:
+      case mkldnn_Oiw16o:
+      case mkldnn_Owi16o:
+      case mkldnn_OIw8i16o2i:
+      case mkldnn_OIw8o16i2o:
+      case mkldnn_IOw16o16i:
+        return mkldnn_oiw;
+      default:
+        LOG(FATAL) << "Unknown MKLDNN format for 3 dimensions: " << desc.data.format;
+        return mkldnn_format_undef;
+    }
   } else if (desc.data.ndims == 4) {
     switch (desc.data.format) {
       case mkldnn_nchw:
@@ -329,6 +364,18 @@ mkldnn_memory_format_t GetDefaultFormat(const mkldnn::memory::desc &desc) {
       case mkldnn_Ohwi16o:
       case mkldnn_OhIw16o4i:
         return mkldnn_oihw;
+      case mkldnn_goiw:
+      case mkldnn_gOwi8o:
+      case mkldnn_gOIw8o8i:
+      case mkldnn_gOIw8i8o:
+      case mkldnn_gOIw16i16o:
+      case mkldnn_gOIw16o16i:
+      case mkldnn_gOiw16o:
+      case mkldnn_gOwi16o:
+      case mkldnn_gOIw8i16o2i:
+      case mkldnn_gOIw8o16i2o:
+      case mkldnn_gIOw16o16i:
+        return mkldnn_goiw;
       default:
         LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
         return mkldnn_format_undef;
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index dd1f3ec07d7..7f423ce4524 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -37,9 +37,12 @@ namespace op {
 DMLC_REGISTER_PARAMETER(MKLDNNConvParam);
 
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
-  if (params.kernel.ndim() != 2)
+  if ((params.kernel.ndim() != 1) &&
+      (params.kernel.ndim() != 2))
     return false;
-  return SupportMKLDNNQuantize(input.dtype()) && input.shape().ndim() == 4;
+  return SupportMKLDNNQuantize(input.dtype()) &&
+         ((input.shape().ndim() == 3) ||
+          (input.shape().ndim() == 4));
 }
 
 mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
@@ -51,15 +54,26 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
   auto weight_md = GetWeightDesc(weights, param.conv_param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
-  CHECK_GE(param.conv_param.stride.ndim(), 2U);
-  CHECK_GE(param.conv_param.pad.ndim(), 2U);
-  CHECK_GE(param.conv_param.dilate.ndim(), 2U);
-  mkldnn::memory::dims strides{0, 0};
-  strides[0] = param.conv_param.stride[0];
-  strides[1] = param.conv_param.stride[1];
-  mkldnn::memory::dims padding{0, 0};
-  padding[0] = param.conv_param.pad[0];
-  padding[1] = param.conv_param.pad[1];
+  mkldnn::memory::dims strides(param.conv_param.kernel.ndim());
+  mkldnn::memory::dims padding(param.conv_param.kernel.ndim());
+  if (param.conv_param.kernel.ndim() == 1) {
+    CHECK_GE(param.conv_param.stride.ndim(), 1U);
+    CHECK_GE(param.conv_param.pad.ndim(), 1U);
+    CHECK_GE(param.conv_param.dilate.ndim(), 1U);
+    strides[0] = param.conv_param.stride[0];
+    padding[0] = param.conv_param.pad[0];
+  } else if (param.conv_param.kernel.ndim() == 2) {
+    CHECK_GE(param.conv_param.stride.ndim(), 2U);
+    CHECK_GE(param.conv_param.pad.ndim(), 2U);
+    CHECK_GE(param.conv_param.dilate.ndim(), 2U);
+    strides[0] = param.conv_param.stride[0];
+    strides[1] = param.conv_param.stride[1];
+    padding[0] = param.conv_param.pad[0];
+    padding[1] = param.conv_param.pad[1];
+  } else {
+    LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size "
+               << param.conv_param.kernel.ndim() << ", supporting only 1 or 2.";
+  }
   mkldnn::primitive_attr attr;
   mkldnn::post_ops ops;
   if (param.mkldnn_param.with_relu) {
@@ -113,9 +127,17 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(
     }
     return conv_pd;
   } else {
-    mkldnn::memory::dims dilates{0, 0};
-    dilates[0] = param.conv_param.dilate[0] - 1;
-    dilates[1] = param.conv_param.dilate[1] - 1;
+    mkldnn::memory::dims dilates(param.conv_param.kernel.ndim());
+    if (param.conv_param.dilate.ndim() == 1) {
+      dilates[0] = param.conv_param.dilate[0] - 1;
+    } else if (param.conv_param.dilate.ndim() == 2) {
+      dilates[0] = param.conv_param.dilate[0] - 1;
+      dilates[1] = param.conv_param.dilate[1] - 1;
+    } else {
+      LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
+                 << param.conv_param.dilate.ndim()
+                 << ", supporting only 1 or 2.";
+    }
     if (bias == nullptr) {
       mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct,
           data_md, weight_md, out_md, strides, dilates, padding, padding,
@@ -151,15 +173,26 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
-  CHECK_GE(param.stride.ndim(), 2U);
-  CHECK_GE(param.pad.ndim(), 2U);
-  CHECK_GE(param.dilate.ndim(), 2U);
-  mkldnn::memory::dims strides{0, 0};
-  strides[0] = param.stride[0];
-  strides[1] = param.stride[1];
-  mkldnn::memory::dims padding{0, 0};
-  padding[0] = param.pad[0];
-  padding[1] = param.pad[1];
+  mkldnn::memory::dims strides(param.kernel.ndim());
+  mkldnn::memory::dims padding(param.kernel.ndim());
+  if (param.kernel.ndim() == 1) {
+    CHECK_GE(param.stride.ndim(), 1U);
+    CHECK_GE(param.pad.ndim(), 1U);
+    CHECK_GE(param.dilate.ndim(), 1U);
+    strides[0] = param.stride[0];
+    padding[0] = param.pad[0];
+  } else if (param.kernel.ndim() == 2) {
+    CHECK_GE(param.stride.ndim(), 2U);
+    CHECK_GE(param.pad.ndim(), 2U);
+    CHECK_GE(param.dilate.ndim(), 2U);
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[1];
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[1];
+  } else {
+    LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size " << param.kernel.ndim()
+               << ", supporting only 1 or 2.";
+  }
 
   // MKL-DNN introduced padded formats since 0.15 which require more memory
   // for computation compared with the actual tensor size. Currently, MKL-DNN
@@ -177,9 +210,16 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
     }
     return conv_pd;
   } else {
-    mkldnn::memory::dims dilates{0, 0};
-    dilates[0] = param.dilate[0] - 1;
-    dilates[1] = param.dilate[1] - 1;
+    mkldnn::memory::dims dilates(param.kernel.ndim());
+    if (param.dilate.ndim() == 1) {
+      dilates[0] = param.dilate[0] - 1;
+    } else if (param.dilate.ndim() == 2) {
+      dilates[0] = param.dilate[0] - 1;
+      dilates[1] = param.dilate[1] - 1;
+    } else {
+      LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
+                 << param.dilate.ndim() << ", supporting only 1 or 2.";
+    }
     mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
         data_md, weight_md, out_md, strides, dilates, padding, padding,
         mkldnn::padding_kind::zero);
@@ -201,15 +241,26 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
   auto weight_md = GetWeightDesc(weights, param.num_group);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
-  CHECK_GE(param.stride.ndim(), 2U);
-  CHECK_GE(param.pad.ndim(), 2U);
-  CHECK_GE(param.dilate.ndim(), 2U);
-  mkldnn::memory::dims strides{0, 0};
-  strides[0] = param.stride[0];
-  strides[1] = param.stride[1];
-  mkldnn::memory::dims padding{0, 0};
-  padding[0] = param.pad[0];
-  padding[1] = param.pad[1];
+  mkldnn::memory::dims strides(param.kernel.ndim());
+  mkldnn::memory::dims padding(param.kernel.ndim());
+  if (param.kernel.ndim() == 1) {
+    CHECK_GE(param.stride.ndim(), 1U);
+    CHECK_GE(param.pad.ndim(), 1U);
+    CHECK_GE(param.dilate.ndim(), 1U);
+    strides[0] = param.stride[0];
+    padding[0] = param.pad[0];
+  } else if (param.kernel.ndim() == 2) {
+    CHECK_GE(param.stride.ndim(), 2U);
+    CHECK_GE(param.pad.ndim(), 2U);
+    CHECK_GE(param.dilate.ndim(), 2U);
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[1];
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[1];
+  } else {
+    LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size " << param.kernel.ndim()
+               << ", supporting only 1 or 2.";
+  }
 
   // MKL-DNN introduced padded formats since 0.15 which require more memory
   // for computation compared with the actual tensor size. Currently, MKL-DNN
@@ -239,9 +290,16 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
     }
     return conv_pd;
   } else {
-    mkldnn::memory::dims dilates{0, 0};
-    dilates[0] = param.dilate[0] - 1;
-    dilates[1] = param.dilate[1] - 1;
+    mkldnn::memory::dims dilates(param.kernel.ndim());
+    if (param.dilate.ndim() == 1) {
+      dilates[0] = param.dilate[0] - 1;
+    } else if (param.dilate.ndim() == 2) {
+      dilates[0] = param.dilate[0] - 1;
+      dilates[1] = param.dilate[1] - 1;
+    } else {
+      LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
+                 << param.dilate.ndim() << ", supporting only 1 or 2.";
+    }
     if (bias == nullptr) {
       mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
           data_md, weight_md, out_md, strides, dilates, padding, padding,
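
One convention worth noting in the hunks above: MXNet's dilate counts the
distance between kernel taps (1 = dense), while MKL-DNN's dilates field
counts the zeros inserted between taps, hence the recurring
`param.dilate[i] - 1`. A quick sanity check (effective_kernel is a
hypothetical helper, not in the PR):

    def effective_kernel(k, mxnet_dilate):
        # Effective extent of a kernel of size k with MXNet dilation d;
        # the corresponding MKL-DNN dilates entry is d - 1.
        return (k - 1) * mxnet_dilate + 1

    assert effective_kernel(3, 1) == 3  # dense kernel, MKL-DNN dilate 0
    assert effective_kernel(3, 2) == 5  # one zero per gap, MKL-DNN dilate 1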
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index dfa98d1f5ee..65e0e5c4b27 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -626,8 +626,12 @@ std::vector<std::pair<int, int>> SgMKLDNNConvInplaceOption(
 }
 
 nnvm::NodePtr SgMKLDNNConvQuantizedOp(const NodeAttrs& attrs) {
+  auto const &param = nnvm::get<MKLDNNConvFusionParam>(attrs.parsed);
   nnvm::NodePtr node = nnvm::Node::Create();
   node->attrs.op = Op::Get("_sg_mkldnn_conv");
+  CHECK_EQ(param.full_conv_param.conv_param.kernel.ndim(), 2U)
+      << "Quantized Convolution of MKL-DNN only supports 2D kernel currently."
+      <<  "Please exclude this layer from the quantized model.";
   node->attrs.name = "quantized_" + attrs.name;
   node->attrs.dict = attrs.dict;
   node->attrs.dict["quantized"] = "true";
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 09157396f83..a895594ce28 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1602,33 +1602,33 @@ def check_batchnorm_training(stype):
 def test_convolution_grouping():
     for dim in [1, 2, 3]:
         num_filter = 4
-        num_group = 2
-        kernel = (3,) * dim
-        shape = (1, 4) + (9,) * dim
-
-        x = mx.sym.Variable('x')
-        w = mx.sym.Variable('w')
-        b = mx.sym.Variable('b')
-        y1 = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, kernel=kernel)
-        xslice = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1)
-        wslice = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0)
-        bslice = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0)
-        y2 = mx.sym.Concat(*[mx.sym.Convolution(data=xslice[i], weight=wslice[i], bias=bslice[i],
-                                                num_filter=num_filter//num_group, kernel=kernel)
-                           for i in range(num_group)])
-
-        exe1 = y1.simple_bind(default_context(), x=shape)
-        exe2 = y2.simple_bind(default_context(), x=shape, w=(num_filter, shape[1]//num_group) + kernel, b=(num_filter,))
-        for arr1, arr2 in zip(exe1.arg_arrays, exe2.arg_arrays):
-            arr1[:] = np.float32(np.random.normal(size=arr1.shape))
-            arr2[:] = arr1
-        exe1.forward(is_train=True)
-        exe1.backward(exe1.outputs[0])
-        exe2.forward(is_train=True)
-        exe2.backward(exe2.outputs[0])
-
-        for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays):
-            np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-3)
+        for num_group in [1, 2]:
+            kernel = (3,) * dim
+            shape = (1, 4) + (9,) * dim
+
+            x = mx.sym.Variable('x')
+            w = mx.sym.Variable('w')
+            b = mx.sym.Variable('b')
+            y1 = mx.sym.Convolution(data=x, weight=w, bias=b, num_filter=num_filter, num_group=num_group, kernel=kernel)
+            xslice = mx.sym.SliceChannel(data=x, num_outputs=num_group, axis=1)
+            wslice = mx.sym.SliceChannel(data=w, num_outputs=num_group, axis=0)
+            bslice = mx.sym.SliceChannel(data=b, num_outputs=num_group, axis=0)
+            y2 = mx.sym.Concat(*[mx.sym.Convolution(data=xslice[i], weight=wslice[i], bias=bslice[i],
+                                                    num_filter=num_filter//num_group, kernel=kernel)
+                            for i in range(num_group)])
+
+            exe1 = y1.simple_bind(default_context(), x=shape)
+            exe2 = y2.simple_bind(default_context(), x=shape, w=(num_filter, shape[1]//num_group) + kernel, b=(num_filter,))
+            for arr1, arr2 in zip(exe1.arg_arrays, exe2.arg_arrays):
+                arr1[:] = np.float32(np.random.normal(size=arr1.shape))
+                arr2[:] = arr1
+            exe1.forward(is_train=True)
+            exe1.backward(exe1.outputs[0])
+            exe2.forward(is_train=True)
+            exe2.backward(exe2.outputs[0])
+
+            for arr1, arr2 in zip(exe1.outputs + exe1.grad_arrays, exe2.outputs + exe2.grad_arrays):
+                np.testing.assert_allclose(arr1.asnumpy(), arr2.asnumpy(), rtol=1e-3, atol=1e-3)
 
 
 @unittest.skip("Flaky test https://github.com/apache/incubator-mxnet/issues/12203")
@@ -6772,7 +6772,7 @@ def get_output_names_callback(name, arr):
 
 @with_seed()
 def test_activation():
-    shape=(9, 10)
+    shapes = [(9,), (9, 10), (9, 10, 10), (1, 9, 10, 10)]
     dtype_l = [np.float64, np.float32, np.float16]
     rtol_l = [1e-7, 1e-6, 1e-2]
     atol_l = [1e-7, 1e-6, 1e-2]
@@ -6803,17 +6803,19 @@ def test_activation():
     }
     # Loop over operators
     for name, op in unary_ops.items():
-        # Loop over dtype's
-        for ind in range(len(dtype_l)):
-            dtype = dtype_l[ind]
-            rtol = rtol_l[ind]
-            atol = atol_l[ind]
-            compare_forw_backw_unary_op(
-                name, op[0], op[1], op[2], shape, op[3], op[4], rtol, atol,
-                dtype)
-        # Finite difference testing
-        finite_diff_unary_op(
-            name, op[0], shape, op[3], op[4], rtol_fd, atol_fd, num_eps)
+        # Loop over shapes
+        for shape in shapes:
+            # Loop over dtype's
+            for ind in range(len(dtype_l)):
+                dtype = dtype_l[ind]
+                rtol = rtol_l[ind]
+                atol = atol_l[ind]
+                compare_forw_backw_unary_op(
+                    name, op[0], op[1], op[2], shape, op[3], op[4], rtol, atol,
+                    dtype)
+            # Finite difference testing
+            finite_diff_unary_op(
+                name, op[0], shape, op[3], op[4], rtol_fd, atol_fd, num_eps)
 
 @with_seed()
 def test_ravel():
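
The widened shape list in test_activation above mirrors the new
SupportMKLDNNAct gate (fp32 input with 1 to 4 dimensions). A minimal
sketch of exercising that path on an MKL-DNN-enabled build (shapes copied
from the test):

    import mxnet as mx

    # fp32 data with ndim in [1, 4] now takes the MKL-DNN activation path.
    for shape in [(9,), (9, 10), (9, 10, 10), (1, 9, 10, 10)]:
        x = mx.nd.random.uniform(-1, 1, shape=shape, dtype='float32')
        y = mx.nd.Activation(x, act_type='relu')
        print(shape, '->', y.shape)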


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services