Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/07/08 02:08:01 UTC

[incubator-mxnet] branch master updated: fix fp32 flatten issue (#15351)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 091fece  fix fp32 flatten issue (#15351)
091fece is described below

commit 091fece5431b43c931568d279c1dd9d664318c36
Author: Wuxun Zhang <wu...@intel.com>
AuthorDate: Mon Jul 8 10:07:37 2019 +0800

    fix fp32 flatten issue (#15351)
    
    * Fix flatten issue before slice op
    
    * fix cpplint
    
    * address comments
    
    * retrigger CI
    
    * trigger CI
    
    * retrigger CI
    
    * use SupportMKLDNNReshape and update operator list
---
 docs/tutorials/mkldnn/operator_list.md      |   2 +
 src/operator/nn/mkldnn/mkldnn_base-inl.h    |   2 +
 src/operator/nn/mkldnn/mkldnn_flatten.cc    |  87 +++++++++++++
 src/operator/nn/mkldnn/mkldnn_ops-inl.h     |   9 +-
 src/operator/nn/mkldnn/mkldnn_reshape-inl.h |  68 ++++++++++
 src/operator/nn/mkldnn/mkldnn_reshape.cc    | 185 +++++++++++++---------------
 src/operator/tensor/matrix_op.cc            |  20 ++-
 tests/python/gpu/test_operator_gpu.py       |  14 +++
 tests/python/mkl/test_mkldnn.py             |  20 +++
 9 files changed, 296 insertions(+), 111 deletions(-)

diff --git a/docs/tutorials/mkldnn/operator_list.md b/docs/tutorials/mkldnn/operator_list.md
index 4958f8d..0ef0f29 100644
--- a/docs/tutorials/mkldnn/operator_list.md
+++ b/docs/tutorials/mkldnn/operator_list.md
@@ -44,6 +44,8 @@ To help users understanding MKL-DNN backend better, the following table summariz
 | **elemwise_add**   | 1D-4D input                | Y                        | Y              | Y              |
 | **Concat**         | 1D-4D input                | Y                        | Y              | Y              |
 | **slice**          | 1D-4D input                | N                        | Y              | N              |
+| **Reshape**        | 1D-4D input                | N                        | Y              | N              |
+| **Flatten**        | 1D-4D input                | N                        | Y              | N              |
 | **Quantization**   | 1D-4D input                | N                        | N              | Y              |
 | **Dequantization** | 1D-4D input                | N                        | N              | Y              |
 | **Requantization** | 1D-4D input                | N                        | N              | Y              |
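
The two new rows advertise Reshape and Flatten as MKL-DNN-backed for 1D-4D fp32 inputs. A minimal sketch to observe these paths at runtime, assuming an MKL-DNN-enabled MXNet build (MKLDNN_VERBOSE is an MKL-DNN library switch, not part of this patch):

    import os
    os.environ['MKLDNN_VERBOSE'] = '1'  # MKL-DNN prints one line per executed primitive
    import mxnet as mx

    x = mx.nd.random.uniform(shape=(2, 16, 16, 16), dtype='float32')
    y = mx.nd.flatten(x)  # 1D-4D fp32 input: eligible for the new reorder-based path
    y.wait_to_read()      # force execution so the verbose log is emitted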
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 5670983..e01b7b1 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -176,6 +176,7 @@ struct DeconvolutionParam;
 struct SoftmaxParam;
 struct SoftmaxOutputParam;
 struct TransposeParam;
+struct ReshapeParam;
 bool SupportMKLDNNAct(const ActivationParam& param);
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
 bool SupportQuantizedMKLDNNAct(const ActivationParam &param);
@@ -184,6 +185,7 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
 bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
+bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
 }  // namespace op
 
 static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_flatten.cc
new file mode 100644
index 0000000..fdc02f9
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_flatten.cc
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_flatten.cc
+ * \brief Implement flatten operator by using mkldnn reorder primitive
+ * \author Wuxun Zhang
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "mkldnn_reshape-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNFlattenFwd : public MKLDNNReshapeFwd {
+ public:
+  explicit MKLDNNFlattenFwd(const OpReqType &req,
+                            const NDArray &input,
+                            const NDArray &output)
+    : MKLDNNReshapeFwd(req, input, output) {}
+};
+
+static MKLDNNFlattenFwd &GetFlattenForward(const OpReqType &req,
+                                           const NDArray &input,
+                                           const NDArray &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature,
+                                         MKLDNNFlattenFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature,
+                                            MKLDNNFlattenFwd, OpHash> fwds;
+#endif
+  OpSignature key;
+  key.AddSign(req);
+  key.AddSign(input);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNFlattenFwd fwd(req, input, output);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &input,
+                          const OpReqType &req,
+                          const NDArray &output) {
+  if (req == kNullOp) return;
+  CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
+
+  auto fwd = GetFlattenForward(req, input, output);
+  auto ws_size = fwd.GetWorkspaceSize();
+  void* ws_ptr = nullptr;
+  if (ws_size) {
+    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
+      .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
+    ws_ptr = reinterpret_cast<void*>(ws.dptr_);
+  }
+
+  fwd.Execute(input, output, ws_ptr);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif
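
GetFlattenForward keeps one MKLDNNFlattenFwd per thread, keyed by the request type and the input's descriptor, so repeated calls with an identical signature reuse the already-built reorder primitives instead of recreating them. A hedged Python analogue of that thread-local memoization pattern (all names here are illustrative, not MXNet API):

    import threading

    _tls = threading.local()  # per-thread cache, like DMLC_CXX11_THREAD_LOCAL above

    def get_flatten_forward(req, input_sig, build_fwd):
        # Return the cached forward object for (req, input_sig); build it on a miss.
        cache = getattr(_tls, 'fwds', None)
        if cache is None:
            cache = _tls.fwds = {}
        key = (req, input_sig)        # mirrors OpSignature: req plus input descriptor
        if key not in cache:
            cache[key] = build_fwd()  # construct the primitives only once per key
        return cache[key]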
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 2699a02..502abff 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -119,12 +119,17 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
                             const OpReqType &req,
                             const NDArray &output);
 
-void MKLDNNReshapeForward(const nnvm::NodeAttrs &attrs,
+void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
                           const OpContext &ctx,
-                          const NDArray &data,
+                          const NDArray &input,
                           const OpReqType &req,
                           const NDArray &output);
 
+void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &input,
+                          const OpReqType &req,
+                          const NDArray &output);
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
new file mode 100644
index 0000000..63e367b
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file mkldnn_reshape-inl.h
+ * \brief Function definition of mkldnn reshape operator
+ */
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+#include <vector>
+#include "mkldnn_base-inl.h"
+#include "../../tensor/matrix_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNReshapeFwd {
+ protected:
+  std::shared_ptr<mkldnn::memory> data_;
+  std::shared_ptr<mkldnn::memory> out_;
+  std::shared_ptr<mkldnn::memory> temp_;
+  std::vector<mkldnn::primitive> prims_;
+  bool needInvalidateInput = false;
+
+ public:
+  MKLDNNReshapeFwd(const OpReqType &req,
+                   const NDArray &input,
+                   const NDArray &output);
+  int GetWorkspaceSize();
+  void SetNewMem(const NDArray &input,
+                 const NDArray &output,
+                 void* workspace = nullptr);
+  void Execute(const NDArray &input,
+               const NDArray &output,
+               void* workspace = nullptr);
+};
+
+typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
+MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
+                                    const OpReqType &req,
+                                    const NDArray &input,
+                                    const NDArray &output);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
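
The header pins down the forward-object protocol that Reshape and Flatten now share: construct once per signature, query the temporary-buffer requirement via GetWorkspaceSize, then call Execute with fresh data handles (and the workspace, if any) on every invocation. A rough Python mirror of that lifecycle, purely illustrative:

    import numpy as np

    class ReshapeFwdSketch:
        def __init__(self, in_shape, dtype=np.float32):
            # the temp buffer must hold one full copy of the input in default layout
            self.ws_bytes = int(np.prod(in_shape)) * np.dtype(dtype).itemsize

        def get_workspace_size(self):
            return self.ws_bytes

        def execute(self, inp, out, workspace=None):
            # the real code rebinds mkldnn memory handles, then submits cached reorders
            out[...] = inp.reshape(out.shape)

    fwd = ReshapeFwdSketch((2, 16, 16, 16))
    ws = bytearray(fwd.get_workspace_size())   # caller provides the scratch space
    x = np.random.rand(2, 16, 16, 16).astype(np.float32)
    y = np.empty((2, 16 * 16 * 16), dtype=np.float32)
    fwd.execute(x, y, ws)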
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc
index 4f1d67a..063c85d 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape.cc
+++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -26,7 +26,7 @@
 #if MXNET_USE_MKLDNN == 1
 
 #include <mkldnn.hpp>
-#include "../../tensor/matrix_op-inl.h"
+#include "mkldnn_reshape-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -43,117 +43,106 @@ bool SupportMKLDNNReshape(const ReshapeParam &param,
   return true;
 }
 
-typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
-
-class MKLDNNReshapeForward {
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> out_;
-  std::shared_ptr<mkldnn::memory> temp_;
-  std::vector<mkldnn::primitive> prims_;
-
-  bool needInvalidateInput = false;
-
- public:
-  MKLDNNReshapeForward(const ReshapeParam &param,
-                       const OpReqType &req,
-                       const NDArray &input,
-                       const NDArray &output) {
-    auto engine = CpuEngine::Get()->get_engine();
-
-    // data_
-    auto in_mem = input.GetMKLDNNData();
-    auto in_pd = in_mem->get_primitive_desc();
-    data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
-
-    // temp_
-    auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
-    auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
-    auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
-    auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
-    auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
-    temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
-
-    // destination
-    out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
-
-    if (req == kWriteInplace) {
-      // If the input has MKL-DNN internal layout, we need reorder it to a temporal buffer with
-      // default layout and copy from the temporal buffer back to output buffer which has the same
-      // address with input buffer.
-      // If the input has default layout, then nothing need to do.
-      if (input.IsMKLDNNData()) {
-        prims_.push_back(mkldnn::reorder(*data_, *temp_));   // reorder to default
-        prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy back
-        needInvalidateInput = true;
-      }
-    } else if (req == kWriteTo) {
-      if (input.IsMKLDNNData()) {
-        prims_.push_back(mkldnn::reorder(*data_, *temp_));   // reorder to default
-        prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy to the output buffer
-        needInvalidateInput = false;
-      } else {
-        prims_.push_back(mkldnn::reorder(*data_, *out_));    // copy directly from input to output
-        needInvalidateInput = false;
-      }
+MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
+                                   const NDArray &input,
+                                   const NDArray &output) {
+  auto engine = CpuEngine::Get()->get_engine();
+
+  // data_
+  auto in_mem = input.GetMKLDNNData();
+  auto in_pd = in_mem->get_primitive_desc();
+  data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
+
+  // temp_
+  auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
+  auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
+  auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
+  auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
+  auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
+  temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+
+  // destination
+  out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+
+  if (req == kWriteInplace) {
+    // If the input has an MKL-DNN internal layout, we need to reorder it to a temporary
+    // buffer with the default layout and then copy from the temporary buffer back to the
+    // output buffer, which shares the same address as the input buffer.
+    // If the input already has the default layout, nothing needs to be done.
+    if (input.IsMKLDNNData()) {
+      prims_.push_back(mkldnn::reorder(*data_, *temp_));   // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy back
+      needInvalidateInput = true;
+    }
+  } else if (req == kWriteTo) {
+    if (input.IsMKLDNNData()) {
+      prims_.push_back(mkldnn::reorder(*data_, *temp_));   // reorder to default
+      prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy to the output buffer
+      needInvalidateInput = false;
     } else {
-      LOG(FATAL) << "not supported req type: " << req;
+      prims_.push_back(mkldnn::reorder(*data_, *out_));    // copy directly from input to output
+      needInvalidateInput = false;
     }
+  } else {
+    LOG(FATAL) << "not supported req type: " << req;
   }
+}
 
-  int GetWorkspaceSize() {
-    return temp_ ? temp_->get_primitive_desc().get_size() : 0;
-  }
+int MKLDNNReshapeFwd::GetWorkspaceSize() {
+  return temp_ ? temp_->get_primitive_desc().get_size() : 0;
+}
 
-  void SetNewMem(const NDArray &input, const NDArray &output, void* workspace = nullptr) {
-    if (input.IsMKLDNNData()) {
-      this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
-    } else {
-      MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
-        this->data_->set_data_handle(input.data().dptr<DTYPE>());
-      })
-    }
+void MKLDNNReshapeFwd::SetNewMem(const NDArray &input,
+                                 const NDArray &output,
+                                 void* workspace) {
+  if (input.IsMKLDNNData()) {
+    this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
+  } else {
+    MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
+      this->data_->set_data_handle(input.data().dptr<DTYPE>());
+    })
+  }
 
-    if (output.IsMKLDNNData()) {
-      this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
-    } else {
-      MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
-        this->out_->set_data_handle(output.data().dptr<DTYPE>());
-      })
-    }
+  if (output.IsMKLDNNData()) {
+    this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
+  } else {
+    MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
+      this->out_->set_data_handle(output.data().dptr<DTYPE>());
+    })
+  }
 
-    if (workspace) {
-      this->temp_->set_data_handle(workspace);
-    }
+  if (workspace) {
+    this->temp_->set_data_handle(workspace);
   }
+}
 
-  void Execute(const NDArray &input,
-               const NDArray &output,
-               void* workspace = nullptr) {
-    // set memory handles
-    SetNewMem(input, output, workspace);
-    // register primitives
-    auto stream = MKLDNNStream::Get();
-    for (auto &v : this->prims_) {
-      stream->RegisterPrim(v);
-    }
-    stream->Submit();
-    // invalidate mkldnn memory in input
-    if (needInvalidateInput) {
-      const_cast<NDArray &>(input).InvalidateMKLDNNData();
-    }
+void MKLDNNReshapeFwd::Execute(const NDArray &input,
+                               const NDArray &output,
+                               void* workspace) {
+  // set memory handles
+  SetNewMem(input, output, workspace);
+  // register primitives
+  auto stream = MKLDNNStream::Get();
+  for (auto &v : this->prims_) {
+    stream->RegisterPrim(v);
   }
-};
+  stream->Submit();
+  // invalidate mkldnn memory in input
+  if (needInvalidateInput) {
+    const_cast<NDArray &>(input).InvalidateMKLDNNData();
+  }
+}
 
-static MKLDNNReshapeForward &GetReshapeForward(const ReshapeParam& param,
-                                               const OpReqType &req,
-                                               const NDArray &input,
-                                               const NDArray &output) {
+MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
+                                    const OpReqType &req,
+                                    const NDArray &input,
+                                    const NDArray &output) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNReshapeSignature,
-                                         MKLDNNReshapeForward, OpHash> fwds;
+                                         MKLDNNReshapeFwd, OpHash> fwds;
 #else
   static MX_THREAD_LOCAL std::unordered_map<MKLDNNReshapeSignature,
-                                            MKLDNNReshapeForward, OpHash> fwds;
+                                            MKLDNNReshapeFwd, OpHash> fwds;
 #endif
   MKLDNNReshapeSignature key(param);
   key.AddSign(req);
@@ -162,7 +151,7 @@ static MKLDNNReshapeForward &GetReshapeForward(const ReshapeParam& param,
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    MKLDNNReshapeForward fwd(param, req, input, output);
+    MKLDNNReshapeFwd fwd(req, input, output);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
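
Stripped of the MKL-DNN plumbing, the constructor above reduces to a small decision table: when the input carries an MKL-DNN internal layout, reorder to a default-layout temp buffer and copy onward; otherwise a single copy (or, in place, nothing) suffices, with the in-place MKL-DNN case also invalidating the input's internal buffer afterwards. A hedged Python rendering of that plan selection:

    def reshape_plan(req, input_is_mkldnn):
        # Return (primitives, invalidate_input), mirroring MKLDNNReshapeFwd's constructor.
        if req == 'write_inplace':
            if input_is_mkldnn:
                return ['reorder(data->temp)', 'reorder(temp->out)'], True
            return [], False  # already default layout and in place: nothing to do
        if req == 'write_to':
            if input_is_mkldnn:
                return ['reorder(data->temp)', 'reorder(temp->out)'], False
            return ['reorder(data->out)'], False  # plain copy, input left intact
        raise ValueError('not supported req type: %s' % req)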
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index b4abc9f..c2bcb29 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -111,12 +111,13 @@ static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
+  const ReshapeParam& param = nnvm::get<ReshapeParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
   // If inputs are supposed to be in MKLDNN format and
   // MKLDNNsupport the data type or the shape. Then convert
   // it to the output format and shape
-  if (SupportMKLDNNArray(inputs[0].dtype(), inputs[0].shape())) {
+  if (SupportMKLDNNReshape(param, inputs[0])) {
     MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     return;
   }
@@ -233,12 +234,9 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(inputs.size(), 1U);
   CHECK_EQ(outputs.size(), 1U);
 #if MXNET_USE_MKLDNN == 1
-  if (inputs[0].IsMKLDNNData()) {
-    MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]);
-    // If the output is a special MKLDNN layout and the number of dimensions
-    // is larger than 2, we should use the default layout.
-    if (outputs[0].IsMKLDNNData() && inputs[0].shape().ndim() > 2)
-      const_cast<NDArray &>(outputs[0]).Reorder2Default();
+  auto data_ndim = inputs[0].shape().ndim();
+  if (data_ndim <= 4 && inputs[0].dtype() == mshadow::kFloat32) {
+    MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     return;
   } else {
     // This happens if inputs are supposed to be in MKLDNN format
@@ -252,10 +250,10 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
 
 #if MXNET_USE_MKLDNN == 1
 static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
-                                   const int dev_mask,
-                                   DispatchMode* dispatch_mode,
-                                   std::vector<int> *in_attrs,
-                                   std::vector<int> *out_attrs) {
+                                      const int dev_mask,
+                                      DispatchMode* dispatch_mode,
+                                      std::vector<int> *in_attrs,
+                                      std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 1);
   CHECK_EQ(out_attrs->size(), 1);
   return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
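
At the dispatch level, FlattenEx now routes every 1D-4D fp32 input through MKLDNNFlattenForward instead of merely copying when the input already has an MKL-DNN layout, which is what fixed the wrong results when a Flatten fed a slice. A simplified NDArray-level sanity check against NumPy (the original bug additionally needed a convolution output in MKL-DNN layout, as in the tests below):

    import mxnet as mx
    import numpy as np

    val = np.random.rand(2, 16, 16, 16).astype(np.float32)
    x = mx.nd.array(val)
    out = mx.nd.slice(mx.nd.flatten(x), begin=0, end=1)
    ref = val.reshape(2, -1)[0:1]          # NumPy reference for flatten + slice
    np.testing.assert_allclose(out.asnumpy(), ref, rtol=1e-5, atol=1e-6)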
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 064f783..5b4f81d 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1130,6 +1130,20 @@ def test_pooling_full_2d():
 
 
 @with_seed()
+def test_flatten_slice_after_conv():
+    ctx_list = []
+
+    data = mx.sym.Variable('conv_data')
+    conv = mx.symbol.Convolution(data=data, name='conv', num_filter=16, kernel=(3,3), stride=(1,1))
+    flatten = mx.symbol.flatten(data=conv)
+    slice_sym = mx.symbol.slice(data=flatten, begin=0, end=1)
+
+    ctx_list = [{'ctx': mx.gpu(0), 'conv_data': (2, 16, 16, 16), 'type_dict': {'conv_data': np.float32}},
+                {'ctx': mx.cpu(0), 'conv_data': (2, 16, 16, 16), 'type_dict': {'conv_data': np.float32}}]
+    check_consistency(slice_sym, ctx_list)
+
+
+@with_seed()
 def test_global_pooling():
     def test_1d_pooling(pool_type, p_value=2):
         data = (2, 3, 20)
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index 662edcf..3e623b5 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -233,6 +233,26 @@ def test_slice_reshape_before_conv():
     mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6)
 
 
+@with_seed()
+def test_flatten_slice_after_conv():
+    data = mx.symbol.Variable('data')
+    weight = mx.symbol.Variable('weight')
+    bias = mx.symbol.Variable('bias')
+    conv1 = mx.symbol.Convolution(data=data, weight=weight, bias=bias, name='conv1', num_filter=64, kernel=(3,3), stride=(1,1))
+    flatten1 = mx.symbol.flatten(data=conv1)
+    slice1 = mx.symbol.slice(data=flatten1, begin=0, end=1)
+
+    shape = (2, 16, 16, 16)
+    val = np.random.rand(2, 16, 16, 16).astype(np.float32)
+    exe = slice1.simple_bind(Context.default_ctx, data=shape)
+    exe.arg_arrays[0][:] = val
+    exe.arg_arrays[1][:] = np.random.normal(size=exe.arg_arrays[1].shape)
+    exe.arg_arrays[2][:] = np.random.normal(size=exe.arg_arrays[2].shape)
+    p = exe.forward(is_train=False)
+    p[0].wait_to_read()
+    print(p[0])
+
+
 def test_mkldnn_sum_inplace_with_cpu_layout():
 
     x_shape = (32, 3, 224, 224)
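
The new MKL test above exercises the Flatten-then-slice path but only prints the result. A hedged tightening that also asserts against a reference, by running the convolution alone and doing flatten + slice in NumPy, could look like this (a sketch, not part of the commit):

    import mxnet as mx
    import numpy as np

    def check_flatten_slice_after_conv():
        data = mx.symbol.Variable('data')
        conv = mx.symbol.Convolution(data=data, name='conv', num_filter=64,
                                     kernel=(3, 3), stride=(1, 1))
        sym = mx.symbol.slice(mx.symbol.flatten(conv), begin=0, end=1)

        shape = (2, 16, 16, 16)
        exe = sym.simple_bind(mx.cpu(), data=shape)
        for arr in exe.arg_arrays:
            arr[:] = np.random.normal(size=arr.shape).astype(np.float32)
        out = exe.forward(is_train=False)[0].asnumpy()

        # Reference: run the convolution alone, then flatten + slice in NumPy.
        args = dict(zip(conv.list_arguments(), [a.copy() for a in exe.arg_arrays]))
        ref_conv = conv.bind(mx.cpu(), args).forward(is_train=False)[0].asnumpy()
        ref = ref_conv.reshape(shape[0], -1)[0:1]
        np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-6)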