You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/07/08 02:08:01 UTC
[incubator-mxnet] branch master updated: fix fp32 flatten issue
(#15351)
This is an automated email from the ASF dual-hosted git repository.
patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 091fece fix fp32 flatten issue (#15351)
091fece is described below
commit 091fece5431b43c931568d279c1dd9d664318c36
Author: Wuxun Zhang <wu...@intel.com>
AuthorDate: Mon Jul 8 10:07:37 2019 +0800
fix fp32 flatten issue (#15351)
* Fix flatten issue before slice op
* fix cpplint
* address comments
* retrigger CI
* trigger CI
* retrigger CI
* use SupportMKLDNNReshape and update operator list
---
docs/tutorials/mkldnn/operator_list.md | 2 +
src/operator/nn/mkldnn/mkldnn_base-inl.h | 2 +
src/operator/nn/mkldnn/mkldnn_flatten.cc | 87 +++++++++++++
src/operator/nn/mkldnn/mkldnn_ops-inl.h | 9 +-
src/operator/nn/mkldnn/mkldnn_reshape-inl.h | 68 ++++++++++
src/operator/nn/mkldnn/mkldnn_reshape.cc | 185 +++++++++++++---------------
src/operator/tensor/matrix_op.cc | 20 ++-
tests/python/gpu/test_operator_gpu.py | 14 +++
tests/python/mkl/test_mkldnn.py | 20 +++
9 files changed, 296 insertions(+), 111 deletions(-)
diff --git a/docs/tutorials/mkldnn/operator_list.md b/docs/tutorials/mkldnn/operator_list.md
index 4958f8d..0ef0f29 100644
--- a/docs/tutorials/mkldnn/operator_list.md
+++ b/docs/tutorials/mkldnn/operator_list.md
@@ -44,6 +44,8 @@ To help users understanding MKL-DNN backend better, the following table summariz
| **elemwise_add** | 1D-4D input | Y | Y | Y |
| **Concat** | 1D-4D input | Y | Y | Y |
| **slice** | 1D-4D input | N | Y | N |
+| **Reshape** | 1D-4D input | N | Y | N |
+| **Flatten** | 1D-4D input | N | Y | N |
| **Quantization** | 1D-4D input | N | N | Y |
| **Dequantization** | 1D-4D input | N | N | Y |
| **Requantization** | 1D-4D input | N | N | Y |
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 5670983..e01b7b1 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -176,6 +176,7 @@ struct DeconvolutionParam;
struct SoftmaxParam;
struct SoftmaxOutputParam;
struct TransposeParam;
+struct ReshapeParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
bool SupportQuantizedMKLDNNAct(const ActivationParam &param);
@@ -184,6 +185,7 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
+bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
} // namespace op
static int GetTypeSize(int dtype) {
diff --git a/src/operator/nn/mkldnn/mkldnn_flatten.cc b/src/operator/nn/mkldnn/mkldnn_flatten.cc
new file mode 100644
index 0000000..fdc02f9
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_flatten.cc
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_flatten.cc
+ * \brief Implement flatten operator by using mkldnn reorder primitive
+ * \author Wuxun Zhang
+*/
+
+#if MXNET_USE_MKLDNN == 1
+
+#include "mkldnn_reshape-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNFlattenFwd : public MKLDNNReshapeFwd {
+ public:
+ explicit MKLDNNFlattenFwd(const OpReqType &req,
+ const NDArray &input,
+ const NDArray &output)
+ : MKLDNNReshapeFwd(req, input, output) {}
+};
+
+static MKLDNNFlattenFwd &GetFlattenForward(const OpReqType &req,
+ const NDArray &input,
+ const NDArray &output) {
+#if DMLC_CXX11_THREAD_LOCAL
+ static thread_local std::unordered_map<OpSignature,
+ MKLDNNFlattenFwd, OpHash> fwds;
+#else
+ static MX_THREAD_LOCAL std::unordered_map<OpSignature,
+ MKLDNNFlattenFwd, OpHash> fwds;
+#endif
+ OpSignature key;
+ key.AddSign(req);
+ key.AddSign(input);
+
+ auto it = fwds.find(key);
+ if (it == fwds.end()) {
+ MKLDNNFlattenFwd fwd(req, input, output);
+ it = AddToCache(&fwds, key, fwd);
+ }
+ return it->second;
+}
+
+void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &input,
+ const OpReqType &req,
+ const NDArray &output) {
+ if (req == kNullOp) return;
+ CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";
+
+ auto fwd = GetFlattenForward(req, input, output);
+ auto ws_size = fwd.GetWorkspaceSize();
+ void* ws_ptr = nullptr;
+ if (ws_size) {
+ mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+ mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
+ .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
+ ws_ptr = reinterpret_cast<void*>(ws.dptr_);
+ }
+
+ fwd.Execute(input, output, ws_ptr);
+}
+
+} // namespace op
+} // namespace mxnet
+
+#endif
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index 2699a02..502abff 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -119,12 +119,17 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
const OpReqType &req,
const NDArray &output);
-void MKLDNNReshapeForward(const nnvm::NodeAttrs &attrs,
+void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
- const NDArray &data,
+ const NDArray &input,
const OpReqType &req,
const NDArray &output);
+void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
+ const OpContext &ctx,
+ const NDArray &input,
+ const OpReqType &req,
+ const NDArray &output);
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape-inl.h b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
new file mode 100644
index 0000000..63e367b
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_reshape-inl.h
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkldnn_reshape-inl.h
+ * \brief Function definition of mkldnn reshape operator
+ */
+
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
+
+#if MXNET_USE_MKLDNN == 1
+#include <vector>
+#include "mkldnn_base-inl.h"
+#include "../../tensor/matrix_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNReshapeFwd {
+ protected:
+ std::shared_ptr<mkldnn::memory> data_;
+ std::shared_ptr<mkldnn::memory> out_;
+ std::shared_ptr<mkldnn::memory> temp_;
+ std::vector<mkldnn::primitive> prims_;
+ bool needInvalidateInput = false;
+
+ public:
+ MKLDNNReshapeFwd(const OpReqType &req,
+ const NDArray &input,
+ const NDArray &output);
+ int GetWorkspaceSize();
+ void SetNewMem(const NDArray &input,
+ const NDArray &output,
+ void* workspace = nullptr);
+ void Execute(const NDArray &input,
+ const NDArray &output,
+ void* workspace = nullptr);
+};
+
+typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
+MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
+ const OpReqType &req,
+ const NDArray &input,
+ const NDArray &output);
+
+} // namespace op
+} // namespace mxnet
+
+#endif // MXNET_USE_MKLDNN == 1
+#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_reshape.cc b/src/operator/nn/mkldnn/mkldnn_reshape.cc
index 4f1d67a..063c85d 100644
--- a/src/operator/nn/mkldnn/mkldnn_reshape.cc
+++ b/src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -26,7 +26,7 @@
#if MXNET_USE_MKLDNN == 1
#include <mkldnn.hpp>
-#include "../../tensor/matrix_op-inl.h"
+#include "mkldnn_reshape-inl.h"
namespace mxnet {
namespace op {
@@ -43,117 +43,106 @@ bool SupportMKLDNNReshape(const ReshapeParam &param,
return true;
}
-typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
-
-class MKLDNNReshapeForward {
- std::shared_ptr<mkldnn::memory> data_;
- std::shared_ptr<mkldnn::memory> out_;
- std::shared_ptr<mkldnn::memory> temp_;
- std::vector<mkldnn::primitive> prims_;
-
- bool needInvalidateInput = false;
-
- public:
- MKLDNNReshapeForward(const ReshapeParam ¶m,
- const OpReqType &req,
- const NDArray &input,
- const NDArray &output) {
- auto engine = CpuEngine::Get()->get_engine();
-
- // data_
- auto in_mem = input.GetMKLDNNData();
- auto in_pd = in_mem->get_primitive_desc();
- data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
-
- // temp_
- auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
- auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
- auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
- auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
- auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
- temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
-
- // destination
- out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
-
- if (req == kWriteInplace) {
- // If the input has MKL-DNN internal layout, we need reorder it to a temporal buffer with
- // default layout and copy from the temporal buffer back to output buffer which has the same
- // address with input buffer.
- // If the input has default layout, then nothing need to do.
- if (input.IsMKLDNNData()) {
- prims_.push_back(mkldnn::reorder(*data_, *temp_)); // reorder to default
- prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy back
- needInvalidateInput = true;
- }
- } else if (req == kWriteTo) {
- if (input.IsMKLDNNData()) {
- prims_.push_back(mkldnn::reorder(*data_, *temp_)); // reorder to default
- prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy to the output buffer
- needInvalidateInput = false;
- } else {
- prims_.push_back(mkldnn::reorder(*data_, *out_)); // copy directly from input to output
- needInvalidateInput = false;
- }
+MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
+ const NDArray &input,
+ const NDArray &output) {
+ auto engine = CpuEngine::Get()->get_engine();
+
+ // data_
+ auto in_mem = input.GetMKLDNNData();
+ auto in_pd = in_mem->get_primitive_desc();
+ data_ = std::make_shared<mkldnn::memory>(in_pd, nullptr);
+
+ // temp_
+ auto temp_dims = mkldnn::memory::dims(input.shape().begin(), input.shape().end());
+ auto temp_type = static_cast<mkldnn::memory::data_type>(in_pd.desc().data.data_type);
+ auto temp_fmt = static_cast<mkldnn::memory::format>(GetDefaultFormat(in_pd.desc()));
+ auto temp_desc = mkldnn::memory::desc(temp_dims, temp_type, temp_fmt);
+ auto temp_pd = mkldnn::memory::primitive_desc(temp_desc, engine);
+ temp_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+
+ // destination
+ out_ = std::make_shared<mkldnn::memory>(temp_pd, nullptr);
+
+ if (req == kWriteInplace) {
+ // If the input has MKL-DNN internal layout, we need reorder it to a temporal buffer with
+ // default layout and copy from the temporal buffer back to output buffer which has the same
+ // address with input buffer.
+ // If the input has default layout, then nothing need to do.
+ if (input.IsMKLDNNData()) {
+ prims_.push_back(mkldnn::reorder(*data_, *temp_)); // reorder to default
+ prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy back
+ needInvalidateInput = true;
+ }
+ } else if (req == kWriteTo) {
+ if (input.IsMKLDNNData()) {
+ prims_.push_back(mkldnn::reorder(*data_, *temp_)); // reorder to default
+ prims_.push_back(mkldnn::reorder(*temp_, *out_)); // copy to the output buffer
+ needInvalidateInput = false;
} else {
- LOG(FATAL) << "not supported req type: " << req;
+ prims_.push_back(mkldnn::reorder(*data_, *out_)); // copy directly from input to output
+ needInvalidateInput = false;
}
+ } else {
+ LOG(FATAL) << "not supported req type: " << req;
}
+}
- int GetWorkspaceSize() {
- return temp_ ? temp_->get_primitive_desc().get_size() : 0;
- }
+int MKLDNNReshapeFwd::GetWorkspaceSize() {
+ return temp_ ? temp_->get_primitive_desc().get_size() : 0;
+}
- void SetNewMem(const NDArray &input, const NDArray &output, void* workspace = nullptr) {
- if (input.IsMKLDNNData()) {
- this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
- } else {
- MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
- this->data_->set_data_handle(input.data().dptr<DTYPE>());
- })
- }
+void MKLDNNReshapeFwd::SetNewMem(const NDArray &input,
+ const NDArray &output,
+ void* workspace) {
+ if (input.IsMKLDNNData()) {
+ this->data_->set_data_handle(input.GetMKLDNNData()->get_data_handle());
+ } else {
+ MSHADOW_TYPE_SWITCH(input.dtype(), DTYPE, {
+ this->data_->set_data_handle(input.data().dptr<DTYPE>());
+ })
+ }
- if (output.IsMKLDNNData()) {
- this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
- } else {
- MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
- this->out_->set_data_handle(output.data().dptr<DTYPE>());
- })
- }
+ if (output.IsMKLDNNData()) {
+ this->out_->set_data_handle(output.GetMKLDNNData()->get_data_handle());
+ } else {
+ MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
+ this->out_->set_data_handle(output.data().dptr<DTYPE>());
+ })
+ }
- if (workspace) {
- this->temp_->set_data_handle(workspace);
- }
+ if (workspace) {
+ this->temp_->set_data_handle(workspace);
}
+}
- void Execute(const NDArray &input,
- const NDArray &output,
- void* workspace = nullptr) {
- // set memory handles
- SetNewMem(input, output, workspace);
- // register primitives
- auto stream = MKLDNNStream::Get();
- for (auto &v : this->prims_) {
- stream->RegisterPrim(v);
- }
- stream->Submit();
- // invalidate mkldnn memory in input
- if (needInvalidateInput) {
- const_cast<NDArray &>(input).InvalidateMKLDNNData();
- }
+void MKLDNNReshapeFwd::Execute(const NDArray &input,
+ const NDArray &output,
+ void* workspace) {
+ // set memory handles
+ SetNewMem(input, output, workspace);
+ // register primitives
+ auto stream = MKLDNNStream::Get();
+ for (auto &v : this->prims_) {
+ stream->RegisterPrim(v);
}
-};
+ stream->Submit();
+ // invalidate mkldnn memory in input
+ if (needInvalidateInput) {
+ const_cast<NDArray &>(input).InvalidateMKLDNNData();
+ }
+}
-static MKLDNNReshapeForward &GetReshapeForward(const ReshapeParam& param,
- const OpReqType &req,
- const NDArray &input,
- const NDArray &output) {
+MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
+ const OpReqType &req,
+ const NDArray &input,
+ const NDArray &output) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNReshapeSignature,
- MKLDNNReshapeForward, OpHash> fwds;
+ MKLDNNReshapeFwd, OpHash> fwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNReshapeSignature,
- MKLDNNReshapeForward, OpHash> fwds;
+ MKLDNNReshapeFwd, OpHash> fwds;
#endif
MKLDNNReshapeSignature key(param);
key.AddSign(req);
@@ -162,7 +151,7 @@ static MKLDNNReshapeForward &GetReshapeForward(const ReshapeParam& param,
auto it = fwds.find(key);
if (it == fwds.end()) {
- MKLDNNReshapeForward fwd(param, req, input, output);
+ MKLDNNReshapeFwd fwd(req, input, output);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc
index b4abc9f..c2bcb29 100644
--- a/src/operator/tensor/matrix_op.cc
+++ b/src/operator/tensor/matrix_op.cc
@@ -111,12 +111,13 @@ static void ReshapeComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
+ const ReshapeParam& param = nnvm::get<ReshapeParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
// If inputs are supposed to be in MKLDNN format and
// MKLDNNsupport the data type or the shape. Then convert
// it to the output format and shape
- if (SupportMKLDNNArray(inputs[0].dtype(), inputs[0].shape())) {
+ if (SupportMKLDNNReshape(param, inputs[0])) {
MKLDNNReshapeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
return;
}
@@ -233,12 +234,9 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
#if MXNET_USE_MKLDNN == 1
- if (inputs[0].IsMKLDNNData()) {
- MKLDNNCopy(attrs, ctx, inputs[0], req[0], outputs[0]);
- // If the output is a special MKLDNN layout and the number of dimensions
- // is larger than 2, we should use the default layout.
- if (outputs[0].IsMKLDNNData() && inputs[0].shape().ndim() > 2)
- const_cast<NDArray &>(outputs[0]).Reorder2Default();
+ auto data_ndim = inputs[0].shape().ndim();
+ if (data_ndim <= 4 && inputs[0].dtype() == mshadow::kFloat32) {
+ MKLDNNFlattenForward(attrs, ctx, inputs[0], req[0], outputs[0]);
return;
} else {
// This happens if inputs are supposed to be in MKLDNN format
@@ -252,10 +250,10 @@ static void FlattenEx(const nnvm::NodeAttrs& attrs,
#if MXNET_USE_MKLDNN == 1
static inline bool FlattenStorageType(const nnvm::NodeAttrs& attrs,
- const int dev_mask,
- DispatchMode* dispatch_mode,
- std::vector<int> *in_attrs,
- std::vector<int> *out_attrs) {
+ const int dev_mask,
+ DispatchMode* dispatch_mode,
+ std::vector<int> *in_attrs,
+ std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 064f783..5b4f81d 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1130,6 +1130,20 @@ def test_pooling_full_2d():
@with_seed()
+def test_flatten_slice_after_conv():
+ ctx_list = []
+
+ data = mx.sym.Variable('conv_data')
+ conv = mx.symbol.Convolution(data=data, name='conv', num_filter=16, kernel=(3,3), stride=(1,1))
+ flatten = mx.symbol.flatten(data=conv)
+ slice_sym = mx.symbol.slice(data=flatten, begin=0, end=1)
+
+ ctx_list = [{'ctx': mx.gpu(0), 'conv_data': (2, 16, 16, 16), 'type_dict': {'conv_data': np.float32}},
+ {'ctx': mx.cpu(0), 'conv_data': (2, 16, 16, 16), 'type_dict': {'conv_data': np.float32}}]
+ check_consistency(slice_sym, ctx_list)
+
+
+@with_seed()
def test_global_pooling():
def test_1d_pooling(pool_type, p_value=2):
data = (2, 3, 20)
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index 662edcf..3e623b5 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -233,6 +233,26 @@ def test_slice_reshape_before_conv():
mx.test_utils.assert_almost_equal(out1.asnumpy(), out2.asnumpy(), rtol=1e-5, atol=1e-6)
+@with_seed()
+def test_flatten_slice_after_conv():
+ data = mx.symbol.Variable('data')
+ weight = mx.symbol.Variable('weight')
+ bias = mx.symbol.Variable('bias')
+ conv1= mx.symbol.Convolution(data = data, weight=weight, bias=bias, name='conv1', num_filter=64, kernel=(3,3), stride=(1,1))
+ flatten1 = mx.symbol.flatten(data = conv1)
+ slice1 = mx.symbol.slice(data = flatten1, begin=0, end=1)
+
+ shape = (2, 16, 16, 16)
+ val = np.random.rand(2, 16, 16, 16).astype(np.float32)
+ exe = slice1.simple_bind(Context.default_ctx, data=shape)
+ exe.arg_arrays[0][:] = val
+ exe.arg_arrays[1][:] = np.random.normal(size=exe.arg_arrays[1].shape)
+ exe.arg_arrays[2][:] = np.random.normal(size=exe.arg_arrays[2].shape)
+ p = exe.forward(is_train=False)
+ p[0].wait_to_read()
+ print(p[0])
+
+
def test_mkldnn_sum_inplace_with_cpu_layout():
x_shape = (32, 3, 224, 224)