Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/01/31 21:31:46 UTC

[incubator-mxnet] branch master updated: Refactor operators & MKLDNN (#8302)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2cc2aa2  Refactor operators & MKLDNN (#8302)
2cc2aa2 is described below

commit 2cc2aa2272881326c8a50c6204aedd71e1821c3f
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Wed Jan 31 13:31:41 2018 -0800

    Refactor operators & MKLDNN (#8302)
    
    * Remove MKL code.
    
    * Use NNVM interface.
    
    Use NNVM interface for upsampling.
    
    Use NNVM interface for convolution.
    
    Use NNVM interface for deconvolution.
    
    Use NNVM interface for FullyConnected.
    
    Move NNVM interface to batch norm.
    
    Use NNVM interface for depthwise convolution.
    
    Use NNVM interface for softmax activation.
    
    Use NNVM interface for pooling.
    
    Use NNVM interface for dropout.
    
    Use NNVM interface for activation.
    
    Use NNVM interface for CuDNN batch norm.
    
    Use NNVM interface for CuDNN pooling.
    
    Use NNVM interface for CuDNN softmax activation.
    
    Use NNVM interface for CuDNN activation.
    
    Use NNVM interface for CuDNN convolution.
    
    Use NNVM interface for CuDNN deconvolution.
    
    Move concat to nn/
    
    Use NNVM interface for concat.
    
    Fix headers in concat.
    
    Move lrn to nn/.
    
    Use NNVM interface for LRN.
    
    Fix a compilation error in convolution.
    
    Fix a compilation error in activation.
    
    Fix coding style.
    
    Fix coding style for make lint.
    
    use enums in batch norm.
    
    Use CoreOpRunner for refactored Ops.
    
    Make FullyConnected stateless.
    
    Make upsampling stateless.
    
    Make pooling stateless.
    
    Make batchnorm stateless.
    
    Make SoftmaxActivation stateless.
    
    Fix a code style problem.
    
    pass amalgamation test for batch norm.
    
    pass amalgamation test for dropout.
    
    Get convolution ops from a function.
    
    Fix compilation errors for GPU.
    
    Fix thread_local on different platforms.
    
    Avoid using thread_local for non-CuDNN conv/deconv.
    
    Remove TODO in deconv.
    
    Fix a bug in batch norm.
    
    Fix a bug in fully connected.
    
    Don't set #inputs for backward convolution.
    
    * Integrate MKLDNN.
    
    Update MXNet for MKLDNN.
    
    Enable MKLDNN Relu.
    
    Fix a compilation error.
    
    Change Makefile for MKLDNN.
    
    Remove infer storage in convolution.
    
    Update MXNet for MKLDNN.
    
    Support MKLDNN storage type in python.
    
    Update activation.
    
    Add MKLDNN base classes.
    
    Implement MKLDNN fully connected.
    
    Add MKLDNN convolution.
    
    Update MKLDNN interface in NDArray.
    
    MKLDNN convolution handle CreateMKLDNNData failure.
    
    Add another GetMKLDNNData in NDArray.
    
    Have mkldnn define the data format.
    
    Create output MKLDNN memory explicitly for FC.
    
    Fix a bug in NDArray.
    
    Fix a bug in GetWeightDesc.
    
    Convert data layout if necessary in FC.
    
    remove unnecessary print in MKLDNN convolution.
    
    Add MKLDNN deconvolution.
    
    Add MKLDNNStream to manage primitives and memories.
    
    Use MKLDNNStream to register memory in NDArray.
    
    Use MKLDNNStream to manage resources in operators.
    
    Handle kAddTo in MKLDNN operators.
    
    Fix a bug in deconvolution.
    
    Fix bugs in NDArray.
    
    Revert "Fix bugs in NDArray."
    
    This reverts commit f5624a4aa9f9b9f9fe31f5e6cfa7a9752838fc4e.
    
    Fix a bug in NDArray.
    
    Fix a bug in NDArray.
    
    Reorder MKLDNN memory to default format in SetTBlob.
    
    Disable MKLDNN correctly.
    
    Fix a bug in activation.
    
    Reshape of NDArray supports MKLDNN.
    
    Fix a memory ref bug in NDArray.
    
    Reshape NDArray in MKLDNN FullyConnected.
    
    Fix data format conversion.
    
    Create MKLDNN NDArray in python.
    
    Support Slice for MKLDNN NDArray.
    
    Reduce the overhead of summing the result to the output array.
    
    Avoid unnecessary memory copy in NDArray.
    
    Fix a bug in data reordering.
    
    Fix a bug in NDArray.
    
    Don't hard code MKLDNN type.
    
    Support dilation in MKLDNN convolution.
    
    Fix a bug in sum results.
    
    Rewrite GetMKLDNNData.
    
    Add prepare_mkldnn.sh
    
    Enable MKLDNN activation.
    
    Fix a bug on FullyConnected.
    
    Handle 3 dims for MKLDNN NDArray.
    
    Fix a bug in MKLDNN FC.
    
    Support MKLDNN storage in KV store.
    
    Fix a bug in executor for non-default NDArray.
    
    Fix a link error in cast_storage.cc.
    
    Remove unnecessary function def
    
    Fall back to default storage if the type isn't supported by MKLDNN.
    
    Use NDArray for MKLDNN in python.
    
    Reshape output of MKLDNN convolution.
    
    Fix a bug in NDArray.
    
    Support more operations in MKLDNN NDArray.
    
    Fix a bug in deconvolution.
    
    Fix bugs in MKLDNN deconvolution.
    
    We still need to compute bias correctly.
    
    Have elemwise binary ops fall back to the default implementation for MKLDNN.
    
    Limit the cases in which MKLDNN operations are called.
    
    Force the layout of mkldnn::memory from NDArray.
    
    Add MKLDNN softmax.
    
    Fix output storage type of MKLDNN softmax.
    
    Add MKLDNN sum.
    
    Fix a bug in elemwise sum.
    
    Fix a bug in MKLDNN softmax.
    
    Fix a bug in imperative.
    
    Clean up dispatch modes.
    
    Remove redundant code.
    
    MKLDNN Pooling Op integration
    
    MKLDNN Pooling Op integration add missing file
    
    fix mkldnn pooling op workspace issue
    
    handle workspace in MKLDNN pooling correctly.
    
    Use a non-MKLDNN op for testing.
    
    Allow sharing arguments and their gradients between executors.
    
    Avoid using MKLDNN pooling when it's not supported.
    
    Support MKLDNN properly.
    
    Choose MKLDNN softmax more carefully.
    
    Fix a bug in MKLDNN pooling.
    
    Fall back if MKLDNN pooling isn't supported.
    
    Fix a bug in Slice of NDArray.
    
    Use int32 for workspace memory.
    
    Exclude MKLDNN act with tanh.
    
    Have two Reshape functions in NDArray.
    
    Copy data for NDArray with diff shapes.
    
    Add MKLDNN copy.
    
    Add MKLDNN version of elemwise_add.
    
    Add MKLDNN version of Flatten.
    
    add mkldnn support for concat
    
    simplify MKLDNN Flatten.
    
    Enable MKLDNN deconvolution with bias.
    
    Fix a bug in CuDNN deconvolution.
    
    avoid using MKLDNNStorage when it's not defined.
    
    Remove ./cudnn_lrn-inl.h
    
    Fix for make lint.
    
    add mkldnn support for concat
    
    fix the coding style for the mkldnn concat PR
    
    Only add input data for MKLDNN concat backward
    
    Remove unnecessary TODO.
    
    remove unnecessary __repr__ in MKLNDArray.
    
    better condition check for readability.
    
    Use macro when including mkldnn.hpp.
    
    Revert "Use CoreOpRunner for refactored Ops."
    
    This reverts commit a28586fc25950cc006cb317e26e0d17541ef0586.
    
    Fix a bug in test core.
    
    Limit MKLDNN ops being used.
    
    Fix complaints from "make pylint"
    
    Move ContainStorage to common/utils.h
    
    Limit MKLDNN concat being used.
    
    Add license.
    
    Fix amalgamation
    
    Fix compilation error in mkldnn_ops-inl.h
    
    Fix a bug in deconvolution.
    
    Fix a bug in pooling.
    
    MKLDNN ops allocate temp mem.
    
    Fix a bug in pooling.
    
    Allocate aligned memory from temp space.
    
    Have parameter gradients stored in the default storage.
    
    Handle all cases in CopyFrom.
    
    Ensure NDArray returns memory with right memory descriptors.
    
    use auto to define memory in the operator.
    
    Use raw pointer for mkldnn memory.
    
    Move more code to mkldnn_base.cc
    
    Fix a compilation error.
    
    Address review comments.
    
    fix a bug in activation backward.
    
    Add a missing macro in mkldnn_base.cc
    
    Fix a bug in data iterator in examples.
    
    Avoid memory allocation in ReshapeMKLDNN.
    
    Avoid memory allocation in storage cast.
    
    Fix a bug in cast storage.
    
    Handle sliced MKLDNN NDArray.
    
    Use memcpy if NDArray uses default format.
    
    Revert "Limit MKLDNN ops being used."
    
    This reverts commit 75e2ae570d03483868ec4ed8ed46015c7fa6c6fb.
    
    Ensure mkldnn act backward uses the same input layout.
    
    Fix a bug in mkldnn activation.
    
    Use MKLDNN sum in more cases.
    
    Improve perf of reorder.
    
    Avoid memory reorder in conv and deconv.
    
    Avoid unnecessary storage cast in fallback path.
    
    Revert "Use MKLDNN sum in more cases."
    
    This reverts commit 7a21ebca8bbe17fde49c3b1ca3f31b835a33afb8.
    
    Handle sliced ndarray in more cases.
    
    Fix a complaint from make lint.
    
    Update Jenkins to test MKLDNN.
    
    debug compiling mkldnn.
    
    Use MKLDNN sum in more cases.
    
    Add mkldnn as a submodule.
    
    Compile with mkldnn in 3rdparty.
    
    Fix some coding styles.
    
    write the path to mkldnn lib in libmxnet.so.
    
    use rpath with $ORIGIN.
    
    Pack all lib files in Jenkins.
    
    pack and unpack mxnet with MKLDNN.
    
    Update Jenkinsfile
    
    Update Jenkinsfile
    
    Add mkldnn batch normalization
    
    Fix bugs in BN.
    
    Avoid memory allocation in MKLDNNCopy.
    
    only use MKLDNN BatchNorm for special cases.
    
    MKLDNN BatchNorm doesn't work well on the default layout.
    
    Add MKL-DNN based LRN
    
    Code Style Changes
    
    Fix a bug in BN.
    
    Fix a bug in LRN.
    
    Handle non-default storage in memory plan.
    
    Fix coding style.
    
    Fix a compilation error without mkldnn.
    
    Fix some coding styles for batch norm
    
    Improve forward of convolution.
    
    Add openmp and simd support to BN operator
    
    Retrieve MKLDNN Conv primitive based on signature.
    
    Retrieve Act primitive based on its signature.
    
    Fix a bug in pooling.
    
    Disable some MKLDNN activation and pooling.
    
    Cast MKLDNN storage with different data types.
    
    Check if it's a view of NDArray.
    
    Reshaped and sliced arrays share the same chunks.
    
    Implement caching MKLDNN Act correctly.
    
    Fix a bug in check_consistency.
    
    Fix a potential bug when destroying NDArray.
    
    Fix bugs when allocating mem in NDArray.
    
    Fix coding style.
    
    Add macro when using mkldnn in ndarray.
    
    Fix a compilation error.
    
    Fix a bug in concat.
    
    Remove MKLDNNStorage.
    
    handle different layouts in CopyFromToDnsImpl.
    
    Fallback correctly.
    
    Force weight grad to use default layout.
    
    Reorder weight arrays in (de)conv for faster inference.
    
    Avoid caching TBlob from NDArray.
    
    This commit may add some overhead from managing NDArrays for each fallback.
    
    Fix a bug in Flatten.
    
    handle ndarray with default layout in mkldnn BN correctly.
    
    Align to page when mkldnn is enabled.
    
    Use default mem alloc for mkldnn.
    
    Reuse NDArrays.
    
    Support WriteInplace for sum.
    
    fix complaints from "make lint".
    
    Avoid reallocation in NDArray.
    
    Handle weight arrays with special MKLDNN layouts.
    
    Remove unnecessary GetWeights.
    
    Fix compilation error without MKLDNN.
    
    Fix a bug in (de)conv for weight arrays.
    
    Fix a minor bug in MKLDNN conv.
    
    Fix a bug in MKLDNNOpSignature.
    
    Reimplement fallback for MKLDNN ops.
    
    Fix a bug in FallbackExecutor.
    
    Add params in hashcode.
    
    Invalidate data in outputs to accelerate.
    
    Fix a minor bug.
    
    Update mkldnn_base-inl.h
    
    Add primitive caching for Pooling forward computation
    
    Add hashcode in pooling parameters.
    
    Support NDArray copy with types unsupported by MKLDNN.
    
    Avoid using MKLDNN concat for negative dimension.
    
    Fix make lint complaints.
    
    Disable mkldnn avg pooling for now.
    
    Fix a compile warning.
    
    Fix compile error when MKLDNN is disabled.
    
    OP primitive cache: use memory as signature for MKLDNN storage type
    
    Remove MKLDNN array in python.
    
    Disable Clang tests in Jenkins.
    
    Use mklml dockers to test mkldnn.
    
    Update MKLDNN repo to zhengda's mkldnn repo.
    
    Update MKLDNN repo to ashok's.
    
    Fix a bug in fallback.
    
    Change avg pooling algorithm to pooling_avg_include_padding
    
    Fix a code style issue in mkldnn pooling.
    
    Temp fix a bug in FC.
    
    Revert "Disable Clang tests in Jenkins."
    
    This reverts commit b4efa8f89592d30a27f9c30e2237e9420ac6749a.
    
    Rebase and Refactor deconv (#20)
    
    * rebase to Da Zheng's refactor branch (Jan. 14); add signature for mkldnn Deconv and modify class MKLDNNDeconvForward
    
    * fix make lint complaints
    
    A simple way of caching BN inference.
    
    cache BN forward for both training and inference.
    
    Fix some minor problems in BN.
    
    Fix a bug in caching BN.
    
    force building with avx2 in Jenkins.
    
    Remove the remaining MKLDNNStorageType
    
    Some minor updates in NDArray.
    
    a lot of updates to address comments.
    
    minor changes.
    
    * revert modification in test_executor.
    
    * Fix a bug in FlattenStorageType.
    
    * Remove BN debug.
    
    * Remove remaining MXNET_USE_MKL2017
    
    * Remove unused code in pooling.
    
    * Fixing bugs in gtests.
    
    * Fix lint errors.
    
    * a lot of minor updates to address comments.
    
    * Fix coding style in MKLDNN Pooling (#22)
    
    * revert the code change in the previous code refactor.
    
    * Fix a bug in pooling.
    
    * LRN coding style changes (#21)
    
    * LRN coding style change
    
    * Add const for local variables
    
    * Add req for LRN forward
    
    * rebase code
    
    * align API interface
    
    * revert modification in test_executor.
    
    * cast storage with MKLDNN properly.
    
    * Minor updates to address comments.
    
    * some minor updates.
    
    * Switch to the master branch of MKLDNN.
    
    * Minor updates to address comments.
    
    * Update activation.cc
    
    * Fix a bug in converting NDArray.
    
    * Add gluon model zoo tests.
    
    * Update GPU tests on model zoo.
    
    * Avoid using mobilenet for GPU tests with gluon models.
    
    mobilenet can't pass the test even without MKLDNN.
    
    * Update GPU tests on gluon.
    
    * change cmake to compile MKLDNN.
    
    * update cmake for MKLDNN.
    
    * Implement the align function myself.
    
    * Switch to intel/mkl-dnn.
    
    * Fix errors in align unittest.
    
    * Add unit test for LRN.
    
    * fix a compilation error.
    
    * use storage_type_assign to determine storage type.
    
    * avoid global pooling in mkldnn.
    
    There is a bug in global pooling in mkldnn.
    
    * compare all MKLDNN ops with native impls.
    
    add MXNET_MKLDNN_DEBUG to control the test.
    
    * Fix a bug in testing correctness.
    
    * print the name of buggy operator.
    
    * undo some modifications.
    
    * Fix a bug on reshaped array.
    
    * avoid testing outputs with NullOp.
    
    * turn on MKLDNN tests in Jenkins.
    
    * print each operator in MKLDNN tests.
    
    * rename test_gluon_model_zoo.py
    
    * Create hashcode for operator parameters properly.
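    
    Several commits above add MKLDNNStream to manage primitives and memories.
    A rough usage sketch follows (the exact interface is defined in
    src/operator/nn/mkldnn/mkldnn_base-inl.h below; fwd_pd, data, weight, and
    out are illustrative placeholders, not names from this patch):
    
        #if MXNET_USE_MKLDNN == 1
        // Queue a primitive on the thread-local stream; Submit() executes
        // everything registered during this operator invocation and keeps
        // the registered memories alive until then.
        MKLDNNStream *stream = MKLDNNStream::Get();
        stream->RegisterPrim(
            mkldnn::convolution_forward(fwd_pd, data, weight, out));
        stream->Submit();
        #endif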
---
 .gitmodules                                        |    4 +
 3rdparty/mkldnn                                    |    1 +
 CMakeLists.txt                                     |   14 +-
 Jenkinsfile                                        |   59 +-
 Makefile                                           |   40 +-
 amalgamation/mxnet_predict0.cc                     |    2 +-
 cmake/ChooseBlas.cmake                             |    4 +-
 cmake/Modules/FindMKL.cmake                        |   18 +-
 example/image-classification/common/data.py        |    3 +-
 include/mxnet/ndarray.h                            |  258 ++---
 include/mxnet/tensor_blob.h                        |   29 -
 prepare_mkldnn.sh                                  |  118 +++
 python/mxnet/test_utils.py                         |    4 +
 src/common/exec_utils.h                            |   71 +-
 src/executor/attach_op_execs_pass.cc               |   49 +-
 src/executor/graph_executor.cc                     |    3 +-
 src/executor/infer_graph_attr_pass.cc              |    5 -
 src/imperative/cached_op.cc                        |   10 +
 src/imperative/imperative_utils.h                  |   13 +-
 src/kvstore/kvstore_dist.h                         |   20 -
 src/ndarray/ndarray.cc                             |  599 +++++++++++-
 src/operator/concat-inl.h                          |  264 -----
 src/operator/concat.cc                             |  112 ---
 src/operator/convolution_v1.cc                     |    5 -
 src/operator/lrn-inl.h                             |  215 -----
 src/operator/lrn.cc                                |   76 --
 src/operator/lrn.cu                                |   55 --
 src/operator/mkl/mkl_batch_norm-inl.h              |  391 --------
 src/operator/mkl/mkl_concat-inl.h                  |  314 ------
 src/operator/mkl/mkl_convolution-inl.h             |  490 ----------
 src/operator/mkl/mkl_cppwrapper.cc                 |   44 -
 src/operator/mkl/mkl_cppwrapper.h                  | 1020 --------------------
 src/operator/mkl/mkl_elementwise_copy-inl.h        |   69 --
 src/operator/mkl/mkl_elementwise_sum-inl.h         |  117 ---
 src/operator/mkl/mkl_fully_connected-inl.h         |  192 ----
 src/operator/mkl/mkl_lrn-inl.h                     |  265 -----
 src/operator/mkl/mkl_memory-inl.h                  |  137 ---
 src/operator/mkl/mkl_memory.cc                     |  291 ------
 src/operator/mkl/mkl_memory.h                      |  123 ---
 src/operator/mkl/mkl_pooling-inl.h                 |  357 -------
 src/operator/mkl/mkl_relu-inl.h                    |  272 ------
 src/operator/mkl/mkl_util-inl.h                    |  110 ---
 src/operator/nn/activation-inl.h                   |  266 +++--
 src/operator/nn/activation.cc                      |  182 +++-
 src/operator/nn/activation.cu                      |   90 +-
 src/operator/nn/batch_norm-inl.h                   |  447 ++++-----
 src/operator/nn/batch_norm.cc                      |  296 ++++--
 src/operator/nn/batch_norm.cu                      |  118 ++-
 src/operator/nn/concat-inl.h                       |  160 +++
 src/operator/nn/concat.cc                          |  289 ++++++
 src/operator/{ => nn}/concat.cu                    |   14 +-
 src/operator/nn/convolution-inl.h                  |  350 +------
 src/operator/nn/convolution.cc                     |  433 ++++++++-
 src/operator/nn/convolution.cu                     |  160 ++-
 src/operator/nn/cudnn/cudnn_activation-inl.h       |  125 ++-
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h       |  158 +--
 src/operator/nn/cudnn/cudnn_batch_norm.cc          |  104 +-
 src/operator/nn/cudnn/cudnn_batch_norm.cu          |   58 +-
 src/operator/nn/cudnn/cudnn_convolution-inl.h      |   60 +-
 src/operator/nn/cudnn/cudnn_deconvolution-inl.h    |   64 +-
 src/operator/nn/cudnn/cudnn_pooling-inl.h          |  283 +++---
 .../nn/cudnn/cudnn_softmax_activation-inl.h        |  102 +-
 src/operator/nn/deconvolution-inl.h                |  360 ++-----
 src/operator/nn/deconvolution.cc                   |  409 +++++++-
 src/operator/nn/deconvolution.cu                   |  109 ++-
 src/operator/nn/depthwise_convolution-inl.h        |   36 +-
 src/operator/nn/depthwise_convolution_tf.cuh       |    6 +-
 src/operator/nn/dropout-inl.h                      |  163 +---
 src/operator/nn/dropout.cc                         |   87 +-
 src/operator/nn/dropout.cu                         |   17 +-
 src/operator/nn/fully_connected-inl.h              |  357 +++----
 src/operator/nn/fully_connected.cc                 |  217 ++++-
 src/operator/nn/fully_connected.cu                 |   50 +-
 src/operator/nn/lrn-inl.h                          |  127 +++
 src/operator/nn/lrn.cc                             |  203 ++++
 src/operator/nn/{dropout.cu => lrn.cu}             |   19 +-
 src/operator/nn/mkldnn/mkldnn_act.cc               |  217 +++++
 src/operator/nn/mkldnn/mkldnn_base-inl.h           |  488 ++++++++++
 src/operator/nn/mkldnn/mkldnn_base.cc              |  385 ++++++++
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h     |  431 +++++++++
 src/operator/nn/mkldnn/mkldnn_concat.cc            |   95 ++
 src/operator/nn/mkldnn/mkldnn_convolution.cc       |  357 +++++++
 src/operator/nn/mkldnn/mkldnn_copy.cc              |   57 ++
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc     |  390 ++++++++
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc   |  200 ++++
 src/operator/nn/mkldnn/mkldnn_lrn-inl.h            |  143 +++
 src/operator/nn/mkldnn/mkldnn_ops-inl.h            |  114 +++
 src/operator/nn/mkldnn/mkldnn_pooling-inl.h        |  121 +++
 src/operator/nn/mkldnn/mkldnn_pooling.cc           |  322 ++++++
 src/operator/nn/mkldnn/mkldnn_softmax.cc           |   56 ++
 src/operator/nn/mkldnn/mkldnn_sum.cc               |   73 ++
 src/operator/nn/pooling-inl.h                      |  353 +++----
 src/operator/nn/pooling.cc                         |  375 ++++++-
 src/operator/nn/pooling.cu                         |   96 +-
 src/operator/nn/softmax.cc                         |   47 +
 src/operator/nn/softmax_activation-inl.h           |  205 ++--
 src/operator/nn/softmax_activation.cc              |   30 +-
 src/operator/nn/softmax_activation.cu              |   49 +-
 src/operator/nn/upsampling-inl.h                   |  359 +++----
 src/operator/nn/upsampling.cc                      |  162 +++-
 src/operator/nn/upsampling.cu                      |   38 +-
 src/operator/tensor/cast_storage-inl.h             |   28 +-
 src/operator/tensor/elemwise_binary_op_basic.cc    |  106 +-
 .../tensor/elemwise_binary_scalar_op_basic.cc      |    4 +-
 src/operator/tensor/elemwise_sum.cc                |   42 +-
 src/operator/tensor/elemwise_unary_op_basic.cc     |   66 +-
 src/operator/tensor/matrix_op.cc                   |   54 ++
 src/storage/cpu_device_storage.h                   |    6 +
 tests/ci_build/ci_build.sh                         |    1 +
 tests/cpp/include/test_core_op.h                   |   17 +-
 tests/cpp/include/test_op_runner.h                 |    5 +-
 tests/cpp/operator/activation_perf.cc              |   28 +-
 tests/cpp/operator/batchnorm_test.cc               |    5 +
 tests/cpp/operator/dropout_perf.cc                 |   36 +-
 tests/cpp/operator/fully_conn_perf.cc              |   48 +-
 tests/cpp/operator/mkldnn.cc                       |   73 ++
 tests/python/gpu/test_gluon_model_zoo_gpu.py       |  163 ++++
 tests/python/gpu/test_operator_gpu.py              |    7 +
 118 files changed, 9735 insertions(+), 8279 deletions(-)

diff --git a/.gitmodules b/.gitmodules
index 170c105..42f0027 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -22,3 +22,7 @@
 [submodule "3rdparty/googletest"]
 	path = 3rdparty/googletest
 	url = https://github.com/google/googletest.git
+[submodule "3rdparty/mkldnn"]
+	path = 3rdparty/mkldnn
+	url = https://github.com/intel/mkl-dnn.git
+	branch = master
diff --git a/3rdparty/mkldnn b/3rdparty/mkldnn
new file mode 160000
index 0000000..3e1f8f5
--- /dev/null
+++ b/3rdparty/mkldnn
@@ -0,0 +1 @@
+Subproject commit 3e1f8f53f6845dce23abf8089501c2eb45420b9e
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 14b40e4..dfa9834 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -33,8 +33,8 @@ mxnet_option(USE_OPENMP           "Build with Openmp support" ON)
 mxnet_option(USE_CUDNN            "Build with cudnn support"  ON) # one could set CUDNN_ROOT for search path
 mxnet_option(USE_LAPACK           "Build with lapack support" ON IF NOT MSVC)
 mxnet_option(USE_MKL_IF_AVAILABLE "Use MKL if found" ON)
-mxnet_option(USE_MKLML_MKL        "Use MKLML variant of MKL (if MKL found)" ON IF USE_MKL_IF_AVAILABLE AND UNIX AND (NOT APPLE))
-mxnet_option(USE_MKL_EXPERIMENTAL "Use experimental MKL (if MKL enabled and found)" OFF)
+mxnet_option(USE_MKLDNN           "Use MKLDNN variant of MKL (if MKL found)" ON IF USE_MKL_IF_AVAILABLE AND UNIX AND (NOT APPLE))
+mxnet_option(USE_MKLML_MKL        "Use MKLDNN variant of MKL (if MKL found)" ON IF USE_MKL_IF_AVAILABLE AND UNIX AND (NOT APPLE))
 mxnet_option(USE_OPERATOR_TUNING  "Enable auto-tuning of operators" ON AND NOT MSVC)
 mxnet_option(USE_GPERFTOOLS       "Build with GPerfTools support (if found)" ON)
 mxnet_option(USE_JEMALLOC         "Build with Jemalloc support"   ON)
@@ -138,14 +138,11 @@ if(USE_VTUNE)
 endif()
 
 if(USE_MKL_IF_AVAILABLE)
-  if(USE_MKL_EXPERIMENTAL AND NOT USE_MKLML_MKL)
-    message(ERROR " USE_MKL_EXPERIMENTAL can only be used when USE_MKL_EXPERIMENTAL is enabled")
-  endif()
   find_package(MKL)
   if(MKL_FOUND)
     include_directories(${MKL_INCLUDE_DIR})
     include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src/operator/mkl)
-    add_definitions(-DMXNET_USE_MKL2017=1)
+	add_definitions(-DMXNET_USE_MKLDNN=1)
     add_definitions(-DUSE_MKL=1)
     add_definitions(-DCUB_MKL=1)
     list(APPEND mxnet_LINKER_LIBS ${MKL_LIBRARIES})
@@ -154,11 +151,6 @@ if(USE_MKL_IF_AVAILABLE)
     endif()
     # If using MKL, use the Intel OMP libraries
     list(APPEND mxnet_LINKER_LIBS iomp5)
-    if(USE_MKL_EXPERIMENTAL)
-      add_definitions(-DMKL_EXPERIMENTAL=1)
-    else()
-      add_definitions(-DMKL_EXPERIMENTAL=0)
-    endif()
   else()
     message(STATUS " MKL not found")
   endif()
diff --git a/Jenkinsfile b/Jenkinsfile
index 05cda74..80f9424 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -24,6 +24,7 @@
 mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, dmlc-core/libdmlc.a, nnvm/lib/libnnvm.a'
 // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default.
 mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/dmlc-core/libdmlc.a'
+mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmklml_gnu.so, lib/libmkldnn.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, dmlc-core/libdmlc.a, nnvm/lib/libnnvm.a'
 // command to start a docker container
 docker_run = 'tests/ci_build/ci_build.sh'
 // timeout in minutes
@@ -161,18 +162,18 @@ def python3_gpu_ut(docker_type) {
 }
 
 // Python 2
-def python2_mklml_ut(docker_type) {
+def python2_mkldnn_ut(docker_type) {
   timeout(time: max_time, unit: 'MINUTES') {
     sh "${docker_run} ${docker_type} find . -name '*.pyc' -type f -delete"
-    sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-2.7 --with-timer --verbose tests/python/cpu"
+    sh "${docker_run} ${docker_type} PYTHONPATH=./python/ MXNET_MKLDNN_DEBUG=1 nosetests-2.7 --with-timer --verbose tests/python/cpu"
   }
 }
 
 // Python 3
-def python3_mklml_ut(docker_type) {
+def python3_mkldnn_ut(docker_type) {
   timeout(time: max_time, unit: 'MINUTES') {
     sh "${docker_run} ${docker_type} find . -name '*.pyc' -type f -delete"
-    sh "${docker_run} ${docker_type} PYTHONPATH=./python/ nosetests-3.4 --with-timer --verbose tests/python/cpu"
+    sh "${docker_run} ${docker_type} PYTHONPATH=./python/ MXNET_MKLDNN_DEBUG=1 nosetests-3.4 --with-timer --verbose tests/python/cpu"
   }
 }
 
@@ -243,21 +244,20 @@ try {
         }
       }
     },
-    'CPU: MKLML': {
+    'CPU: MKLDNN': {
       node('mxnetlinux-cpu') {
-        ws('workspace/build-mklml-cpu') {
+        ws('workspace/build-mkldnn-cpu') {
           init_git()
           def flag = """ \
             DEV=1                         \
             USE_PROFILER=1                \
             USE_CPP_PACKAGE=1             \
             USE_BLAS=openblas             \
-            USE_MKL2017=1                 \
-            USE_MKL2017_EXPERIMENTAL=1    \
+            USE_MKLDNN=1                  \
             -j\$(nproc)
             """
           make("cpu_mklml", flag)
-          pack_lib('mklml_cpu')
+          pack_lib('mkldnn_cpu', mx_mkldnn_lib)
         }
       }
     },
@@ -278,24 +278,23 @@ try {
         }
       }
     },
-    'GPU: MKLML': {
+    'GPU: MKLDNN': {
       node('mxnetlinux-cpu') {
-        ws('workspace/build-mklml-gpu') {
+        ws('workspace/build-mkldnn-gpu') {
           init_git()
           def flag = """ \
             DEV=1                         \
             USE_PROFILER=1                \
             USE_CPP_PACKAGE=1             \
             USE_BLAS=openblas             \
-            USE_MKL2017=1                 \
-            USE_MKL2017_EXPERIMENTAL=1    \
+            USE_MKLDNN=1                  \
             USE_CUDA=1                    \
             USE_CUDA_PATH=/usr/local/cuda \
             USE_CUDNN=1                   \
             -j\$(nproc)
             """
           make("build_cuda", flag)
-          pack_lib('mklml_gpu')
+          pack_lib('mkldnn_gpu', mx_mkldnn_lib)
         }
       }
     },
@@ -442,43 +441,43 @@ try {
         }
       }
     },
-    'Python2: MKLML-CPU': {
+    'Python2: MKLDNN-CPU': {
       node('mxnetlinux-cpu') {
-        ws('workspace/ut-python2-mklml-cpu') {
+        ws('workspace/ut-python2-mkldnn-cpu') {
           init_git()
-          unpack_lib('mklml_cpu')
+          unpack_lib('mkldnn_cpu', mx_mkldnn_lib)
           python2_ut('cpu_mklml')
-          python2_mklml_ut('cpu_mklml')
+          python2_mkldnn_ut('cpu_mklml')
         }
       }
     },
-    'Python2: MKLML-GPU': {
+    'Python2: MKLDNN-GPU': {
       node('mxnetlinux-gpu') {
-        ws('workspace/ut-python2-mklml-gpu') {
+        ws('workspace/ut-python2-mkldnn-gpu') {
           init_git()
-          unpack_lib('mklml_gpu')
+          unpack_lib('mkldnn_gpu', mx_mkldnn_lib)
           python2_gpu_ut('gpu_mklml')
-          python2_mklml_ut('gpu_mklml')
+          python2_mkldnn_ut('gpu_mklml')
         }
       }
     },
-    'Python3: MKLML-CPU': {
+    'Python3: MKLDNN-CPU': {
       node('mxnetlinux-cpu') {
-        ws('workspace/ut-python3-mklml-cpu') {
+        ws('workspace/ut-python3-mkldnn-cpu') {
           init_git()
-          unpack_lib('mklml_cpu')
+          unpack_lib('mkldnn_cpu', mx_mkldnn_lib)
           python3_ut('cpu_mklml')
-          python3_mklml_ut('cpu_mklml')
+          python3_mkldnn_ut('cpu_mklml')
         }
       }
     },
-    'Python3: MKLML-GPU': {
+    'Python3: MKLDNN-GPU': {
       node('mxnetlinux-gpu') {
-        ws('workspace/ut-python3-mklml-gpu') {
+        ws('workspace/ut-python3-mkldnn-gpu') {
           init_git()
-          unpack_lib('mklml_gpu')
+          unpack_lib('mkldnn_gpu', mx_mkldnn_lib)
           python3_gpu_ut('gpu_mklml')
-          python3_mklml_ut('gpu_mklml')
+          python3_mkldnn_ut('gpu_mklml')
         }
       }
     },
diff --git a/Makefile b/Makefile
index 976035b..d325aa6 100644
--- a/Makefile
+++ b/Makefile
@@ -59,11 +59,11 @@ endif
 # use customized config file
 include $(config)
 
-ifeq ($(USE_MKL2017), 1)
-# must run ./prepare_mkl before including mshadow.mk
-	RETURN_STRING := $(shell ./prepare_mkl.sh $(MKLML_ROOT))
-	MKLROOT := $(firstword $(RETURN_STRING))
-	export USE_MKLML = $(lastword $(RETURN_STRING))
+ifeq ($(USE_MKLDNN), 1)
+	RETURN_STRING := $(shell ./prepare_mkldnn.sh $(MKLDNN_ROOT))
+	MKLDNNROOT := $(firstword $(RETURN_STRING))
+	MKLROOT := $(lastword $(RETURN_STRING))
+	export USE_MKLML = 1
 endif
 
 include mshadow/make/mshadow.mk
@@ -131,23 +131,16 @@ ifeq ($(USE_NNPACK), 1)
 	LDFLAGS += -lnnpack
 endif
 
-ifeq ($(USE_MKL2017), 1)
-	CFLAGS += -DMXNET_USE_MKL2017=1
+ifeq ($(USE_MKLDNN), 1)
+	CFLAGS += -DMXNET_USE_MKLDNN=1
 	CFLAGS += -DUSE_MKL=1
-	CFLAGS += -I$(ROOTDIR)/src/operator/mkl/
-	CFLAGS += -I$(MKLML_ROOT)/include
-	LDFLAGS += -L$(MKLML_ROOT)/lib
-	ifeq ($(USE_MKL2017_EXPERIMENTAL), 1)
-		CFLAGS += -DMKL_EXPERIMENTAL=1
-	else
-		CFLAGS += -DMKL_EXPERIMENTAL=0
-	endif
-	ifeq ($(UNAME_S), Darwin)
-		LDFLAGS += -lmklml
-	else
-		LDFLAGS += -Wl,--as-needed -lmklml_intel -lmklml_gnu
+	CFLAGS += -I$(ROOTDIR)/src/operator/nn/mkldnn/
+	ifneq ($(MKLDNNROOT), $(MKLROOT))
+		CFLAGS += -I$(MKLROOT)/include
+		LDFLAGS += -L$(MKLROOT)/lib
 	endif
-	LDFLAGS +=  -liomp5
+	CFLAGS += -I$(MKLDNNROOT)/include
+	LDFLAGS += -L$(MKLDNNROOT)/lib -lmkldnn -Wl,-rpath,'$${ORIGIN}'
 endif
 
 ifeq ($(USE_OPERATOR_TUNING), 1)
@@ -161,7 +154,7 @@ endif
 #   -  for Ubuntu, installing atlas will not automatically install the atlas provided lapack library
 # silently switching lapack off instead of letting the build fail because of backward compatibility
 ifeq ($(USE_LAPACK), 1)
-ifeq ($(USE_BLAS),$(filter $(USE_BLAS),blas openblas atlas))
+ifeq ($(USE_BLAS),$(filter $(USE_BLAS),blas openblas atlas mkl))
 ifeq (,$(wildcard /lib/liblapack.a))
 ifeq (,$(wildcard /usr/lib/liblapack.a))
 ifeq (,$(wildcard /usr/lib64/liblapack.a))
@@ -179,7 +172,7 @@ ifeq ($(USE_LAPACK), 1)
 	ifneq ($(USE_LAPACK_PATH), )
 		LDFLAGS += -L$(USE_LAPACK_PATH)
 	endif
-	ifeq ($(USE_BLAS),$(filter $(USE_BLAS),blas openblas atlas))
+	ifeq ($(USE_BLAS),$(filter $(USE_BLAS),blas openblas atlas mkl))
 		LDFLAGS += -llapack
 	endif
 	CFLAGS += -DMXNET_USE_LAPACK
@@ -569,7 +562,8 @@ clean: cyclean $(EXTRA_PACKAGES_CLEAN)
 else
 clean: cyclean testclean $(EXTRA_PACKAGES_CLEAN)
 	$(RM) -r build lib bin *~ */*~ */*/*~ */*/*/*~ R-package/NAMESPACE R-package/man R-package/R/mxnet_generated.R \
-		R-package/inst R-package/src/image_recordio.h R-package/src/*.o R-package/src/*.so mxnet_*.tar.gz
+		R-package/inst R-package/src/image_recordio.h R-package/src/*.o R-package/src/*.so mxnet_*.tar.gz \
+		external/mkldnn/install/*
 	cd $(DMLC_CORE); $(MAKE) clean; cd -
 	cd $(PS_PATH); $(MAKE) clean; cd -
 	cd $(NNVM_PATH); $(MAKE) clean; cd -
diff --git a/amalgamation/mxnet_predict0.cc b/amalgamation/mxnet_predict0.cc
index f35591d..cfee605 100644
--- a/amalgamation/mxnet_predict0.cc
+++ b/amalgamation/mxnet_predict0.cc
@@ -66,7 +66,7 @@
 #include "src/operator/operator_util.cc"
 #include "src/operator/nn/activation.cc"
 #include "src/operator/nn/batch_norm.cc"
-#include "src/operator/concat.cc"
+#include "src/operator/nn/concat.cc"
 #include "src/operator/nn/convolution.cc"
 #include "src/operator/nn/deconvolution.cc"
 #include "src/operator/nn/dropout.cc"
diff --git a/cmake/ChooseBlas.cmake b/cmake/ChooseBlas.cmake
index 3a8723a..13d7083 100644
--- a/cmake/ChooseBlas.cmake
+++ b/cmake/ChooseBlas.cmake
@@ -23,7 +23,7 @@ if(USE_MKL_IF_AVAILABLE)
     find_package(MKL)
   endif()
   if(MKL_FOUND)
-    if(USE_MKLML_MKL)
+	if(USE_MKLDNN)
       set(BLAS "open")
     else()
       set(BLAS "MKL")
@@ -55,4 +55,4 @@ elseif(BLAS STREQUAL "apple")
   list(APPEND mshadow_LINKER_LIBS ${Accelerate_LIBRARIES})
   add_definitions(-DMSHADOW_USE_MKL=0)
   add_definitions(-DMSHADOW_USE_CBLAS=1)
-endif()
\ No newline at end of file
+endif()
diff --git a/cmake/Modules/FindMKL.cmake b/cmake/Modules/FindMKL.cmake
index 743a871..7040556 100644
--- a/cmake/Modules/FindMKL.cmake
+++ b/cmake/Modules/FindMKL.cmake
@@ -19,7 +19,7 @@
 #
 # Options:
 #
-#   USE_MKLML_MKL                   : Search for MKL:ML library variant
+#   USE_MKLDNN                    : Search for MKL:ML library variant
 #
 #   MKL_USE_SINGLE_DYNAMIC_LIBRARY  : use single dynamic library interface
 #   MKL_USE_STATIC_LIBS             : use static libraries
@@ -33,7 +33,7 @@
 #   MKL_INCLUDE_DIR      : unclude directory
 #   MKL_LIBRARIES        : the libraries to link against.
 #
-# cjolivier01: Changed to also look for MKLML library (subset of mkl) instead of standard MKL package
+# cjolivier01: Changed to also look for MKLDNN library (subset of mkl) instead of standard MKL package
 #
 
 if(MKL_FOUND)
@@ -43,7 +43,7 @@ endif()
 # ---[ Root folders
 set(INTEL_ROOT "/opt/intel" CACHE PATH "Folder contains intel libs")
 
-if(USE_MKLML_MKL)
+if(USE_MKLDNN)
 
   find_path(MKL_ROOT include/mkl_blas.h
     PATHS $ENV{MKL_ROOT}
@@ -66,13 +66,14 @@ if(USE_MKLML_MKL)
   set(__mkl_libs "")
 
   if(WIN32)
-    list(APPEND __mkl_libs intel)
+    list(APPEND __mkl_libs mklml_intel)
   else()
-    list(APPEND __mkl_libs gnu)
+    list(APPEND __mkl_libs mklml_gnu)
   endif()
+  list(APPEND __mkl_libs mkldnn)
 
   foreach (__lib ${__mkl_libs})
-    set(__mkl_lib "mklml_${__lib}")
+    set(__mkl_lib "${__lib}")
     string(TOUPPER ${__mkl_lib} __mkl_lib_upper)
 
     if(MKL_USE_STATIC_LIBS)
@@ -90,8 +91,7 @@ if(USE_MKLML_MKL)
     list(APPEND MKL_LIBRARIES ${${__mkl_lib_upper}_LIBRARY})
   endforeach()
 
-
-else(USE_MKLML_MKL)
+else(USE_MKLDNN)
 
   # ---[ Options
   mxnet_option(MKL_USE_SINGLE_DYNAMIC_LIBRARY "Use single dynamic library interface" ON)
@@ -193,7 +193,7 @@ else(USE_MKLML_MKL)
     list(APPEND MKL_LIBRARIES ${MKL_RTL_LIBRARY})
   endif()
 
-endif(USE_MKLML_MKL)
+endif(USE_MKLDNN)
 
 include(FindPackageHandleStandardArgs)
 find_package_handle_standard_args(MKL DEFAULT_MSG ${__looked_for})
diff --git a/example/image-classification/common/data.py b/example/image-classification/common/data.py
index dc8915c..05f5ddc 100755
--- a/example/image-classification/common/data.py
+++ b/example/image-classification/common/data.py
@@ -112,7 +112,8 @@ def get_rec_iter(args, kv=None):
     image_shape = tuple([int(l) for l in args.image_shape.split(',')])
     if 'benchmark' in args and args.benchmark:
         data_shape = (args.batch_size,) + image_shape
-        train = SyntheticDataIter(args.num_classes, data_shape, 500, np.float32)
+        train = SyntheticDataIter(args.num_classes, data_shape,
+                args.num_examples / args.batch_size, np.float32)
         return (train, None)
     if kv:
         (rank, nworker) = (kv.rank, kv.num_workers)
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index a18d2da..43bc205 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -35,12 +35,13 @@
 #include <map>
 #include <string>
 #include <memory>
+#include <algorithm>
+#if MXNET_USE_MKLDNN == 1
+#include <mkldnn.hpp>
+#endif
 #include "./base.h"
 #include "./storage.h"
 #include "./engine.h"
-#if MKL_EXPERIMENTAL == 1
-#include <mkl_memory.h>
-#endif
 // check c++11
 #if DMLC_USE_CXX11 == 0
 #error "cxx11 was required for ndarray module"
@@ -72,6 +73,7 @@ enum NDArrayFormatErr {
   kRSPIdxErr,     // indices error for row sparse
 };
 
+class MKLDNNMemory;
 
 /*!
  * \brief ndarray interface
@@ -80,9 +82,6 @@ class NDArray {
  public:
   /*! \brief default constructor */
   NDArray() {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = MKLMemHolder::create();
-#endif
   }
   /*!
    * \brief constructs a new dynamic NDArray
@@ -96,56 +95,14 @@ class NDArray {
       : ptr_(std::make_shared<Chunk>(shape, ctx, delay_alloc, dtype)),
         shape_(shape), dtype_(dtype), storage_type_(kDefaultStorage),
         entry_({nullptr, 0, 0}) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = std::make_shared<MKLMemHolder>();
-#endif
   }
   /*! \brief constructor for NDArray with storage type
    */
   NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx,
           bool delay_alloc = true, int dtype = mshadow::default_type_flag,
           std::vector<int> aux_types = {}, std::vector<TShape> aux_shapes = {},
-          TShape storage_shape = TShape(mshadow::Shape1(0)))
-      : shape_(shape), dtype_(dtype), storage_type_(stype),
-        entry_({nullptr, 0, 0}) {
-      // Assign default aux types if not given
-      if (aux_types.size() == 0) {
-        if (stype == kRowSparseStorage) {
-          aux_types = {mshadow::kInt64};
-        } else if (stype == kCSRStorage) {
-          aux_types = {mshadow::kInt64, mshadow::kInt64};
-        } else {
-          LOG(FATAL) << "Unknown storage type " << stype;
-        }
-      }
-      // Assign default shapes if not given
-      // unknown shapes are intialized as {0} such that Size() would return 0
-      if (aux_shapes.size() == 0) {
-        if (stype == kRowSparseStorage) {
-          aux_shapes = {TShape(mshadow::Shape1(0))};
-        } else if (stype == kCSRStorage) {
-          // aux shapes for indptr and indices
-          aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))};
-        } else {
-          LOG(FATAL) << "Unknown storage type " << stype;
-        }
-      }
-      if (storage_shape.Size() == 0) {
-        if (stype == kRowSparseStorage) {
-          storage_shape = shape;
-          storage_shape[0] = aux_shapes[rowsparse::kIdx][0];
-        } else if (stype == kCSRStorage) {
-          storage_shape = aux_shapes[csr::kIdx];
-        } else {
-          LOG(FATAL) << "Unknown storage type " << stype;
-        }
-      }
-      ptr_ = std::make_shared<Chunk>(stype, storage_shape, ctx, delay_alloc,
-                                     dtype, aux_types, aux_shapes);
-#if MKL_EXPERIMENTAL == 1
-      Mkl_mem_ = std::make_shared<MKLMemHolder>();
-#endif
-  }
+          TShape storage_shape = TShape(mshadow::Shape1(0)));
+
   /*!
    * \brief constructing a static NDArray that shares data with TBlob
    *  Use with caution: allocate ONLY ONE NDArray for each TBlob,
@@ -157,17 +114,11 @@ class NDArray {
       : ptr_(std::make_shared<Chunk>(data, dev_id)), shape_(data.shape_),
         dtype_(data.type_flag_), storage_type_(kDefaultStorage),
         entry_({nullptr, 0, 0}) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = std::make_shared<MKLMemHolder>();
-#endif
   }
   /*! \brief create ndarray from shared memory */
   NDArray(int shared_pid, int shared_id, const TShape& shape, int dtype)
       : ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)), shape_(shape),
         dtype_(dtype), storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = std::make_shared<MKLMemHolder>();
-#endif
   }
 
   /*!
@@ -184,11 +135,24 @@ class NDArray {
           const TBlob &data, const std::vector<TBlob> &aux_data, int dev_id)
       : ptr_(std::make_shared<Chunk>(stype, data, aux_data, dev_id)), shape_(shape),
         dtype_(data.type_flag_), storage_type_(stype), entry_({nullptr, 0, 0}) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = std::make_shared<MKLMemHolder>();
-#endif
   }
 
+  /*
+   * This indicates whether an array is a view of another array (created by
+   * reshape or slice). If an array is a view and the data is stored in
+   * MKLDNN format, we need to convert the data to the default format when
+   * data in the view is accessed.
+   */
+  inline bool IsView() const {
+    // View only works on the default storage
+    if (storage_type() != kDefaultStorage)
+      return false;
+    // If the array reuses memory, its shape may be different from the storage
+    // shape. However, we shouldn't consider it as a view.
+    if (reuse_)
+      return false;
+    return byte_offset_ > 0 || shape() != ptr_->storage_shape;
+  }
 
   /*!
    * \return the shape of current NDArray.
@@ -271,9 +235,6 @@ class NDArray {
             << "Unexpected storage type: " << stype;
       res = TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type);
     });
-#if MKL_EXPERIMENTAL == 1
-    res.Mkl_mem_ = Mkl_mem_;
-#endif
     return res;
   }
   /*!
@@ -534,15 +495,12 @@ class NDArray {
     CHECK_GE(ptr_->shandle.size,
              shape.Size() * mshadow::mshadow_sizeof(dtype))
         << "NDArray.AsArray: target memory size is bigger";
-#if MKL_EXPERIMENTAL == 1
-    if (Mkl_mem_ != nullptr) {
-      // convert prv to cpu
-      Mkl_mem_->check_and_prv_to_cpu(ptr_->shandle.dptr);
-    }
-#endif
+    // We can't reuse memory in a view.
+    CHECK(!IsView());
     NDArray ret = *this;
     ret.shape_ = shape;
     ret.dtype_ = dtype;
+    ret.reuse_ = true;
     return ret;
   }
   /*!
@@ -611,6 +569,83 @@ class NDArray {
              << "CheckAndAllocAuxData is not intended for kDefaultStorage";
     ptr_->CheckAndAllocAuxData(i, aux_shape);
   }
+
+#if MXNET_USE_MKLDNN == 1
+  /*
+   * Test if the data is stored in one of the special MKLDNN formats.
+   */
+  bool IsMKLDNNData() const {
+    return ptr_->IsMKLDNN();
+  }
+  /*
+   * Test if the data is stored in one of the default MXNet formats.
+   */
+  bool IsDefaultData() const {
+    return ptr_->IsDefault();
+  }
+  /*
+   * All functions below return a raw pointer to mkldnn memory. In practice,
+   * a shared pointer holds the memory, either in the NDArray or in the
+   * MKLDNN stream. As long as we call these functions inside an operator,
+   * the returned memory is always valid.
+   */
+
+  /*
+   * This function returns mkldnn::memory with the default primitive_desc.
+   */
+  const mkldnn::memory *GetMKLDNNData() const;
+  /*
+   * This function returns mkldnn::memory with the given primitive_desc
+   * as long as the array size meets the required size in the given primitive_desc.
+   */
+  const mkldnn::memory *GetMKLDNNData(
+      const mkldnn::memory::primitive_desc &desc) const;
+  /*
+   * This function returns mkldnn::memory with the given primitive_desc.
+   * The returned mkldnn::memory will have the same physical layout as
+   * the given primitive_desc.
+   */
+  const mkldnn::memory *GetMKLDNNDataReorder(
+      const mkldnn::memory::primitive_desc &desc) const;
+
+  /*
+   * This function copies data from mkldnn memory.
+   */
+  void CopyFrom(const mkldnn::memory &mem);
+  /*
+   * This function allocates memory for the array and creates mkldnn memory
+   * with the specified format.
+   */
+  mkldnn::memory *CreateMKLDNNData(
+      const mkldnn::memory::primitive_desc &desc);
+
+  /*
+   * Reorder the memory to the specified layout.
+   */
+  void MKLDNNDataReorder(const mkldnn::memory::primitive_desc &desc);
+  void Reorder2Default() {
+    CHECK_EQ(storage_type(), kDefaultStorage);
+    ptr_->Reorder2Default();
+  }
+
+  void InvalidateMKLDNNData() {
+    // Removing mkl_mem_ means the NDArray will store data in the default format.
+    ptr_->mkl_mem_ = nullptr;
+  }
+
+  /*
+   * This function is used inside operators to reshape an array.
+   * It doesn't change the layout of the original array, and it allocates memory from
+   * the temporary buffer. The returned array is only valid inside the current
+   * invocation of this operator.
+   * This is different from Reshape. Reshape will cause data in the array to be
+   * converted to the default layout and allocate memory from malloc directly,
+   * which can be expensive.
+   * It's used by FullyConnected right now.
+   */
+  NDArray MKLDNNDataReshape(const TShape &shape) const;
+#endif
+
   /*!
    * \brief Save list of ndarray into the Stream.x
    * \param fo The stream of output.
@@ -645,6 +680,12 @@ class NDArray {
                for csr, aux_handles[0] = indptr, aux_handles[1] = indices
     */
     std::vector<Storage::Handle> aux_handles;
+
+#if MXNET_USE_MKLDNN == 1
+    /*! This is created when data is stored in MKLDNN format.
+     */
+    std::shared_ptr<mkldnn::memory> mkl_mem_;
+#endif
     /*! \brief variable from engine */
     Engine::VarHandle var;
     /*!
@@ -706,7 +747,7 @@ class NDArray {
         : static_data(false), delay_alloc(false) {
       var = Engine::Get()->NewVariable();
       ctx = Context::CPUShared(0);
-      shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);;
+      shandle.size = shape.Size() * mshadow::mshadow_sizeof(dtype);
       shandle.ctx = ctx;
       shandle.shared_pid = shared_pid;
       shandle.shared_id = shared_id;
@@ -781,6 +822,9 @@ class NDArray {
     inline void CheckAndAlloc(void) {
       if (delay_alloc) {
         shandle = Storage::Get()->Alloc(shandle.size, shandle.ctx);
+#if MXNET_USE_MKLDNN == 1
+        mkl_mem_ = nullptr;
+#endif
         delay_alloc = false;
       }
     }
@@ -789,15 +833,22 @@ class NDArray {
     // size is the number of bytes
     void CheckAndAlloc(uint64_t dbytes) {
       CHECK_EQ(kDefaultStorage, storage_type)
-              << "CheckAndAlloc(dbytes) is not intended for kDefaultStorage";
+          << "CheckAndAlloc(dbytes) is only intended for kDefaultStorage";
+      dbytes = std::max(dbytes, shandle.size);
       if (delay_alloc) {
         shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
+#if MXNET_USE_MKLDNN == 1
+        mkl_mem_ = nullptr;
+#endif
         delay_alloc = false;
       } else if (shandle.size < dbytes) {
         // free storage if necessary and alloc again
         if (shandle.size > 0) Storage::Get()->Free(shandle);
         // init storage
         shandle = Storage::Get()->Alloc(dbytes, shandle.ctx);
+#if MXNET_USE_MKLDNN == 1
+        mkl_mem_ = nullptr;
+#endif
       }
     }
 
@@ -823,20 +874,19 @@ class NDArray {
     // storage shape is also updated
     // if data is already allocated, try reuse the storage. Otherwise, free the current one
     // and allocate new storage
-    inline void CheckAndAllocData(const TShape &shape, int dtype) {
-      CHECK_NE(aux_shapes.size(), 0) << "data is expected to be allocated after aux_data";
-      auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
-      if (shandle.size < dbytes) {
-        // free storage if necessary and alloc again
-        if (shandle.size > 0) Storage::Get()->Free(shandle);
-        // init storage
-        shandle = Storage::Get()->Alloc(dbytes, ctx);
-      }
-      // init shape
-      storage_shape = shape;
-      // delay_alloc is only set when data storage handle is present
-      delay_alloc = false;
-    }
+    void CheckAndAllocData(const TShape &shape, int dtype);
+
+#if MXNET_USE_MKLDNN == 1
+    // Have MKL memory reference the data in the default storage,
+    // or create memory for MKLDNN.
+    void SetMKLMem(const TShape &shape, int dtype);
+    // If the data is stored in MKLDNN layout, we reorder data in mkl_mem_ and
+    // save the result in shandle.
+    void Reorder2Default();
+    bool IsMKLDNN() const;
+    bool IsDefault() const;
+#endif
+
     // create storage handle for aux data based on shape
     // this function assumes ctx, aux shapes and aux types are set
     // aux shape is also updated
@@ -862,45 +912,11 @@ class NDArray {
       set_aux_shape(i, shape);
     }
     /*! \brief destructor */
-    ~Chunk() {
-      bool skip_free = static_data || delay_alloc;
-      Storage::Handle h = this->shandle;
-      std::vector<Storage::Handle> aux_h = this->aux_handles;
-      Engine::Get()->DeleteVariable([h, aux_h, skip_free](RunContext s) {
-        if (skip_free == false) {
-          Storage::Get()->Free(h);
-          for (size_t i = 0; i < aux_h.size(); i++) {
-            if (aux_h[i].size > 0) Storage::Get()->Free(aux_h[i]);
-          }
-        }
-      }, shandle.ctx, var);
-    }
+    ~Chunk();
   };  // struct Chunk
 
-  void SetTBlob() const {
-    CHECK(ptr_ != nullptr);
-    TShape shape = shape_;
-    char *dptr = static_cast<char*>(ptr_->shandle.dptr);
-    auto stype = storage_type();
-    if (stype == kDefaultStorage) {
-      dptr += byte_offset_;
-    } else if (stype == kCSRStorage || stype == kRowSparseStorage) {
-      shape = storage_shape();
-    } else {
-      LOG(FATAL) << "unknown storage type " << stype;
-    }
-    tblob_.dptr_ = dptr;
-    tblob_.shape_ = shape;
-    tblob_.type_flag_ = dtype_;
-    tblob_.SetDLTensor(ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id);
-#if MKL_EXPERIMENTAL == 1
-    tblob_.Mkl_mem_ = Mkl_mem_;
-#endif
-  }
+  void SetTBlob() const;
 
-#if MKL_EXPERIMENTAL == 1
-  std::shared_ptr<MKLMemHolder> Mkl_mem_;
-#endif
   /*! \brief internal data of NDArray */
   std::shared_ptr<Chunk> ptr_{nullptr};
   /*! \brief shape of current NDArray */
@@ -909,6 +925,8 @@ class NDArray {
   size_t byte_offset_ = 0;
   /*! \brief type of data */
   int dtype_ = -1;
+  /*! \brief whether the NDArray uses memory of another NDArray. */
+  bool reuse_ = false;
   /*! \brief storage type of data */
   NDArrayStorageType storage_type_ = kUndefinedStorage;
   /*! \brief node entry for autograd */
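
The NDArray changes above are the core of the new MKLDNN interface. A minimal
sketch of the intended call pattern inside an operator (ExampleForward and its
arguments are hypothetical; only the NDArray methods come from the
declarations above):

    #if MXNET_USE_MKLDNN == 1
    void ExampleForward(const NDArray &in, NDArray *out) {
      // The raw pointers stay valid for the duration of the operator call;
      // the owning shared_ptr lives in the NDArray or the MKLDNN stream.
      const mkldnn::memory *in_mem = in.GetMKLDNNData();
      mkldnn::memory *out_mem =
          out->CreateMKLDNNData(in_mem->get_primitive_desc());
      if (out_mem == nullptr) {
        // CreateMKLDNNData can fail (e.g. for a view); fall back to a copy.
        out->CopyFrom(*in_mem);
        return;
      }
      // ... create and submit an mkldnn primitive that writes to *out_mem ...
    }
    #endif
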
diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index b65cd2b..168ddcc 100755
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -36,9 +36,6 @@
 #include <utility>
 #include <algorithm>
 #include "./base.h"
-#if MXNET_USE_MKL2017 == 1
-#include <mkl_memory.h>
-#endif
 namespace mxnet {
 
 /* Forward declaration for friend declaration in TBlob */
@@ -66,17 +63,10 @@ class TBlob {
   /*! \brief type flag of the tensor blob */
   int type_flag_;
 
-  /*! \brief storing mkl chunk buffer blob, use for experimental only */
-#if MKL_EXPERIMENTAL == 1
-  std::shared_ptr<MKLMemHolder> Mkl_mem_;
-#endif
   /*! \brief default constructor, default copy assign will work */
   TBlob(void)
       : dptr_(NULL),
         type_flag_(mshadow::DataType<real_t>::kFlag) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = NULL;
-#endif
     SetDLTensor(cpu::kDevMask, 0);
   }
   /*!
@@ -90,9 +80,6 @@ class TBlob {
   TBlob(DType *dptr, const TShape &shape, int dev_mask, int dev_id = -1)
       : dptr_(dptr), shape_(shape),
         type_flag_(mshadow::DataType<DType>::kFlag) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = NULL;
-#endif
     SetDLTensor(dev_mask, dev_id);
   }
   /*!
@@ -105,9 +92,6 @@ class TBlob {
    */
   TBlob(void *dptr, const TShape &shape, int dev_mask, int type_flag, int dev_id = -1)
       : dptr_(dptr), shape_(shape), type_flag_(type_flag) {
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = NULL;
-#endif
     SetDLTensor(dev_mask, dev_id);
   }
   /*!
@@ -135,9 +119,6 @@ class TBlob {
     shape_ = src.shape_;
     type_flag_ = mshadow::DataType<DType>::kFlag;
     SetDLTensor(Device::kDevMask, -1);
-#if MKL_EXPERIMENTAL == 1
-    Mkl_mem_ = NULL;
-#endif
     return *this;
   }
   /*!
@@ -172,11 +153,6 @@ class TBlob {
     CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
       << "TBlob.get_with_shape: data type do not match specified type."
       << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
-#if MKL_EXPERIMENTAL == 1
-    if (Mkl_mem_ != nullptr) {
-      Mkl_mem_->check_and_prv_to_cpu(dptr_);
-    }
-#endif
     return mshadow::Tensor<Device, 2, DType>(static_cast<DType*>(dptr_),
                                              shape_.FlatTo2D(),
                                              shape_[shape_.ndim() - 1],
@@ -217,11 +193,6 @@ class TBlob {
     CHECK(mshadow::DataType<DType>::kFlag == type_flag_)
       << "TBlob.get_with_shape: data type do not match specified type."
       << "Expected: " << type_flag_ << " v.s. given " << mshadow::DataType<DType>::kFlag;
-#if MKL_EXPERIMENTAL == 1
-    if (Mkl_mem_ != nullptr) {
-      Mkl_mem_->check_and_prv_to_cpu(dptr_);
-    }
-#endif
     return static_cast<DType*>(dptr_);
   }
   /*! \brief device mask of the corresponding device */
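
With the MKL_EXPERIMENTAL hooks removed from TBlob, layout conversion now
happens on the NDArray side before a TBlob reaches a non-MKLDNN kernel. A
minimal sketch of the fallback pattern (arr is a hypothetical NDArray; the
methods come from the ndarray.h changes above):

    #if MXNET_USE_MKLDNN == 1
    if (arr.IsMKLDNNData())
      arr.Reorder2Default();  // rewrite the data into the default layout
    #endif
    TBlob blob = arr.data();  // now safe for a native CPU kernel
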
diff --git a/prepare_mkldnn.sh b/prepare_mkldnn.sh
new file mode 100755
index 0000000..7cd7d6a
--- /dev/null
+++ b/prepare_mkldnn.sh
@@ -0,0 +1,118 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# set -ex
+#
+# All modification made by Intel Corporation: © 2016 Intel Corporation
+#
+# All contributions by the University of California:
+# Copyright (c) 2014, 2015, The Regents of the University of California (Regents)
+# All rights reserved.
+#
+# All other contributions:
+# Copyright (c) 2014, 2015, the respective contributors
+# All rights reserved.
+# For the list of contributors go to https://github.com/BVLC/caffe/blob/master/CONTRIBUTORS.md
+#
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+#     * Redistributions of source code must retain the above copyright notice,
+#       this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above copyright
+#       notice, this list of conditions and the following disclaimer in the
+#       documentation and/or other materials provided with the distribution.
+#     * Neither the name of Intel Corporation nor the names of its contributors
+#       may be used to endorse or promote products derived from this software
+#       without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+
+MXNET_ROOTDIR="$(pwd)"
+MKLDNN_ROOTDIR="$MXNET_ROOTDIR/3rdparty/mkldnn/"
+MKLDNN_SRCDIR="$MKLDNN_ROOTDIR/src"
+MKLDNN_BUILDDIR="$MKLDNN_ROOTDIR/build"
+MKLDNN_INSTALLDIR="$MKLDNN_ROOTDIR/install"
+MKLDNN_LIBDIR="$MXNET_ROOTDIR/lib"
+
+# MKLDNN install destination
+HOME_MKLDNN=$1
+if [ ! -z "$HOME_MKLDNN" ]; then
+  mkdir -p $HOME_MKLDNN
+  if [ ! -w $HOME_MKLDNN ]; then
+    echo "MKLDNN install to $HOME_MKLDNN failed, please try with sudo" >&2
+    exit 1
+  fi
+fi
+
+if [ -z "$MKLDNNROOT" ]; then
+if [ ! -f "$MKLDNN_INSTALLDIR/lib/libmkldnn.so" ]; then
+    mkdir -p $MKLDNN_INSTALLDIR
+    cd $MKLDNN_ROOTDIR
+    if [ -z "$MKLROOT" ] && [ ! -f "$MKLDNN_INSTALLDIR/include/mkl_cblas.h" ]; then
+        rm -rf external && cd scripts && ./prepare_mkl.sh && cd ..
+        cp -a external/*/* $MKLDNN_INSTALLDIR/.
+    fi
+    echo "Building MKLDNN ..." >&2
+    cd $MXNET_ROOTDIR
+    g++ --version >&2
+    if [ -z "$ARCH_OPT" ]; then
+        cmake $MKLDNN_ROOTDIR -DCMAKE_INSTALL_PREFIX=$MKLDNN_INSTALLDIR -B$MKLDNN_BUILDDIR
+    else
+        cmake $MKLDNN_ROOTDIR -DCMAKE_INSTALL_PREFIX=$MKLDNN_INSTALLDIR -B$MKLDNN_BUILDDIR -DARCH_OPT_FLAGS=$ARCH_OPT
+    fi
+    make -C $MKLDNN_BUILDDIR -j$(grep -c processor /proc/cpuinfo) VERBOSE=1 >&2
+    make -C $MKLDNN_BUILDDIR install
+    rm -rf $MKLDNN_BUILDDIR
+    mkdir -p $MKLDNN_LIBDIR
+    cp $MKLDNN_INSTALLDIR/lib/* $MKLDNN_LIBDIR
+fi
+MKLDNNROOT=$MKLDNN_INSTALLDIR
+fi
+
+if [ -z "$MKLROOT" ] && [ -f "$MKLDNNROOT/include/mkl_cblas.h" ]; then
+  MKLROOT=$MKLDNNROOT;
+fi
+
+# user specified MKLDNN install folder
+if [ -d "$HOME_MKLDNN" ]; then
+  # skip if user specified MKLDNNROOT
+  [ "$MKLDNNROOT" != "$HOME_MKLDNN" ] && rsync -a $MKLDNNROOT/include $MKLDNNROOT/lib $HOME_MKLDNN/.
+  [ "$MKLROOT" != "$HOME_MKLDNN" ] && rsync -a $MKLROOT/include $MKLROOT/lib $HOME_MKLDNN/.
+  # update ldconfig if possible
+  if [ -w /etc/ld.so.conf.d ]; then
+    echo "$HOME_MKLDNN/lib" > /etc/ld.so.conf.d/mxnmkldnn.conf && ldconfig
+  fi
+  # return values to the calling script (Makefile, cmake)
+  echo $HOME_MKLDNN $HOME_MKLDNN
+else
+  echo $MKLDNNROOT $MKLROOT
+fi
+
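
Usage note for the script above: run with no arguments, it builds MKL-DNN under
3rdparty/mkldnn and echoes two paths (MKLDNNROOT and MKLROOT) on stdout for the
calling Makefile or CMake to capture; run as ./prepare_mkldnn.sh <install_dir>,
it additionally rsyncs the built headers and libraries into that directory and
echoes the prefix for both paths.
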
diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index 6461904..56f4b9c 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1287,6 +1287,10 @@ def check_consistency(sym, ctx_list, scale=1.0, grad_req='write',
             arr[:] = arg_params[name]
         for name, arr in exe.aux_dict.items():
             arr[:] = aux_params[name]
+        # grad_req='add' accumulates into the gradient arrays, so they must be zero-initialized.
+        if grad_req == "add":
+            for arr in exe.grad_arrays:
+                arr[:] = np.zeros(arr.shape, dtype=arr.dtype)
 
     dtypes = [np.dtype(exe.outputs[0].dtype) for exe in exe_list]
     max_idx = np.argmax(dtypes)
diff --git a/src/common/exec_utils.h b/src/common/exec_utils.h
index dcd1504..5fd1a9b 100644
--- a/src/common/exec_utils.h
+++ b/src/common/exec_utils.h
@@ -43,19 +43,61 @@ namespace common {
           indices are not recorded
  * \return true if any source NDArray need to cast storage
  */
-inline bool SetupDefaultBlobs(const std::vector<NDArray>& src,
-                              std::vector<TBlob> *blobs,
-                              std::vector<NDArray> *temp_src,
-                              std::vector<NDArray> *temp_dst,
-                              std::unordered_map<uint32_t, uint32_t> *idx_map = nullptr) {
+inline bool SetupDefaultBlobsIn(const std::vector<NDArray>& src,
+                                const std::vector<NDArray> *bufs,
+                                std::vector<TBlob> *blobs,
+                                std::vector<NDArray> *temp_src,
+                                std::vector<NDArray> *temp_dst,
+                                std::unordered_map<uint32_t, uint32_t> *idx_map) {
   bool require_cast = false;
   for (size_t i = 0; i < src.size(); i++) {
     auto& nd = src[i];
-    if (nd.storage_type() != kDefaultStorage) {
-      if (idx_map != nullptr) {
-        (*idx_map)[i] = temp_dst->size();
-      }
-      NDArray temp(nd.shape(), nd.ctx(), false, nd.dtype());
+    bool is_default = nd.storage_type() == kDefaultStorage;
+#if MXNET_USE_MKLDNN == 1
+    // We have to make sure it's default storage and default layout.
+    is_default = nd.IsDefaultData();
+#endif
+    if (!is_default) {
+      (*idx_map)[i] = temp_dst->size();
+      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
+                                                             true, nd.dtype());
+#if MXNET_USE_MKLDNN == 1
+      CHECK(temp.IsDefaultData());
+#endif
+      temp_src->emplace_back(nd);
+      temp_dst->emplace_back(temp);
+      blobs->emplace_back(temp.data());
+      require_cast = true;
+    } else {
+      blobs->push_back(nd.data());
+    }
+  }
+  return require_cast;
+}
+
+inline bool SetupDefaultBlobsOut(const std::vector<NDArray>& src,
+                                 const std::vector<OpReqType> &req,
+                                 const std::vector<NDArray> *bufs,
+                                 std::vector<TBlob> *blobs,
+                                 std::vector<NDArray> *temp_src,
+                                 std::vector<NDArray> *temp_dst) {
+  bool require_cast = false;
+  for (size_t i = 0; i < src.size(); i++) {
+    auto& nd = src[i];
+    bool is_default = nd.storage_type() == kDefaultStorage;
+#if MXNET_USE_MKLDNN == 1
+    // If the request is kWriteTo, we don't need to worry about whether it contains valid data.
+    if (req[i] == kWriteTo && is_default)
+      const_cast<NDArray &>(nd).InvalidateMKLDNNData();
+    // We have to make sure it's default storage and default layout.
+    is_default = nd.IsDefaultData();
+#endif
+    if (!is_default) {
+      NDArray temp = bufs != nullptr ? bufs->at(i) : NDArray(nd.shape(), nd.ctx(),
+                                                             true, nd.dtype());
+#if MXNET_USE_MKLDNN == 1
+      CHECK(temp.IsDefaultData());
+#endif
       temp_src->emplace_back(nd);
       temp_dst->emplace_back(temp);
       blobs->emplace_back(temp.data());
@@ -76,6 +118,9 @@ inline bool SetupDefaultBlobs(const std::vector<NDArray>& src,
  */
 inline void SetupDefaultBlobsInOut(const std::vector<NDArray> &ndinputs,
                                    const std::vector<NDArray> &ndoutputs,
+                                   const std::vector<OpReqType> &req,
+                                   const std::vector<NDArray> *in_bufs,
+                                   const std::vector<NDArray> *out_bufs,
                                    std::vector<TBlob> *input_blobs,
                                    std::vector<TBlob> *output_blobs,
                                    std::vector<NDArray> *pre_temp_src,
@@ -85,9 +130,11 @@ inline void SetupDefaultBlobsInOut(const std::vector<NDArray> &ndinputs,
                                    std::unordered_map<uint32_t, uint32_t> *in_temp_idx_map,
                                    const std::vector<uint32_t> &mutate_idx) {
   // populate input blobs
-  SetupDefaultBlobs(ndinputs, input_blobs, pre_temp_src, pre_temp_dst, in_temp_idx_map);
+  SetupDefaultBlobsIn(ndinputs, in_bufs, input_blobs, pre_temp_src, pre_temp_dst,
+                      in_temp_idx_map);
   // populate output blobs
-  SetupDefaultBlobs(ndoutputs, output_blobs, post_temp_dst, post_temp_src);
+  SetupDefaultBlobsOut(ndoutputs, req, out_bufs, output_blobs, post_temp_dst,
+                       post_temp_src);
   // add mutable inputs to post temp list
   for (const auto idx : mutate_idx) {
     auto map_iter = in_temp_idx_map->find(idx);
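
The input/output split exists because outputs must consult the OpReqType (and,
under MKLDNN, may have stale MKLDNN data invalidated) while inputs only need
their storage checked. A rough sketch of how a caller drives the fallback path;
the wrapper function is hypothetical, the helpers are the ones above, and the
flow mirrors the executor changes below:

    void RunWithFallback(const nnvm::NodeAttrs &attrs, const mxnet::FCompute &fn,
                         const mxnet::OpContext &op_ctx,
                         const std::vector<mxnet::NDArray> &inputs,
                         const std::vector<mxnet::OpReqType> &req,
                         const std::vector<mxnet::NDArray> &outputs) {
      using namespace mxnet;
      std::vector<TBlob> in_blobs, out_blobs;
      std::vector<NDArray> pre_src, pre_dst, post_src, post_dst;
      std::unordered_map<uint32_t, uint32_t> in_idx_map;
      // Swap non-default (or MKLDNN-layout) arrays for default-storage buffers.
      common::SetupDefaultBlobsInOut(inputs, outputs, req, nullptr, nullptr,
                                     &in_blobs, &out_blobs, &pre_src, &pre_dst,
                                     &post_src, &post_dst, &in_idx_map, {});
      // Copy input data into the temporary default-storage buffers.
      common::CastNonDefaultStorage(pre_src, pre_dst, op_ctx, false);
      fn(attrs, op_ctx, in_blobs, req, out_blobs);
      // Copy the results back into the original output arrays.
      common::CastNonDefaultStorage(post_src, post_dst, op_ctx, false);
    }
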
diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc
index 1bcc40a..e4d4955 100644
--- a/src/executor/attach_op_execs_pass.cc
+++ b/src/executor/attach_op_execs_pass.cc
@@ -30,11 +30,8 @@
 #include "../common/utils.h"
 #include "../common/exec_utils.h"
 #include "./exec_pass.h"
-#if MXNET_USE_MKL2017 == 1
-#include <mkl_memory.h>
-#include "../operator/mkl/mkl_memory-inl.h"
-#include "../operator/mkl/mkl_util-inl.h"
-#endif
+#include "../operator/nn/mkldnn/mkldnn_base-inl.h"
+
 namespace mxnet {
 
 namespace op {
@@ -58,23 +55,34 @@ class StorageFallbackOpExecutor : public OpExecutor {
  protected:
   // initialize the data blobs
   void InitBlobs() {
-    using namespace common;
     if (!init_) {
-      in_data_.clear(); out_data_.clear();
-      pre_temp_src_.clear(); pre_temp_dst_.clear();
-      post_temp_src_.clear(); post_temp_dst_.clear();
-      in_temp_idx_map_.clear();
-      SetupDefaultBlobsInOut(in_array, out_array, &in_data_, &out_data_,
-                             &pre_temp_src_, &pre_temp_dst_,
-                             &post_temp_src_, &post_temp_dst_,
-                             &in_temp_idx_map_, mutate_idx_);
+      pre_temp_buf_.clear();
+      post_temp_buf_.clear();
+      for (size_t i = 0; i < in_array.size(); i++) {
+        auto &nd = in_array[i];
+        pre_temp_buf_.emplace_back(nd.shape(), nd.ctx(), true, nd.dtype());
+      }
+      for (size_t i = 0; i < out_array.size(); i++) {
+        auto &nd = out_array[i];
+        post_temp_buf_.emplace_back(nd.shape(), nd.ctx(), true, nd.dtype());
+      }
       init_ = true;
     }
   }
 
   // storage fallback before fcompute is launched
   void PreFCompute(bool is_gpu) {
+    using namespace common;
     InitBlobs();
+    in_data_.clear(); out_data_.clear();
+    pre_temp_src_.clear(); pre_temp_dst_.clear();
+    post_temp_src_.clear(); post_temp_dst_.clear();
+    in_temp_idx_map_.clear();
+    SetupDefaultBlobsInOut(in_array, out_array, req, &pre_temp_buf_, &post_temp_buf_,
+                           &in_data_, &out_data_,
+                           &pre_temp_src_, &pre_temp_dst_,
+                           &post_temp_src_, &post_temp_dst_,
+                           &in_temp_idx_map_, mutate_idx_);
     common::CastNonDefaultStorage(pre_temp_src_, pre_temp_dst_, op_ctx, is_gpu);
   }
 
@@ -85,6 +93,8 @@ class StorageFallbackOpExecutor : public OpExecutor {
 
   // default storage tensor blobs for fcompute
   std::vector<TBlob> in_data_, out_data_;
+  // These are NDArray buffers for cast storage.
+  std::vector<NDArray> pre_temp_buf_, post_temp_buf_;
   // source NDArray for cast storage
   std::vector<NDArray> pre_temp_src_, post_temp_src_;
   // destination NDArray for cast storage
@@ -106,10 +116,6 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
     PreFCompute(is_gpu);
     fcompute_(state_, op_ctx, in_data_, req, out_data_);
     PostFCompute(is_gpu);
-#if MKL_EXPERIMENTAL == 1
-    mkl_tblobs_prv_to_cpu(in_data_);
-    mkl_tblobs_prv_to_cpu(out_data_);
-#endif
   }
 
   ExecType exec_type() const override {
@@ -175,10 +181,6 @@ class FComputeExecutor : public StorageFallbackOpExecutor {
     PreFCompute(is_gpu);
     fcompute_(attrs_, op_ctx, in_data_, req, out_data_);
     PostFCompute(is_gpu);
-#if MKL_EXPERIMENTAL == 1
-    mkl_tblobs_prv_to_cpu(in_data_);
-    mkl_tblobs_prv_to_cpu(out_data_);
-#endif
   }
 
   ExecType exec_type() const override {
@@ -202,6 +204,9 @@ class FComputeExExecutor : public OpExecutor {
  public:
   void Run(RunContext rctx, bool is_gpu) override {
     op_ctx.run_ctx = rctx;
+#if MXNET_USE_MKLDNN == 1
+    InvalidateOutputs(out_array, req);
+#endif
     fcompute_(attrs_, op_ctx, in_array, req, out_array);
   }
 
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 2a7d2b9..f685370 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1209,7 +1209,8 @@ void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
       const NDArray& src = data_pool_.at(storage_id);
       data_entry_[i] = src.AsArray(vshape[i], vdtype[i]);
     } else {
-      data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i]);
+      data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i],
+                               true, vdtype[i]);
     }
     if (log_verbose_) {
       LOG(INFO) << "\tinit data entry\t" << i << "\tas " << common::stype_string(storage_type);
diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc
index 73a34c8..01fab22 100644
--- a/src/executor/infer_graph_attr_pass.cc
+++ b/src/executor/infer_graph_attr_pass.cc
@@ -423,11 +423,6 @@ nnvm::Graph InferStorageType(nnvm::Graph&& graph,
     DispatchModeVector dispatch_modes(graph.indexed_graph().num_nodes(), DispatchMode::kUndefined);
     graph.attrs["dispatch_mode"] = std::make_shared<any>(std::move(dispatch_modes));
   }
-  // initialize unknown values for dispatch modes
-  if (graph.attrs.count("dispatch_mode") == 0) {
-    DispatchModeVector dispatch_modes(graph.indexed_graph().num_nodes(), DispatchMode::kUndefined);
-    graph.attrs["dispatch_mode"] = std::make_shared<any>(std::move(dispatch_modes));
-  }
   // initialize the dev_mask vector from the context vector
   if (graph.attrs.count("dev_mask") == 0) {
     CHECK_GT(graph.attrs.count("context"), 0);
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index eaa95a5..93a8bc6 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -214,6 +214,12 @@ nnvm::Graph Imperative::CachedOp::GetForwardGraph(
 
   StorageVector storage(idx.num_node_entries(), exec::kBadStorageID);
   for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
+  const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
+  CHECK_EQ(stypes.size(), storage.size());
+  for (size_t i = 0; i < stypes.size(); i++) {
+    if (stypes[i] != kDefaultStorage)
+      storage[i] = exec::kDynamicStorageID;
+  }
 
   auto mem_plan = PlanMemory(
       &g, std::move(storage), g.GetAttr<std::vector<uint32_t> >(
@@ -320,6 +326,10 @@ nnvm::Graph Imperative::CachedOp::GetBackwardGraph(
   for (size_t i = 0; i < num_forward_entries; ++i) storage[i] = exec::kExternalStorageID;
   for (const auto i : idx.input_nodes()) storage[idx.entry_id(i, 0)] = exec::kExternalStorageID;
   for (const auto i : idx.outputs()) storage[idx.entry_id(i)] = exec::kExternalStorageID;
+  for (size_t i = 0; i < stypes.size(); i++) {
+    if (stypes[i] != kDefaultStorage)
+      storage[i] = exec::kDynamicStorageID;
+  }
 
   auto mem_plan = PlanMemory(
       &g, std::move(storage), g.GetAttr<std::vector<uint32_t> >("backward_ref_count"),
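
Marking every non-default-storage entry as kDynamicStorageID excludes it from
the static memory plan: sparse (and, with this patch, MKLDNN-layout) arrays
size their buffers at run time, so the planner can neither pool nor
pre-allocate them, and the forward and backward graphs of a cached op need the
same treatment.
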
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index fc28f50..966a753 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -362,9 +362,9 @@ inline void PushFCompute(const FCompute& fn,
       // mapping from index in input_blobs to index in pre_temp_dst
       std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
       // setup blobs
-      SetupDefaultBlobsInOut(inputs, outputs, &input_blobs, &output_blobs,
-                             &pre_temp_src, &pre_temp_dst, &post_temp_src,
-                             &post_temp_dst, &in_temp_idx_map, mutate_idx);
+      SetupDefaultBlobsInOut(inputs, outputs, req, nullptr, nullptr,
+                             &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst,
+                             &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx);
       // setup context
       OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested};
       bool is_gpu = ctx.dev_mask() == gpu::kDevMask;
@@ -460,9 +460,9 @@ inline void PushOperator(const OpStatePtr& state,
         // mapping from index in input_blobs to index in pre_temp_dst
         std::unordered_map<uint32_t, uint32_t> in_temp_idx_map;
         // populate input blobs and output blobs
-        SetupDefaultBlobsInOut(inputs, outputs, &input_blobs, &output_blobs,
-                               &pre_temp_src, &pre_temp_dst, &post_temp_src, &post_temp_dst,
-                               &in_temp_idx_map, mutate_idx);
+        SetupDefaultBlobsInOut(inputs, outputs, req, nullptr, nullptr,
+                               &input_blobs, &output_blobs, &pre_temp_src, &pre_temp_dst,
+                               &post_temp_src, &post_temp_dst, &in_temp_idx_map, mutate_idx);
         // setup contexts
         bool is_gpu = rctx.get_ctx().dev_mask() == gpu::kDevMask;
         // pre-fcompute fallback
@@ -607,6 +607,7 @@ inline bool CheckAndInferStorageType(nnvm::Graph* p_g, exec::DevMaskVector&& dev
     }
     if (match) return true;
   }
+  g.attrs.erase("dispatch_mode");
   g.attrs.erase("storage_type");
   g.attrs.erase("storage_type_inputs");
   if (node_range.second > node_range.first) {
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index e98102b..e01cc42 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -32,11 +32,6 @@
 #include "mxnet/engine.h"
 #include "ps/ps.h"
 #include "./kvstore_dist_server.h"
-#if MKL_EXPERIMENTAL == 1
-#include <mkl_memory.h>
-#include "../operator/mkl/mkl_memory-inl.h"
-#include "../operator/mkl/mkl_util-inl.h"
-#endif
 namespace mxnet {
 namespace kvstore {
 
@@ -237,9 +232,6 @@ class KVStoreDist : public KVStoreLocal {
         PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ?
                       EncodeDefaultKey(key, size, false) :
                       EncodeCompressedKey(key, size, false);
-#if MKL_EXPERIMENTAL == 1
-        mkl_set_tblob_eager_mode(recv_buf.data());
-#endif
         real_t* data = recv_buf.data().dptr<real_t>();
         // false means not to delete data when SArray is deleted
         auto vals = new ps::SArray<real_t>(data, size, false);
@@ -389,9 +381,6 @@ class KVStoreDist : public KVStoreLocal {
       [this, key, pskv, small_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
         size_t size = small_buf.shape().Size();
         real_t* data = small_buf.data().dptr<real_t>();
-#if MKL_EXPERIMENTAL == 1
-        mkl_set_tblob_eager_mode(small_buf.data());
-#endif
         // do push. false means no delete
         ps::SArray<real_t> vals(data, size, false);
         CHECK_NOTNULL(ps_worker_)->ZPush(
@@ -416,9 +405,6 @@ class KVStoreDist : public KVStoreLocal {
           // convert to ps keys
           size_t size = send_buf.shape().Size();
           real_t* data = send_buf.data().dptr<real_t>();
-#if MKL_EXPERIMENTAL == 1
-          mkl_set_tblob_eager_mode(send_buf.data());
-#endif
           // do push. false means no delete
           ps::SArray<real_t> vals(data, size, false);
           CHECK_NOTNULL(ps_worker_)->ZPush(
@@ -440,9 +426,6 @@ class KVStoreDist : public KVStoreLocal {
     using namespace rowsparse;
     auto push_to_servers = [this, key, send_buf]
                            (RunContext rctx, Engine::CallbackOnComplete cb) {
-#if MKL_EXPERIMENTAL == 1
-      mkl_set_tblob_eager_mode(send_buf.data());
-#endif
       real_t* data = send_buf.data().dptr<real_t>();
       const int64_t num_rows = send_buf.aux_shape(kIdx)[0];
       const auto offsets = send_buf.aux_data(kIdx).dptr<int64_t>();
@@ -481,9 +464,6 @@ class KVStoreDist : public KVStoreLocal {
       // allocate memory for the buffer
       size_t num_rows = indices.shape().Size();
       recv_buf.CheckAndAlloc({mshadow::Shape1(num_rows)});
-#if MKL_EXPERIMENTAL == 1
-      mkl_set_tblob_eager_mode(recv_buf.data());
-#endif
       real_t* data = recv_buf.data().dptr<real_t>();
       const auto offsets = indices.data().dptr<int64_t>();
       const auto unit_len = recv_buf.shape().ProdShape(1, recv_buf.shape().ndim());
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 4db314f..ae7209e 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -31,10 +31,14 @@
 #include <mxnet/resource.h>
 #include <mxnet/imperative.h>
 #include <mshadow/tensor.h>
+#if MXNET_USE_MKLDNN == 1
+#include <mkldnn.hpp>
+#endif
 #include "./ndarray_function.h"
 #include "../common/utils.h"
 #include "../operator/tensor/matrix_op-inl.h"
 #include "../operator/tensor/init_op.h"
+#include "../operator/nn/mkldnn/mkldnn_base-inl.h"
 
 #if MXNET_USE_OPENCV
 #include <opencv2/opencv.hpp>
@@ -46,6 +50,104 @@ DMLC_REGISTRY_ENABLE(::mxnet::NDArrayFunctionReg);
 
 namespace mxnet {
 
+NDArray::NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx,
+    bool delay_alloc, int dtype, std::vector<int> aux_types,
+    std::vector<TShape> aux_shapes, TShape storage_shape) : shape_(shape),
+  dtype_(dtype), storage_type_(stype), entry_({nullptr, 0, 0}) {
+  // Assign default aux types if not given
+  if (aux_types.size() == 0
+      && stype != kDefaultStorage) {
+    if (stype == kRowSparseStorage) {
+      aux_types = {mshadow::kInt64};
+    } else if (stype == kCSRStorage) {
+      aux_types = {mshadow::kInt64, mshadow::kInt64};
+    } else {
+      LOG(FATAL) << "Unknown storage type " << stype;
+    }
+  }
+  // Assign default shapes if not given
+  // unknown shapes are initialized as {0} so that Size() returns 0
+  if (aux_shapes.size() == 0
+      && stype != kDefaultStorage) {
+    if (stype == kRowSparseStorage) {
+      aux_shapes = {TShape(mshadow::Shape1(0))};
+    } else if (stype == kCSRStorage) {
+      // aux shapes for indptr and indices
+      aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))};
+    } else {
+      LOG(FATAL) << "Unknown storage type " << stype;
+    }
+  }
+  if (storage_shape.Size() == 0
+      && stype != kDefaultStorage) {
+    if (stype == kRowSparseStorage) {
+      storage_shape = shape;
+      storage_shape[0] = aux_shapes[rowsparse::kIdx][0];
+    } else if (stype == kCSRStorage) {
+      storage_shape = aux_shapes[csr::kIdx];
+    } else {
+      LOG(FATAL) << "Unknown storage type " << stype;
+    }
+  }
+  if (stype == kDefaultStorage)
+    ptr_ = std::make_shared<Chunk>(shape, ctx, delay_alloc, dtype);
+  else
+    ptr_ = std::make_shared<Chunk>(stype, storage_shape, ctx, delay_alloc,
+        dtype, aux_types, aux_shapes);
+}
+
+struct ChunkMem {
+  Storage::Handle h;
+  std::vector<Storage::Handle> aux_h;
+#if MXNET_USE_MKLDNN == 1
+  std::shared_ptr<mkldnn::memory> mem;
+#endif
+};
+
+NDArray::Chunk::~Chunk() {
+  bool skip_free = static_data || delay_alloc;
+  ChunkMem mem;
+  mem.h = this->shandle;
+  mem.aux_h = this->aux_handles;
+#if MXNET_USE_MKLDNN == 1
+  // We want to delete mkldnn memory after deleting the variable.
+  mem.mem = this->mkl_mem_;
+#endif
+  Engine::Get()->DeleteVariable([mem, skip_free](RunContext s) {
+    if (skip_free == false) {
+#if MXNET_USE_MKLDNN == 1
+      if (mem.mem) {
+        CHECK_LE(mem.mem->get_primitive_desc().get_size(), mem.h.size);
+        CHECK_EQ(mem.mem->get_data_handle(), mem.h.dptr);
+      }
+#endif
+      if (mem.h.size > 0) Storage::Get()->Free(mem.h);
+      for (size_t i = 0; i < mem.aux_h.size(); i++) {
+        if (mem.aux_h[i].size > 0) Storage::Get()->Free(mem.aux_h[i]);
+      }
+    }
+  }, shandle.ctx, var);
+}
+
+void NDArray::Chunk::CheckAndAllocData(const TShape &shape, int dtype) {
+  CHECK_NE(aux_shapes.size(), 0)
+      << "data is expected to be allocated after aux_data";
+  auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
+  if (shandle.size < dbytes) {
+    // free storage if necessary and alloc again
+    if (shandle.size > 0) Storage::Get()->Free(shandle);
+    // init storage
+    shandle = Storage::Get()->Alloc(dbytes, ctx);
+#if MXNET_USE_MKLDNN == 1
+    mkl_mem_ = nullptr;
+#endif
+  }
+  // init shape
+  storage_shape = shape;
+  // delay_alloc is only set when data storage handle is present
+  delay_alloc = false;
+}
+
 NDArray NDArray::grad() const {
   if (Imperative::AGInfo::IsNone(*this)) return NDArray();
   Imperative::AGInfo& info = Imperative::AGInfo::Get(entry_.node);
@@ -64,15 +166,55 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
   return ret;
 }
 
+#if MXNET_USE_MKLDNN == 1
+
+NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
+  CHECK(!is_none()) << "NDArray is not initialized";
+  CHECK_GE(shape_.Size(), shape.Size())
+    << "NDArray.Reshape: target shape size is larger current shape";
+  CHECK_EQ(storage_type(), kDefaultStorage);
+  if (!IsMKLDNNData()) {
+    NDArray ret = this->Detach();
+    ret.shape_ = shape;
+    return ret;
+  } else {
+    NDArray ret(shape, ctx(), true, dtype());
+    // We shouldn't submit the reorder primitive here because submit will
+    // be called in operators.
+    auto format = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc());
+    CHECK_NE(format, ptr_->mkl_mem_->get_primitive_desc().desc().data.format);
+    auto def_pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), format);
+    auto def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    MKLDNNStream *stream = MKLDNNStream::Get();
+    stream->RegisterMem(ptr_->mkl_mem_);
+    stream->RegisterPrim(mkldnn::reorder(*ptr_->mkl_mem_, *def_mem));
+    // def_mem points to a memory region in the temp space. It's only valid
+    // inside an operator. As such, the returned NDArray can only be valid
+    // inside an operator and the shared pointer doesn't need to do anything
+    // when it's destroyed.
+    ret.ptr_->mkl_mem_ = std::shared_ptr<mkldnn::memory>(def_mem,
+                                                         [](mkldnn::memory *mem){});
+    ret.ptr_->shandle.dptr = def_mem->get_data_handle();
+    ret.ptr_->shandle.size = def_mem->get_primitive_desc().get_size();
+    ret.ptr_->delay_alloc = false;
+    ret.ptr_->static_data = true;
+    ret.byte_offset_ = byte_offset_;
+    return ret;
+  }
+}
+
+#endif
+
 NDArray NDArray::Reshape(const TShape &shape) const {
   CHECK(!is_none()) << "NDArray is not initialized";
-  auto stype = storage_type();
-  // reshape is not supported for non-default ndarray with dismatching shapes
-  CHECK((shape_ == shape) || stype == kDefaultStorage)
-    << "Reshape for storage type " << stype << " is not implemented yet";
   CHECK_GE(shape_.Size(), shape.Size())
     << "NDArray.Reshape: target shape size is larger current shape";
   NDArray ret = this->Detach();
+  // If the shape doesn't change, we can just return it now.
+  if (ret.shape_ == shape)
+    return ret;
+  // Otherwise, reshape only works on the default layout.
+  CHECK_EQ(storage_type(), kDefaultStorage);
   ret.shape_ = shape;
   return ret;
 }
@@ -95,7 +237,6 @@ NDArray NDArray::ReshapeWithRecord(const TShape &shape) {
   return ret;
 }
 
-
 NDArray NDArray::Slice(index_t begin, index_t end) const {
   CHECK(!is_none()) << "NDArray is empty";
   CHECK_LE(begin, end)
@@ -127,8 +268,8 @@ NDArray NDArray::SliceWithRecord(index_t begin, index_t end) {
 }
 
 NDArray NDArray::At(index_t idx) const {
-  CHECK(storage_type() == kDefaultStorage) << "Storage type "
-                                           << storage_type() << " doesn't support At()";
+  CHECK(storage_type() == kDefaultStorage)
+      << "Storage type " << storage_type() << " doesn't support At()";
   NDArray ret = this->Slice(idx, idx+1);
   if (shape_.ndim() > 1) {
     return ret.Reshape(TShape(shape_.data()+1, shape_.data()+shape_.ndim()));
@@ -181,6 +322,400 @@ void NDArray::set_fresh_out_grad(bool state) const {
   info.fresh_out_grad = state;
 }
 
+#if MXNET_USE_MKLDNN == 1
+static inline bool same_shape(const TShape &shape, mkldnn_dims_t dims, int ndims) {
+  if (shape.ndim() != (size_t)ndims)
+    return false;
+  for (int i = 0; i < ndims; i++)
+    if (shape[i] != dims[i])
+      return false;
+  return true;
+}
+
+static inline bool same_shape(const TShape &shape, int dtype, mkldnn::memory::desc desc) {
+  return same_shape(shape, desc.data.dims, desc.data.ndims)
+      && get_mkldnn_type(dtype) == desc.data.data_type;
+}
+
+bool NDArray::Chunk::IsMKLDNN() const {
+  if (storage_type != kDefaultStorage)
+    return false;
+  if (mkl_mem_ == nullptr)
+    return false;
+  auto desc = mkl_mem_->get_primitive_desc().desc();
+  return desc.data.format != GetDefaultFormat(desc);
+}
+
+bool NDArray::Chunk::IsDefault() const {
+  if (storage_type != kDefaultStorage)
+    return false;
+  // If we don't have mkldnn memory yet, we just assume it's in the default
+  // format.
+  if (mkl_mem_ == nullptr)
+    return true;
+  auto desc = mkl_mem_->get_primitive_desc().desc();
+  return desc.data.format == GetDefaultFormat(desc);
+}
+
+void NDArray::Chunk::Reorder2Default() {
+  if (mkl_mem_ == nullptr)
+    return;
+
+  auto format = GetDefaultFormat(mkl_mem_->get_primitive_desc().desc());
+  CHECK(format != mkl_mem_->get_primitive_desc().desc().data.format);
+
+  auto def_pd = GetPrimitiveDesc(mkl_mem_->get_primitive_desc(), format);
+  mkldnn_mem_ptr def_mem(new mkldnn::memory(def_pd));
+  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
+  std::vector<mkldnn::primitive> net;
+  net.push_back(mkldnn::reorder(*mkl_mem_, *def_mem));
+  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+
+  CHECK(shandle.size >= def_pd.get_size());
+  CheckAndAlloc(def_pd.get_size());
+  // TODO(zhengda) We need to avoid memory copy here.
+  memcpy(shandle.dptr, def_mem->get_data_handle(), def_pd.get_size());
+  mkl_mem_.reset(new mkldnn::memory(def_pd, shandle.dptr));
+}
+
+void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
+  // The shape of the array and the one of the MKL memory may mismatch.
+  // For example, if the array stores parameters, the MKL memory may store data
+  // in 5 dimensions while the NDArray stores data in 4 dimensions.
+  if (mkl_mem_ && mkl_mem_->get_data_handle() == shandle.dptr
+      && same_shape(shape, dtype, mkl_mem_->get_primitive_desc().desc())) {
+    return;
+  }
+
+  mkldnn::memory::dims dims;
+  // These are the shapes supported by MKLDNN.
+  if (shape.ndim() == 1 || shape.ndim() == 2 || shape.ndim() == 4
+      || shape.ndim() == 5) {
+    dims.resize(shape.ndim());
+    for (size_t i = 0; i < dims.size(); i++)
+      dims[i] = shape[i];
+  } else if (shape.ndim() == 3) {
+    // If there are 3 dimensions, we'll force it to 4 dimensions.
+    dims.resize(shape.ndim() + 1);
+    dims[0] = 1;
+    for (size_t i = 0; i < shape.ndim(); i++)
+      dims[i + 1] = shape[i];
+  } else {
+    LOG(FATAL) << "MKLDNN doesn't support " << shape.ndim() << " dimensions";
+  }
+  mkldnn::memory::format layout = mkldnn::memory::format::format_undef;
+  switch (dims.size()) {
+    case 1: layout = mkldnn::memory::format::x; break;
+    case 2: layout = mkldnn::memory::format::nc; break;
+    case 4: layout = mkldnn::memory::format::nchw; break;
+    // This isn't the right layout when the data has 5 dimensions in MXNet.
+    // MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have
+    // a corresponding format.
+    case 5: layout = mkldnn::memory::format::goihw; break;
+  }
+  mkldnn::memory::desc data_md{dims, get_mkldnn_type(dtype), layout};
+  auto cpu_engine = CpuEngine::Get()->get_engine();
+  if (shandle.dptr == nullptr) {
+    CHECK(delay_alloc);
+    CheckAndAlloc();
+  }
+  mkldnn::memory::primitive_desc pd(data_md, cpu_engine);
+  CHECK(shandle.size >= pd.get_size());
+  mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr));
+}
+
+/*
+ * Here we want to get an MKLDNN memory whose primitive desc is exactly the
+ * same as the given one. operator== can't guarantee that: it can return true
+ * even if the formats differ, so we double-check the format explicitly.
+ */
+static inline mkldnn::memory *GetMKLDNNExact(
+    const mkldnn::memory *mem, mkldnn::memory::primitive_desc desc) {
+  auto src_desc = mem->get_primitive_desc();
+  if (desc == src_desc && desc.desc().data.format == src_desc.desc().data.format) {
+    return const_cast<mkldnn::memory *>(mem);
+  } else {
+    std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(
+            desc, mem->get_data_handle()));
+    MKLDNNStream::Get()->RegisterMem(ret);
+    return ret.get();
+  }
+}
+
+const mkldnn::memory *NDArray::GetMKLDNNData(
+    const mkldnn::memory::primitive_desc &desc) const {
+  if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
+    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
+    return nullptr;
+  }
+  auto mem = GetMKLDNNData();
+  mkldnn::memory::primitive_desc _desc = desc;
+  auto desc1 = mem->get_primitive_desc().desc();
+  auto desc2 = _desc.desc();
+  // If the MKL memory has the same format and shape as required,
+  // or both use the default format, we can return the MKL memory.
+  if (mem->get_primitive_desc() == desc
+      || (desc1.data.format == GetDefaultFormat(desc1)
+        && desc2.data.format == GetDefaultFormat(desc2))) {
+    return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
+  } else {
+    return nullptr;
+  }
+}
+
+const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
+    const mkldnn::memory::primitive_desc &desc) const {
+  if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
+    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
+    return nullptr;
+  }
+  CHECK(storage_type() == kDefaultStorage);
+
+  auto mem = GetMKLDNNData();
+  // If the memory descriptor matches, it's easy.
+  MKLDNNStream *stream = MKLDNNStream::Get();
+  if (mem->get_primitive_desc() == desc) {
+    return GetMKLDNNExact(mem, desc);
+  }
+
+  mkldnn::memory::primitive_desc _desc = desc;
+  // Now we need to determine if we should reorder the memory.
+  // If both use the default formats, no reorder is needed.
+  auto desc1 = mem->get_primitive_desc().desc();
+  auto desc2 = _desc.desc();
+  if (desc1.data.format == GetDefaultFormat(desc1) &&
+      desc2.data.format == GetDefaultFormat(desc2)) {
+    mkldnn_mem_ptr ret(new mkldnn::memory(desc, mem->get_data_handle()));
+    stream->RegisterMem(ret);
+    return ret.get();
+  } else {
+    auto ret = TmpMemMgr::Get()->Alloc(desc);
+    stream->RegisterPrim(mkldnn::reorder(*mem, *ret));
+    return ret;
+  }
+}
+
+const mkldnn::memory *NDArray::GetMKLDNNData() const {
+  CHECK(storage_type() == kDefaultStorage);
+  // If this array uses MKLDNN layout and it's a view, we have to change its
+  // layout to the default layout.
+  if (IsMKLDNNData() && IsView())
+    ptr_->Reorder2Default();
+  ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_);
+  // If shandle has data, the data in shandle and mkl_mem_ should match.
+  if (ptr_->shandle.dptr)
+    CHECK(ptr_->shandle.dptr == ptr_->mkl_mem_->get_data_handle());
+  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
+  auto pd = ptr_->mkl_mem_->get_primitive_desc();
+  if (IsView()) {
+    // Sliced array must use the default layout.
+    CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format);
+
+    // Create a default-layout view into the sliced region.
+    void *off_addr = static_cast<char *>(ptr_->mkl_mem_->get_data_handle())
+        + byte_offset_;
+
+    // Create the primitive desc for the new mkldnn memory.
+    mkldnn::memory::dims dims(shape().ndim());
+    for (size_t i = 0; i < dims.size(); i++)
+      dims[i] = shape()[i];
+    mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>(
+        GetDefaultFormat(shape().ndim()));
+    mkldnn::memory::data_type cpp_type = static_cast<mkldnn::memory::data_type>(
+        pd.desc().data.data_type);
+    mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
+    mkldnn::memory::primitive_desc new_pd(data_md, pd.get_engine());
+
+    std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(new_pd, off_addr));
+    MKLDNNStream::Get()->RegisterMem(ret);
+    return ret.get();
+  } else {
+    return ptr_->mkl_mem_.get();
+  }
+}
+
+void NDArray::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd) {
+  CHECK_EQ(storage_type(), kDefaultStorage);
+  // If the memory already uses the specified layout, don't do anything.
+  if (ptr_->mkl_mem_ != nullptr && ptr_->mkl_mem_->get_primitive_desc() == pd)
+    return;
+  auto _pd = pd;
+  auto _desc = _pd.desc();
+  auto def_format = GetDefaultFormat(_desc);
+  // If the memory is default, don't do anything.
+  if (def_format == _desc.data.format && ptr_->IsDefault())
+    return;
+  // If the specified layout is default, we should use Reorder2Default.
+  if (def_format == _desc.data.format) {
+    ptr_->Reorder2Default();
+    return;
+  }
+
+  std::shared_ptr<mkldnn::memory> new_mem(new mkldnn::memory(pd));
+  ptr_->SetMKLMem(shape_, dtype_);
+  auto old_mem = ptr_->mkl_mem_;
+  // It's possible that the specified layout has a different number of dimensions.
+  if (old_mem->get_primitive_desc().desc().data.ndims != _desc.data.ndims) {
+    // For now, we only support reorder from the default layout.
+    CHECK(ptr_->IsDefault());
+    auto def_pd = GetPrimitiveDesc(pd, def_format);
+    old_mem.reset(new mkldnn::memory(def_pd, old_mem->get_data_handle()));
+  }
+  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
+  std::vector<mkldnn::primitive> net;
+  net.push_back(mkldnn::reorder(*old_mem, *new_mem));
+  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+
+  CHECK(ptr_->shandle.size >= pd.get_size());
+  ptr_->CheckAndAlloc(pd.get_size());
+  // TODO(zhengda) We need to avoid memory copy here.
+  memcpy(ptr_->shandle.dptr, new_mem->get_data_handle(), pd.get_size());
+  ptr_->mkl_mem_.reset(new mkldnn::memory(pd, ptr_->shandle.dptr));
+}
+
+void NDArray::CopyFrom(const mkldnn::memory &mem) {
+  CHECK(ptr_ != nullptr) << "The NDArray hasn't been initialized";
+  if (ptr_->mkl_mem_.get() == &mem)
+    return;
+
+  CHECK(mem.get_primitive_desc().get_size() == shape().Size() * GetTypeSize(dtype_))
+      << "The size of NDArray doesn't match the requested MKLDNN memory desc";
+  MKLDNNStream *stream = MKLDNNStream::Get();
+  // If this array uses MKLDNN layout and it's a view, we have to change its
+  // layout to the default layout.
+  if (IsMKLDNNData() && IsView())
+    ptr_->Reorder2Default();
+  ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_,
+                  dtype_);
+  stream->RegisterMem(ptr_->mkl_mem_);
+  auto from_desc = mem.get_primitive_desc().desc();
+  auto this_desc = ptr_->mkl_mem_->get_primitive_desc().desc();
+  auto from_def_format = GetDefaultFormat(from_desc);
+  if (IsView()) {
+    // Sliced array must use the default layout.
+    CHECK_EQ(GetDefaultFormat(this_desc), this_desc.data.format);
+  }
+  // It's possible that the memory and the NDArray don't have the same shape.
+  if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)
+      // If the source memory uses the default layout, we can reshape directly.
+      && from_def_format == from_desc.data.format) {
+    // In this case, we can simply create a new MKLDNN memory for the required
+    // shape.
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+    stream->RegisterMem(tmp_mem);
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_));
+  } else if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)) {
+    // In this case, the source memory stores data in a customized layout. We
+    // need to reorganize the data in memory before we can reshape.
+    auto def_pd = GetPrimitiveDesc(mem.get_primitive_desc(), from_def_format);
+    auto def_mem = TmpMemMgr::Get()->Alloc(def_pd);
+    stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
+    // Now we can reshape it
+    mkldnn::memory::dims dims(this_desc.data.dims,
+                              this_desc.data.dims + this_desc.data.ndims);
+    auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
+    auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
+    mkldnn::memory::desc data_md(dims, this_dtype, this_format);
+    mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
+    mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
+    stream->RegisterMem(tmp_mem);
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_));
+  } else if (mem.get_primitive_desc() == ptr_->mkl_mem_->get_primitive_desc()) {
+    // If the layout is the same, we can just copy data.
+    stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_));
+  } else {
+    auto src_def = GetDefaultFormat(mem.get_primitive_desc().desc());
+    auto dst_def = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc());
+    // If neither uses the default layout, there isn't much we can do
+    // other than reorder the data layout directly.
+    if (dst_def != ptr_->mkl_mem_->get_primitive_desc().desc().data.format
+        && src_def != mem.get_primitive_desc().desc().data.format) {
+      stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_));
+    } else if (dst_def == ptr_->mkl_mem_->get_primitive_desc().desc().data.format) {
+      // If the dest mem uses the default memory layout, we can simply use
+      // the default format of the source memory to improve perf of reorder.
+      auto pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), src_def);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->mkl_mem_->get_data_handle()));
+      stream->RegisterMem(tmp_mem);
+      stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
+    } else {
+      // If the src mem uses the default memory layout, we can use
+      // the default format of the source memory to improve perf.
+      auto pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
+      stream->RegisterMem(tmp_mem);
+      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_));
+    }
+  }
+}
+mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
+                                                mkldnn_memory_format_t format);
+
+mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &desc) {
+  // This array shouldn't be a view.
+  CHECK(!IsView());
+
+  if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
+    LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
+    return nullptr;
+  }
+
+  mkldnn::memory::primitive_desc _desc = desc;
+  auto required_format = _desc.desc().data.format;
+  auto def_format = GetDefaultFormat(_desc.desc());
+  // If the required format is a default format, we don't need to worry about the shape.
+  // If the shape isn't the same, it actually implicitly reshapes data.
+  if (required_format == def_format) {
+    ptr_->SetMKLMem(shape_, dtype_);
+    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
+    return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
+  }
+
+  if (ptr_->mkl_mem_)
+    CHECK(ptr_->mkl_mem_->get_data_handle() == ptr_->shandle.dptr);
+  if (ptr_->mkl_mem_ && ptr_->mkl_mem_->get_primitive_desc() == desc) {
+    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
+    return GetMKLDNNExact(ptr_->mkl_mem_.get(), desc);
+  }
+
+  CHECK(ptr_->shandle.size >= desc.get_size());
+  ptr_->CheckAndAlloc(desc.get_size());
+  ptr_->mkl_mem_.reset(new mkldnn::memory(desc, ptr_->shandle.dptr));
+  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_);
+  return ptr_->mkl_mem_.get();
+}
+#endif
+
+void NDArray::SetTBlob() const {
+  CHECK(ptr_ != nullptr);
+  TShape shape = shape_;
+  char *dptr = static_cast<char*>(ptr_->shandle.dptr);
+  auto stype = storage_type();
+  if (stype == kDefaultStorage) {
+#if MXNET_USE_MKLDNN == 1
+    if (IsMKLDNNData()) {
+      ptr_->Reorder2Default();
+      dptr = static_cast<char*>(ptr_->shandle.dptr);
+    }
+#endif
+    dptr += byte_offset_;
+  } else if (stype == kCSRStorage || stype == kRowSparseStorage) {
+    CHECK_EQ(byte_offset_, 0);
+    shape = storage_shape();
+  } else {
+    LOG(FATAL) << "unknown storage type " << stype;
+  }
+  tblob_.dptr_ = dptr;
+  tblob_.shape_ = shape;
+  tblob_.type_flag_ = dtype_;
+  tblob_.SetDLTensor(ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id);
+}
 
 /*!
 * \brief run a ternary operation
@@ -449,11 +984,51 @@ inline void CopyFromToRspImpl(const NDArray& from, const NDArray& to, RunContext
 // Make a copy of a dense NDArray
 template<typename from_xpu, typename to_xpu>
 inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext ctx) {
-  using namespace mshadow;
-  CHECK_EQ(from.storage_type(), to.storage_type()) << "Copying with different storage type";
-  TBlob tmp = to.data();
-  ndarray::Copy<from_xpu, to_xpu>(from.data(), &tmp,
-                                  from.ctx(), to.ctx(), ctx);
+#if MXNET_USE_MKLDNN == 1
+  // If neither is MKLDNN, we can copy data normally.
+  if (!from.IsMKLDNNData() && !to.IsMKLDNNData()) {
+#endif
+    using namespace mshadow;
+    CHECK_EQ(from.storage_type(), to.storage_type()) << "Copying with different storage type";
+    TBlob tmp = to.data();
+    ndarray::Copy<from_xpu, to_xpu>(from.data(), &tmp,
+                                    from.ctx(), to.ctx(), ctx);
+#if MXNET_USE_MKLDNN == 1
+  } else if (SupportMKLDNN(from.dtype(), from.shape())
+             && SupportMKLDNN(to.dtype(), to.shape())
+             && from.ctx().dev_mask() == cpu::kDevMask
+             && to.ctx().dev_mask() == cpu::kDevMask) {
+    // If we copy data directly, we need to make sure both NDArrays are supported
+    // by MKLDNN.
+    auto from_mem = from.GetMKLDNNData();
+    auto to_mem = to.GetMKLDNNData();
+    if (from_mem->get_primitive_desc() == to_mem->get_primitive_desc()) {
+      size_t size = std::min(from_mem->get_primitive_desc().get_size(),
+                             to_mem->get_primitive_desc().get_size());
+      memcpy(to_mem->get_data_handle(), from_mem->get_data_handle(), size);
+    } else {
+      std::vector<mkldnn::primitive> net;
+      net.push_back(mkldnn::reorder(*from_mem, *to_mem));
+      mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+    }
+  } else {
+    // In this case, one of the NDArrays isn't supported by MKLDNN, so we need
+    // to convert the MKLDNN array to the default format first and copy data
+    // with Copy().
+    NDArray tmp_from = from;
+    if (tmp_from.IsMKLDNNData()) {
+      tmp_from = NDArray(from.shape(), from.ctx(), false, from.dtype());
+      auto tmp_mem = from.GetMKLDNNData();
+      tmp_from.CopyFrom(*tmp_mem);
+      MKLDNNStream::Get()->Submit();
+    }
+    CHECK(tmp_from.IsDefaultData());
+    CHECK(to.IsDefaultData());
+    TBlob tmp = to.data();
+    ndarray::Copy<from_xpu, to_xpu>(tmp_from.data(), &tmp,
+                                    from.ctx(), to.ctx(), ctx);
+  }
+#endif
 }
 
 // Make a copy of an NDArray based on storage type
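
Taken together, the new NDArray methods form a small protocol for operators:
read inputs with GetMKLDNNData (or GetMKLDNNDataReorder when a specific layout
is required), obtain output memory with CreateMKLDNNData, queue primitives on
MKLDNNStream, and submit once. A condensed sketch of that pattern for an
elementwise sum; the operator function itself is hypothetical, while the
NDArray and MKLDNNStream calls are the ones introduced above:

    #include <mkldnn.hpp>

    using namespace mxnet;

    // Hypothetical operator body: out = a + b via MKLDNN's sum primitive.
    void MKLDNNSumForward(const NDArray &a, const NDArray &b, const NDArray &out) {
      auto a_mem = a.GetMKLDNNData();
      // Reorder b to a's layout if needed; the reorder is queued on the stream.
      auto b_mem = b.GetMKLDNNDataReorder(a_mem->get_primitive_desc());
      std::vector<mkldnn::memory::primitive_desc> input_pds = {
          a_mem->get_primitive_desc(), b_mem->get_primitive_desc()};
      std::vector<float> scales = {1.0f, 1.0f};
      mkldnn::sum::primitive_desc sum_pd(scales, input_pds);
      // Ask the output array for memory in the layout the primitive prefers.
      mkldnn::memory *out_mem =
          const_cast<NDArray &>(out).CreateMKLDNNData(sum_pd.dst_primitive_desc());
      std::vector<mkldnn::primitive::at> in_prims;
      in_prims.push_back(*a_mem);
      in_prims.push_back(*b_mem);
      MKLDNNStream::Get()->RegisterPrim(mkldnn::sum(sum_pd, in_prims, *out_mem));
      // Execute all queued reorders and the sum in one shot.
      MKLDNNStream::Get()->Submit();
    }

When CreateMKLDNNData can't provide the requested layout it returns nullptr,
and operators fall back to writing a temporary memory and calling CopyFrom, as
the commit log notes for convolution.
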
diff --git a/src/operator/concat-inl.h b/src/operator/concat-inl.h
deleted file mode 100644
index 4225ddf..0000000
--- a/src/operator/concat-inl.h
+++ /dev/null
@@ -1,264 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015 by Contributors
- * \file concat-inl.h
- * \brief
- * \author Bing Xu
-*/
-#ifndef MXNET_OPERATOR_CONCAT_INL_H_
-#define MXNET_OPERATOR_CONCAT_INL_H_
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <cstring>
-#include <map>
-#include <string>
-#include <vector>
-#include <utility>
-#include "./operator_common.h"
-#include "./channel_op_common.h"
-#include "./tensor/broadcast_reduce_op.h"
-
-namespace mxnet {
-namespace op {
-
-namespace concat_enum {
-enum ConcatOpInputs {kData0, kData1, kData2, kData3, kData4};
-enum ConcatOpOutputs {kOut};
-}  // namespace concat_enum
-
-struct ConcatParam : public dmlc::Parameter<ConcatParam> {
-  int num_args;
-  int dim;
-  DMLC_DECLARE_PARAMETER(ConcatParam) {
-    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
-    .describe("Number of inputs to be concated.");
-    DMLC_DECLARE_FIELD(dim).set_default(1)
-    .describe("the dimension to be concated.");
-  }
-};  // struct ConcatParam
-
-template<typename xpu, typename DType>
-class ConcatOp : public Operator {
- public:
-  explicit ConcatOp(ConcatParam param)
-    : size_(param.num_args), dimension_(param.dim) {}
-
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(static_cast<int>(in_data.size()), size_);
-    CHECK_EQ(out_data.size(), 1U);
-    int axis = CheckAxis(dimension_, in_data[concat_enum::kData0].ndim());
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    std::vector<Tensor<xpu, 3, DType> > data(size_);
-    Tensor<xpu, 3, DType> out;
-    size_t leading = 1, trailing = 1;
-    for (int i = 0; i < axis; ++i) {
-      leading *= out_data[concat_enum::kOut].shape_[i];
-    }
-    for (int i = axis + 1; i < out_data[concat_enum::kOut].ndim(); ++i) {
-      trailing *= out_data[concat_enum::kOut].shape_[i];
-    }
-    size_t mid = out_data[concat_enum::kOut].shape_[axis];
-    Shape<3> oshape = Shape3(leading, mid, trailing);
-    out = out_data[concat_enum::kOut].get_with_shape<xpu, 3, DType>(oshape, s);
-
-    for (int i = 0; i < size_; ++i) {
-      Shape<3> dshape = Shape3(leading, in_data[i].shape_[axis], trailing);
-      data[i] = in_data[i].get_with_shape<xpu, 3, DType>(dshape, s);
-    }
-    Concatenate(data, &out, 1, req[concat_enum::kOut]);
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1U);
-    CHECK_EQ(in_grad.size(), static_cast<size_t>(size_));
-    int axis = CheckAxis(dimension_, out_grad[concat_enum::kData0].ndim());
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    std::vector<Tensor<xpu, 3, DType> > grad_in(size_);
-    Tensor<xpu, 3, DType> grad;
-    size_t leading = 1, trailing = 1;
-    for (int i = 0; i < axis; ++i) {
-      leading *= out_grad[concat_enum::kOut].shape_[i];
-    }
-    for (int i = axis + 1; i < out_grad[concat_enum::kOut].ndim(); ++i) {
-      trailing *= out_grad[concat_enum::kOut].shape_[i];
-    }
-    size_t mid = out_grad[concat_enum::kOut].shape_[axis];
-    Shape<3> oshape = Shape3(leading, mid, trailing);
-    grad = out_grad[concat_enum::kOut].get_with_shape<xpu, 3, DType>(oshape, s);
-
-    for (int i = 0; i < size_; ++i) {
-      Shape<3> dshape = Shape3(leading, in_grad[i].shape_[axis], trailing);
-      grad_in[i] = in_grad[i].get_with_shape<xpu, 3, DType>(dshape, s);
-    }
-    Split(grad, &grad_in, 1, req);
-  }
-
- private:
-  int size_;
-  int dimension_;
-};  // class ConcatOp
-
-template<typename xpu>
-Operator *CreateOp(ConcatParam param, int dtype, std::vector<TShape> *in_shape);
-
-#if DMLC_USE_CXX11
-class ConcatProp : public OperatorProperty {
- public:
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  std::vector<std::string> ListArguments() const override {
-    std::vector<std::string> ret;
-    for (int i = 0; i < param_.num_args; ++i) {
-      ret.push_back(std::string("arg") + std::to_string(i));
-    }
-    return ret;
-  }
-
-  bool InferShape(std::vector<TShape> *in_shape,
-                  std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
-    TShape dshape;
-    index_t size = 0;
-    bool has_zero = false;
-    int axis = -1;
-    for (int i = 0; i < param_.num_args; ++i) {
-      TShape tmp = (*in_shape)[i];
-      if (tmp.ndim()) {
-        axis = CheckAxis(param_.dim, tmp.ndim());
-        has_zero = tmp[axis] == 0 || has_zero;
-        size += tmp[axis];
-        tmp[axis] = 0;
-        shape_assign(&dshape, tmp);
-      }
-    }
-
-    TShape tmp = (*out_shape)[0];
-    if (tmp.ndim()) {
-      axis = CheckAxis(param_.dim, tmp.ndim());
-      tmp[axis] = 0;
-      shape_assign(&dshape, tmp);
-    }
-
-    if (dshape.ndim() == 0) return false;
-
-    for (int i = 0; i < param_.num_args; ++i) {
-      CHECK(shape_assign(&(*in_shape)[i], dshape))
-        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
-    }
-
-    if (!has_zero) dshape[axis] = size;
-    CHECK(shape_assign(&(*out_shape)[0], dshape))
-      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
-
-    return dshape.Size() != 0;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    int dtype = -1;
-
-    for (size_t i = 0; i < in_type->size(); ++i) {
-      if (dtype == -1) {
-        dtype = in_type->at(i);
-      } else {
-        CHECK(in_type->at(i) == dtype ||
-              in_type->at(i) == -1) <<
-              "Non-uniform data type in Concat";
-      }
-    }
-
-    if (dtype == -1) {
-      LOG(FATAL) << "Not enough information to infer type in Concat.";
-      return false;
-    }
-
-    size_t nin = this->ListArguments().size();
-    in_type->clear();
-    for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype);
-
-    size_t naux = this->ListAuxiliaryStates().size();
-    aux_type->clear();
-    for (size_t i = 0; i < naux; ++i) aux_type->push_back(dtype);
-
-    size_t nout = this->ListOutputs().size();
-    out_type->clear();
-    for (size_t i = 0; i < nout; ++i) out_type->push_back(dtype);
-
-    return true;
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new ConcatProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    return "Concat";
-  }
-
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    return out_grad;
-  }
-
-  Operator* CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not implemented";
-    return NULL;
-  }
-
-  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                             std::vector<int> *in_type) const override;
-
- private:
-  ConcatParam param_;
-};  // class ConcatProp
-#endif  // DMLC_USE_CXX11
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // MXNET_OPERATOR_CONCAT_INL_H_
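
Two mechanics in concat-inl.h are worth restating outside mshadow notation.
First, InferShape is bidirectional: every known input shape is merged into a
common shape with the concat axis zeroed out, and the output's axis size
becomes the sum of the input axis sizes. Second, Forward/Backward operate on
a 3D view in which any concat axis reduces to axis 1 of
(leading, axis_dim, trailing). A standalone sketch of both rules, assuming
fully known shapes and plain vectors instead of TShape:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Concat shape rule: all dims except `axis` must match across inputs;
    // the output's `axis` dim is the sum of the inputs' `axis` dims.
    std::vector<int> ConcatInferShape(const std::vector<std::vector<int> >& in,
                                      std::size_t axis) {
      assert(!in.empty());
      std::vector<int> out = in[0];
      for (std::size_t i = 1; i < in.size(); ++i) {
        assert(in[i].size() == out.size());
        for (std::size_t d = 0; d < out.size(); ++d) {
          if (d == axis)
            out[d] += in[i][d];          // concat axis: sizes add up
          else
            assert(in[i][d] == out[d]);  // other axes must agree
        }
      }
      return out;
    }

    // Flatten an N-D shape to (leading, shape[axis], trailing) so that
    // concat/split along `axis` becomes an operation along axis 1.
    void FlattenAroundAxis(const std::vector<int>& shape, std::size_t axis,
                           int* leading, int* trailing) {
      *leading = 1;
      *trailing = 1;
      for (std::size_t d = 0; d < axis; ++d) *leading *= shape[d];
      for (std::size_t d = axis + 1; d < shape.size(); ++d) *trailing *= shape[d];
    }
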
diff --git a/src/operator/concat.cc b/src/operator/concat.cc
deleted file mode 100644
index 4d3c2fa..0000000
--- a/src/operator/concat.cc
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015 by Contributors
- * \file concat.cc
- * \brief
- * \author Bing Xu
-*/
-
-#include "./concat-inl.h"
-#if MXNET_USE_MKL2017 == 1
-#include <mkl_memory.h>
-#include "./mkl/mkl_memory-inl.h"
-#include "./mkl/mkl_concat-inl.h"
-#endif  // MXNET_USE_MKL2017
-
-namespace mxnet {
-namespace op {
-template<>
-Operator* CreateOp<cpu>(ConcatParam param, int dtype, std::vector<TShape> *in_shape) {
-  Operator *op = NULL;
-#if MXNET_USE_MKL2017 == 1
-  // MKL natively supports only 4D input tensors for the concat operation;
-  // 2D/3D input tensors are reshaped to 4D in mkl_concat-inl.h,
-  // so 2D, 3D, and 4D inputs are all accepted here.
-  size_t dims = (*in_shape)[0].ndim();
-  bool supportedDim = (dims >= 2 && dims <= 4);
-  if ((1 == param.dim) && supportedDim &&
-    (param.num_args < (dnnResourceMultipleDst - dnnResourceMultipleSrc))) {
-    switch (dtype) {
-      case mshadow::kFloat32:
-      return new MKLConcatOp<cpu, float>(param);
-    case mshadow::kFloat64:
-      return new MKLConcatOp<cpu, double>(param);
-    default:
-      break;
-    }
-  }
-  if (enableMKLWarnGenerated())
-    LOG(INFO) << MKLConcatOp<cpu, float>::getName() << ": skipping MKL optimization";
-#endif
-  MSHADOW_TYPE_SWITCH(dtype, DType, {
-    op = new ConcatOp<cpu, DType>(param);
-  });
-  return op;
-}
-
-Operator* ConcatProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                                       std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0), in_shape);
-}
-
-DMLC_REGISTER_PARAMETER(ConcatParam);
-
-MXNET_REGISTER_OP_PROPERTY(Concat, ConcatProp)
-.describe(R"code(Joins input arrays along a given axis.
-
-.. note:: `Concat` is deprecated. Use `concat` instead.
-
-The dimensions of the input arrays should be the same except the axis along
-which they will be concatenated.
-The dimension of the output array along the concatenated axis will be equal
-to the sum of the corresponding dimensions of the input arrays.
-
-Example::
-
-   x = [[1,1],[2,2]]
-   y = [[3,3],[4,4],[5,5]]
-   z = [[6,6], [7,7],[8,8]]
-
-   concat(x,y,z,dim=0) = [[ 1.,  1.],
-                          [ 2.,  2.],
-                          [ 3.,  3.],
-                          [ 4.,  4.],
-                          [ 5.,  5.],
-                          [ 6.,  6.],
-                          [ 7.,  7.],
-                          [ 8.,  8.]]
-
-   Note that you cannot concat x,y,z along dimension 1 since dimension
-   0 is not the same for all the input arrays.
-
-   concat(y,z,dim=1) = [[ 3.,  3.,  6.,  6.],
-                        [ 4.,  4.,  7.,  7.],
-                        [ 5.,  5.,  8.,  8.]]
-
-)code" ADD_FILELINE)
-.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
-.add_arguments(ConcatParam::__FIELDS__())
-.set_key_var_num_args("num_args");
-
-NNVM_REGISTER_OP(Concat).add_alias("concat");
-
-}  // namespace op
-}  // namespace mxnet
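
The guard in CreateOp<cpu> above encodes three MKL preconditions: the concat
must run along dim 1, inputs must be 2-4D (lower ranks are padded to 4D), and
the number of inputs must fit in MKL's multiple-src resource slots
(dnnResourceMultipleDst - dnnResourceMultipleSrc). A self-contained sketch of
that dispatch-with-fallback shape, where the dtype codes and factory
functions are hypothetical stand-ins for the real constructors:

    #include <cstdio>

    enum DTypeCode { kF32, kF64, kI32 };

    const char* MakeMKLConcat(DTypeCode)     { return "MKLConcatOp"; }
    const char* MakeGenericConcat(DTypeCode) { return "ConcatOp"; }

    // Prefer the MKL kernel only when its preconditions hold;
    // otherwise fall back to the generic implementation.
    const char* DispatchConcat(int dim, int ndim, int num_args,
                               int max_mkl_srcs, DTypeCode dtype) {
      bool mkl_ok = dim == 1 && ndim >= 2 && ndim <= 4 &&
                    num_args < max_mkl_srcs &&
                    (dtype == kF32 || dtype == kF64);
      return mkl_ok ? MakeMKLConcat(dtype) : MakeGenericConcat(dtype);
    }

    int main() {
      std::printf("%s\n", DispatchConcat(1, 4, 3, 16, kF32));  // MKLConcatOp
      std::printf("%s\n", DispatchConcat(0, 4, 3, 16, kF32));  // ConcatOp
    }
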
diff --git a/src/operator/convolution_v1.cc b/src/operator/convolution_v1.cc
index 7de6a34..86c0fbb 100644
--- a/src/operator/convolution_v1.cc
+++ b/src/operator/convolution_v1.cc
@@ -25,11 +25,6 @@
 */
 
 #include "./convolution_v1-inl.h"
-#if MXNET_USE_MKL2017 == 1
-#include <mkl_memory.h>
-#include "./mkl/mkl_memory-inl.h"
-#include "./mkl/mkl_convolution-inl.h"
-#endif  // MXNET_USE_MKL2017
 #if MXNET_USE_NNPACK == 1
 #include "./nnpack/nnpack_convolution-inl.h"
 #endif  // MXNET_USE_NNPACK
diff --git a/src/operator/lrn-inl.h b/src/operator/lrn-inl.h
deleted file mode 100644
index adfe467..0000000
--- a/src/operator/lrn-inl.h
+++ /dev/null
@@ -1,215 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015 by Contributors
- * \file lrn-inl.h
- * \brief
- * \author Bing Xu
-*/
-#ifndef MXNET_OPERATOR_LRN_INL_H_
-#define MXNET_OPERATOR_LRN_INL_H_
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include "./operator_common.h"
-#include "./mshadow_op.h"
-
-namespace mxnet {
-namespace op {
-
-namespace lrn_enum {
-enum LRNInputs {kData};
-enum LRNOutputs {kOut, kTmpNorm};
-}  // namespace lrn_enum
-
-struct LRNParam : public dmlc::Parameter<LRNParam> {
-  float alpha;
-  float beta;
-  float knorm;
-  uint32_t nsize;
-  DMLC_DECLARE_PARAMETER(LRNParam) {
-    DMLC_DECLARE_FIELD(alpha).set_default(1e-4f)
-    .describe("The variance scaling parameter :math:`\\alpha` in the LRN expression.");
-    DMLC_DECLARE_FIELD(beta).set_default(0.75f)
-    .describe("The power parameter :math:`\\beta` in the LRN expression.");
-    DMLC_DECLARE_FIELD(knorm).set_default(2.0f)
-    .describe("The parameter :math:`k` in the LRN expression.");
-    DMLC_DECLARE_FIELD(nsize)
-    .describe("normalization window width in elements.");
-  }
-};  // struct LRNParam
-
-template<typename xpu>
-class LocalResponseNormOp : public Operator {
- public:
-  explicit LocalResponseNormOp(LRNParam param) {
-    param_ = param;
-  }
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    // TODO(xxx): Test with gradient checker
-    CHECK_EQ(in_data.size(), 1U);
-    CHECK_EQ(out_data.size(), 2U);
-    // CHECK_EQ(req.size(), 2);
-    CHECK_EQ(param_.nsize % 2, 1U) << "LRN only supports odd values for nsize";
-    const real_t salpha = param_.alpha / param_.nsize;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4> data = in_data[lrn_enum::kData].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> out = out_data[lrn_enum::kOut].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> tmp_norm = out_data[lrn_enum::kTmpNorm].get<xpu, 4, real_t>(s);
-    tmp_norm = chpool<red::sum>(F<mshadow_op::square>(data), param_.nsize) * salpha + param_.knorm;
-    Assign(out, req[lrn_enum::kOut], data * F<mshadow_op::power>(tmp_norm, -param_.beta));
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1U);
-    CHECK_EQ(in_data.size(), 1U);
-    CHECK_EQ(out_data.size(), 2U);
-    const real_t salpha = param_.alpha / param_.nsize;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4> grad = out_grad[lrn_enum::kOut].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> tmp_norm = out_data[lrn_enum::kTmpNorm].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> data = in_data[lrn_enum::kData].get<xpu, 4, real_t>(s);
-    Tensor<xpu, 4> grad_in = in_grad[lrn_enum::kData].get<xpu, 4, real_t>(s);
-    grad_in = grad * F<mshadow_op::power>(tmp_norm, -param_.beta);
-    grad_in += (-2.0f * param_.beta * salpha) *
-               chpool<red::sum>(grad * data *
-                                F<mshadow_op::power>(tmp_norm, -param_.beta - 1.0f),
-                                param_.nsize) * data;
-  }
-
- private:
-  LRNParam param_;
-};  // class LocalResponseNormOp
-
-template<typename xpu>
-Operator *CreateOp(LRNParam param, int dtype);
-
-#if DMLC_USE_CXX11
-class LocalResponseNormProp : public OperatorProperty {
- public:
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(std::vector<TShape> *in_shape,
-                  std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
-    const TShape &dshape = in_shape->at(0);
-    if (dshape.ndim() == 0) return false;
-    out_shape->clear();
-    out_shape->push_back(dshape);
-    out_shape->push_back(dshape);
-    return true;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    CHECK_GE(in_type->size(), 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
-      }
-    }
-    int n_out = this->ListOutputs().size();
-    out_type->clear();
-    for (int i = 0; i < n_out; ++i ) out_type->push_back(dtype);
-    return true;
-  }
-
-  OperatorProperty* Copy() const override {
-    auto ptr = new LocalResponseNormProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
-
-  std::string TypeString() const override {
-    return "LRN";
-  }
-
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    return {
-      out_grad[lrn_enum::kOut], in_data[lrn_enum::kData],
-      out_data[lrn_enum::kTmpNorm], out_data[lrn_enum::kOut]
-    };
-  }
-
-  int NumVisibleOutputs() const override {
-    return 1;
-  }
-
-  int NumOutputs() const override {
-    return 2;
-  }
-
-  std::vector<std::string> ListArguments() const override {
-    return {"data"};
-  }
-
-  std::vector<std::string> ListOutputs() const override {
-    return {"output", "tmp_norm"};
-  }
-
-  Operator* CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not Implemented.";
-    return NULL;
-  }
-
-  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                             std::vector<int> *in_type) const override;
-
- private:
-  LRNParam param_;
-};  // LocalResponseNormProp
-#endif  // DMLC_USE_CXX11
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_LRN_INL_H_
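
The two mshadow expressions in Forward above amount to, per spatial position
and channel c: norm[c] = knorm + (alpha / nsize) * sum of squared activations
over the channel window, then out[c] = data[c] * norm[c]^(-beta). A scalar
sketch of that computation, assuming a centered window of odd width nsize as
the odd-nsize check implies:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // LRN forward for one spatial position across C channels.
    std::vector<float> LRNForward(const std::vector<float>& a, int nsize,
                                  float alpha, float beta, float knorm) {
      const int C = static_cast<int>(a.size());
      std::vector<float> b(C);
      for (int c = 0; c < C; ++c) {
        float sum_sq = 0.f;
        for (int j = std::max(0, c - nsize / 2);
             j <= std::min(C - 1, c + nsize / 2); ++j)
          sum_sq += a[j] * a[j];
        const float norm = knorm + (alpha / nsize) * sum_sq;
        b[c] = a[c] * std::pow(norm, -beta);
      }
      return b;
    }
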
diff --git a/src/operator/lrn.cc b/src/operator/lrn.cc
deleted file mode 100644
index 9b3afd8..0000000
--- a/src/operator/lrn.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015 by Contributors
- * \file lrn.cc
- * \brief
- * \author Bing Xu
-*/
-
-#include "./lrn-inl.h"
-#if MXNET_USE_CUDNN == 1
-#include "./cudnn_lrn-inl.h"
-#endif
-#if MXNET_USE_MKL2017 == 1
-#include <mkl_memory.h>
-#include "./mkl/mkl_memory-inl.h"
-#include "./mkl/mkl_lrn-inl.h"
-#endif
-
-namespace mxnet {
-namespace op {
-template<>
-Operator* CreateOp<cpu>(LRNParam param, int dtype) {
-#if MXNET_USE_MKL2017 == 1
-  return new MKLLRNOp<cpu, float>(param);
-#endif
-  return new LocalResponseNormOp<cpu>(param);
-}
-
-// DO_BIND_DISPATCH comes from operator_common.h
-Operator* LocalResponseNormProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-    std::vector<int> *in_type) const {
-    DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
-}
-
-DMLC_REGISTER_PARAMETER(LRNParam);
-
-MXNET_REGISTER_OP_PROPERTY(LRN, LocalResponseNormProp)
-.add_argument("data", "NDArray-or-Symbol", "Input data.")
-.add_arguments(LRNParam::__FIELDS__())
-.describe(R"code(Applies local response normalization to the input.
-
-The local response normalization layer performs "lateral inhibition" by normalizing
-over local input regions.
-
-If :math:`a_{x,y}^{i}` is the activity of a neuron computed by applying kernel :math:`i` at position
-:math:`(x, y)` and then applying the ReLU nonlinearity, the response-normalized
-activity :math:`b_{x,y}^{i}` is given by the expression:
-
-.. math::
-   b_{x,y}^{i} = \frac{a_{x,y}^{i}}{\Bigg({k + \alpha \sum_{j=max(0, i-\frac{n}{2})}^{min(N-1, i+\frac{n}{2})} (a_{x,y}^{j})^{2}}\Bigg)^{\beta}}
-
-where the sum runs over :math:`n` "adjacent" kernel maps at the same spatial position, and :math:`N` is the total
-number of kernels in the layer.
-
-)code" ADD_FILELINE);
-
-}  // namespace op
-}  // namespace mxnet
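
As a concrete instance of the displayed formula: with :math:`n = 2` and
:math:`N = 3` kernel maps, the window for kernel map :math:`i = 1` runs over
:math:`j \in \{0, 1, 2\}`, giving

.. math::
   b_{x,y}^{1} = \frac{a_{x,y}^{1}}{\Big(k + \alpha \big((a_{x,y}^{0})^{2} + (a_{x,y}^{1})^{2} + (a_{x,y}^{2})^{2}\big)\Big)^{\beta}}

Note that the implementation in lrn-inl.h scales the window sum by
``alpha / nsize`` (``salpha``) rather than by ``alpha`` alone, so it matches
this formula only up to that normalization of the window width.
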
diff --git a/src/operator/lrn.cu b/src/operator/lrn.cu
deleted file mode 100644
index ba872f1..0000000
--- a/src/operator/lrn.cu
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * Copyright (c) 2015 by Contributors
- * \file lrn.cu
- * \brief
- * \author Bing Xu
-*/
-
-#include "./lrn-inl.h"
-#if MXNET_USE_CUDNN == 1
-#include "./cudnn_lrn-inl.h"
-#endif
-
-namespace mxnet {
-namespace op {
-template<>
-Operator* CreateOp<gpu>(LRNParam param, int dtype) {
-  Operator *op = NULL;
-#if MXNET_USE_CUDNN == 1
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new CuDNNLocalResponseNormOp<DType>(param);
-  })
-#else
-#if CUDA_VERSION == 7000
-  LOG(FATAL) << "Due to old CUDA compiler bug, LRN is disabled."
-             << "Please upgrade CUDA to 7.5+ or use CUDNN";
-#else
-  op = new LocalResponseNormOp<gpu>(param);
-#endif  // CUDA_VERSION
-#endif  // MXNET_USE_CUDNN
-  return op;
-}
-
-}  // namespace op
-}  // namespace mxnet
-
-
diff --git a/src/operator/mkl/mkl_batch_norm-inl.h b/src/operator/mkl/mkl_batch_norm-inl.h
deleted file mode 100644
index b5967f4..0000000
--- a/src/operator/mkl/mkl_batch_norm-inl.h
+++ /dev/null
@@ -1,391 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_batch_norm-inl.h
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_BATCH_NORM_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_BATCH_NORM_INL_H_
-#include <mxnet/storage.h>
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include "../operator_common.h"
-#include "../mshadow_op.h"
-#include "./mkl_util-inl.h"
-
-namespace mxnet {
-namespace op {
-
-template<typename xpu, typename DType>
-class MKLBatchNormOp : public Operator {
- public:
-  explicit MKLBatchNormOp(BatchNormParam param) {
-    this->param_ = param;
-    fwd_top_data = MKLData<DType>::create();
-    fwd_bottom_data = MKLData<DType>::create();
-    bwd_top_diff = MKLData<DType>::create();
-    bwd_bottom_diff = MKLData<DType>::create();
-    scaleShift_space.dptr = NULL;
-    scaleShiftDiff_space.dptr = NULL;
-  }
-  virtual ~MKLBatchNormOp() {
-    if (batchNormFwdInference != NULL) dnnDelete<DType>(batchNormFwdInference);
-    if (batchNormFwdTraining != NULL) dnnDelete<DType>(batchNormFwdTraining);
-    if (batchNormBwdScaleShift != NULL) dnnDelete<DType>(batchNormBwdScaleShift);
-    dnnLayoutDelete<DType>(layout_usr_);
-    if (scaleShift_space.dptr)
-      Storage::Get()->Free(scaleShift_space);
-    if (scaleShiftDiff_space.dptr)
-      Storage::Get()->Free(scaleShiftDiff_space);
-  }
-  static std::string getName() {
-    return "MKLBatchNormOp";
-  }
-
- private:
-  void LayerSetUp(const mshadow::Tensor<xpu, 4, DType> &data,
-                  const mshadow::Tensor<xpu, 4, DType> &out) {
-    eps_ = param_.eps;
-    size_t dim = 4, sizes[4], strides[4];
-    channels_ = data.shape_[1];
-    height_ = data.shape_[2];
-    width_ = data.shape_[3];
-    num_ = data.shape_[0];
-
-    sizes[0] = width_;
-    sizes[1] = height_;
-    sizes[2] = channels_;
-    sizes[3] = num_;
-
-    strides[0] = 1;
-    strides[1] = sizes[0];
-    strides[2] = sizes[0] * sizes[1];
-    strides[3] = sizes[0] * sizes[1] * sizes[2];
-
-    // Names are for debugging only
-    fwd_bottom_data->name = "fwd_bottom_data   @ " + getName();
-    fwd_top_data->name = "fwd_top_data      @ " + getName();
-    bwd_bottom_diff->name = "bwd_bottom_diff   @ " + getName();
-    bwd_top_diff->name = "bwd_top_diff      @ " + getName();
-
-    dnnError_t e;
-    e = dnnLayoutCreate<DType>(&layout_usr_, dim, sizes, strides);
-    CHECK_EQ(e, E_SUCCESS);
-
-    fwd_bottom_data->create_user_layout(dim, sizes, strides);
-    fwd_top_data->create_user_layout(dim, sizes, strides);
-    bwd_bottom_diff->create_user_layout(dim, sizes, strides);
-    bwd_top_diff->create_user_layout(dim, sizes, strides);
-
-    // Primitives will be allocated during the first fwd pass
-    batchNormFwdInference = NULL;
-    batchNormFwdTraining = NULL;
-    batchNormBwdScaleShift = NULL;
-    int scaleShift_size = channels_*2*sizeof(DType);
-    scaleShift_space = Storage::Get()->Alloc(scaleShift_size, Context::CPU());
-    scaleShiftDiff_space = Storage::Get()->Alloc(scaleShift_size, Context::CPU());
-    DType * scaleShift_buf = reinterpret_cast<DType*>(scaleShift_space.dptr);
-    /*!use_weight_bias_*/
-    for (int i = 0; i < channels_; i++) {
-        scaleShift_buf[i] = 1.0;
-        scaleShift_buf[channels_ + i] = 0;
-    }
-  }
-
- public:
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 3);
-    CHECK_EQ(aux_states.size(), 2);
-    if (ctx.is_train) {
-      CHECK_EQ(out_data.size(), 3);
-      CHECK_EQ(req.size(), 3);
-    } else {
-      CHECK_GE(out_data.size(), 1);
-      CHECK_GE(req.size(), 1);
-      CHECK_EQ(req[batchnorm::kOut], kWriteTo);
-    }
-
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType>  data;
-    Tensor<xpu, 4, DType>  out;
-    if (in_data[batchnorm::kData].ndim() == 2) {
-      Shape<4> dshape = Shape4(in_data[batchnorm::kData].shape_[0],
-                               in_data[batchnorm::kData].shape_[1], 1, 1);
-      data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        in_data[batchnorm::kData], dshape, s);
-      out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_data[batchnorm::kOut], dshape, s);
-    } else {
-      data = mkl_experimental_direct_get<xpu, 4, DType>(in_data[batchnorm::kData], s);
-      out = mkl_experimental_direct_get<xpu, 4, DType>(out_data[batchnorm::kOut], s);
-    }
-
-    // const real_t scale = static_cast<real_t>(in_data[batchnorm::kData].shape_[1]) /
-    //   static_cast<real_t>(in_data[batchnorm::kData].shape_.Size());
-
-    Tensor<xpu, 1, DType> slope = in_data[batchnorm::kGamma].get<xpu, 1, DType>(s);
-    Tensor<xpu, 1, DType> bias = in_data[batchnorm::kBeta].get<xpu, 1, DType>(s);
-    Tensor<xpu, 1, DType> moving_mean = aux_states[batchnorm::kMovingMean].get<xpu, 1, DType>(s);
-    Tensor<xpu, 1, DType> moving_var = aux_states[batchnorm::kMovingVar].get<xpu, 1, DType>(s);
-
-    if (param_.fix_gamma)
-      slope = 1.f;
-
-    dnnError_t e;
-    if (!init_mkldnn_) {
-      LayerSetUp(data, out);
-      init_mkldnn_ = true;
-    }
-    void* bottom_data = NULL;
-#if MKL_EXPERIMENTAL == 1
-    bottom_data =
-          reinterpret_cast<void *>(mkl_prv_data<DType>(in_data[batchnorm::kData]));
-#endif
-    int bwd_flags = dnnUseScaleShift;
-    if (param_.use_global_stats)
-      bwd_flags = dnnUseScaleShift | dnnUseInputMeanVariance;
-#if MKL_EXPERIMENTAL == 1
-    if (NULL != bottom_data) {
-      // Is it the first pass? Create a primitive.
-      if (batchNormFwdInference == NULL) {
-        std::shared_ptr<MKLMemHolder> bottom_data_mem = in_data[batchnorm::kData].Mkl_mem_;
-        std::shared_ptr<PrvMemDescr> bottom_prv_desc = bottom_data_mem->get_prv_descriptor();
-        CHECK(bottom_prv_desc->get_descr_type() == PrvMemDescr::PRV_DESCR_MKL2017);
-        std::shared_ptr<MKLData<DType> > mem_descr
-          = std::static_pointer_cast<MKLData<DType>>(bottom_prv_desc);
-        CHECK(mem_descr != NULL);
-        fwd_bottom_data = mem_descr;
-
-        e = dnnBatchNormalizationCreateForward_v2<DType>(
-             &batchNormFwdInference, NULL, mem_descr->layout_int, eps_,
-             dnnUseInputMeanVariance | dnnUseScaleShift);
-        CHECK_EQ(e, E_SUCCESS);
-
-        e = dnnBatchNormalizationCreateForward_v2<DType>(
-              &batchNormFwdTraining, NULL, mem_descr->layout_int, eps_,
-              dnnUseScaleShift);
-        CHECK_EQ(e, E_SUCCESS);
-
-        fwd_top_data->create_internal_layout(batchNormFwdInference, dnnResourceDst);
-        bwd_top_diff->create_internal_layout(batchNormFwdInference, dnnResourceDst);
-        bwd_bottom_diff->create_internal_layout(batchNormFwdInference, dnnResourceSrc);
-
-        e = dnnBatchNormalizationCreateBackward_v2<DType>(
-                &batchNormBwdScaleShift, NULL, mem_descr->layout_int, eps_, bwd_flags);
-        CHECK_EQ(e, E_SUCCESS);
-      }
-    }
-#endif
-    if (NULL == bottom_data) {
-      if (batchNormFwdInference == NULL) {
-        e = dnnBatchNormalizationCreateForward_v2<DType>(
-          &batchNormFwdInference, NULL, layout_usr_, eps_,
-          dnnUseInputMeanVariance | dnnUseScaleShift);
-        CHECK_EQ(e, E_SUCCESS);
-
-        e = dnnBatchNormalizationCreateForward_v2<DType>(
-              &batchNormFwdTraining, NULL, layout_usr_, eps_, dnnUseScaleShift);
-        CHECK_EQ(e, E_SUCCESS);
-
-        e = dnnBatchNormalizationCreateBackward_v2<DType>(
-              &batchNormBwdScaleShift, NULL, layout_usr_, eps_, bwd_flags);
-        CHECK_EQ(e, E_SUCCESS);
-      }
-      bottom_data = reinterpret_cast<void *>(data.dptr_);
-    }
-
-    DType * scaleShift_buf = reinterpret_cast<DType*>(scaleShift_space.dptr);
-     // use_weight_bias_
-    for (int i = 0; i < channels_; i++) {
-        scaleShift_buf[i] = (slope.dptr_)[i];
-    }
-    for (int i = 0; i < channels_; i++) {
-      scaleShift_buf[channels_ + i] = (bias.dptr_)[i];
-    }
-
-    void* BatchNorm_res[dnnResourceNumber];
-    BatchNorm_res[dnnResourceSrc] = bottom_data;
-    BatchNorm_res[dnnResourceScaleShift] = scaleShift_space.dptr;
-
-    BatchNorm_res[dnnResourceDst] = fwd_top_data->get_output_ptr(out.dptr_,
-      fwd_top_data, out_data[batchnorm::kOut]);
-    if (ctx.is_train && !param_.use_global_stats) {
-      Tensor<xpu, 1, DType> mean = out_data[batchnorm::kMean].get<xpu, 1, DType>(s);
-      Tensor<xpu, 1, DType> var = out_data[batchnorm::kVar].get<xpu, 1, DType>(s);
-      CHECK(req[batchnorm::kMean] == kNullOp || req[batchnorm::kMean] == kWriteTo);
-      CHECK(req[batchnorm::kVar] == kNullOp || req[batchnorm::kVar] == kWriteTo);
-      BatchNorm_res[dnnResourceMean] = mean.dptr_;
-      BatchNorm_res[dnnResourceVariance] = var.dptr_;
-      e = dnnExecute<DType>(batchNormFwdTraining, BatchNorm_res);
-      CHECK_EQ(e, E_SUCCESS);
-    } else {
-      BatchNorm_res[dnnResourceMean] = moving_mean.dptr_;
-      BatchNorm_res[dnnResourceVariance] = moving_var.dptr_;
-      e = dnnExecute<DType>(batchNormFwdInference, BatchNorm_res);
-      CHECK_EQ(e, E_SUCCESS);
-    }
-
-#if MKL_EXPERIMENTAL == 0
-    if (fwd_top_data->conversion_needed()) {
-      fwd_top_data->convert_from_prv(out.dptr_);
-    }
-#endif
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1);
-    CHECK_EQ(in_data.size(), 3);
-    CHECK_EQ(out_data.size(), 3);
-    CHECK_EQ(in_grad.size(), 3);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> data, grad, grad_in;
-
-    if (in_data[batchnorm::kData].ndim() == 2) {
-      Shape<4> dshape = Shape4(out_grad[batchnorm::kOut].shape_[0],
-                               out_grad[batchnorm::kOut].shape_[1], 1, 1);
-      data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        in_data[batchnorm::kData], dshape, s);
-      grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_grad[batchnorm::kOut], dshape, s);
-      grad_in = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        in_grad[batchnorm::kData], dshape, s);
-    } else {
-      data = mkl_experimental_direct_get<xpu, 4, DType>(in_data[batchnorm::kData], s);
-      grad = mkl_experimental_direct_get<xpu, 4, DType>(out_grad[batchnorm::kOut], s);
-      grad_in = mkl_experimental_direct_get<xpu, 4, DType>(in_grad[batchnorm::kData], s);
-    }
-
-    Tensor<xpu, 1, DType> slope = in_data[batchnorm::kGamma].get<xpu, 1, DType>(s);
-    Tensor<xpu, 1, DType> gslope = in_grad[batchnorm::kGamma].get<xpu, 1, DType>(s);
-    Tensor<xpu, 1, DType> gbias = in_grad[batchnorm::kBeta].get<xpu, 1, DType>(s);
-    Tensor<xpu, 1, DType> mean = out_data[batchnorm::kMean].get<xpu, 1, DType>(s);
-    Tensor<xpu, 1, DType> var = out_data[batchnorm::kVar].get<xpu, 1, DType>(s);
-    Tensor<xpu, 1, DType> moving_mean = aux_states[batchnorm::kMovingMean].get<xpu, 1, DType>(s);
-    Tensor<xpu, 1, DType> moving_var = aux_states[batchnorm::kMovingVar].get<xpu, 1, DType>(s);
-
-    if (param_.fix_gamma)  slope = 1.f;
-
-    void* bottom_data = NULL;
-#if MKL_EXPERIMENTAL == 1
-    bottom_data = reinterpret_cast<void *>(mkl_prv_data<DType>(in_data[batchnorm::kData]));
-#endif
-    if (NULL == bottom_data)
-      bottom_data = reinterpret_cast<void *>(data.dptr_);
-
-    dnnError_t e;
-    void* BatchNorm_res[dnnResourceNumber];
-    BatchNorm_res[dnnResourceSrc] = bottom_data;
-    BatchNorm_res[dnnResourceScaleShift] = scaleShift_space.dptr;
-    if (ctx.is_train && !param_.use_global_stats) {
-      int size = mean.size(0);  // Tensor<xpu, 1, DType>
-      float * moving_mean_ptr = reinterpret_cast<float*>(moving_mean.dptr_);
-      float * mean_ptr = reinterpret_cast<float*>(mean.dptr_);
-      float * moving_var_ptr = reinterpret_cast<float*>(moving_var.dptr_);
-      float * var_ptr = reinterpret_cast<float*>(var.dptr_);
-      float minus_mom = (1 - param_.momentum);
-      for (int i = 0; i < size; i++) {
-        moving_mean_ptr[i] = moving_mean_ptr[i] * param_.momentum
-          + mean_ptr[i] * minus_mom;
-      }
-      for (int i = 0; i < size; i++) {
-        moving_var_ptr[i] = moving_var_ptr[i] * param_.momentum
-          + var_ptr[i] * minus_mom;
-      }
-      BatchNorm_res[dnnResourceMean] = mean.dptr_;
-      BatchNorm_res[dnnResourceVariance] = var.dptr_;
-    } else {
-      BatchNorm_res[dnnResourceMean] = moving_mean.dptr_;
-      BatchNorm_res[dnnResourceVariance] = moving_var.dptr_;
-    }
-
-
-    BatchNorm_res[dnnResourceDiffSrc] = bwd_bottom_diff->get_output_ptr(grad_in.dptr_,
-      bwd_bottom_diff, in_grad[batchnorm::kData]);
-    BatchNorm_res[dnnResourceDiffDst] = bwd_top_diff->get_converted_prv(grad.dptr_,
-             true, out_grad[batchnorm::kOut]);
-    BatchNorm_res[dnnResourceDiffScaleShift] = scaleShiftDiff_space.dptr;
-    e = dnnExecute<DType>(batchNormBwdScaleShift, BatchNorm_res);
-    CHECK_EQ(e, E_SUCCESS);
-#if MKL_EXPERIMENTAL == 0
-    if (bwd_bottom_diff->conversion_needed()) {
-      bwd_bottom_diff->convert_from_prv(grad_in.dptr_);
-    }
-#endif
-    DType * scaleShiftDiff_buf = reinterpret_cast<DType*>(scaleShiftDiff_space.dptr);
-    if (!param_.fix_gamma) {
-      // Store ScaleShift blobs
-      DType* diff_scale = gslope.dptr_;
-      for (int i = 0; i < channels_; i++) {
-        diff_scale[i] = scaleShiftDiff_buf[i];
-      }
-    } else {
-      int gslope_size = gslope.size(0);
-      float * gslope_ptr = reinterpret_cast<float*>(gslope.dptr_);
-      for (int i = 0; i < gslope_size; i++) {
-        *gslope_ptr++ = 0.0f;
-      }
-    }
-    DType* diff_shift = gbias.dptr_;
-    for (int i = 0; i < channels_; i++) {
-      diff_shift[i] = scaleShiftDiff_buf[channels_ + i];
-    }
-  }
-
- private:
-  BatchNormParam param_;
-  DType eps_;
-  bool use_weight_bias_;
-
-  int num_;
-  int channels_;
-  int height_;
-  int width_;
-  bool init_mkldnn_ = false;
-  std::shared_ptr<MKLData<DType> > fwd_top_data;
-  std::shared_ptr<MKLData<DType> > fwd_bottom_data;
-  std::shared_ptr<MKLData<DType> > bwd_top_diff;
-  std::shared_ptr<MKLData<DType> > bwd_bottom_diff;
-  dnnPrimitive_t batchNormFwdInference = NULL;
-  dnnPrimitive_t batchNormFwdTraining = NULL;
-  dnnPrimitive_t batchNormBwdScaleShift = NULL;
-  Storage::Handle scaleShift_space;
-  Storage::Handle scaleShiftDiff_space;
-  dnnLayout_t layout_usr_ = NULL;
-};  // class BatchNormOp
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_MKL_MKL_BATCH_NORM_INL_H_
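
One detail worth calling out in the batch-norm code above: the moving
statistics are refreshed inside Backward, just before the MKL primitive
executes, as an exponential moving average with the configured momentum. A
minimal sketch of that update rule:

    #include <cstddef>

    // moving = momentum * moving + (1 - momentum) * batch_stat,
    // applied element-wise to moving_mean and moving_var.
    void UpdateMovingStats(float* moving, const float* batch_stat,
                           std::size_t n, float momentum) {
      for (std::size_t i = 0; i < n; ++i)
        moving[i] = momentum * moving[i] + (1.f - momentum) * batch_stat[i];
    }
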
diff --git a/src/operator/mkl/mkl_concat-inl.h b/src/operator/mkl/mkl_concat-inl.h
deleted file mode 100644
index 1ed1e81..0000000
--- a/src/operator/mkl/mkl_concat-inl.h
+++ /dev/null
@@ -1,314 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_concat-inl.h
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_CONCAT_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_CONCAT_INL_H_
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <cstring>
-#include <map>
-#include <string>
-#include <vector>
-#include <utility>
-#include "../operator_common.h"
-#include "../channel_op_common.h"
-#include "./mkl_util-inl.h"
-namespace mxnet {
-namespace op {
-
-
-template<typename xpu, typename DType>
-class MKLConcatOp : public Operator {
- public:
-  static std::string getName() {
-    return "MKLConcatOp";
-  }
-  explicit MKLConcatOp(ConcatParam param)
-    : size_(param.num_args), dimension_(param.dim), init_mkldnn_(false) {
-    concatFwd_ = static_cast<dnnPrimitive_t>(NULL);
-    concatBwd_ = static_cast<dnnPrimitive_t>(NULL);
-    fwd_top_data_ = MKLData<DType>::create();
-    bwd_top_diff_ = MKLData<DType>::create();
-
-    num_concats_ = param.num_args;
-  }
-  virtual ~MKLConcatOp() {
-    dnnDelete<DType>(concatFwd_);
-    dnnDelete<DType>(concatBwd_);
-  }
-
- private:
-  void LayerSetUp(const std::vector<mshadow::Tensor<xpu, 4, DType> > &data,
-                  const mshadow::Tensor<xpu, 4, DType> &out,
-                  size_t data_shape_size, size_t *split_channels_) {
-    size_t dim_src = data_shape_size;
-    size_t dim_dst = dim_src;
-    num_concats_ = size_;
-    channels_ = 0;
-
-    for (size_t i = 1; i < num_concats_; ++i) {
-      for (size_t j = 1; j < data_shape_size; ++j) {
-        if (j == dimension_) continue;
-        CHECK_EQ(data[0].shape_[j], data[i].shape_[j]);
-      }
-    }
-
-    for (size_t i = 0; i < num_concats_; ++i) {
-      CHECK_EQ(static_cast<int>(dim_src), data[i].shape_.kDimension);
-
-      fwd_bottom_data_.push_back(MKLData<DType>::create());
-      bwd_bottom_diff_.push_back(MKLData<DType>::create());
-      fwd_bottom_data_[i]->name = "fwd_bottom_data_[i]";
-      bwd_bottom_diff_[i]->name = "bwd_bottom_diff_[i]";
-
-      size_t *sizes_src = new size_t[dim_src];
-      size_t *strides_src = new size_t[dim_src];
-      for (size_t d = 0; d < dim_src; ++d) {
-        sizes_src[d] = data[i].shape_[dim_src - d - 1];
-        strides_src[d] = (d == 0) ? 1 : strides_src[d - 1] * sizes_src[d - 1];
-      }
-
-      split_channels_[i] = data[i].shape_[1];
-      channels_ += split_channels_[i];
-      fwd_bottom_data_[i]->create_user_layout(dim_src, sizes_src, strides_src);
-      bwd_bottom_diff_[i]->create_user_layout(dim_src, sizes_src, strides_src);
-      delete[] sizes_src;
-      delete[] strides_src;
-    }
-    size_t *sizes_dst = new size_t[dim_dst];
-    size_t *strides_dst = new size_t[dim_dst];
-    for (size_t d = 0; d < dim_dst; ++d) {
-      if (d == 2)
-        sizes_dst[d] = channels_;
-      else
-        sizes_dst[d] = data[0].shape_[dim_dst - 1 - d];
-      strides_dst[d] = (d == 0) ? 1 : strides_dst[d - 1] * sizes_dst[d - 1];
-    }
-    bwd_top_diff_->create_user_layout(dim_dst, sizes_dst, strides_dst);
-    fwd_top_data_->create_user_layout(dim_dst, sizes_dst, strides_dst);
-    delete[] sizes_dst;
-    delete[] strides_dst;
-    concatFwd_ = NULL;
-    concatBwd_ = NULL;
-  }
-
- public:
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(static_cast<int>(in_data.size()), size_);
-    CHECK_EQ(out_data.size(), 1);
-    CHECK_LT(dimension_, static_cast<size_t>(in_data[concat_enum::kData0].ndim()));
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    std::vector<Tensor<xpu, 4, DType> > data(size_);
-    Tensor<xpu, 4, DType> out;
-    if (in_data[0].ndim() == 2) {
-      for (int i = 0; i < size_; ++i) {
-        Shape<4> dshape = Shape4(in_data[i].shape_[0],
-                                 in_data[i].shape_[1], 1, 1);
-        data[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-          in_data[i], dshape, s);
-      }
-      Shape<4> dshape = Shape4(out_data[concat_enum::kOut].shape_[0],
-                               out_data[concat_enum::kOut].shape_[1], 1, 1);
-      out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_data[concat_enum::kOut], dshape, s);
-    } else if (in_data[0].ndim() == 3) {
-      for (int i = 0; i < size_; ++i) {
-        Shape<4> dshape = Shape4(in_data[i].shape_[0],
-          in_data[i].shape_[1], in_data[i].shape_[2], 1);
-        data[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-          in_data[i], dshape, s);
-      }
-      Shape<4> dshape = Shape4(out_data[concat_enum::kOut].shape_[0],
-        out_data[concat_enum::kOut].shape_[1],
-        out_data[concat_enum::kOut].shape_[2], 1);
-      out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_data[concat_enum::kOut], dshape, s);
-    } else {
-      for (int i = 0; i < size_; ++i) {
-        data[i] = mkl_experimental_direct_get<xpu, 4, DType>(in_data[i], s);
-      }
-      out = mkl_experimental_direct_get<xpu, 4, DType>(out_data[concat_enum::kOut], s);
-    }
-    size_t *split_channels_ = new size_t[num_concats_];
-    if (!init_mkldnn_) {
-      init_mkldnn_ = true;
-      LayerSetUp(data, out, 4, split_channels_);
-    }
-
-    dnnError_t e;
-    std::vector<void*> bottom_data;
-    bool isFirstPass = (concatFwd_ == NULL);
-    dnnLayout_t *layouts = NULL;
-    if (isFirstPass) {
-      layouts = new dnnLayout_t[num_concats_];
-    }
-
-    for (size_t i = 0; i < num_concats_; i++) {
-      void * bottom_i = NULL;
-#if MKL_EXPERIMENTAL == 1
-      bottom_i = mkl_prv_data<DType>(in_data[i]);
-      if (bottom_i != NULL) {
-        if (isFirstPass) {
-          std::shared_ptr<MKLData<DType> > mem_descr =
-            mkl_get_mem_desc<DType>(in_data[i].Mkl_mem_);
-          fwd_bottom_data_[i] = mem_descr;
-          layouts[i] = mem_descr->layout_int;
-        }
-      }
-#endif
-      if (bottom_i == NULL) {
-        bottom_i = data[i].dptr_;
-        if (isFirstPass) {
-          layouts[i] = fwd_bottom_data_[i]->layout_usr;
-        }
-      }
-
-      bottom_data.push_back(reinterpret_cast<void *>(bottom_i));
-    }
-
-    if (isFirstPass) {
-      e = dnnConcatCreate<DType>(&concatFwd_, NULL, num_concats_, layouts);
-      CHECK_EQ(e, E_SUCCESS);
-
-      fwd_top_data_->create_internal_layout(concatFwd_, dnnResourceDst);
-      bwd_top_diff_->create_internal_layout(concatFwd_, dnnResourceDst);
-
-      e = dnnSplitCreate<DType>(&concatBwd_, NULL, num_concats_,
-            bwd_top_diff_->layout_int, split_channels_);
-      CHECK_EQ(e, E_SUCCESS);
-
-      for (size_t n = 0; n < num_concats_; ++n) {
-        fwd_bottom_data_[n]->create_internal_layout(concatFwd_,
-          (dnnResourceType_t)(dnnResourceMultipleSrc + n));
-        bwd_bottom_diff_[n]->create_internal_layout(concatBwd_,
-          (dnnResourceType_t)(dnnResourceMultipleDst + n));
-      }
-    }
-    delete[] layouts;
-
-    void *concat_res[dnnResourceNumber];
-    for (size_t i = 0; i < num_concats_; ++i) {
-      concat_res[dnnResourceMultipleSrc + i]
-        = reinterpret_cast<void*>(bottom_data[i]);
-    }
-
-    concat_res[dnnResourceDst] = fwd_top_data_->get_output_ptr(out.dptr_,
-      fwd_top_data_, out_data[concat_enum::kOut]);
-    e = dnnExecute<DType>(concatFwd_, concat_res);
-    CHECK_EQ(e, E_SUCCESS);
-    delete[] split_channels_;
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1);
-    CHECK_EQ(in_grad.size(), static_cast<size_t>(size_));
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    std::vector<Tensor<xpu, 4, DType> > grad_in(size_);
-    Tensor<xpu, 4, DType> grad;
-    if (in_grad[0].ndim() == 2) {
-      Shape<4> dshape = Shape4(out_grad[concat_enum::kOut].shape_[0],
-        out_grad[concat_enum::kOut].shape_[1], 1, 1);
-      grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_grad[concat_enum::kOut], dshape, s);
-      for (int i = 0; i < size_; ++i) {
-        dshape = Shape4(in_grad[i].shape_[0],
-          in_grad[i].shape_[1], 1, 1);
-        grad_in[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-          in_grad[i], dshape, s);
-      }
-    } else if (in_grad[0].ndim() == 3) {
-      Shape<4> dshape = Shape4(out_grad[concat_enum::kOut].shape_[0],
-        out_grad[concat_enum::kOut].shape_[1],
-        out_grad[concat_enum::kOut].shape_[2], 1);
-      grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_grad[concat_enum::kOut], dshape, s);
-      for (int i = 0; i < size_; ++i) {
-        dshape = Shape4(in_grad[i].shape_[0],
-          in_grad[i].shape_[1], in_grad[i].shape_[2], 1);
-        grad_in[i] = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-          in_grad[i], dshape, s);
-      }
-    } else {
-      grad = mkl_experimental_direct_get<xpu, 4, DType>(out_grad[concat_enum::kOut], s);
-      for (int i = 0; i < size_; ++i) {
-        grad_in[i] = mkl_experimental_direct_get<xpu, 4, DType>(in_grad[i], s);
-      }
-    }
-
-    int need_bwd = 0;
-    for (size_t n = 0; n < num_concats_; n++) {
-      need_bwd += req[n];
-    }
-    if (!need_bwd) {
-      return;
-    }
-
-    dnnError_t e;
-    void *concat_res[dnnResourceNumber];
-    concat_res[dnnResourceSrc] = bwd_top_diff_->get_converted_prv(grad.dptr_, true,
-      out_grad[concat_enum::kOut]);
-    for (size_t i = 0; i < num_concats_; ++i) {
-      concat_res[dnnResourceMultipleDst + i] = bwd_bottom_diff_[i]->get_output_ptr(
-        grad_in[i].dptr_, bwd_bottom_diff_[i], in_grad[i]);
-    }
-    e = dnnExecute<DType>(concatBwd_, concat_res);
-    CHECK_EQ(e, E_SUCCESS);
-  }
-
- private:
-  int size_;
-  size_t dimension_;
-
-  bool init_mkldnn_;
-
-  dnnPrimitive_t concatFwd_;
-  dnnPrimitive_t concatBwd_;
-  std::shared_ptr<MKLData<DType> > fwd_top_data_;
-  std::vector< std::shared_ptr<MKLData<DType> > > fwd_bottom_data_;
-  std::shared_ptr<MKLData<DType> > bwd_top_diff_;
-  std::vector< std::shared_ptr<MKLData<DType> > > bwd_bottom_diff_;
-
-
-  size_t width_;
-  size_t height_;
-  size_t channels_;
-  size_t num_;
-  size_t num_concats_;
-};  // class MKLConcatOp
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // MXNET_OPERATOR_MKL_MKL_CONCAT_INL_H_
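
LayerSetUp above describes each tensor to MKL with sizes listed
fastest-varying first (W, H, C, N for NCHW data), and each stride as the
running product of the faster-varying sizes. A small sketch of that layout
description, assuming a contiguous NCHW tensor:

    #include <cstddef>
    #include <vector>

    // Build MKL-style user layout arrays from an NCHW shape: reverse the
    // shape for sizes, then accumulate strides as a running product.
    void UserLayout(const std::vector<std::size_t>& shape_nchw,
                    std::vector<std::size_t>* sizes,
                    std::vector<std::size_t>* strides) {
      const std::size_t dim = shape_nchw.size();
      sizes->resize(dim);
      strides->resize(dim);
      for (std::size_t d = 0; d < dim; ++d) {
        (*sizes)[d] = shape_nchw[dim - 1 - d];
        (*strides)[d] = (d == 0) ? 1 : (*strides)[d - 1] * (*sizes)[d - 1];
      }
    }
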
diff --git a/src/operator/mkl/mkl_convolution-inl.h b/src/operator/mkl/mkl_convolution-inl.h
deleted file mode 100644
index 813d061..0000000
--- a/src/operator/mkl/mkl_convolution-inl.h
+++ /dev/null
@@ -1,490 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_convolution-inl.h
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_CONVOLUTION_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_CONVOLUTION_INL_H_
-#include <mxnet/storage.h>
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <algorithm>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include "../operator_common.h"
-#include "../nn/convolution-inl.h"
-#include "./mkl_util-inl.h"
-
-namespace mxnet {
-namespace op {
-
-template<typename xpu, typename DType>
-class MKLConvolutionOp : public Operator {
- public:
-  static std::string getName() {
-    return "MKLConvolutionOp";
-  }
-  void SetupBuffer() {
-    convolutionBwdBias = static_cast<dnnPrimitive_t>(NULL);
-    convolutionBwdFilter = static_cast<dnnPrimitive_t>(NULL);
-    convolutionBwdData = static_cast<dnnPrimitive_t>(NULL);
-    convolutionFwd = static_cast<dnnPrimitive_t>(NULL);
-    fwd_bottom_data = MKLData<DType>::create();
-    fwd_top_data = MKLData<DType>::create();
-    fwd_filter_data = MKLData<DType>::create();
-    fwd_bias_data = MKLData<DType>::create();
-    bwdd_top_diff = MKLData<DType>::create();
-    bwdd_bottom_diff = MKLData<DType>::create();
-    bwdd_filter_data = MKLData<DType>::create();
-    bwdf_top_diff = MKLData<DType>::create();
-    bwdf_filter_diff = MKLData<DType>::create();
-    bwdf_bottom_data = MKLData<DType>::create();
-    bwdb_top_diff = MKLData<DType>::create();
-    bwdb_bias_diff = MKLData<DType>::create();
-    // Names are for debugging purposes only.
-    fwd_bottom_data->name = "fwd_bottom_data   @ " + this->getName();
-    fwd_top_data->name = "fwd_top_data      @ " + this->getName();
-    fwd_filter_data->name = "fwd_filter_data   @ " + this->getName();
-    fwd_bias_data->name = "fwd_bias_data     @ " + this->getName();
-    bwdd_top_diff->name = "bwdd_top_diff     @ " + this->getName();
-    bwdd_bottom_diff->name = "bwdd_bottom_diff  @ " + this->getName();
-    bwdd_filter_data->name = "bwdd_filter_data  @ " + this->getName();
-    bwdf_top_diff->name = "bwdf_top_diff     @ " + this->getName();
-    bwdf_bottom_data->name = "bwdf_bottom_data  @ " + this->getName();
-    bwdf_filter_diff->name = "bwdf_filter_diff  @ " + this->getName();
-    bwdb_top_diff->name = "bwdb_top_diff     @ " + this->getName();
-    bwdb_bias_diff->name = "bwdb_bias_diff    @ " + this->getName();
-  }
-
-  explicit MKLConvolutionOp(ConvolutionParam p):
-                            convolutionFwd(NULL),
-                            convolutionBwdData(static_cast<dnnPrimitive_t>(NULL)),
-                            convolutionBwdFilter(static_cast<dnnPrimitive_t>(NULL)),
-                            convolutionBwdBias(static_cast<dnnPrimitive_t>(NULL)) {
-    this->param_ = p;
-    init_mkldnn_ = false;
-    // convert MBytes first to Bytes and then to elements.
-    param_.workspace = (param_.workspace << 20) / sizeof(DType);
-    SetupBuffer();
-  }
-  void ReleaseBuffer() {
-    if (convolutionFwd != NULL) {
-     dnnDelete<DType>(convolutionFwd);
-     convolutionFwd = NULL;
-    }
-    if (convolutionBwdData != NULL) {
-     dnnDelete<DType>(convolutionBwdData);
-     convolutionBwdData = NULL;
-    }
-    if (convolutionBwdFilter != NULL) {
-     dnnDelete<DType>(convolutionBwdFilter);
-     convolutionBwdFilter = NULL;
-    }
-    if (!param_.no_bias && convolutionBwdBias != NULL) {
-     dnnDelete<DType>(convolutionBwdBias);
-     convolutionBwdBias = NULL;
-    }
-  }
-  virtual ~MKLConvolutionOp() {
-    ReleaseBuffer();
-  }
-
- private:
-  void LayerSetUp(const mshadow::Tensor<xpu, 4, DType> &data,
-                  const mshadow::Tensor<xpu, 4, DType> &out) {
-    this->width_ = data.shape_[3];
-    this->height_ = data.shape_[2];
-    this->channels_ = data.shape_[1];
-    this->num_ = data.shape_[0];
-    this->group_ = param_.num_group;
-    this->width_out_ = out.shape_[3];
-    this->height_out_ = out.shape_[2];
-    int channel_out_ = out.shape_[1];
-    this->num_output_ = channel_out_;
-    kernel_w_ = param_.kernel[1];
-    kernel_h_ = param_.kernel[0];
-    stride_w_ = param_.stride[1];
-    stride_h_ = param_.stride[0];
-    pad_w_ = param_.pad[1];
-    pad_h_ = param_.pad[0];
-    int status;
-    size_t n, g;
-    size_t iw, ih, ic;
-    size_t ow, oh, oc;
-    size_t kw, kh;
-    size_t dimension = 4;
-    g = std::max(this->group_, 1);
-    n = this->num_;
-    iw = this->width_;
-    ih = this->height_;
-    ic = this->channels_;
-    ow = this->width_out_;
-    oh = this->height_out_;
-    oc = this->num_output_;
-    kw = this->kernel_w_;
-    kh = this->kernel_h_;
-    oc = this->num_output_;
-    size_t bdata_sizes[4] = { iw, ih, ic, n };
-    size_t bdata_strides[4] = { 1, iw, iw*ih, iw*ih*ic };
-    /* starting with MKL 2017 Gold in case of groups filter layout
-    * becomes 5D, i.e. groups become a separate dimension */
-    size_t g_mkl2017 = g;
-    size_t f_dimension = dimension + (g != 1);
-    if (getMKLBuildDate() < 20160701) {
-     g_mkl2017 = 1;
-     f_dimension = dimension;
-    }
-    size_t fdata_sizes[5] = { kw, kh, ic / g, oc / g_mkl2017, g_mkl2017 };
-    size_t fdata_strides[5] = { 1, kw, kw*kh, kw*kh*ic / g, kw*kh*ic / g*oc / g };
-    size_t bias_sizes[1] = { oc };
-    size_t bias_strides[1] = { 1 };
-    size_t tdata_sizes[4] = { ow, oh, oc, n };
-    size_t tdata_strides[4] = { 1, ow, ow*oh, ow*oh*oc };
-    size_t convolutionStrides[2] = { this->stride_w_, this->stride_h_ };
-    int    inputOffset[2] = { -this->pad_w_, -this->pad_h_ };
-    // Names are for debugging purposes only.
-    /*** convolution section ***/
-    if (!param_.no_bias) {
-      status = dnnGroupsConvolutionCreateForwardBias<DType>(&convolutionFwd,
-                                                            NULL,
-                                                            dnnAlgorithmConvolutionDirect,
-                                                            g,
-                                                            dimension,
-                                                            bdata_sizes,
-                                                            tdata_sizes,
-                                                            fdata_sizes,
-                                                            convolutionStrides,
-                                                            inputOffset,
-                                                            dnnBorderZeros);
-    } else {
-      status = dnnGroupsConvolutionCreateForward<DType>(&convolutionFwd,
-                                                        NULL,
-                                                        dnnAlgorithmConvolutionDirect,
-                                                        g,
-                                                        dimension,
-                                                        bdata_sizes,
-                                                        tdata_sizes,
-                                                        fdata_sizes,
-                                                        convolutionStrides,
-                                                        inputOffset,
-                                                        dnnBorderZeros);
-    }
-    CHECK_EQ(status, 0)
-     << "Failed dnnCreateConvolution<DType>(dnnForward) with status "
-     << status << "\n";
-    fwd_bottom_data->create_layouts(convolutionFwd, dnnResourceSrc, dimension,
-                                    bdata_sizes, bdata_strides);
-    fwd_top_data->create_layouts(convolutionFwd, dnnResourceDst, dimension,
-                                 tdata_sizes, tdata_strides);
-    fwd_filter_data->create_layouts(convolutionFwd, dnnResourceFilter,
-                                    f_dimension, fdata_sizes, fdata_strides);
-    if (!param_.no_bias)
-      fwd_bias_data->create_layouts(convolutionFwd, dnnResourceBias, 1,
-                                    bias_sizes, bias_strides);
-    /*
-    * Backward by data layer setup
-    */
-    status = dnnGroupsConvolutionCreateBackwardData<DType>(&convolutionBwdData,
-                                                           NULL,
-                                                           dnnAlgorithmConvolutionDirect,
-                                                           g,
-                                                           dimension,
-                                                           bdata_sizes,
-                                                           tdata_sizes,
-                                                           fdata_sizes,
-                                                           convolutionStrides,
-                                                           inputOffset,
-                                                           dnnBorderZeros);
-    CHECK_EQ(status, 0)
-     << "Failed dnnConvolutionCreateBackwardData with status "
-     << status << "\n";
-    bwdd_bottom_diff->create_layouts(convolutionBwdData, dnnResourceDiffSrc,
-                                     dimension, bdata_sizes, bdata_strides);
-    bwdd_top_diff->create_layouts(convolutionBwdData, dnnResourceDiffDst,
-                                  dimension, tdata_sizes, tdata_strides);
-    bwdd_filter_data->create_layouts(convolutionBwdData, dnnResourceFilter,
-                                     f_dimension, fdata_sizes, fdata_strides);
-    /*
-    * Backward by filter layer setup
-    */
-    status = dnnGroupsConvolutionCreateBackwardFilter<DType>(&convolutionBwdFilter,
-                                                             NULL,
-                                                             dnnAlgorithmConvolutionDirect,
-                                                             g,
-                                                             dimension,
-                                                             bdata_sizes,
-                                                             tdata_sizes,
-                                                             fdata_sizes,
-                                                             convolutionStrides,
-                                                             inputOffset,
-                                                             dnnBorderZeros);
-    CHECK_EQ(status, 0)
-      << "Failed dnnGroupsConvolutionCreateBackwardFilter with status "
-      << status << "\n";
-    bwdf_bottom_data->create_layouts(convolutionBwdFilter, dnnResourceSrc,
-                                     dimension, bdata_sizes, bdata_strides);
-    bwdf_top_diff->create_layouts(convolutionBwdFilter, dnnResourceDiffDst,
-                                  dimension, tdata_sizes, tdata_strides);
-    bwdf_filter_diff->create_layouts(convolutionBwdFilter, dnnResourceDiffFilter,
-                                     f_dimension, fdata_sizes, fdata_strides);
-    /*
-     * Backward by bias layer setup
-     */
-    if (!param_.no_bias) {
-      status = dnnGroupsConvolutionCreateBackwardBias<DType>(&convolutionBwdBias,
-                                                             NULL,
-                                                             dnnAlgorithmConvolutionDirect,
-                                                             g,
-                                                             dimension,
-                                                             tdata_sizes);
-      CHECK_EQ(status, 0)
-        << "Failed dnnGroupsConvolutionCreateBackwardBias with status "
-        << status << "\n";
-      bwdb_top_diff->create_layouts(convolutionBwdBias, dnnResourceDiffDst,
-                                    dimension, tdata_sizes, tdata_strides);
-      bwdb_bias_diff->create_layouts(convolutionBwdBias, dnnResourceDiffBias, 1,
-                                     bias_sizes, bias_strides);
-    }
-  }
-
- public:
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    DType *data_ptr = NULL;
-    DType *wmat_ptr = NULL;
-    DType *out_ptr = NULL;
-    Tensor<xpu, 4, DType> data =
-      mkl_experimental_direct_get<xpu, 4, DType>(in_data[conv::kData], s);
-    Tensor<xpu, 4, DType> out =
-      mkl_experimental_direct_get<xpu, 4, DType>(out_data[conv::kOut], s);
-    Tensor<xpu, 4, DType> wmat =
-      mkl_experimental_direct_get<xpu, 4, DType>(in_data[conv::kWeight], s);
-    if (!init_mkldnn_) {
-      LayerSetUp(data, out);
-      init_mkldnn_ = true;
-    }
-    CHECK_EQ(data.CheckContiguous(), true);
-    CHECK_EQ(wmat.CheckContiguous(), true);
-    CHECK_EQ(out.CheckContiguous(), true);
-    data_ptr = data.dptr_;
-    wmat_ptr = wmat.dptr_;
-    out_ptr = out.dptr_;
-    int status;
-    void *res_convolutionFwd[dnnResourceNumber];
-    res_convolutionFwd[dnnResourceSrc] =
-      fwd_bottom_data->get_converted_prv(data_ptr, false, in_data[conv::kData]);
-    res_convolutionFwd[dnnResourceFilter] =
-      fwd_filter_data->get_converted_prv(wmat_ptr, true, in_data[conv::kWeight]);
-    if (!param_.no_bias) {
-      Tensor<xpu, 1, DType> bias =
-        mkl_experimental_direct_get<xpu, 1, DType>(in_data[conv::kBias], s);
-      res_convolutionFwd[dnnResourceBias] =
-        fwd_bias_data->get_converted_prv(bias.dptr_, true, in_data[conv::kBias]);
-    }
-
-    res_convolutionFwd[dnnResourceDst] = fwd_top_data->get_output_ptr(out_ptr,
-      fwd_top_data, out_data[conv::kOut]);
-    status = dnnExecute<DType>(convolutionFwd, res_convolutionFwd);
-    CHECK_EQ(status, 0) << "Forward convolution failed with status " << status;
-#if MKL_EXPERIMENTAL == 0
-    if (fwd_top_data->conversion_needed()) {
-        fwd_top_data->convert_from_prv(out_ptr);
-    }
-#endif
-  }
-  void AddToModeAllocAndStoreBuffer(void *src, int blob_size, Storage::Handle *pws) {
-    int blob_byte_size = blob_size * sizeof(DType);
-    *pws = Storage::Get()->Alloc(blob_byte_size, Context::CPU());
-    memcpy(pws->dptr, src, blob_byte_size);
-  }
-  void AddToModeAddAndReleaseBuffer(Storage::Handle *pws, void *dst_, int blob_size) {
-    DType *dst = reinterpret_cast<DType*>(dst_);
-    DType *src = reinterpret_cast<DType*>(pws->dptr);
-#pragma omp parallel for
-    for (int i = 0; i < blob_size; i++) {
-      dst[i] += src[i];
-    }
-    if (pws->dptr)
-      Storage::Get()->Free(*pws);
-    pws->dptr = NULL;
-  }
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    if (param_.kernel.ndim() > 2) {
-      LOG(FATAL) << "Volume convolution is not implemented in mshadow";
-    }
-    CHECK_EQ(out_grad.size(), 1);
-    size_t expected = param_.no_bias == 0 ? 3 : 2;
-    CHECK(in_data.size() == expected && in_grad.size() == expected);
-    CHECK_EQ(req.size(), expected);
-    CHECK_EQ(in_data[conv::kWeight].CheckContiguous(), true);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> data =
-      mkl_experimental_direct_get<xpu, 4, DType>(in_data[conv::kData], s);
-    Shape<3> wmat_shape =
-      Shape3(param_.num_group,
-             param_.num_filter / param_.num_group,
-             data.shape_[1] / param_.num_group * param_.kernel[0] * param_.kernel[1]);
-    Tensor<xpu, 3, DType> wmat =
-      mkl_experimental_direct_get_with_shape<xpu, 3, DType>(
-      in_data[conv::kWeight], wmat_shape, s);
-    Tensor<xpu, 4, DType> grad =
-      mkl_experimental_direct_get<xpu, 4, DType>(out_grad[conv::kOut], s);
-    Tensor<xpu, 4, DType> gdata =
-      mkl_experimental_direct_get<xpu, 4, DType>(in_grad[conv::kData], s);
-    Tensor<xpu, 3, DType> gwmat =
-      mkl_experimental_direct_get_with_shape<xpu, 3, DType>(
-      in_grad[conv::kWeight], wmat_shape, s);
-
-    if (!init_mkldnn_) {
-      init_mkldnn_ = true;
-      LayerSetUp(data, grad);
-    }
-    int status;
-    if (req[0]) {
-      void *res_convolutionBwdData[dnnResourceNumber];
-      res_convolutionBwdData[dnnResourceDiffDst] =
-        bwdd_top_diff->get_converted_prv(grad.dptr_, true, out_grad[conv::kOut]);
-
-      res_convolutionBwdData[dnnResourceFilter] =
-        bwdd_filter_data->get_converted_prv(wmat.dptr_, false, in_data[conv::kWeight]);
-      Storage::Handle addtoWorkspace;
-      if (req[0] == kAddTo) {
-        // workaround until MKL supports addto mode directly
-        AddToModeAllocAndStoreBuffer(gdata.dptr_, in_grad[conv::kData].Size(), &addtoWorkspace);
-      }
-
-      res_convolutionBwdData[dnnResourceDiffSrc] = bwdd_bottom_diff->get_output_ptr(gdata.dptr_,
-        bwdd_bottom_diff, in_grad[conv::kData]);
-      status = dnnExecute<DType>(convolutionBwdData, res_convolutionBwdData);
-      CHECK_EQ(status, 0) << "Backward Data conv failed with status " << status;
-#if MKL_EXPERIMENTAL == 0
-      if (bwdd_bottom_diff->conversion_needed()) {
-        bwdd_bottom_diff->convert_from_prv(gdata.dptr_);
-      }
-#endif
-      if (req[0] == kAddTo) {
-        if (bwdd_bottom_diff->conversion_needed()) {
-          bwdd_bottom_diff->convert_from_prv(gdata.dptr_);
-        }
-        AddToModeAddAndReleaseBuffer(&addtoWorkspace, gdata.dptr_, in_grad[conv::kData].Size());
-      }
-    }
-    if (req[1]) {
-      void *res_convolutionBwdFilter[dnnResourceNumber];
-
-      res_convolutionBwdFilter[dnnResourceDiffDst] =
-        bwdf_top_diff->get_converted_prv(grad.dptr_, true, out_grad[conv::kOut]);
-
-      res_convolutionBwdFilter[dnnResourceSrc] =
-        bwdf_bottom_data->get_converted_prv(data.dptr_, false,
-          in_data[conv::kData]);
-      Storage::Handle addtoWorkspace;
-      if (req[1] == kAddTo) {
-        // workaround until MKL supports addto mode directly
-        AddToModeAllocAndStoreBuffer(gwmat.dptr_, in_grad[conv::kWeight].Size(), &addtoWorkspace);
-      }
-
-      res_convolutionBwdFilter[dnnResourceDiffFilter] = bwdf_filter_diff->get_output_ptr(
-        gwmat.dptr_, bwdf_filter_diff, in_grad[conv::kWeight]);
-      status = dnnExecute<DType>(convolutionBwdFilter, res_convolutionBwdFilter);
-      CHECK_EQ(status, 0) << "Backward Filter conv failed with status " << status;
-#if MKL_EXPERIMENTAL == 0
-      if (bwdf_filter_diff->conversion_needed()) {
-        bwdf_filter_diff->convert_from_prv(gwmat.dptr_);
-      }
-#endif
-      if (req[1] == kAddTo) {
-        if (bwdf_filter_diff->conversion_needed()) {
-          bwdf_filter_diff->convert_from_prv(gwmat.dptr_);
-        }
-        AddToModeAddAndReleaseBuffer(&addtoWorkspace, gwmat.dptr_, in_grad[conv::kWeight].Size());
-      }
-    }
-    if (!param_.no_bias) {
-      Tensor<xpu, 1, DType> gbias =
-        mkl_experimental_direct_get<xpu, 1, DType>(in_grad[conv::kBias], s);
-      void *res_convolutionBwdBias[dnnResourceNumber];
-      res_convolutionBwdBias[dnnResourceDiffDst] =
-        bwdb_top_diff->get_converted_prv(grad.dptr_, true, out_grad[conv::kOut]);
-
-      res_convolutionBwdBias[dnnResourceDiffBias] = bwdb_bias_diff->get_output_ptr(gbias.dptr_,
-        bwdb_bias_diff, in_grad[conv::kBias]);
-      status = dnnExecute<DType>(convolutionBwdBias, res_convolutionBwdBias);
-      CHECK_EQ(status, 0) << "Backward Bias failed with status " << status;
-#if MKL_EXPERIMENTAL == 0
-      if (bwdb_bias_diff->conversion_needed()) {
-        bwdb_bias_diff->convert_from_prv(gbias.dptr_);
-      }
-#endif
-    }
-  }
-
- private:
-  ConvolutionParam param_;
-  size_t width_,
-         height_,
-         width_out_,
-         height_out_,
-         kernel_w_,
-         kernel_h_,
-         stride_w_,
-         stride_h_;
-  int group_,
-      num_,
-      num_output_;
-  size_t channels_;
-  int pad_w_,
-      pad_h_;
-  bool init_mkldnn_;
-  dnnPrimitive_t convolutionFwd;
-  dnnPrimitive_t convolutionBwdData;
-  dnnPrimitive_t convolutionBwdFilter;
-  dnnPrimitive_t convolutionBwdBias;
-  /* Fwd step */
-  std::shared_ptr<MKLData<DType> > fwd_bottom_data, fwd_top_data, fwd_filter_data,
-                                   fwd_bias_data;
-  /* Bwd data step */
-  std::shared_ptr<MKLData<DType> > bwdd_top_diff, bwdd_bottom_diff;
-  std::shared_ptr<MKLData<DType> > bwdd_filter_data;
-  /* Bwd filter step */
-  std::shared_ptr<MKLData<DType> > bwdf_top_diff, bwdf_filter_diff;
-  std::shared_ptr<MKLData<DType> > bwdf_bottom_data;
-  std::shared_ptr<MKLData<DType> > bwdf_filter_diff_iter, bwdf2fwd_filter_diff,
-                                   bwdb_bias_diff_iter;
-  /* Bwd bias step */
-  std::shared_ptr<MKLData<DType> > bwdb_top_diff, bwdb_bias_diff;
-};  // class ConvolutionOp
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_MKL_MKL_CONVOLUTION_INL_H_
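The kAddTo handling in the Backward method above works around MKL primitives that can only overwrite their destination: stash the current gradient in a scratch buffer, run the primitive, then add the stash back. A minimal standalone sketch of that save/compute/add-back pattern, with a hypothetical overwrite_output kernel standing in for dnnExecute:

    #include <vector>

    // Hypothetical kernel that, like the MKL primitives, can only overwrite
    // its destination buffer.
    static void overwrite_output(float* out, const float* in, int n) {
      for (int i = 0; i < n; ++i) out[i] = 2.0f * in[i];
    }

    // Emulate kAddTo on top of an overwrite-only kernel: save the existing
    // contents of `out`, let the kernel overwrite them, then add the saved
    // values back. AddToModeAllocAndStoreBuffer/AddToModeAddAndReleaseBuffer
    // do the same with Storage::Handle-backed buffers.
    static void run_with_addto(float* out, const float* in, int n) {
      std::vector<float> saved(out, out + n);
      overwrite_output(out, in, n);
      for (int i = 0; i < n; ++i) out[i] += saved[i];
    }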
diff --git a/src/operator/mkl/mkl_cppwrapper.cc b/src/operator/mkl/mkl_cppwrapper.cc
deleted file mode 100644
index 507e549..0000000
--- a/src/operator/mkl/mkl_cppwrapper.cc
+++ /dev/null
@@ -1,44 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_cppwrapper.cc
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-
-
-
-#include "mkl_cppwrapper.h"
-#include <stdio.h>
-#if MXNET_USE_MKL2017 == 1
-#include "mkl_service.h"
-
-int getMKLBuildDate() {
-    static int build = 0;
-    if (build == 0) {
-        MKLVersion v;
-        mkl_get_version(&v);
-        build = atoi(v.Build);
-        printf("MKL Build:%d\n", build);
-    }
-    return build;
-}
-
-bool enableMKLWarnGenerated() {
-  return false;
-}
-#endif  // MXNET_USE_MKL2017 == 1
diff --git a/src/operator/mkl/mkl_cppwrapper.h b/src/operator/mkl/mkl_cppwrapper.h
deleted file mode 100644
index 7d66f20..0000000
--- a/src/operator/mkl/mkl_cppwrapper.h
+++ /dev/null
@@ -1,1020 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_cppwrapper.h
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_CPPWRAPPER_H_
-#define MXNET_OPERATOR_MKL_MKL_CPPWRAPPER_H_
-
-
-#include <stdarg.h>
-#include <stddef.h>
-#if MXNET_USE_MKL2017 == 1
-#include "mkl_dnn_types.h"
-#include "mkl_dnn.h"
-#include "mkl_version.h"
-
-
-extern int getMKLBuildDate();
-extern bool enableMKLWarnGenerated();
-
-
-template <typename Dtype> inline dnnError_t dnnLayoutCreate(
-    dnnLayout_t *pLayout, size_t dimension, const size_t size[], const size_t strides[]);
-template <> inline dnnError_t dnnLayoutCreate<float>(
-    dnnLayout_t *pLayout, size_t dimension, const size_t size[], const size_t strides[]) {
-    return dnnLayoutCreate_F32(pLayout, dimension, size, strides);
-}
-template <> inline dnnError_t dnnLayoutCreate<double>(
-    dnnLayout_t *pLayout, size_t dimension, const size_t size[], const size_t strides[]) {
-    return dnnLayoutCreate_F64(pLayout, dimension, size, strides);
-}
-
-template <typename Dtype> inline dnnError_t dnnLayoutCreateFromPrimitive(
-    dnnLayout_t *pLayout, const dnnPrimitive_t primitive, dnnResourceType_t type);
-template <> inline dnnError_t dnnLayoutCreateFromPrimitive<float>(
-    dnnLayout_t *pLayout, const dnnPrimitive_t primitive, dnnResourceType_t type) {
-    return dnnLayoutCreateFromPrimitive_F32(pLayout, primitive, type);
-}
-template <> inline dnnError_t dnnLayoutCreateFromPrimitive<double>(
-    dnnLayout_t *pLayout, const dnnPrimitive_t primitive, dnnResourceType_t type) {
-    return dnnLayoutCreateFromPrimitive_F64(pLayout, primitive, type);
-}
-
-template <typename Dtype> inline size_t dnnLayoutGetMemorySize(
-    const dnnLayout_t layout);
-template <> inline size_t dnnLayoutGetMemorySize<float>(
-    const dnnLayout_t layout) {
-    return dnnLayoutGetMemorySize_F32(layout);
-}
-template <> inline size_t dnnLayoutGetMemorySize<double>(
-    const dnnLayout_t layout) {
-    return dnnLayoutGetMemorySize_F64(layout);
-}
-
-template <typename Dtype> inline int dnnLayoutCompare(
-    const dnnLayout_t l1, const dnnLayout_t l2);
-template <> inline int dnnLayoutCompare<float>(
-    const dnnLayout_t l1, const dnnLayout_t l2) {
-    return dnnLayoutCompare_F32(l1, l2);
-}
-template <> inline int dnnLayoutCompare<double>(
-    const dnnLayout_t l1, const dnnLayout_t l2) {
-    return dnnLayoutCompare_F64(l1, l2);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnAllocateBuffer(
-    void **pPtr, dnnLayout_t layout);
-template <> inline dnnError_t dnnAllocateBuffer<float>(
-    void **pPtr, dnnLayout_t layout) {
-    return dnnAllocateBuffer_F32(pPtr, layout);
-}
-template <> inline dnnError_t dnnAllocateBuffer<double>(
-    void **pPtr, dnnLayout_t layout) {
-    return dnnAllocateBuffer_F64(pPtr, layout);
-}
-
-template <typename Dtype> inline dnnError_t dnnReleaseBuffer(
-    void *ptr);
-template <> inline dnnError_t dnnReleaseBuffer<float>(
-    void *ptr) {
-    return dnnReleaseBuffer_F32(ptr);
-}
-template <> inline dnnError_t dnnReleaseBuffer<double>(
-    void *ptr) {
-    return dnnReleaseBuffer_F64(ptr);
-}
-
-template <typename Dtype> inline dnnError_t dnnLayoutDelete(
-    dnnLayout_t layout);
-template <> inline dnnError_t dnnLayoutDelete<float>(
-    dnnLayout_t layout) {
-    return dnnLayoutDelete_F32(layout);
-}
-template <> inline dnnError_t dnnLayoutDelete<double>(
-    dnnLayout_t layout) {
-    return dnnLayoutDelete_F64(layout);
-}
-
-template <typename Dtype> inline dnnError_t dnnPrimitiveAttributesCreate(
-    dnnPrimitiveAttributes_t *attributes);
-template <> inline dnnError_t dnnPrimitiveAttributesCreate<float>(
-    dnnPrimitiveAttributes_t *attributes) {
-    return dnnPrimitiveAttributesCreate_F32(attributes);
-}
-template <> inline dnnError_t dnnPrimitiveAttributesCreate<double>(
-    dnnPrimitiveAttributes_t *attributes) {
-    return dnnPrimitiveAttributesCreate_F64(attributes);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnPrimitiveAttributesDestroy(
-    dnnPrimitiveAttributes_t attributes);
-template <> inline dnnError_t dnnPrimitiveAttributesDestroy<float>(
-    dnnPrimitiveAttributes_t attributes) {
-    return dnnPrimitiveAttributesDestroy_F32(attributes);
-}
-template <> inline dnnError_t dnnPrimitiveAttributesDestroy<double>(
-    dnnPrimitiveAttributes_t attributes) {
-    return dnnPrimitiveAttributesDestroy_F64(attributes);
-}
-
-template <typename Dtype> inline dnnError_t dnnPrimitiveGetAttributes(
-    dnnPrimitive_t primitive,
-    dnnPrimitiveAttributes_t *attributes);
-template <> inline dnnError_t dnnPrimitiveGetAttributes<float>(
-    dnnPrimitive_t primitive,
-    dnnPrimitiveAttributes_t *attributes) {
-    return dnnPrimitiveGetAttributes_F32(primitive, attributes);
-}
-template <> inline dnnError_t dnnPrimitiveGetAttributes<double>(
-    dnnPrimitive_t primitive,
-    dnnPrimitiveAttributes_t *attributes) {
-    return dnnPrimitiveGetAttributes_F64(primitive, attributes);
-}
-
-template <typename Dtype> inline dnnError_t dnnExecute(
-    dnnPrimitive_t primitive, void *resources[]);
-template <> inline dnnError_t dnnExecute<float>(
-    dnnPrimitive_t primitive, void *resources[]) {
-    return dnnExecute_F32(primitive, resources);
-}
-template <> inline dnnError_t dnnExecute<double>(
-    dnnPrimitive_t primitive, void *resources[]) {
-    return dnnExecute_F64(primitive, resources);
-}
-
-template <typename Dtype> inline dnnError_t dnnExecuteAsync(
-    dnnPrimitive_t primitive, void *resources[]);
-template <> inline dnnError_t dnnExecuteAsync<float>(
-    dnnPrimitive_t primitive, void *resources[]) {
-    return dnnExecuteAsync_F32(primitive, resources);
-}
-template <> inline dnnError_t dnnExecuteAsync<double>(
-    dnnPrimitive_t primitive, void *resources[]) {
-    return dnnExecuteAsync_F64(primitive, resources);
-}
-
-template <typename Dtype> inline dnnError_t dnnWaitFor(
-    dnnPrimitive_t primitive);
-template <> inline dnnError_t dnnWaitFor<float>(
-    dnnPrimitive_t primitive) {
-    return dnnWaitFor_F32(primitive);
-}
-template <> inline dnnError_t dnnWaitFor<double>(
-    dnnPrimitive_t primitive) {
-    return dnnWaitFor_F64(primitive);
-}
-
-template <typename Dtype> inline dnnError_t dnnDelete(
-    dnnPrimitive_t primitive);
-template <> inline dnnError_t dnnDelete<float>(
-    dnnPrimitive_t primitive) {
-    return dnnDelete_F32(primitive);
-}
-template <> inline dnnError_t dnnDelete<double>(
-    dnnPrimitive_t primitive) {
-    return dnnDelete_F64(primitive);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnConversionCreate(
-    dnnPrimitive_t* pConversion, const dnnLayout_t from, const dnnLayout_t to);
-template <> inline dnnError_t dnnConversionCreate<float>(
-    dnnPrimitive_t* pConversion, const dnnLayout_t from, const dnnLayout_t to) {
-    return dnnConversionCreate_F32(pConversion, from, to);
-}
-template <> inline dnnError_t dnnConversionCreate<double>(
-    dnnPrimitive_t* pConversion, const dnnLayout_t from, const dnnLayout_t to) {
-    return dnnConversionCreate_F64(pConversion, from, to);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnConversionExecute(
-    dnnPrimitive_t conversion, void *from, void *to);
-template <> inline dnnError_t dnnConversionExecute<float>(
-    dnnPrimitive_t conversion, void *from, void *to) {
-    return dnnConversionExecute_F32(conversion, from, to);
-}
-template <> inline dnnError_t dnnConversionExecute<double>(
-    dnnPrimitive_t conversion, void *from, void *to) {
-    return dnnConversionExecute_F64(conversion, from, to);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnConvolutionCreateForward(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[], const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type);
-template <> inline dnnError_t dnnConvolutionCreateForward<float>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[], const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnConvolutionCreateForward_F32(
-               pConvolution,
-               attributes,
-               algorithm,
-               dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-
-template <> inline dnnError_t dnnConvolutionCreateForward<double>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[], const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnConvolutionCreateForward_F64(
-               pConvolution,
-               attributes,
-               algorithm,
-               dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnConvolutionCreateForwardBias(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[], const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type);
-template <> inline dnnError_t dnnConvolutionCreateForwardBias<float>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[], const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnConvolutionCreateForwardBias_F32(
-               pConvolution,
-               attributes,
-               algorithm,
-               dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-template <> inline dnnError_t dnnConvolutionCreateForwardBias<double>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[], const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnConvolutionCreateForwardBias_F64(
-               pConvolution,
-               attributes,
-               algorithm,
-               dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnConvolutionCreateBackwardData(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type);
-template <> inline dnnError_t dnnConvolutionCreateBackwardData<float>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnConvolutionCreateBackwardData_F32(
-               pConvolution,
-               attributes,
-               algorithm,
-               dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-template <> inline dnnError_t dnnConvolutionCreateBackwardData<double>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnConvolutionCreateBackwardData_F64(
-               pConvolution,
-               attributes,
-               algorithm,
-               dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-
-template <typename Dtype> inline dnnError_t dnnConvolutionCreateBackwardFilter(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[], const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type);
-template <> inline dnnError_t dnnConvolutionCreateBackwardFilter<float>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[], const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnConvolutionCreateBackwardFilter_F32(
-               pConvolution,
-               attributes,
-               algorithm,
-               dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-template <> inline dnnError_t dnnConvolutionCreateBackwardFilter<double>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t srcSize[], const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnConvolutionCreateBackwardFilter_F64(
-               pConvolution,
-               attributes,
-               algorithm,
-               dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-
-template <typename Dtype> inline dnnError_t dnnConvolutionCreateBackwardBias(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t dstSize[]);
-template <> inline dnnError_t dnnConvolutionCreateBackwardBias<float>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t dstSize[]) {
-    return dnnConvolutionCreateBackwardBias_F32(
-               pConvolution,
-               attributes,
-               algorithm,
-               dimension, dstSize);
-}
-template <> inline dnnError_t dnnConvolutionCreateBackwardBias<double>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t dimension, const size_t dstSize[]) {
-    return dnnConvolutionCreateBackwardBias_F64(
-               pConvolution,
-               attributes,
-               algorithm,
-               dimension, dstSize);
-}
-
-template <typename Dtype> inline dnnError_t dnnGroupsConvolutionCreateForward(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type);
-template <> inline dnnError_t dnnGroupsConvolutionCreateForward<float>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnGroupsConvolutionCreateForward_F32(
-               pConvolution,
-               attributes,
-               algorithm,
-               groups, dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-template <> inline dnnError_t dnnGroupsConvolutionCreateForward<double>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnGroupsConvolutionCreateForward_F64(
-               pConvolution,
-               attributes,
-               algorithm,
-               groups, dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-
-template <typename Dtype> inline dnnError_t dnnGroupsConvolutionCreateForwardBias(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type);
-template <> inline dnnError_t dnnGroupsConvolutionCreateForwardBias<float>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnGroupsConvolutionCreateForwardBias_F32(
-               pConvolution,
-               attributes,
-               algorithm,
-               groups, dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-template <> inline dnnError_t dnnGroupsConvolutionCreateForwardBias<double>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnGroupsConvolutionCreateForwardBias_F64(
-               pConvolution,
-               attributes,
-               algorithm,
-               groups, dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-
-template <typename Dtype> inline dnnError_t dnnGroupsConvolutionCreateBackwardData(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type);
-template <> inline dnnError_t dnnGroupsConvolutionCreateBackwardData<float>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnGroupsConvolutionCreateBackwardData_F32(
-               pConvolution,
-               attributes,
-               algorithm,
-               groups, dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-template <> inline dnnError_t dnnGroupsConvolutionCreateBackwardData<double>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnGroupsConvolutionCreateBackwardData_F64(
-               pConvolution,
-               attributes,
-               algorithm,
-               groups, dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnGroupsConvolutionCreateBackwardFilter(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type);
-template <> inline dnnError_t dnnGroupsConvolutionCreateBackwardFilter<float>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnGroupsConvolutionCreateBackwardFilter_F32(
-               pConvolution,
-               attributes,
-               algorithm,
-               groups, dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-template <> inline dnnError_t dnnGroupsConvolutionCreateBackwardFilter<double>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t srcSize[],
-    const size_t dstSize[], const size_t filterSize[],
-    const size_t convolutionStrides[], const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnGroupsConvolutionCreateBackwardFilter_F64(
-               pConvolution,
-               attributes,
-               algorithm,
-               groups, dimension, srcSize, dstSize, filterSize,
-               convolutionStrides, inputOffset, border_type);
-}
-
-template <typename Dtype> inline dnnError_t dnnGroupsConvolutionCreateBackwardBias(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t dstSize[]);
-template <> inline dnnError_t dnnGroupsConvolutionCreateBackwardBias<float>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t dstSize[]) {
-    return dnnGroupsConvolutionCreateBackwardBias_F32(
-               pConvolution,
-               attributes,
-               algorithm,
-               groups, dimension, dstSize);
-}
-template <> inline dnnError_t dnnGroupsConvolutionCreateBackwardBias<double>(
-    dnnPrimitive_t* pConvolution,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t algorithm,
-    size_t groups, size_t dimension, const size_t dstSize[]) {
-    return dnnGroupsConvolutionCreateBackwardBias_F64(
-               pConvolution,
-               attributes,
-               algorithm,
-               groups, dimension, dstSize);
-}
-
-template <typename Dtype> inline dnnError_t dnnReLUCreateForward(
-    dnnPrimitive_t* pRelu,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, float negativeSlope);
-template <> inline dnnError_t dnnReLUCreateForward<float>(
-    dnnPrimitive_t* pRelu,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, float negativeSlope) {
-    return dnnReLUCreateForward_F32(
-               pRelu,
-               attributes,
-               dataLayout, negativeSlope);
-}
-template <> inline dnnError_t dnnReLUCreateForward<double>(
-    dnnPrimitive_t* pRelu,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, float negativeSlope) {
-    return dnnReLUCreateForward_F64(
-               pRelu,
-               attributes,
-               dataLayout, negativeSlope);
-}
-
-template <typename Dtype> inline dnnError_t dnnReLUCreateBackward(
-    dnnPrimitive_t* pRelu,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t diffLayout, const dnnLayout_t dataLayout, float negativeSlope);
-template <> inline dnnError_t dnnReLUCreateBackward<float>(
-    dnnPrimitive_t* pRelu,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t diffLayout, const dnnLayout_t dataLayout, float negativeSlope) {
-    return dnnReLUCreateBackward_F32(
-               pRelu,
-               attributes,
-               diffLayout, dataLayout, negativeSlope);
-}
-template <> inline dnnError_t dnnReLUCreateBackward<double>(
-    dnnPrimitive_t* pRelu,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t diffLayout, const dnnLayout_t dataLayout, float negativeSlope) {
-    return dnnReLUCreateBackward_F64(
-               pRelu,
-               attributes,
-               diffLayout, dataLayout, negativeSlope);
-}
-
-template <typename Dtype> inline dnnError_t dnnLRNCreateForward(
-    dnnPrimitive_t* pLrn,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, size_t kernel_size, float alpha, float beta, float k);
-template <> inline dnnError_t dnnLRNCreateForward<float>(
-    dnnPrimitive_t* pLrn,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, size_t kernel_size, float alpha, float beta, float k) {
-    return dnnLRNCreateForward_F32(
-               pLrn,
-               attributes,
-               dataLayout, kernel_size, alpha, beta, k);
-}
-template <> inline dnnError_t dnnLRNCreateForward<double>(
-    dnnPrimitive_t* pLrn,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, size_t kernel_size, float alpha, float beta, float k) {
-    return dnnLRNCreateForward_F64(
-               pLrn,
-               attributes,
-               dataLayout, kernel_size, alpha, beta, k);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnLRNCreateBackward(
-    dnnPrimitive_t* pLrn,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t diffLayout, const dnnLayout_t dataLayout,
-    size_t kernel_size, float alpha, float beta, float k);
-template <> inline dnnError_t dnnLRNCreateBackward<float>(
-    dnnPrimitive_t* pLrn,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t diffLayout, const dnnLayout_t dataLayout,
-    size_t kernel_size, float alpha, float beta, float k) {
-    return dnnLRNCreateBackward_F32(
-               pLrn,
-               attributes,
-               diffLayout, dataLayout, kernel_size, alpha, beta, k);
-}
-template <> inline dnnError_t dnnLRNCreateBackward<double>(
-    dnnPrimitive_t* pLrn,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t diffLayout, const dnnLayout_t dataLayout,
-    size_t kernel_size, float alpha, float beta, float k) {
-    return dnnLRNCreateBackward_F64(
-               pLrn,
-               attributes,
-               diffLayout, dataLayout, kernel_size, alpha, beta, k);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnPoolingCreateForward(
-    dnnPrimitive_t* pPooling,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t op,
-    const dnnLayout_t srcLayout,
-    const size_t kernelSize[], const size_t kernelStride[],
-    const int inputOffset[], const dnnBorder_t border_type);
-template <> inline dnnError_t dnnPoolingCreateForward<float>(
-    dnnPrimitive_t* pPooling,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t op,
-    const dnnLayout_t srcLayout,
-    const size_t kernelSize[], const size_t kernelStride[],
-    const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnPoolingCreateForward_F32(
-               pPooling,
-               attributes,
-               op,
-               srcLayout,
-               kernelSize, kernelStride,
-               inputOffset, border_type);
-}
-template <> inline dnnError_t dnnPoolingCreateForward<double>(
-    dnnPrimitive_t* pPooling,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t op,
-    const dnnLayout_t srcLayout,
-    const size_t kernelSize[], const size_t kernelStride[],
-    const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnPoolingCreateForward_F64(
-               pPooling,
-               attributes,
-               op,
-               srcLayout,
-               kernelSize, kernelStride,
-               inputOffset, border_type);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnPoolingCreateBackward(
-    dnnPrimitive_t* pPooling,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t op,
-    const dnnLayout_t srcLayout,
-    const size_t kernelSize[], const size_t kernelStride[],
-    const int inputOffset[], const dnnBorder_t border_type);
-template <> inline dnnError_t dnnPoolingCreateBackward<float>(
-    dnnPrimitive_t* pPooling,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t op,
-    const dnnLayout_t srcLayout,
-    const size_t kernelSize[], const size_t kernelStride[],
-    const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnPoolingCreateBackward_F32(
-               pPooling,
-               attributes,
-               op,
-               srcLayout,
-               kernelSize, kernelStride,
-               inputOffset, border_type);
-}
-template <> inline dnnError_t dnnPoolingCreateBackward<double>(
-    dnnPrimitive_t* pPooling,
-    dnnPrimitiveAttributes_t attributes,
-    dnnAlgorithm_t op,
-    const dnnLayout_t srcLayout,
-    const size_t kernelSize[], const size_t kernelStride[],
-    const int inputOffset[], const dnnBorder_t border_type) {
-    return dnnPoolingCreateBackward_F64(
-               pPooling,
-               attributes,
-               op,
-               srcLayout,
-               kernelSize, kernelStride,
-               inputOffset, border_type);
-}
-
-template <typename Dtype> inline dnnError_t dnnConcatCreate(
-    dnnPrimitive_t *pConcat,
-    dnnPrimitiveAttributes_t attributes,
-    const size_t N,
-    dnnLayout_t src[]);
-template <> inline dnnError_t dnnConcatCreate<float>(
-    dnnPrimitive_t *pConcat,
-    dnnPrimitiveAttributes_t attributes,
-    const size_t N,
-    dnnLayout_t src[]) {
-    return dnnConcatCreate_F32(
-               pConcat,
-               attributes,
-               N,
-               src);
-}
-template <> inline dnnError_t dnnConcatCreate<double>(
-    dnnPrimitive_t *pConcat,
-    dnnPrimitiveAttributes_t attributes,
-    const size_t N,
-    dnnLayout_t src[]) {
-    return dnnConcatCreate_F64(
-               pConcat,
-               attributes,
-               N,
-               src);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnSplitCreate(
-    dnnPrimitive_t *pSplit,
-    dnnPrimitiveAttributes_t attributes,
-    const size_t N,
-    dnnLayout_t src,
-    size_t dst[]);
-template <> inline dnnError_t dnnSplitCreate<float>(
-    dnnPrimitive_t *pSplit,
-    dnnPrimitiveAttributes_t attributes,
-    const size_t N,
-    dnnLayout_t src,
-    size_t dst[]) {
-    return dnnSplitCreate_F32(
-               pSplit,
-               attributes,
-               N,
-               src,
-               dst);
-}
-template <> inline dnnError_t dnnSplitCreate<double>(
-    dnnPrimitive_t *pSplit,
-    dnnPrimitiveAttributes_t attributes,
-    const size_t N,
-    dnnLayout_t src,
-    size_t dst[]) {
-    return dnnSplitCreate_F64(
-               pSplit,
-               attributes,
-               N,
-               src,
-               dst);
-}
-
-template <typename Dtype> inline dnnError_t dnnSumCreate(
-    dnnPrimitive_t *pSum,
-    dnnPrimitiveAttributes_t attributes,
-    const size_t nSummands, dnnLayout_t layout, Dtype *coefficients);
-template <> inline dnnError_t dnnSumCreate<float>(
-    dnnPrimitive_t *pSum,
-    dnnPrimitiveAttributes_t attributes,
-    const size_t nSummands, dnnLayout_t layout, float *coefficients) {
-    return dnnSumCreate_F32(
-               pSum,
-               attributes,
-               nSummands,
-               layout, coefficients);
-}
-template <> inline dnnError_t dnnSumCreate<double>(
-    dnnPrimitive_t *pSum,
-    dnnPrimitiveAttributes_t attributes,
-    const size_t nSummands, dnnLayout_t layout, double *coefficients) {
-    return dnnSumCreate_F64(
-               pSum,
-               attributes,
-               nSummands,
-               layout, coefficients);
-}
-
-template <typename Dtype> inline dnnError_t dnnBatchNormalizationCreateForward_v2(
-    dnnPrimitive_t* pBatchNormalization,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, float eps,
-    int flags);
-
-template <> inline dnnError_t dnnBatchNormalizationCreateForward_v2<float>(
-    dnnPrimitive_t* pBatchNormalization,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, float eps,
-    int flags) {
-    return dnnBatchNormalizationCreateForward_v2_F32(
-               pBatchNormalization,
-               attributes,
-               dataLayout, eps, flags);
-}
-template <> inline dnnError_t dnnBatchNormalizationCreateForward_v2<double>(
-    dnnPrimitive_t* pBatchNormalization,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, float eps,
-    int flags) {
-    return dnnBatchNormalizationCreateForward_v2_F64(
-               pBatchNormalization,
-               attributes,
-               dataLayout, eps, flags);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnBatchNormalizationCreateBackward_v2(
-    dnnPrimitive_t* pBatchNormalization,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, float eps,
-    int flags);
-
-template <> inline  dnnError_t dnnBatchNormalizationCreateBackward_v2<float>(
-    dnnPrimitive_t* pBatchNormalization,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, float eps,
-    int flags) {
-    return dnnBatchNormalizationCreateBackward_v2_F32(
-               pBatchNormalization,
-               attributes,
-               dataLayout, eps, flags);
-}
-
-template <> inline dnnError_t dnnBatchNormalizationCreateBackward_v2<double>(
-    dnnPrimitive_t* pBatchNormalization,
-    dnnPrimitiveAttributes_t attributes,
-    const dnnLayout_t dataLayout, float eps,
-    int flags) {
-    return dnnBatchNormalizationCreateBackward_v2_F64(
-               pBatchNormalization,
-               attributes,
-               dataLayout, eps, flags);
-}
-
-template <typename Dtype> inline dnnError_t dnnInnerProductCreateForward(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels);
-template <> inline dnnError_t dnnInnerProductCreateForward<float>(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels) {
-    return dnnInnerProductCreateForward_F32(pInnerProduct,
-                                            attributes, dimensions,
-                                            srcSize, outputChannels);
-}
-template <> inline dnnError_t dnnInnerProductCreateForward<double>(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels) {
-    return dnnInnerProductCreateForward_F64(pInnerProduct,
-                                            attributes, dimensions,
-                                            srcSize, outputChannels);
-}
-
-template <typename Dtype> inline dnnError_t dnnInnerProductCreateForwardBias(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels);
-
-template <> inline dnnError_t dnnInnerProductCreateForwardBias<float>(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels) {
-    return dnnInnerProductCreateForwardBias_F32(pInnerProduct,
-            attributes, dimensions,
-            srcSize, outputChannels);
-}
-template <> inline dnnError_t dnnInnerProductCreateForwardBias<double>(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels) {
-    return dnnInnerProductCreateForwardBias_F64(pInnerProduct,
-            attributes, dimensions,
-            srcSize, outputChannels);
-}
-
-
-template <typename Dtype> inline dnnError_t dnnInnerProductCreateBackwardData(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels);
-
-template <> inline dnnError_t dnnInnerProductCreateBackwardData<float>(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels) {
-    return dnnInnerProductCreateBackwardData_F32(pInnerProduct,
-            attributes, dimensions,
-            srcSize, outputChannels);
-}
-template <> inline dnnError_t dnnInnerProductCreateBackwardData<double>(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels) {
-    return dnnInnerProductCreateBackwardData_F64(pInnerProduct,
-            attributes, dimensions,
-            srcSize, outputChannels);
-}
-
-
-
-
-template <typename Dtype> inline dnnError_t dnnInnerProductCreateBackwardFilter(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels);
-
-template <> inline dnnError_t dnnInnerProductCreateBackwardFilter<float>(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels) {
-    return dnnInnerProductCreateBackwardFilter_F32(pInnerProduct,
-            attributes, dimensions,
-            srcSize, outputChannels);
-}
-template <> inline dnnError_t dnnInnerProductCreateBackwardFilter<double>(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t srcSize[],
-    size_t outputChannels) {
-    return dnnInnerProductCreateBackwardFilter_F64(pInnerProduct,
-            attributes, dimensions,
-            srcSize, outputChannels);
-}
-
-
-
-template <typename Dtype> inline dnnError_t dnnInnerProductCreateBackwardBias(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t dstSize[]);
-
-template <> inline dnnError_t dnnInnerProductCreateBackwardBias<float>(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t dstSize[]) {
-    return dnnInnerProductCreateBackwardBias_F32(pInnerProduct,
-            attributes, dimensions,
-            dstSize);
-}
-template <> inline dnnError_t dnnInnerProductCreateBackwardBias<double>(
-    dnnPrimitive_t *pInnerProduct,
-    dnnPrimitiveAttributes_t attributes,
-    size_t dimensions,
-    const size_t dstSize[]) {
-    return dnnInnerProductCreateBackwardBias_F64(pInnerProduct,
-            attributes, dimensions,
-            dstSize);
-}
-#endif  // MXNET_USE_MKL2017 == 1
-#endif  // MXNET_OPERATOR_MKL_MKL_CPPWRAPPER_H_
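The deleted header is essentially one pattern repeated for every MKL DNN entry point: declare a primary function template with no definition, then specialize it for float and double so each call resolves to the matching _F32/_F64 C function at compile time, and any other type fails to link. A minimal sketch of that dispatch idiom with a hypothetical do_work API:

    #include <cstdio>

    // Hypothetical per-precision C entry points, mirroring the dnn*_F32/_F64 pairs.
    inline int do_work_f32(float x)  { return std::printf("f32: %g\n", x); }
    inline int do_work_f64(double x) { return std::printf("f64: %g\n", x); }

    // The primary template is declared but never defined, so calling it with
    // an unsupported type is a link-time error rather than a silent conversion.
    template <typename DType> int do_work(DType x);
    template <> inline int do_work<float>(float x)   { return do_work_f32(x); }
    template <> inline int do_work<double>(double x) { return do_work_f64(x); }

    int main() {
      do_work(1.0f);  // resolves to do_work_f32
      do_work(2.0);   // resolves to do_work_f64
      return 0;
    }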
diff --git a/src/operator/mkl/mkl_elementwise_copy-inl.h b/src/operator/mkl/mkl_elementwise_copy-inl.h
deleted file mode 100644
index 48c9312..0000000
--- a/src/operator/mkl/mkl_elementwise_copy-inl.h
+++ /dev/null
@@ -1,69 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_elementwise_copy-inl.h
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_ELEMENTWISE_COPY_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_ELEMENTWISE_COPY_INL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <cstring>
-#include <map>
-#include <string>
-#include <vector>
-#include <utility>
-#include "../operator_common.h"
-#include "../mshadow_op.h"
-#include "./mkl_util-inl.h"
-
-
-namespace mxnet {
-namespace op {
-
-template<typename xpu, typename DType>
-void MKLIdentityCompute(const nnvm::NodeAttrs& attrs,
-  const OpContext& ctx,
-  const std::vector<TBlob>& inputs,
-  const std::vector<OpReqType>& req,
-  const std::vector<TBlob>& outputs) {
-  if (!req[0]) return;
-#if MKL_EXPERIMENTAL == 1
-  if (op::mkl_prv_data<DType>(inputs[0])) {
-    std::shared_ptr<MKLMemHolder> in_data_mem = inputs[0].Mkl_mem_;
-    // Use a copy to avoid potential problems
-    std::shared_ptr<MKLData<DType> > top_data = MKLData<DType>::create();
-    std::shared_ptr<MKLMemHolder> top_mem = outputs[0].Mkl_mem_;
-    top_data->copy_from(in_data_mem);
-    top_mem->set_prv_descriptor(top_data);
-    return;
-  }
-#endif
-  int in_blob_size = inputs[0].Size();
-  int out_blob_size = outputs[0].Size();
-  CHECK_EQ(in_blob_size, out_blob_size) << "MKLIdentityCompute: input and output sizes do not match";
-  memcpy(outputs[0].dptr_, inputs[0].dptr_, in_blob_size * sizeof(DType));
-}
-
-
-
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_MKL_MKL_ELEMENTWISE_COPY_INL_H_
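MKLIdentityCompute above has two paths: if the input already carries an MKL private (layout-converted) buffer, it copies the descriptor so the output reuses that buffer without touching the flat data; otherwise it falls back to a size-checked memcpy. A minimal sketch of the same shape, with std::shared_ptr standing in for the MKLMemHolder descriptor:

    #include <cassert>
    #include <cstring>
    #include <memory>
    #include <vector>

    // Stand-in for a blob that may carry a private, layout-converted buffer
    // (as inputs[0].Mkl_mem_ does in the deleted code).
    struct Blob {
      std::vector<float> data;
      std::shared_ptr<std::vector<float>> prv;  // null when no private buffer
    };

    static void identity_compute(const Blob& in, Blob* out) {
      if (in.prv) {
        // Fast path: duplicate the private buffer; the flat data is not touched.
        out->prv = std::make_shared<std::vector<float>>(*in.prv);
        return;
      }
      assert(in.data.size() == out->data.size());  // sizes must match
      std::memcpy(out->data.data(), in.data.data(),
                  in.data.size() * sizeof(float));
    }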
diff --git a/src/operator/mkl/mkl_elementwise_sum-inl.h b/src/operator/mkl/mkl_elementwise_sum-inl.h
deleted file mode 100644
index d313fd1..0000000
--- a/src/operator/mkl/mkl_elementwise_sum-inl.h
+++ /dev/null
@@ -1,117 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_elementwise_sum-inl.h
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_ELEMENTWISE_SUM_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_ELEMENTWISE_SUM_INL_H_
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <cstring>
-#include <map>
-#include <string>
-#include <vector>
-#include <utility>
-#include "../operator_common.h"
-#include "../mshadow_op.h"
-#include "./mkl_util-inl.h"
-
-
-namespace mxnet {
-namespace op {
-template<typename xpu, typename DType>
-static void LayerSetUp(const std::vector<mshadow::Tensor<xpu, 1, DType> > &data,
-  size_t data_shape_size,
-  std::shared_ptr<MKLData<DType> > fwd_top_data) {
-  // Whether to use an asymptotically slower (for >2 inputs) but stabler method
-  // of computing the gradient for the PROD operation. (No effect for SUM op.)
-  // stable_prod_grad_ = 1;
-  size_t dim_src = data_shape_size;
-  size_t *sizes_src = new size_t[dim_src];
-  size_t *strides_src = new size_t[dim_src];
-  for (size_t d = 0; d < dim_src; ++d) {
-    sizes_src[d] = data[0].shape_[dim_src - d - 1];
-    strides_src[d] = (d == 0) ? 1 : strides_src[d - 1] * sizes_src[d - 1];
-  }
-
-  fwd_top_data->create_user_layout(dim_src, sizes_src, strides_src);
-  delete[] sizes_src;
-  delete[] strides_src;
-}
-
-template<typename xpu, typename DType>
-void MKLElementWiseSumCompute_(const nnvm::NodeAttrs& attrs,
-  const OpContext& ctx,
-  const std::vector<TBlob>& in_data,
-  const std::vector<OpReqType>& req,
-  const std::vector<TBlob>& out_data) {
-  using namespace mshadow;
-  using namespace mshadow::expr;
-  if (req[0] == kNullOp) return;
-  size_t size = in_data.size();
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  std::vector<Tensor<xpu, 1, DType> > data(size);
-  Tensor<xpu, 1, DType> out = out_data[0].FlatTo1D<xpu, DType>(s);
-  size_t in_place_idx = 0;  // which input, if any, aliases the output buffer
-
-  for (size_t i = 0; i < size; ++i) {
-    data[i] = in_data[i].FlatTo1D<xpu, DType>(s);
-    if (data[i].dptr_ == out.dptr_) in_place_idx = i;
-  }
-  std::shared_ptr<MKLData<DType> > fwd_top_data = MKLData<DType>::create();
-  std::vector<DType> coeffs_(data.size(), 1);
-  LayerSetUp(data, 1, fwd_top_data);
-
-
-  dnnError_t e;
-  void *eltwise_res[dnnResourceNumber];
-  dnnPrimitive_t sumPrimitive = NULL;
-  e = dnnSumCreate<DType>(&sumPrimitive, NULL, size, fwd_top_data->layout_usr,
-    &coeffs_[0]);
-  CHECK_EQ(e, E_SUCCESS);
-
-  eltwise_res[dnnResourceDst] = reinterpret_cast<void *>(out.dptr_);
-  eltwise_res[dnnResourceMultipleSrc] =
-    reinterpret_cast<void *>(in_data[in_place_idx].dptr_);
-  for (size_t i = 1; i < size; ++i) {
-    // Input 0 takes the slot vacated by the in-place input, so that every
-    // summand fills exactly one source slot.
-    eltwise_res[dnnResourceMultipleSrc + i] = reinterpret_cast<void *>(
-      i == in_place_idx ? in_data[0].dptr_ : in_data[i].dptr_);
-  }
-
-  e = dnnExecute<DType>(sumPrimitive, eltwise_res);
-  CHECK_EQ(e, E_SUCCESS);
-
-  if (sumPrimitive != NULL) {
-    dnnDelete<DType>(sumPrimitive);
-    sumPrimitive = NULL;
-  }
-}
-
-
-
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_MKL_MKL_ELEMENTWISE_SUM_INL_H_
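
The removed sum kernel follows the MKL 2017 DNN lifecycle: create a layout, create the sum primitive, wire one resource slot per summand, execute, delete. A standalone sketch of that lifecycle, assuming the raw C API from mkl_dnn.h (the dnn*<DType> calls above are templated wrappers from the removed mkl_cppwrapper.h over these _F32/_F64 functions):

    #include <mkl_dnn.h>
    #include <cstdio>

    int main() {
      const size_t sizes[1] = {4}, strides[1] = {1};
      float a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40}, out[4];
      float coeffs[2] = {1.0f, 1.0f};  // plain sum: all coefficients are 1

      dnnLayout_t layout = NULL;
      dnnPrimitive_t sum = NULL;
      if (dnnLayoutCreate_F32(&layout, 1, sizes, strides) != E_SUCCESS) return 1;
      if (dnnSumCreate_F32(&sum, NULL, 2, layout, coeffs) != E_SUCCESS) return 1;

      void* res[dnnResourceNumber] = {};
      res[dnnResourceMultipleSrc + 0] = a;  // each summand fills its own slot
      res[dnnResourceMultipleSrc + 1] = b;
      res[dnnResourceDst] = out;
      if (dnnExecute_F32(sum, res) != E_SUCCESS) return 1;

      std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 11 22 33 44
      dnnDelete_F32(sum);
      dnnLayoutDelete_F32(layout);
      return 0;
    }
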
diff --git a/src/operator/mkl/mkl_fully_connected-inl.h b/src/operator/mkl/mkl_fully_connected-inl.h
deleted file mode 100644
index 5e29670..0000000
--- a/src/operator/mkl/mkl_fully_connected-inl.h
+++ /dev/null
@@ -1,192 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_fully_connected-inl.h
-* \brief
-* \author zhenlin.luo@intel.com
-*          lingyan.guo@intel.com
-*         
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_FULLY_CONNECTED_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_FULLY_CONNECTED_INL_H_
-#include <string>
-#include <algorithm>
-#include <vector>
-#include "../activation-inl.h"
-#include "./mkl_util-inl.h"
-
-namespace mxnet {
-namespace op {
-
-template<typename xpu, typename DType>
-class MKLFullyConnectedOp : public Operator {
- public:
-  explicit MKLFullyConnectedOp(const FullyConnectedParam& p,
-                               const std::vector<TShape>& in_shapes,
-                               const std::vector<TShape>& out_shapes):
-    param_(p) {
-    LayerSetUp(in_shapes, out_shapes);
-  }
-
-  ~MKLFullyConnectedOp() {
-    dnnDelete<DType>(fullyConnectedFwd);
-    dnnDelete<DType>(fullyConnectedBwdData);
-    dnnDelete<DType>(fullyConnectedBwdFilter);
-    dnnDelete<DType>(fullyConnectedBwdBias);
-  }
-  static std::string getName() {
-    return "MKLFullyConnectedOp";
-  }
-
- private:
-  void LayerSetUp(const std::vector<TShape>& in_shapes,
-                  const std::vector<TShape>& out_shapes) {
-    const TShape& ishape = in_shapes[fullc::kData];
-
-    const size_t dim = 4;
-    const size_t src_sizes[4] = {1, 1, ishape.ProdShape(1, ishape.ndim()), ishape[0]};
-    const size_t dst_sizes[2] = {param_.num_hidden, ishape[0]};
-    const size_t output_channels = param_.num_hidden;
-
-    dnnPrimitiveAttributes_t attributes = NULL;
-    MKLDNN_CALL(dnnPrimitiveAttributesCreate<DType>(&attributes));
-    if (!param_.no_bias) {
-      MKLDNN_CALL(dnnInnerProductCreateForwardBias<DType>(
-            &fullyConnectedFwd,
-            attributes,
-            dim,
-            src_sizes,
-            output_channels));
-    } else {
-      MKLDNN_CALL(dnnInnerProductCreateForward<DType>(
-            &fullyConnectedFwd,
-            attributes,
-            dim,
-            src_sizes,
-            output_channels));
-    }
-    MKLDNN_CALL(dnnInnerProductCreateBackwardData<DType>(
-          &fullyConnectedBwdData,
-          attributes,
-          dim,
-          src_sizes,
-          output_channels));
-    MKLDNN_CALL(dnnInnerProductCreateBackwardFilter<DType>(
-          &fullyConnectedBwdFilter,
-          attributes,
-          dim,
-          src_sizes,
-          output_channels));
-    if (!param_.no_bias) {
-      MKLDNN_CALL(dnnInnerProductCreateBackwardBias<DType>(
-            &fullyConnectedBwdBias,
-            attributes,
-            2,
-            dst_sizes));
-    }
-    // TODO(minjie): Shouldn't `attributes` be destroyed?
-  }
-
-
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-
-    void* res_fullyConnected[dnnResourceNumber];
-    if (req[fullc::kOut] == kNullOp) return;
-    CHECK_EQ(req[fullc::kOut], kWriteTo);
-    CHECK_EQ(in_data.size(), param_.no_bias ? 2 : 3);
-    CHECK_EQ(out_data.size(), 1);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-
-    const TShape& ishape = in_data[fullc::kData].shape_;
-    const TShape& oshape = out_data[fullc::kOut].shape_;
-
-    Tensor<xpu, 4, DType> data;
-    Tensor<xpu, 4, DType> out;
-
-    Shape<4> dshape = Shape4(ishape[0], ishape.ProdShape(1, ishape.ndim()), 1, 1);
-    Shape<4> odshape = Shape4(oshape[0], oshape.ProdShape(1, oshape.ndim()), 1, 1);
-
-    data = in_data[fullc::kData].get_with_shape<xpu, 4, DType>(dshape, s);
-    out = out_data[fullc::kOut].get_with_shape<xpu, 4, DType>(odshape, s);
-    res_fullyConnected[dnnResourceSrc] =
-      reinterpret_cast<void *>(in_data[fullc::kData].dptr_);
-    res_fullyConnected[dnnResourceDst] =
-      reinterpret_cast<void *>(out_data[fullc::kOut].dptr_);
-    res_fullyConnected[dnnResourceFilter] =
-      reinterpret_cast<void *>(in_data[fullc::kWeight].dptr_);
-    if (!param_.no_bias) {
-      res_fullyConnected[dnnResourceBias] = reinterpret_cast<void *>(in_data[fullc::kBias].dptr_);
-    }
-
-    MKLDNN_CALL(dnnExecute<DType>(fullyConnectedFwd, res_fullyConnected));
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-
-    void* res_fullyConnected[dnnResourceNumber];
-    CHECK_EQ(out_grad.size(), 1);
-    const size_t expected = param_.no_bias ? 2 : 3;
-    CHECK(in_data.size() == expected && in_grad.size() == expected);
-    CHECK_EQ(req.size(), expected);
-    res_fullyConnected[dnnResourceSrc] =
-      reinterpret_cast<void *>(in_data[fullc::kData].dptr_);
-    res_fullyConnected[dnnResourceFilter] =
-      reinterpret_cast<void *>(in_data[fullc::kWeight].dptr_);
-
-    res_fullyConnected[dnnResourceDiffDst] =
-      reinterpret_cast<void *>(out_grad[fullc::kOut].dptr_);
-    res_fullyConnected[dnnResourceDiffSrc] =
-      reinterpret_cast<void *>(in_grad[fullc::kData].dptr_);
-    res_fullyConnected[dnnResourceDiffFilter] =
-      reinterpret_cast<void *>(in_grad[fullc::kWeight].dptr_);
-    if (!param_.no_bias) {
-      res_fullyConnected[dnnResourceDiffBias] =
-        reinterpret_cast<void *>(in_grad[fullc::kBias].dptr_);
-    }
-    MKLDNN_CALL(dnnExecute<DType>(fullyConnectedBwdFilter, res_fullyConnected));
-    if (!param_.no_bias) {
-      MKLDNN_CALL(dnnExecute<DType>(fullyConnectedBwdBias, res_fullyConnected));
-    }
-    MKLDNN_CALL(dnnExecute<DType>(fullyConnectedBwdData, res_fullyConnected));
-  }
-
- private:
-  dnnPrimitive_t fullyConnectedFwd{nullptr};
-  dnnPrimitive_t fullyConnectedBwdData{nullptr};
-  dnnPrimitive_t fullyConnectedBwdFilter{nullptr};
-  dnnPrimitive_t fullyConnectedBwdBias{nullptr};
-  const FullyConnectedParam param_;
-};  // class MKLFullyConnectedOp
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // MXNET_OPERATOR_MKL_MKL_FULLY_CONNECTED_INL_H_
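
Stripped of the layout plumbing, the inner-product primitives above compute an affine map per batch row. A reference sketch of the forward computation (plain C++; a hypothetical helper for illustration, not MXNet code):

    #include <cstddef>
    #include <vector>

    // out[n][o] = sum_i weight[o][i] * data[n][i] + bias[o]
    void fc_forward(const std::vector<float>& data,    // batch x in_dim, row-major
                    const std::vector<float>& weight,  // out_dim x in_dim
                    const std::vector<float>& bias,    // out_dim; empty => no_bias
                    std::vector<float>* out,           // batch x out_dim
                    size_t batch, size_t in_dim, size_t out_dim) {
      for (size_t n = 0; n < batch; ++n) {
        for (size_t o = 0; o < out_dim; ++o) {
          float acc = bias.empty() ? 0.0f : bias[o];
          for (size_t i = 0; i < in_dim; ++i)
            acc += weight[o * in_dim + i] * data[n * in_dim + i];
          (*out)[n * out_dim + o] = acc;
        }
      }
    }

This also explains the src_sizes trick in LayerSetUp above: the input is described as {1, 1, in_dim, batch} in MKL's fastest-first dimension order, i.e. a batch x in_dim matrix, so the inner product sees a 2-D problem.
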
diff --git a/src/operator/mkl/mkl_lrn-inl.h b/src/operator/mkl/mkl_lrn-inl.h
deleted file mode 100644
index 90dfad5..0000000
--- a/src/operator/mkl/mkl_lrn-inl.h
+++ /dev/null
@@ -1,265 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_lrn-inl.h
-* \brief
-* \author zhenlin.luo@intel.com
-*         lingyan.guo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_LRN_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_LRN_INL_H_
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include "../operator_common.h"
-#include "../mshadow_op.h"
-#include "./mkl_util-inl.h"
-
-namespace mxnet {
-namespace op {
-
-template<typename xpu, typename DType>
-class MKLLRNOp : public Operator {
- public:
-  static std::string getName() {
-    return "MKLLRNOp";
-  }
-
-  explicit MKLLRNOp(LRNParam param) :
-    lrnFwd(static_cast<dnnPrimitive_t>(NULL)),
-    lrnBwd(static_cast<dnnPrimitive_t>(NULL)),
-    lrn_buffer_(NULL) {
-    this->param_ = param;
-    fwd_top_data_ = MKLData<DType>::create();
-    fwd_bottom_data_ = MKLData<DType>::create();
-    bwd_top_diff_ = MKLData<DType>::create();
-    bwd_bottom_diff_ = MKLData<DType>::create();
-    init_mkldnn_ = false;
-  }
-
-  virtual ~MKLLRNOp() {
-    if (lrnFwd != NULL) {
-      dnnDelete<DType>(lrnFwd);
-      lrnFwd = NULL;
-    }
-    if (lrnBwd != NULL) {
-      dnnDelete<DType>(lrnBwd);
-      lrnBwd = NULL;
-    }
-    dnnReleaseBuffer<DType>(lrn_buffer_);
-  }
-
- private:
-  void LayerSetup(const mshadow::Tensor<xpu, 4, DType> &data,
-                  const mshadow::Tensor<xpu, 4, DType> &out) {
-    size_ = param_.nsize;
-    CHECK_EQ(size_ % 2, 1) << "LRN only supports odd values for local size";
-
-    alpha_ = param_.alpha;
-    beta_ = param_.beta;
-    k_ = param_.knorm;
-    size_t dim = 4, sizes[4], strides[4];
-    channels_ = data.shape_[1];
-    height_ = data.shape_[2];
-    width_ = data.shape_[3];
-    num_ = data.shape_[0];
-    sizes[0] = width_;
-    sizes[1] = height_;
-    sizes[2] = channels_;
-    sizes[3] = num_;
-
-    strides[0] = 1;
-    strides[1] = sizes[0];
-    strides[2] = sizes[0] * sizes[1];
-    strides[3] = sizes[0] * sizes[1] * sizes[2];
-
-    fwd_bottom_data_->name = "fwd_bottom_data_   @ " + getName();
-    fwd_top_data_->name = "fwd_top_data_      @ " + getName();
-    bwd_top_diff_->name = "bwd_top_diff_      @ " + getName();
-    bwd_bottom_diff_->name = "bwd_bottom_diff_   @ " + getName();
-
-    fwd_bottom_data_->create_user_layout(dim, sizes, strides);
-    fwd_top_data_->create_user_layout(dim, sizes, strides);
-    bwd_bottom_diff_->create_user_layout(dim, sizes, strides);
-    bwd_top_diff_->create_user_layout(dim, sizes, strides);
-  }
-
- public:
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 1U);
-    CHECK_EQ(out_data.size(), 2U);
-    CHECK_EQ(param_.nsize % 2, 1U) << "LRN only supports odd values for local_size";
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> data = mkl_experimental_direct_get<xpu, 4, DType>(
-      in_data[lrn_enum::kData], s);
-    Tensor<xpu, 4, DType> out = mkl_experimental_direct_get<xpu, 4, DType>(
-      out_data[lrn_enum::kOut], s);
-    if (!init_mkldnn_) {
-      LayerSetup(data, out);
-      init_mkldnn_ = true;
-    }
-
-    const void* bottom_data = NULL;
-#if MKL_EXPERIMENTAL == 1
-    bottom_data =
-          reinterpret_cast<void*>(mkl_prv_data<DType>(in_data[lrn_enum::kData]));
-#endif
-#if MKL_EXPERIMENTAL == 1
-    if (NULL != bottom_data) {
-      if (lrnFwd == NULL) {
-        std::shared_ptr<MKLMemHolder> bottom_data_mem =
-          in_data[lrn_enum::kData].Mkl_mem_;
-        std::shared_ptr<PrvMemDescr> bottom_prv_descriptor =
-          bottom_data_mem->get_prv_descriptor();
-        CHECK_EQ(bottom_prv_descriptor->get_descr_type(),
-            PrvMemDescr::PRV_DESCR_MKL2017);
-        std::shared_ptr<MKLData<DType> > mem_descr
-          = std::static_pointer_cast<MKLData<DType>>(bottom_prv_descriptor);
-        CHECK(mem_descr != nullptr);
-        fwd_bottom_data_ = mem_descr;
-
-        dnnError_t e;
-        dnnLayout_t lrn_buffer_l = NULL;
-
-        e = dnnLRNCreateForward<DType>(&lrnFwd, NULL, fwd_bottom_data_->layout_int,
-                                       size_, alpha_, beta_, k_);
-        CHECK_EQ(e, E_SUCCESS);
-
-        fwd_top_data_->create_internal_layout(lrnFwd, dnnResourceDst);
-
-        e = dnnLRNCreateBackward<DType>(&lrnBwd, NULL,
-                                        fwd_bottom_data_->layout_int, fwd_bottom_data_->layout_int,
-                                        size_, alpha_, beta_, k_);
-        CHECK_EQ(e, E_SUCCESS);
-
-        e = dnnLayoutCreateFromPrimitive<DType>(
-              &lrn_buffer_l, lrnFwd, dnnResourceWorkspace);
-        CHECK_EQ(e, E_SUCCESS);
-        e = dnnAllocateBuffer<DType>(
-              reinterpret_cast<void **>(&lrn_buffer_), lrn_buffer_l);
-        CHECK_EQ(e, E_SUCCESS);
-        dnnLayoutDelete<DType>(lrn_buffer_l);
-
-        bwd_top_diff_->create_internal_layout(lrnBwd, dnnResourceDiffDst);
-        bwd_bottom_diff_->create_internal_layout(lrnBwd, dnnResourceDiffSrc);
-      }
-    }
-#endif
-    if (bottom_data == NULL) {
-      if (lrnFwd == NULL) {
-        dnnError_t e;
-        dnnLayout_t lrn_buffer_l = NULL;
-        e = dnnLRNCreateForward<DType>(&lrnFwd, NULL, fwd_bottom_data_->layout_usr,
-                                       size_, alpha_, beta_, k_);
-        CHECK_EQ(e, E_SUCCESS);
-
-        e = dnnLayoutCreateFromPrimitive<DType>(
-              &lrn_buffer_l, lrnFwd, dnnResourceWorkspace);
-        CHECK_EQ(e, E_SUCCESS);
-        e = dnnAllocateBuffer<DType>(
-              reinterpret_cast<void **>(&lrn_buffer_), lrn_buffer_l);
-        CHECK_EQ(e, E_SUCCESS);
-        dnnLayoutDelete<DType>(lrn_buffer_l);
-
-        e = dnnLRNCreateBackward<DType>(&lrnBwd, NULL,
-                                        fwd_bottom_data_->layout_usr, fwd_bottom_data_->layout_usr,
-                                        size_, alpha_, beta_, k_);
-        CHECK_EQ(e, E_SUCCESS);
-      }
-      bottom_data = data.dptr_;
-    }
-
-    dnnError_t e;
-    void* lrn_res[dnnResourceNumber];
-    lrn_res[dnnResourceSrc] = const_cast<void*>(bottom_data);
-
-    lrn_res[dnnResourceDst] = fwd_top_data_->get_output_ptr(
-      out.dptr_, fwd_top_data_, out_data[lrn_enum::kOut]);
-    lrn_res[dnnResourceWorkspace] = lrn_buffer_;
-    e = dnnExecute<DType>(lrnFwd, lrn_res);
-    CHECK_EQ(e, E_SUCCESS);
-  }
-
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1);
-    CHECK_EQ(in_data.size(), 1);
-    CHECK_EQ(out_data.size(), 2);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> grad = mkl_experimental_direct_get<xpu, 4, DType>(
-      out_grad[lrn_enum::kOut], s);
-    Tensor<xpu, 4, DType> data = mkl_experimental_direct_get<xpu, 4, DType>(
-      in_data[lrn_enum::kData], s);
-    Tensor<xpu, 4, DType> grad_in = mkl_experimental_direct_get<xpu, 4, DType>(
-      in_grad[lrn_enum::kData], s);
-    dnnError_t e;
-    void* lrn_res[dnnResourceNumber];
-    lrn_res[dnnResourceDiffDst] =
-      bwd_top_diff_->get_converted_prv(grad.dptr_, true, out_grad[lrn_enum::kOut]);
-    lrn_res[dnnResourceWorkspace] = lrn_buffer_;
-    lrn_res[dnnResourceSrc] =
-      fwd_bottom_data_->get_converted_prv(data.dptr_, false, in_data[lrn_enum::kData]);
-
-    lrn_res[dnnResourceDiffSrc] = bwd_bottom_diff_->get_output_ptr(
-      grad_in.dptr_, bwd_bottom_diff_, in_grad[lrn_enum::kData]);
-    e = dnnExecute<DType>(lrnBwd, lrn_res);
-    CHECK_EQ(e, E_SUCCESS);
-  }
-
- private:
-  LRNParam param_;
-  int size_;
-  int pre_pad_;
-  DType alpha_;
-  DType beta_;
-  DType k_;
-  int num_;
-  int channels_;
-  int height_;
-  int width_;
-  bool init_mkldnn_;
-
- private:
-  dnnPrimitive_t lrnFwd, lrnBwd;
-  std::shared_ptr<MKLData<DType> > fwd_top_data_;
-  std::shared_ptr<MKLData<DType> > fwd_bottom_data_;
-
-  std::shared_ptr<MKLData<DType> > bwd_top_diff_;
-  std::shared_ptr<MKLData<DType> > bwd_bottom_diff_;
-
-  DType *lrn_buffer_;
-};  // class LocalResponseNormOp
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_MKL_MKL_LRN_INL_H_
-
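
As context for the parameters above (alpha_, beta_, k_, size_): cross-channel LRN divides each element by a power of a windowed sum of squares. A reference sketch for a single spatial position, following the common Krizhevsky/Caffe convention (whether MKL folds the 1/n factor into alpha exactly this way is an assumption, not taken from this code):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // out[c] = in[c] / (k + (alpha / n) * sum_{j in window(c)} in[j]^2)^beta,
    // where the window spans n channels centered on c (n odd, as checked above).
    std::vector<float> lrn_channels(const std::vector<float>& in,
                                    int n, float alpha, float beta, float k) {
      const int C = static_cast<int>(in.size());
      std::vector<float> out(C);
      for (int c = 0; c < C; ++c) {
        float sq = 0.0f;
        for (int j = std::max(0, c - n / 2); j <= std::min(C - 1, c + n / 2); ++j)
          sq += in[j] * in[j];
        out[c] = in[c] / std::pow(k + alpha / n * sq, beta);
      }
      return out;
    }
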
diff --git a/src/operator/mkl/mkl_memory-inl.h b/src/operator/mkl/mkl_memory-inl.h
deleted file mode 100644
index 71af102..0000000
--- a/src/operator/mkl/mkl_memory-inl.h
+++ /dev/null
@@ -1,137 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_memory-inl.h
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_MEMORY_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_MEMORY_INL_H_
-
-
-#include <string>
-#include <vector>
-#include <memory>
-#include "mkl_cppwrapper.h"
-
-namespace mxnet {
-
-template <typename DType>
-struct MKLMemoryDescriptorBase : public PrvMemDescr,
- public std::enable_shared_from_this<MKLMemoryDescriptorBase<DType> > {
-    MKLMemoryDescriptorBase() : layout_usr(NULL), layout_int(NULL),
-    convert_to_int(NULL), convert_from_int(NULL), convert_prv2prv(NULL),
-    name("UNKNOWN"), internal_ptr(NULL) {}
-  virtual ~MKLMemoryDescriptorBase() {
-    dnnLayoutDelete<DType>(layout_usr);
-    dnnLayoutDelete<DType>(layout_int);
-    if (internal_ptr != NULL) {
-      dnnReleaseBuffer<DType>(internal_ptr);
-      internal_ptr = NULL;
-    }
-    if (convert_to_int != NULL) {
-      dnnDelete<DType>(convert_to_int);
-      convert_to_int = NULL;
-    }
-    if (convert_from_int != NULL) {
-      dnnDelete<DType>(convert_from_int);
-      convert_from_int = NULL;
-    }
-    if (convert_prv2prv != NULL) {
-      dnnDelete<DType>(convert_prv2prv);
-      convert_prv2prv = NULL;
-    }
-  }
-  std::shared_ptr<MKLMemoryDescriptorBase<DType> > get_shared_ptr() {
-    return this->shared_from_this();
-  }
-
-  dnnLayout_t layout_usr;
-  dnnLayout_t layout_int;
-  dnnPrimitive_t convert_to_int;
-  dnnPrimitive_t convert_from_int;
-  dnnPrimitive_t convert_prv2prv;
-  std::shared_ptr<MKLMemoryDescriptorBase<DType> > descr_prv2prv_conversion;
-
-
-  std::string name;  // for debugging purposes
-  void allocate() {
-    if (internal_ptr == NULL) {
-      int status = dnnAllocateBuffer<DType>(
-              reinterpret_cast<void **>(&internal_ptr), layout_int);
-      CHECK_EQ(status, E_SUCCESS)
-          << "Failed internal_ptr memory allocation with status "
-          << status << "\n";
-    }
-  }
-  virtual void* prv_ptr(bool allocate_when_uninit = true) {
-    if (internal_ptr == NULL && allocate_when_uninit)
-      allocate();
-    return internal_ptr;
-  }
-  inline bool conversion_needed() {
-    return (convert_to_int != NULL);
-  }
-  void create_conversions();
-  void create_internal_layout(const dnnPrimitive_t primitive,
-                dnnResourceType_t type);
-  void create_user_layout(size_t dimension, const size_t size[],
-              const size_t strides[]);
-  void create_layouts(
-    const dnnPrimitive_t primitive, dnnResourceType_t type,
-    size_t dimension, const size_t size[], const size_t strides[]);
-
-  virtual PrvDescrType get_descr_type() {
-    return PRV_DESCR_MKL2017;
-  }
-  virtual size_t prv_size() {
-    return dnnLayoutGetMemorySize<DType>(layout_int);
-  }
-  virtual size_t prv_count() {
-    return dnnLayoutGetMemorySize<DType>(layout_int) / sizeof(DType);
-  }
-  virtual void convert_from_prv(void* cpu_ptr);
-  virtual void convert_to_prv(void* cpu_ptr);
-  virtual bool layout_compare(std::shared_ptr<PrvMemDescr> other);
-  virtual void convert_from_other(std::shared_ptr<PrvMemDescr> other);
- protected:
-  DType* internal_ptr;
-};
-
-template <typename DType>
-struct MKLMemoryDescriptor : MKLMemoryDescriptorBase<DType> {
-  // The last get_converted_prv() argument is a hack for reusing
-  // in backward a conversion done already in the forward direction.
-  DType* get_converted_prv(DType *data_ptr, bool set_prv_ptr,
-      const TBlob &blob);
-  void* get_output_ptr(DType *data_ptr, std::shared_ptr<MKLMemoryDescriptor<DType> > self_ptr,
-    const TBlob &blob, bool in_place = false);
-  bool copy_from(std::shared_ptr<MKLMemHolder> dnn_chunk);
-  MKLMemoryDescriptor() {}
-};
-
-template <typename DType> struct MKLData : MKLMemoryDescriptor<DType> {
-  static std::shared_ptr<MKLData<DType> > create() {
-    return std::make_shared<MKLData<DType> >();
-  }
-};
-
-template struct MKLData<float>;
-template struct MKLData<double>;
-
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_MKL_MKL_MEMORY_INL_H_
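
A recurring detail in the descriptors above: MKL 2017 layouts list dimensions fastest-varying first, so NCHW data is described as {W, H, C, N}. A small sketch of the reversal loop that every create_user_layout call site repeats:

    #include <cstddef>

    // Convert an NCHW shape into MKL's fastest-first sizes/strides arrays.
    void nchw_to_mkl_layout(const size_t shape[4],  // {N, C, H, W}
                            size_t sizes[4], size_t strides[4]) {
      for (size_t d = 0; d < 4; ++d) {
        sizes[d] = shape[4 - 1 - d];  // yields {W, H, C, N}
        strides[d] = (d == 0) ? 1 : strides[d - 1] * sizes[d - 1];
      }
    }

For example, shape {2, 3, 4, 5} yields sizes {5, 4, 3, 2} and strides {1, 5, 20, 60}, matching the LayerSetUp code in the operators above.
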
diff --git a/src/operator/mkl/mkl_memory.cc b/src/operator/mkl/mkl_memory.cc
deleted file mode 100644
index 7682fe1..0000000
--- a/src/operator/mkl/mkl_memory.cc
+++ /dev/null
@@ -1,291 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_memory.cc
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-#include "../operator_common.h"
-
-#if MXNET_USE_MKL2017 == 1
-#include <mkl_memory.h>
-#include "mkl_memory-inl.h"
-#include "mkl_util-inl.h"
-
-namespace mxnet {
-
-template <typename Dtype>
-void MKLMemoryDescriptorBase<Dtype>::create_conversions() {
-  int status;
-  if (this->convert_from_int) {
-    status = dnnDelete<Dtype>(this->convert_from_int);
-    CHECK_EQ(status, E_SUCCESS);
-    this->convert_from_int = NULL;
-  }
-  if (this->convert_to_int) {
-    status = dnnDelete<Dtype>(this->convert_to_int);
-    CHECK_EQ(status, E_SUCCESS);
-    this->convert_to_int = NULL;
-  }
-  if (layout_int
-      && !dnnLayoutCompare<Dtype>(layout_usr, layout_int)) {
-    CHECK(layout_usr);
-    status = dnnConversionCreate<Dtype>(&convert_to_int, layout_usr,
-            layout_int);
-    CHECK_EQ(status, E_SUCCESS)
-            << "Failed creation convert_to_int with status "
-            << status << " for buffer: " << this->name << "\n";
-    status = dnnConversionCreate<Dtype>(&convert_from_int, layout_int,
-            layout_usr);
-    CHECK_EQ(status, E_SUCCESS)
-            << "Failed creation convert_from_int with status "
-            << status << " for buffer: " << this->name << "\n";
-  }
-}
-
-template <typename Dtype>
-void MKLMemoryDescriptorBase<Dtype>::create_internal_layout(
-    const dnnPrimitive_t primitive, dnnResourceType_t type) {
-  int status;
-  if (this->layout_int) {
-    status = dnnLayoutDelete<Dtype>(this->layout_int);
-    CHECK_EQ(status, E_SUCCESS);
-  }
-  status = dnnLayoutCreateFromPrimitive<Dtype>(
-      &this->layout_int, primitive, type);
-  CHECK_EQ(status, E_SUCCESS)
-      << "Failed dnnLayoutCreateFromPrimitive with status "
-      << status << " for buffer: " << this->name << "\n";
-
-  if (this->layout_usr)
-    this->create_conversions();
-}
-
-template <typename Dtype>
-void MKLMemoryDescriptorBase<Dtype>::create_user_layout(
-    size_t dimension, const size_t size[], const size_t strides[]) {
-  int status;
-  if (this->layout_usr) {
-    status = dnnLayoutDelete<Dtype>(this->layout_usr);
-    CHECK_EQ(status, E_SUCCESS);
-  }
-
-  status = dnnLayoutCreate<Dtype>(
-      &this->layout_usr, dimension, size, strides);
-  CHECK_EQ(status, E_SUCCESS) << "Failed dnnLayoutCreate with status "
-      << status << " for buffer: " << this->name << "\n";
-
-  if (this->layout_int)
-    this->create_conversions();
-}
-
-template <typename Dtype>
-void MKLMemoryDescriptorBase<Dtype>::create_layouts(
-    const dnnPrimitive_t primitive, dnnResourceType_t type,
-    size_t dimension, const size_t size[], const size_t strides[]) {
-  this->create_internal_layout(primitive, type);
-  this->create_user_layout(dimension, size, strides);
-}
-
-
-template <typename Dtype>
-void MKLMemoryDescriptorBase<Dtype>::convert_from_prv(void* cpu_ptr) {
-  CHECK(cpu_ptr);
-  CHECK(this->convert_from_int);
-  int status;
-  void *convert_resources[dnnResourceNumber];
-
-  convert_resources[dnnResourceFrom] = this->prv_ptr();
-  convert_resources[dnnResourceTo]   = cpu_ptr;
-  status = dnnExecute<Dtype>(this->convert_from_int, convert_resources);
-  CHECK_EQ(status, 0) << "Conversion from prv failed with status " << status;
-}
-
-template <typename Dtype>
-void MKLMemoryDescriptorBase<Dtype>::convert_to_prv(void* cpu_ptr) {
-  CHECK(cpu_ptr);
-  CHECK(this->convert_to_int);
-  int status;
-  void *convert_resources[dnnResourceNumber];
-
-  convert_resources[dnnResourceFrom] = cpu_ptr;
-  convert_resources[dnnResourceTo]   = this->prv_ptr();
-  status = dnnExecute<Dtype>(this->convert_to_int, convert_resources);
-  CHECK_EQ(status, 0) << "Conversion from prv failed with status " << status;
-}
-
-
-template <typename Dtype>
-bool MKLMemoryDescriptorBase<Dtype>::layout_compare(
-  std::shared_ptr<PrvMemDescr> other) {
-  CHECK_EQ(other->get_descr_type(),
-              PrvMemDescr::PRV_DESCR_MKL2017);
-  std::shared_ptr<MKLMemoryDescriptorBase<Dtype> > other_descr =
-    std::static_pointer_cast<MKLMemoryDescriptorBase<Dtype> >(other);
-
-  if (dnnLayoutCompare<Dtype>(other_descr->layout_int,
-      this->layout_int))
-    return true;
-  else
-    return false;
-}
-
-template <typename Dtype>
-void MKLMemoryDescriptorBase<Dtype>::convert_from_other(
-  std::shared_ptr<PrvMemDescr> other) {
-    std::shared_ptr<MKLMemoryDescriptorBase<Dtype> > other_descr =
-        std::static_pointer_cast<MKLMemoryDescriptorBase<Dtype> >
-            (other);
-
-  int status;
-  dnnPrimitive_t convert;
-  status = dnnConversionCreate<Dtype>(&convert,
-    other_descr->layout_int, this->layout_int);
-  CHECK_EQ(status, 0) << "Conversion creation failed with status " << status;
-
-  void *convert_resources[dnnResourceNumber];
-  convert_resources[dnnResourceFrom] = other_descr->prv_ptr();
-  convert_resources[dnnResourceTo]   = this->prv_ptr();
-  status = dnnExecute<Dtype>(convert, convert_resources);
-  CHECK_EQ(status, 0) << "Conversion from other failed with status "
-                      << status;
-
-  dnnDelete<Dtype>(convert);
-}
-
-
-template <typename Dtype>
-Dtype* MKLMemoryDescriptor<Dtype>::get_converted_prv(
-    Dtype *cpu_ptr, bool set_prv_ptr, const TBlob &blob) {
-  Dtype* prv_ptr = NULL;
-  std::shared_ptr<MKLMemHolder> dnn_chunk = NULL;
-#if MKL_EXPERIMENTAL == 1
-  dnn_chunk = blob.Mkl_mem_;
-#endif
-#if MKL_EXPERIMENTAL == 1
-  if (dnn_chunk != NULL)
-    prv_ptr = static_cast<Dtype*>(dnn_chunk->prv_data());
-#endif
-
-  if (this->convert_to_int != NULL) {
-#if MKL_EXPERIMENTAL == 1
-    int status;
-    void *convert_resources[dnnResourceNumber];
-#endif
-    if (prv_ptr == NULL) {
-      this->allocate();
-      this->convert_to_prv(cpu_ptr);
-#if MKL_EXPERIMENTAL == 1
-      if (set_prv_ptr) {
-        dnn_chunk->set_prv_descriptor(this->get_shared_ptr(), true);
-      }
-#endif
-      return this->internal_ptr;
-    }
-#if MKL_EXPERIMENTAL == 1
-    if (prv_ptr != NULL)  {
-      std::shared_ptr<MKLData<Dtype> > current_descr =
-        op::mkl_get_mem_desc<Dtype>(dnn_chunk);
-      if (!dnnLayoutCompare<Dtype>(current_descr->layout_int,
-        this->layout_int)) {
-        if (this->convert_prv2prv) {
-          CHECK_EQ(dnnLayoutCompare<Dtype>(
-            this->descr_prv2prv_conversion->layout_int,
-            this->layout_int), 0);
-          status = 0;
-        } else {
-          status = dnnConversionCreate<Dtype>(&this->convert_prv2prv,
-            current_descr->layout_int, this->layout_int);
-          if (status == 0)
-            this->descr_prv2prv_conversion = current_descr;
-        }
-        if (status != 0) {
-          this->allocate();
-          convert_resources[dnnResourceFrom] = cpu_ptr;
-          convert_resources[dnnResourceTo] =
-            reinterpret_cast<void*>(this->internal_ptr);
-          status = dnnExecute<Dtype>(this->convert_to_int, convert_resources);
-          CHECK_EQ(status, 0) << "Conversion failed with status " << status;
-        } else {
-          this->allocate();
-          convert_resources[dnnResourceFrom] = reinterpret_cast<void*>(prv_ptr);
-          convert_resources[dnnResourceTo] =
-            reinterpret_cast<void*>(this->internal_ptr);
-          status = dnnExecute<Dtype>(this->convert_prv2prv, convert_resources);
-          CHECK_EQ(status, 0) << "Conversion failed with status " << status;
-        }
-        if (set_prv_ptr) {
-          dnn_chunk->set_prv_descriptor(this->get_shared_ptr(), true);
-        }
-        return this->internal_ptr;
-      } else if (current_descr.get() != this) {
-        // MKL_DLOG(INFO) << "layout OK                 "
-        //  << current_descr->name << " == " << this->name;
-      }
-    }
-#endif
-    return const_cast<Dtype *>(prv_ptr);
-  } else {
-    if (prv_ptr != NULL) {
-#if MKL_EXPERIMENTAL == 1
-      std::shared_ptr<MKLMemoryDescriptorBase<float> > other_descr =
-        std::static_pointer_cast<MKLMemoryDescriptorBase<float> >
-        (dnn_chunk->prv_descriptor_);
-      dnn_chunk->check_and_prv_to_cpu(cpu_ptr);
-#endif
-      // printf("get_converted_prv release %s\n", other_descr->name.c_str());
-    }
-  }
-  return cpu_ptr;
-}
-
-template <typename Dtype>
-void* MKLMemoryDescriptor<Dtype>::get_output_ptr(Dtype *data_ptr,
-  std::shared_ptr<MKLMemoryDescriptor<Dtype> > self_ptr, const TBlob &blob, bool in_place) {
-#if MKL_EXPERIMENTAL == 1
-  std::shared_ptr<MKLMemHolder> dnn_chunk = blob.Mkl_mem_;
-#endif
-  if (this->conversion_needed()) {
-    void * prv_ptr =  this->prv_ptr();
-#if MKL_EXPERIMENTAL == 1
-    if (!in_place) {
-      dnn_chunk->set_prv_descriptor(self_ptr);
-    } else {
-      Dtype * blob_prv = op::mkl_prv_data<Dtype>(blob);
-      if (blob_prv != NULL)
-        return blob_prv;
-    }
-#endif
-    return prv_ptr;
-  } else {
-#if MKL_EXPERIMENTAL == 1
-    std::shared_ptr<MKLMemoryDescriptorBase<float> > other_descr =
-      std::static_pointer_cast<MKLMemoryDescriptorBase<float> >
-      (dnn_chunk->prv_descriptor_);
-    dnn_chunk->check_and_prv_to_cpu(data_ptr);
-#endif
-    return data_ptr;
-  }
-}
-
-template class MKLMemoryDescriptor<double>;
-template class MKLMemoryDescriptor<float>;
-
-template class MKLMemoryDescriptorBase<float>;
-template class MKLMemoryDescriptorBase<double>;
-}  // namespace mxnet
-#endif
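
The conversion machinery above reduces to one pattern: if the user layout and the primitive's internal layout differ, allocate a private buffer and run a conversion primitive each way. A condensed sketch using the raw C API from mkl_dnn.h (error handling trimmed; both layouts are assumed to be created elsewhere):

    #include <mkl_dnn.h>

    // Mirrors convert_to_prv()/convert_from_prv(): user buffer -> private
    // buffer, run compute primitives on the private copy, then convert back.
    int roundtrip_f32(dnnLayout_t usr, dnnLayout_t internal, float* cpu_ptr) {
      if (dnnLayoutCompare_F32(usr, internal)) return 0;  // same layout: no-op

      void* prv = NULL;
      dnnPrimitive_t to_prv = NULL, from_prv = NULL;
      if (dnnAllocateBuffer_F32(&prv, internal) != E_SUCCESS) return -1;
      if (dnnConversionCreate_F32(&to_prv, usr, internal) != E_SUCCESS) return -1;
      if (dnnConversionCreate_F32(&from_prv, internal, usr) != E_SUCCESS) return -1;

      void* res[dnnResourceNumber] = {};
      res[dnnResourceFrom] = cpu_ptr;  // user buffer -> private buffer
      res[dnnResourceTo] = prv;
      if (dnnExecute_F32(to_prv, res) != E_SUCCESS) return -1;

      // ... execute compute primitives against 'prv' here ...

      res[dnnResourceFrom] = prv;      // private buffer -> user buffer
      res[dnnResourceTo] = cpu_ptr;
      if (dnnExecute_F32(from_prv, res) != E_SUCCESS) return -1;

      dnnDelete_F32(to_prv);
      dnnDelete_F32(from_prv);
      dnnReleaseBuffer_F32(prv);
      return 0;
    }
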
diff --git a/src/operator/mkl/mkl_memory.h b/src/operator/mkl/mkl_memory.h
deleted file mode 100644
index 13f1fd2..0000000
--- a/src/operator/mkl/mkl_memory.h
+++ /dev/null
@@ -1,123 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_memory.h
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_MEMORY_H_
-#define MXNET_OPERATOR_MKL_MKL_MEMORY_H_
-
-#include <string>
-#include <vector>
-#include <memory>
-
-
-namespace mxnet {
-// Base class
-struct PrvMemDescr {
-  virtual void convert_from_prv(void* cpu_ptr) = 0;
-  virtual void convert_to_prv(void* cpu_ptr) = 0;
-  virtual void convert_from_other(std::shared_ptr<PrvMemDescr> other) = 0;
-  virtual void* prv_ptr(bool allocate_when_uninit = true) = 0;
-  // returns true for matching layouts
-  virtual bool layout_compare(std::shared_ptr<PrvMemDescr> other) = 0;
-  virtual size_t prv_count() = 0;
-  virtual size_t prv_size() = 0;
-  // This might help using prv_ptr_ by different accelerators/engines
-  enum PrvDescrType {
-    PRV_DESCR_MKL2017,
-    PRV_DESCR_MKLDNN
-  };
-  virtual PrvDescrType get_descr_type() = 0;
-};
-
-#if MKL_EXPERIMENTAL == 1
-// Currently HEAD_AT_PRV does not free the CPU data
-enum SyncedHead {
-  HEAD_AT_CPU,
-  HEAD_AT_PRV,
-};
-struct MKLMemHolder {
-  SyncedHead head_;
-  std::shared_ptr<PrvMemDescr> prv_descriptor_;
-  bool  b_disable_prv_2_cpu;
-  bool  b_eager_mode;
-  void disable_prv_2_cpu(bool flag) {
-    b_disable_prv_2_cpu = flag;
-  }
-  void set_eager_mode(bool eager_mode) {
-    b_eager_mode = eager_mode;
-  }
-  void set_prv_descriptor(std::shared_ptr<PrvMemDescr> descriptor, bool same_data = false) {
-    head_ = HEAD_AT_PRV;
-    prv_descriptor_ = descriptor;
-  }
-  std::shared_ptr<PrvMemDescr> get_prv_descriptor() {
-    return  prv_descriptor_;
-  }
-  bool head_at_prv() {
-    return head_ == HEAD_AT_PRV;
-  }
-  void* prv_data(bool allocate_when_uninit = true) {
-    if (head_ != HEAD_AT_PRV) {
-      return NULL;
-    }
-    if (prv_descriptor_ == NULL) {
-      LOG(FATAL) << "prv_descriptor_ is NULL";
-    }
-    CHECK(prv_descriptor_.get());
-    return reinterpret_cast<void*>(prv_descriptor_->prv_ptr(allocate_when_uninit));
-  }
-
-  int prv_count() {
-    if (head_ != HEAD_AT_PRV) {
-      return 0;
-    }
-    if (prv_descriptor_ == NULL) {
-      LOG(FATAL) << "prv_descriptor_ is NULL";
-    }
-    CHECK(prv_descriptor_.get());
-    return prv_descriptor_->prv_count();
-  }
-  static std::shared_ptr<MKLMemHolder> create() {
-    return std::make_shared<MKLMemHolder>();
-  }
-  void check_and_prv_to_cpu(void *dptr_) {
-    if (!b_disable_prv_2_cpu && head_ == HEAD_AT_PRV) {
-      CHECK(prv_descriptor_ != nullptr);
-      prv_descriptor_->convert_from_prv(dptr_);
-      // The operator consumes (and may modify) the CPU copy, so move the head back to CPU
-      head_ = HEAD_AT_CPU;
-    }
-    if (b_disable_prv_2_cpu) {
-      b_disable_prv_2_cpu = false;
-    }
-  }
-  MKLMemHolder() :
-    head_(HEAD_AT_CPU), prv_descriptor_(nullptr),
-    b_disable_prv_2_cpu(false), b_eager_mode(false) {}
-};
-#else
-struct MKLMemHolder {
- public:
-  virtual std::shared_ptr<PrvMemDescr> get_prv_descriptor() = 0;
-};
-#endif
-
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_MKL_MKL_MEMORY_H_
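
The MKLMemHolder above is a two-state ownership protocol: the live copy of a tensor is either the CPU buffer (HEAD_AT_CPU) or a private MKL-layout buffer (HEAD_AT_PRV), and any CPU reader must convert back first. A minimal model of that protocol (a sketch, not MXNet code; convert_from_prv stands in for PrvMemDescr::convert_from_prv):

    #include <functional>

    enum class Head { AtCPU, AtPrv };

    struct HolderModel {
      Head head = Head::AtCPU;
      // Called before any CPU consumer touches cpu_ptr, as in
      // check_and_prv_to_cpu() above.
      void sync_to_cpu(void* cpu_ptr,
                       const std::function<void(void*)>& convert_from_prv) {
        if (head == Head::AtPrv) {
          convert_from_prv(cpu_ptr);  // copy private data back to the CPU buffer
          head = Head::AtCPU;         // the CPU consumer may now mutate it
        }
      }
    };
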
diff --git a/src/operator/mkl/mkl_pooling-inl.h b/src/operator/mkl/mkl_pooling-inl.h
deleted file mode 100644
index 5662a61..0000000
--- a/src/operator/mkl/mkl_pooling-inl.h
+++ /dev/null
@@ -1,357 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_pooling-inl.h
-* \brief
-* \author zhenlin.luo@intel.com
-*         lingyan.guo@intel.com
-*
-*******************************************************************************/
-
-#ifndef MXNET_OPERATOR_MKL_MKL_POOLING_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_POOLING_INL_H_
-#include <vector>
-#include <string>
-#include <utility>
-#include "../operator_common.h"
-#include "../nn/pooling-inl.h"
-#include "./mkl_util-inl.h"
-
-namespace mxnet {
-namespace op {
-
-
-template<typename xpu, typename DType>
-class MKLPoolingOp : public Operator {
- public:
-  static std::string getName() {
-    return "MKLPoolingOp";
-  }
-  explicit MKLPoolingOp(PoolingParam p) {
-    poolingFwd = static_cast<dnnPrimitive_t>(NULL);
-    poolingBwd = static_cast<dnnPrimitive_t>(NULL);
-    max_idx_data = static_cast<DType*>(NULL);
-    fwd_top_data = MKLData<DType>::create();
-    fwd_bottom_data = MKLData<DType>::create();
-    bwd_top_diff = MKLData<DType>::create();
-    bwd_bottom_diff = MKLData<DType>::create();
-    this->param_ = p;
-    init_mkldnn_ = false;
-  }
-  virtual ~MKLPoolingOp() {
-    if (poolingFwd != NULL) {
-      dnnDelete<DType>(poolingFwd);
-      poolingFwd = NULL;
-    }
-    if (poolingBwd != NULL) {
-      dnnDelete<DType>(poolingBwd);
-      poolingBwd = NULL;
-    }
-    if (max_idx_data != NULL) {
-      dnnReleaseBuffer<DType>(max_idx_data);
-      max_idx_data = NULL;
-    }
-  }
-
- private:
-  void LayerSetUp(const mshadow::Tensor<xpu, 4, DType> &data,
-                  const mshadow::Tensor<xpu, 4, DType> &out) {
-    channels_ = data.shape_[1];
-    height_ = data.shape_[2];
-    width_ = data.shape_[3];
-    num_ = data.shape_[0];
-    global_pooling_ = param_.global_pool;
-    if (global_pooling_) {
-      kernel_h_ = height_;
-      kernel_w_ = width_;
-    } else {
-      kernel_h_ = param_.kernel[0];
-      kernel_w_ = param_.kernel[1];
-    }
-    CHECK_GT(kernel_h_, 0) << "Filter dimensions cannot be zero.";
-    CHECK_GT(kernel_w_, 0) << "Filter dimensions cannot be zero.";
-    pad_h_ = param_.pad[0];
-    pad_w_ = param_.pad[1];
-    if (global_pooling_) {
-      stride_h_ = stride_w_ = 1;
-    } else {
-      stride_h_ = param_.stride[0];
-      stride_w_ = param_.stride[1];
-    }
-    if (global_pooling_) {
-      CHECK(pad_h_ == 0 && pad_w_ == 0 && stride_h_ == 1 && stride_w_ == 1)
-        << "With Global_pooling: true; only pad = 0 and stride = 1";
-    }
-    if (pad_h_ != 0 || pad_w_ != 0) {
-      CHECK(param_.pool_type == pool_enum::kAvgPooling
-          || param_.pool_type == pool_enum::kMaxPooling)
-        << "Padding implemented only for average and max pooling.";
-      CHECK_LT(pad_h_, kernel_h_);
-      CHECK_LT(pad_w_, kernel_w_);
-    }
-    pooled_height_ = out.shape_[2];
-    pooled_width_ = out.shape_[3];
-
-    size_t dim = 4;
-    size_t src_sizes[4], src_strides[4];
-    size_t dst_sizes[4], dst_strides[4];
-    src_sizes[0] = width_;
-    src_sizes[1] = height_;
-    src_sizes[2] = channels_;
-    src_sizes[3] = num_;
-    src_strides[0] = 1;
-    src_strides[1] = src_sizes[0];
-    src_strides[2] = src_sizes[0] * src_sizes[1];
-    src_strides[3] = src_sizes[0] * src_sizes[1] * src_sizes[2];
-    dst_sizes[0] = pooled_width_;
-    dst_sizes[1] = pooled_height_;
-    dst_sizes[2] = src_sizes[2];
-    dst_sizes[3] = src_sizes[3];
-    dst_strides[0] = 1;
-    dst_strides[1] = dst_sizes[0];
-    dst_strides[2] = dst_sizes[0] * dst_sizes[1];
-    dst_strides[3] = dst_sizes[0] * dst_sizes[1] * dst_sizes[2];
-    src_offset[0] = -pad_w_;
-    src_offset[1] = -pad_h_;
-    src_offset[2] = -pad_w_;
-    src_offset[3] = -pad_h_;
-    kernel_stride[0] = stride_w_;
-    kernel_stride[1] = stride_h_;
-    kernel_size[0] = kernel_w_;
-    kernel_size[1] = kernel_h_;
-
-    // Names are for debugging only
-    fwd_bottom_data->name = "fwd_bottom_data   @ " + getName();
-    fwd_top_data->name = "fwd_top_data      @ " + getName();
-    bwd_top_diff->name = "bwd_top_diff      @ " + getName();
-    bwd_bottom_diff->name = "bwd_bottom_diff   @ " + getName();
-
-    fwd_bottom_data->create_user_layout(dim, src_sizes, src_strides);
-    fwd_top_data->create_user_layout(dim, dst_sizes, dst_strides);
-    bwd_bottom_diff->create_user_layout(dim, src_sizes, src_strides);
-    bwd_top_diff->create_user_layout(dim, dst_sizes, dst_strides);
-
-    // Primitives will be allocated during the first fwd pass
-    poolingFwd = NULL;
-    poolingBwd = NULL;
-    max_idx_data = NULL;
-  }
-
- public:
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 1);
-    CHECK_EQ(out_data.size(), 1);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    if (param_.kernel.ndim() >= 3) {
-      LOG(FATAL) << "Not implmented";
-    }
-    Tensor<xpu, 4, DType> data = mkl_experimental_direct_get<xpu, 4, DType>(
-      in_data[pool_enum::kData], s);
-    Tensor<xpu, 4, DType> out = mkl_experimental_direct_get<xpu, 4, DType>(
-      out_data[pool_enum::kOut], s);
-    if (!init_mkldnn_) {
-      LayerSetUp(data, out);
-      init_mkldnn_ = true;
-    }
-    auto first_pass = false;
-    if (poolingFwd == NULL) first_pass = true;
-
-    dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
-
-    switch (param_.pool_type) {
-    case pool_enum::kMaxPooling:
-      algorithm = dnnAlgorithmPoolingMax;
-      break;
-    case pool_enum::kAvgPooling:
-      algorithm = dnnAlgorithmPoolingAvgIncludePadding;
-
-      break;
-    default:
-      LOG(FATAL) << "Unknown pooling method.";
-    }
-
-    dnnError_t status;
-    void* pooling_res[dnnResourceNumber];
-
-    void* bottom_data = NULL;
-#if MKL_EXPERIMENTAL == 1
-    bottom_data =
-          reinterpret_cast<void *>(mkl_prv_data<DType>(in_data[pool_enum::kData]));
-#endif
-    dnnBorder_t border_type = dnnBorderZerosAsymm;
-    switch (param_.pooling_convention) {
-    case pool_enum::kFull:
-      border_type = dnnBorderZeros;
-      break;
-    case pool_enum::kValid:
-      border_type = dnnBorderZerosAsymm;
-      break;
-    default:
-      border_type = dnnBorderZerosAsymm;
-      break;
-    }
-    if (NULL == bottom_data) {
-      bottom_data = data.dptr_;
-      if (NULL == poolingFwd) {
-        status = dnnPoolingCreateForward<DType>(&poolingFwd, NULL,
-                                                algorithm, fwd_bottom_data->layout_usr,
-                                                kernel_size, kernel_stride,
-                                                src_offset, border_type);
-        CHECK_EQ(status, E_SUCCESS);
-        // Now create poolingBwd
-        status = dnnPoolingCreateBackward<DType>(&poolingBwd, NULL,
-                                                 algorithm, fwd_bottom_data->layout_usr,
-                                                 kernel_size, kernel_stride,
-                                                 src_offset, border_type);
-        CHECK_EQ(status, E_SUCCESS);
-      }
-    }
-#if MKL_EXPERIMENTAL == 1
-    if (NULL != bottom_data) {
-       if (NULL == poolingFwd) {
-          std::shared_ptr<MKLMemHolder> bottom_data_mem = in_data[pool_enum::kData].Mkl_mem_;
-          std::shared_ptr<PrvMemDescr> bottom_prv_descriptor =
-            bottom_data_mem->get_prv_descriptor();
-          CHECK_EQ(bottom_prv_descriptor->get_descr_type(),
-                   PrvMemDescr::PRV_DESCR_MKL2017);
-          std::shared_ptr<MKLData<DType> > mem_descr
-            = std::static_pointer_cast<MKLData<DType>>(bottom_prv_descriptor);
-          CHECK(mem_descr != nullptr);
-          fwd_bottom_data = mem_descr;
-
-          status = dnnPoolingCreateForward<DType>(&poolingFwd, NULL,
-                                                  algorithm, fwd_bottom_data->layout_int,
-                                                  kernel_size, kernel_stride,
-                                                  src_offset, border_type);
-          CHECK_EQ(status, E_SUCCESS);
-          fwd_top_data->create_internal_layout(poolingFwd, dnnResourceDst);
-
-          // Now create poolingBwd
-          status = dnnPoolingCreateBackward<DType>(&poolingBwd, NULL,
-                                                   algorithm, fwd_bottom_data->layout_int,
-                                                   kernel_size, kernel_stride,
-                                                   src_offset, border_type);
-          CHECK_EQ(status, E_SUCCESS);
-          bwd_top_diff->create_internal_layout(poolingFwd, dnnResourceDst);
-          bwd_bottom_diff->create_internal_layout(poolingFwd, dnnResourceSrc);
-        }
-    }
-#endif
-
-    if (first_pass) {
-      dnnLayout_t max_idx_datal = NULL;
-      status = dnnLayoutCreateFromPrimitive<DType>(
-          &max_idx_datal, poolingFwd, dnnResourceWorkspace);
-      CHECK_EQ(status, E_SUCCESS);
-      status = dnnAllocateBuffer<DType>(reinterpret_cast<void**>(&max_idx_data), max_idx_datal);
-      CHECK_EQ(status, E_SUCCESS);
-#if MKL_EXPERIMENTAL == 0
-      fwd_bottom_data->create_internal_layout(poolingFwd, dnnResourceSrc);
-      fwd_top_data->create_internal_layout(poolingFwd, dnnResourceDst);
-      bwd_top_diff->create_internal_layout(poolingBwd, dnnResourceDiffDst);
-      bwd_bottom_diff->create_internal_layout(poolingBwd, dnnResourceDiffSrc);
-#endif
-      dnnLayoutDelete<DType>(max_idx_datal);
-      first_pass = false;
-    }
-    pooling_res[dnnResourceSrc] = bottom_data;
-    pooling_res[dnnResourceWorkspace] = max_idx_data;
-
-    pooling_res[dnnResourceDst] = fwd_top_data->get_output_ptr(
-      out.dptr_, fwd_top_data, out_data[pool_enum::kOut]);
-    status = dnnExecute<DType>(poolingFwd, pooling_res);
-    CHECK_EQ(status, E_SUCCESS);
-#if MKL_EXPERIMENTAL == 0
-    if (fwd_top_data->conversion_needed()) {
-      fwd_top_data->convert_from_prv(out.dptr_);
-    }
-#endif
-  }
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    if (!req[0]) {
-      return;
-    }
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1);
-    CHECK_EQ(in_data.size(), 1);
-    CHECK_EQ(out_data.size(), 1);
-    CHECK_EQ(req.size(), 1);
-    CHECK_EQ(in_grad.size(), 1);
-    if (param_.kernel.ndim() >= 3) {
-      LOG(FATAL) << "Not implmented";
-    }
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> grad = mkl_experimental_direct_get<xpu, 4, DType>(
-      out_grad[pool_enum::kOut], s);
-    Tensor<xpu, 4, DType> input_grad = mkl_experimental_direct_get<xpu, 4, DType>(
-      in_grad[pool_enum::kData], s);
-    dnnError_t e;
-    void* pooling_res[dnnResourceNumber];
-    pooling_res[dnnResourceWorkspace] = reinterpret_cast<void *>(max_idx_data);
-
-    pooling_res[dnnResourceDiffDst] =
-      bwd_top_diff->get_converted_prv(grad.dptr_, true, out_grad[pool_enum::kOut]);
-
-    pooling_res[dnnResourceDiffSrc] = bwd_bottom_diff->get_output_ptr(
-      input_grad.dptr_, bwd_bottom_diff, in_grad[pool_enum::kData]);
-    e = dnnExecute<DType>(poolingBwd, pooling_res);
-    CHECK_EQ(e, E_SUCCESS);
-#if MKL_EXPERIMENTAL == 0
-    if (bwd_bottom_diff->conversion_needed()) {
-      bwd_bottom_diff->convert_from_prv(input_grad.dptr_);
-    }
-#endif
-  }
-
- private:
-  PoolingParam param_;
-  int kernel_h_, kernel_w_;
-  int stride_h_, stride_w_;
-  int pad_h_, pad_w_;
-  int channels_, num_;
-  int height_, width_;
-  int pooled_height_, pooled_width_;
-  bool global_pooling_;
-
- private:
-  size_t kernel_size[2],
-         kernel_stride[4];
-  int src_offset[4];  // 2*(dimension-2)
-  dnnPrimitive_t poolingFwd, poolingBwd;
-  DType *max_idx_data;
-
-  std::shared_ptr<MKLData<DType> > fwd_top_data;
-  std::shared_ptr<MKLData<DType> > fwd_bottom_data;
-  std::shared_ptr<MKLData<DType> > bwd_top_diff;
-  std::shared_ptr<MKLData<DType> > bwd_bottom_diff;
-  bool init_mkldnn_;
-};  // class MKLPoolingOp
-}   // namespace op
-}   // namespace mxnet
-
-#endif  // MXNET_OPERATOR_MKL_MKL_POOLING_INL_H_
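
The pooling_convention switch above maps kValid/kFull to MKL border types; the difference shows up in the output size. A sketch of the per-dimension arithmetic, assuming MXNet's documented convention (kValid floors, kFull ceils):

    #include <cmath>

    // Output length of one pooled dimension.
    inline int pooled_size(int in, int kernel, int pad, int stride, bool full) {
      const double span = static_cast<double>(in + 2 * pad - kernel) / stride;
      return 1 + static_cast<int>(full ? std::ceil(span) : std::floor(span));
    }

    // pooled_size(6, 3, 0, 2, /*full=*/false) == 2; with full == true it is 3.
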
diff --git a/src/operator/mkl/mkl_relu-inl.h b/src/operator/mkl/mkl_relu-inl.h
deleted file mode 100644
index 8d7ab5e..0000000
--- a/src/operator/mkl/mkl_relu-inl.h
+++ /dev/null
@@ -1,272 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_relu-inl.h
-* \brief
-* \author zhenlin.luo@intel.com
-*         lingyan.guo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_RELU_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_RELU_INL_H_
-
-
-#include <dmlc/logging.h>
-#include <dmlc/parameter.h>
-#include <mxnet/operator.h>
-#include <algorithm>
-#include <map>
-#include <vector>
-#include <string>
-#include <utility>
-#include "../operator_common.h"
-#include "./mkl_util-inl.h"
-
-namespace mxnet {
-namespace op {
-
-template<typename xpu, typename DType>
-class MKLReluOp : public Operator {
- public:
-  static std::string getName() {
-    return "MKLReluOp";
-  }
-  MKLReluOp():
-      reluFwd_(NULL),
-      reluBwd_(NULL) {
-    init_mkldnn_ = false;
-    fwd_top_data_ = MKLData<DType>::create();
-    fwd_bottom_data_ = MKLData<DType>::create();
-    bwd_top_diff_ = MKLData<DType>::create();
-    bwd_bottom_diff_ = MKLData<DType>::create();
-  }
-
-  ~MKLReluOp() {
-    if (reluFwd_ != NULL) {
-      dnnDelete<DType>(reluFwd_);
-      reluFwd_ = NULL;
-    }
-    if (reluBwd_ != NULL) {
-      dnnDelete<DType>(reluBwd_);
-      reluBwd_ = NULL;
-    }
-  }
-
- private:
-  void LayerSetUp(const mshadow::Tensor<xpu, 4, DType> &data,
-                  const mshadow::Tensor<xpu, 4, DType> &out) {
-    size_t dim = 4;
-    size_t *sizes = new size_t[dim];
-    size_t *strides = new size_t[dim];
-    for (size_t d = 0; d < dim; ++d) {
-      sizes[d] = data.shape_[dim - 1 - d];
-      strides[d] = (d == 0) ? 1 : strides[d - 1] * sizes[d - 1];
-    }
-    // Names are for debugging only
-    fwd_bottom_data_->name = "fwd_bottom_data   @ " + getName();
-    fwd_top_data_->name = "fwd_top_data      @ " + getName();
-    bwd_bottom_diff_->name = "bwd_bottom_diff   @ " + getName();
-    bwd_top_diff_->name = "bwd_top_diff      @ " + getName();
-    fwd_bottom_data_->create_user_layout(dim, sizes, strides);
-    fwd_top_data_->create_user_layout(dim, sizes, strides);
-    bwd_bottom_diff_->create_user_layout(dim, sizes, strides);
-    bwd_top_diff_->create_user_layout(dim, sizes, strides);
-    delete[] sizes;
-    delete[] strides;
-  }
-
- public:
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 1);
-    CHECK_EQ(out_data.size(), 1);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> data;
-    Tensor<xpu, 4, DType> out;
-    if (in_data[activation::kData].ndim() == 1) {
-      Shape<4> dshape = Shape4(in_data[activation::kData].shape_[0], 1, 1, 1);
-      data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        in_data[activation::kData], dshape, s);
-      out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_data[activation::kOut], dshape, s);
-    } else if (in_data[activation::kData].ndim() == 2) {
-      Shape<4> dshape = Shape4(in_data[activation::kData].shape_[0],
-      in_data[activation::kData].shape_[1], 1, 1);
-      data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        in_data[activation::kData], dshape, s);
-      out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_data[activation::kOut], dshape, s);
-    } else if (in_data[activation::kData].ndim() == 3) {
-      Shape<4> dshape = Shape4(in_data[activation::kData].shape_[0],
-        in_data[activation::kData].shape_[1],
-        in_data[activation::kData].shape_[2], 1);
-      data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        in_data[activation::kData], dshape, s);
-      out = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_data[activation::kOut], dshape, s);
-    } else {
-      data = mkl_experimental_direct_get<xpu, 4, DType>(in_data[activation::kData], s);
-      out = mkl_experimental_direct_get<xpu, 4, DType>(out_data[activation::kOut], s);
-    }
-    if (!init_mkldnn_) {
-      LayerSetUp(data, out);
-      init_mkldnn_ = true;
-    }
-    void* bottom_data = NULL;
-#if MKL_EXPERIMENTAL == 1
-    bottom_data =
-          reinterpret_cast<void *>(mkl_prv_data<DType>(in_data[activation::kData]));
-#endif
-#if MKL_EXPERIMENTAL == 1
-    if (bottom_data != NULL) {
-      if (reluFwd_ == NULL) {
-      std::shared_ptr<MKLData<DType> > mem_descr =
-        mkl_get_mem_desc<DType>(in_data[activation::kData].Mkl_mem_);
-      DType negative_slope = 0;
-      dnnError_t e;
-      e = dnnReLUCreateForward<DType>(&reluFwd_, NULL, mem_descr->layout_int,
-                                      negative_slope);
-      CHECK_EQ(e, E_SUCCESS);
-      e = dnnReLUCreateBackward<DType>(&reluBwd_, NULL, mem_descr->layout_int,
-                                       mem_descr->layout_int, negative_slope);
-      CHECK_EQ(e, E_SUCCESS);
-
-      fwd_bottom_data_ = mem_descr;
-      fwd_top_data_->create_internal_layout(reluFwd_, dnnResourceDst);
-      bwd_top_diff_->create_internal_layout(reluFwd_, dnnResourceDst);
-      bwd_bottom_diff_->create_internal_layout(reluFwd_, dnnResourceSrc);
-      }
-    }
-#endif
-    if (bottom_data  == NULL) {
-      bottom_data = data.dptr_;
-      if (reluFwd_ == NULL) {
-        dnnError_t e;
-        DType negative_slope = 0;
-        e = dnnReLUCreateForward<DType>(&reluFwd_, NULL,
-                                        fwd_bottom_data_->layout_usr, negative_slope);
-        CHECK_EQ(e, E_SUCCESS);
-        e = dnnReLUCreateBackward<DType>(&reluBwd_, NULL,
-                                         fwd_bottom_data_->layout_usr, fwd_bottom_data_->layout_usr,
-                                         negative_slope);
-        CHECK_EQ(e, E_SUCCESS);
-      }
-    }
-    dnnError_t e;
-    void* relu_res[dnnResourceNumber];
-    relu_res[dnnResourceSrc] = bottom_data;
-
-    relu_res[dnnResourceDst] = fwd_top_data_->get_output_ptr(
-      out.dptr_, fwd_top_data_, out_data[activation::kOut], (data.dptr_ == out.dptr_));
-    e = dnnExecute<DType>(reluFwd_, relu_res);
-    CHECK_EQ(e, E_SUCCESS);
-#if MKL_EXPERIMENTAL == 0
-    if (fwd_top_data_->conversion_needed()) {
-      fwd_top_data_->convert_from_prv(out.dptr_);
-    }
-#endif
-  }
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    if (!req[0]) {
-      return;
-    }
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1);
-    CHECK(in_data.size() == 1 && in_grad.size() == 1);
-    CHECK_EQ(req.size(), 1);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    Tensor<xpu, 4, DType> m_out_grad;
-    Tensor<xpu, 4, DType> m_out_data;
-    Tensor<xpu, 4, DType> m_in_grad;
-
-    if (out_grad[activation::kOut].ndim() == 1) {
-      Shape<4> dshape = Shape4(out_grad[activation::kOut].shape_[0], 1, 1, 1);
-      m_out_grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_grad[activation::kOut], dshape, s);
-      m_out_data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_data[activation::kOut], dshape, s);
-      m_in_grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        in_grad[activation::kData], dshape, s);
-    } else if (out_grad[activation::kOut].ndim() == 2) {
-      Shape<4> dshape = Shape4(out_grad[activation::kOut].shape_[0],
-                               out_grad[activation::kOut].shape_[1], 1, 1);
-      m_out_grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_grad[activation::kOut], dshape, s);
-      m_out_data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_data[activation::kOut], dshape, s);
-      m_in_grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        in_grad[activation::kData], dshape, s);
-    } else if (out_grad[activation::kOut].ndim() == 3) {
-      Shape<4> dshape = Shape4(out_grad[activation::kOut].shape_[0],
-        out_grad[activation::kOut].shape_[1],
-        out_grad[activation::kOut].shape_[2], 1);
-      m_out_grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_grad[activation::kOut], dshape, s);
-      m_out_data = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        out_data[activation::kOut], dshape, s);
-      m_in_grad = mkl_experimental_direct_get_with_shape<xpu, 4, DType>(
-        in_grad[activation::kData], dshape, s);
-    } else {
-      m_out_grad = mkl_experimental_direct_get<xpu, 4, DType>(out_grad[activation::kOut], s);
-      m_out_data = mkl_experimental_direct_get<xpu, 4, DType>(out_data[activation::kOut], s);
-      m_in_grad = mkl_experimental_direct_get<xpu, 4, DType>(in_grad[activation::kData], s);
-    }
-    dnnError_t e;
-    void* relu_res[dnnResourceNumber];
-
-    void* bottom_data = NULL;
-#if MKL_EXPERIMENTAL == 1
-    bottom_data = reinterpret_cast<void *>(mkl_prv_data<DType>(out_data[activation::kOut]));
-#endif
-    if (NULL == bottom_data) {
-      bottom_data = reinterpret_cast<void *>(const_cast<DType*>(m_out_data.dptr_));
-    }
-    relu_res[dnnResourceSrc] = bottom_data;
-    relu_res[dnnResourceDiffDst] = bwd_top_diff_->get_converted_prv(m_out_grad.dptr_,
-                true, out_grad[activation::kOut]);
-    relu_res[dnnResourceDiffSrc] = bwd_bottom_diff_->get_output_ptr(
-      m_in_grad.dptr_, bwd_bottom_diff_, in_grad[activation::kData]);
-    e = dnnExecute<DType>(reluBwd_, relu_res);
-    CHECK_EQ(e, E_SUCCESS);
-#if MKL_EXPERIMENTAL == 0
-    if (bwd_bottom_diff_->conversion_needed()) {
-      bwd_bottom_diff_->convert_from_prv(m_in_grad.dptr_);
-    }
-#endif
-  }
-
- private:
-  bool init_mkldnn_;
-  std::shared_ptr<MKLData<DType> > fwd_top_data_;
-  std::shared_ptr<MKLData<DType> > fwd_bottom_data_;
-  std::shared_ptr<MKLData<DType> > bwd_top_diff_;
-  std::shared_ptr<MKLData<DType> > bwd_bottom_diff_;
-  dnnPrimitive_t reluFwd_, reluBwd_;
-};  // class MKLReluOp
-}  // namespace op
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_MKL_MKL_RELU_INL_H_
diff --git a/src/operator/mkl/mkl_util-inl.h b/src/operator/mkl/mkl_util-inl.h
deleted file mode 100644
index 4ad786a..0000000
--- a/src/operator/mkl/mkl_util-inl.h
+++ /dev/null
@@ -1,110 +0,0 @@
-/*******************************************************************************
-* Copyright 2016 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*
-* \file mkl_util-inl.h
-* \brief
-* \author lingyan.guo@intel.com
-*         zhenlin.luo@intel.com
-*
-*******************************************************************************/
-#ifndef MXNET_OPERATOR_MKL_MKL_UTIL_INL_H_
-#define MXNET_OPERATOR_MKL_MKL_UTIL_INL_H_
-#include <vector>
-#define MKLDNN_CALL(func)                                                               \
-  {                                                                                     \
-    dnnError_t status = (func);                                                                \
-    CHECK_EQ(status, E_SUCCESS) << "MKL DNN call failed (status: " << status << ").";           \
-  }
-
-
-namespace mxnet {
-namespace op {
-
-#if MKL_EXPERIMENTAL == 1
-  template<typename DType>
-  inline DType * mkl_prv_data(const TBlob &b) {
-    std::shared_ptr<MKLMemHolder> bottom_data_mem = b.Mkl_mem_;
-    bool mem_valid = (bottom_data_mem != nullptr) && bottom_data_mem->head_at_prv();
-    if (mem_valid) {
-      return reinterpret_cast<DType*>(bottom_data_mem->prv_data());
-    }
-    return NULL;
-  }
-
-  template<typename DType>
-  inline int mkl_prv_count(const TBlob &b) {
-    std::shared_ptr<MKLMemHolder> bottom_data_mem = b.Mkl_mem_;
-    bool mem_valid = (bottom_data_mem != nullptr) && bottom_data_mem->head_at_prv();
-    if (mem_valid) {
-      return bottom_data_mem->prv_count();
-    }
-    return 0;
-  }
-#endif
-  inline void mkl_set_priv_flag(const TBlob &b) {
-#if MKL_EXPERIMENTAL == 1
-    std::shared_ptr<MKLMemHolder> bottom_data_mem = b.Mkl_mem_;
-    bool mem_valid = (bottom_data_mem != nullptr) && bottom_data_mem->head_at_prv();
-    if (mem_valid) {
-      bottom_data_mem->disable_prv_2_cpu(true);
-    }
-#endif
-  }
-#if MKL_EXPERIMENTAL == 1
-  template<typename DType>
-  inline std::shared_ptr<MKLData<DType> > mkl_get_mem_desc(
-    const std::shared_ptr<MKLMemHolder> data_mem) {
-    std::shared_ptr<PrvMemDescr> prv_descriptor =
-      data_mem->get_prv_descriptor();
-    CHECK_EQ(prv_descriptor->get_descr_type(),
-      PrvMemDescr::PRV_DESCR_MKL2017);
-    std::shared_ptr<MKLData<DType> > mem_descr
-      = std::static_pointer_cast<MKLData<DType>>
-      (prv_descriptor);
-    CHECK(mem_descr != NULL);
-    return mem_descr;
-  }
-#endif
-  template<typename xpu, int dim, typename DType>
-  inline  mshadow::Tensor<xpu, dim, DType> mkl_experimental_direct_get(
-    const TBlob &b, mshadow::Stream<xpu> *s) {
-    mkl_set_priv_flag(b);
-    return b.get<xpu, dim, DType>(s);
-  }
-  template<typename xpu, int dim, typename DType>
-  inline  mshadow::Tensor<xpu, dim, DType> mkl_experimental_direct_get_with_shape(
-    const TBlob &b, const mshadow::Shape<dim> &shape, mshadow::Stream<xpu> *s) {
-    mkl_set_priv_flag(b);
-    return b.get_with_shape<xpu, dim, DType>(shape, s);
-  }
-}  // namespace op
-#if MKL_EXPERIMENTAL == 1
-inline void mkl_tblobs_prv_to_cpu(const std::vector<TBlob> &data) {
-  for (size_t i = 0; i < data.size(); i++) {
-    std::shared_ptr<MKLMemHolder> mem_holder = data[i].Mkl_mem_;
-    if (mem_holder != nullptr && mem_holder->b_eager_mode) {
-      mem_holder->check_and_prv_to_cpu(data[i].dptr_);
-    }
-  }
-}
-inline void mkl_set_tblob_eager_mode(const TBlob &data) {
-  std::shared_ptr<MKLMemHolder> mem_holder = data.Mkl_mem_;
-  if (mem_holder != nullptr) {
-    mem_holder->set_eager_mode(true);
-  }
-}
-#endif
-}  // namespace mxnet
-#endif  // MXNET_OPERATOR_MKL_MKL_UTIL_INL_H_
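
The deleted helpers above implement the MKL_EXPERIMENTAL fallback idiom that the rest of the old MKL code relies on: ask the TBlob for its private (MKL-layout) buffer, and fall back to the plain user-layout pointer when there is none. A minimal sketch of that idiom, assuming the declarations from the deleted header (pick_input_ptr itself is illustrative and not part of the patch):

    // Prefer MKL's private buffer when the memory head is at prv;
    // otherwise use the ordinary user-layout pointer.
    template<typename DType>
    DType *pick_input_ptr(const mxnet::TBlob &b, DType *user_ptr) {
      DType *prv = mxnet::op::mkl_prv_data<DType>(b);  // NULL unless head_at_prv()
      return prv != NULL ? prv : user_ptr;
    }
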
diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index ac8b747..a440f97 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -21,7 +21,7 @@
  * Copyright (c) 2015 by Contributors
  * \file activation-inl.h
  * \brief Activation operator
- * \author Bing Xu
+ * \author Bing Xu, Da Zheng
 */
 
 #ifndef MXNET_OPERATOR_NN_ACTIVATION_INL_H_
@@ -37,6 +37,7 @@
 #include <utility>
 #include "../operator_common.h"
 #include "../mxnet_op.h"
+#include "../mshadow_op.h"
 
 namespace mxnet {
 namespace op {
@@ -45,6 +46,7 @@ namespace op {
 namespace activation {
 enum ActivationOpInputs {kData};
 enum ActivationOpOutputs {kOut};
+enum ActivationOpResource {kTempSpace};
 enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU};
 }  // activation
 
@@ -59,160 +61,148 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
     .add_enum("softrelu", activation::kSoftReLU)
     .describe("Activation function to be applied.");
   }
-};
 
-/**
- * \brief This is the implementation of activation operator.
- * \tparam xpu The device that the op will be executed on.
- */
-template<typename xpu, typename ForwardOp, typename BackwardOp, typename DType>
-class ActivationOp : public Operator {
- public:
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(in_data.size(), 1U);
-    CHECK_EQ(out_data.size(), 1U);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    const TBlob& input = in_data[activation::kData];
-    const size_t sz = input.shape_.Size();
-    if (sz) {
-      MXNET_ASSIGN_REQ_SWITCH(req[activation::kOut], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<ForwardOp, Req>, xpu>::Launch(
-          s, sz,
-          out_data[activation::kOut].dptr<DType>(),
-          input.dptr<DType>());
-      });
-    }
+  bool operator==(const ActivationParam& other) const {
+    return this->act_type == other.act_type;
   }
+};
 
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    CHECK_EQ(out_grad.size(), 1U);
-    CHECK(in_data.size() == 1 && in_grad.size() == 1);
-    CHECK_EQ(req.size(), 1U);
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    const TBlob& m_out_grad = out_grad[activation::kOut];
-    const TBlob& m_out_data = out_data[activation::kOut];
-    const TBlob&  m_in_grad = in_grad[activation::kData];
-    const size_t sz = m_out_data.shape_.Size();
-    if (sz) {
-      MXNET_ASSIGN_REQ_SWITCH(req[activation::kData], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<
-          mxnet::op::mxnet_op::backward_grad_tuned<BackwardOp>, Req>, xpu>::Launch(
-          s, sz,
-          m_in_grad.dptr<DType>(),
-          m_out_grad.dptr<DType>(),
-          m_out_data.dptr<DType>());
-      });
-    }
-  }
-};  // class ActivationOp
-
-// Declare Factory function, used for dispatch specialization
-template<typename xpu>
-Operator* CreateOp(ActivationParam type, int dtype, const TShape& dshape);
+}  // namespace op
+}  // namespace mxnet
 
-#if DMLC_USE_CXX11
-class ActivationProp : public OperatorProperty {
- public:
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
+namespace std {
+template<>
+struct hash<mxnet::op::ActivationParam> {
+  size_t operator()(const mxnet::op::ActivationParam& val) const {
+    return val.act_type;
   }
+};
+}  // namespace std
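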
+
+namespace mxnet {
+namespace op {
 
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
+template<typename xpu, typename ForwardOp, typename BackwardOp, typename DType>
+void ActivationForward(const OpContext &ctx, const TBlob &in_data,
+                       const OpReqType &req, const TBlob &out_data) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const size_t sz = in_data.shape_.Size();
+  if (sz) {
+    MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+      mxnet_op::Kernel<mxnet_op::op_with_req<ForwardOp, Req>, xpu>::Launch(
+        s, sz,
+        out_data.dptr<DType>(),
+        in_data.dptr<DType>());
+    });
   }
+}
 
-  bool InferShape(std::vector<TShape> *in_shape,
-                  std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
-    const TShape &dshape = in_shape->at(activation::kData);
-    if (dshape.ndim() == 0) return false;
-    out_shape->clear();
-    out_shape->push_back(dshape);
-    return true;
+template<typename xpu, typename ForwardOp, typename BackwardOp, typename DType>
+void ActivationBackward(const OpContext &ctx, const TBlob &out_grad,
+                        const TBlob &out_data, const OpReqType &req,
+                        const TBlob &in_grad) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const size_t sz = out_data.shape_.Size();
+  if (sz) {
+    MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+      mxnet_op::Kernel<mxnet_op::op_with_req<
+        mxnet::op::mxnet_op::backward_grad_tuned<BackwardOp>, Req>, xpu>::Launch(
+        s, sz,
+        in_grad.dptr<DType>(),
+        out_grad.dptr<DType>(),
+        out_data.dptr<DType>());
+    });
   }
+}
 
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    CHECK_GE(in_type->size(), 1U);
-    int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    for (index_t i = 0; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-          (*in_type)[i] = dtype;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
-      }
+template<typename xpu>
+void ActivationComputeImpl(const ActivationParam &param, const OpContext &ctx,
+                           const TBlob &input, OpReqType req, const TBlob &output) {
+  MSHADOW_REAL_TYPE_SWITCH(input.type_flag_, DType, {
+    switch (param.act_type) {
+      case activation::kReLU:
+        ActivationForward<xpu, mshadow_op::relu, mshadow_op::relu_grad, DType>(
+            ctx, input, req, output);
+        break;
+      case activation::kSigmoid:
+        ActivationForward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad, DType>(
+            ctx, input, req, output);
+        break;
+      case activation::kTanh:
+        ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad, DType>(
+            ctx, input, req, output);
+        break;
+      case activation::kSoftReLU:
+        ActivationForward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>(
+            ctx, input, req, output);
+        break;
+      default:
+        LOG(FATAL) << "unknown activation type";
     }
-    out_type->clear();
-    out_type->push_back(dtype);
-    return true;
-  }
+  });
+}
 
-  OperatorProperty* Copy() const override {
-    auto ptr = new ActivationProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
+template<typename xpu>
+void ActivationGradComputeImpl(const ActivationParam &param, const OpContext &ctx,
+                               const TBlob &out_grad, const TBlob &out_data,
+                               OpReqType req, const TBlob &output) {
+  MSHADOW_REAL_TYPE_SWITCH(out_grad.type_flag_, DType, {
+    switch (param.act_type) {
+      case activation::kReLU:
+        ActivationBackward<xpu, mshadow_op::relu, mshadow_op::relu_grad, DType>(
+            ctx, out_grad, out_data, req, output);
+        break;
+      case activation::kSigmoid:
+        ActivationBackward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad, DType>(
+            ctx, out_grad, out_data, req, output);
+        break;
+      case activation::kTanh:
+        ActivationBackward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad, DType>(
+            ctx, out_grad, out_data, req, output);
+        break;
+      case activation::kSoftReLU:
+        ActivationBackward<xpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>(
+            ctx, out_grad, out_data, req, output);
+        break;
+      default:
+        LOG(FATAL) << "unknown activation type";
+    }
+  });
+}
 
-  std::string TypeString() const override {
-    return "Activation";
-  }
+template<typename xpu>
+void ActivationCompute(const nnvm::NodeAttrs& attrs,
+    const OpContext& ctx,
+    const std::vector<TBlob>& inputs,
+    const std::vector<OpReqType>& req,
+    const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  ActivationComputeImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
+}
 
-  // declare dependency and inplace optimization options
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
+template<typename xpu>
+void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
+    const OpContext& ctx,
+    const std::vector<TBlob>& inputs,
+    const std::vector<OpReqType>& req,
+    const std::vector<TBlob>& outputs) {
 #if MXNET_USE_CUDNN == 1
-    return {out_grad[activation::kOut], out_data[activation::kOut], in_data[activation::kData]};
+  CHECK_EQ(inputs.size(), 3U);
 #else
-    return {out_grad[activation::kOut], out_data[activation::kOut]};
-#endif  // MXNET_USE_CUDNN
-  }
+  CHECK_EQ(inputs.size(), 2U);
+#endif
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  ActivationGradComputeImpl<xpu>(param, ctx, inputs[0], inputs[1], req[0], outputs[0]);
+}
 
-  std::vector<std::pair<int, void*> > BackwardInplaceOption(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data,
-    const std::vector<void*> &in_grad) const override {
-    return {{out_grad[activation::kOut], in_grad[activation::kData]}};
-  }
-
-  std::vector<std::pair<int, void*> > ForwardInplaceOption(
-    const std::vector<int> &in_data,
-    const std::vector<void*> &out_data) const override {
-    return {{in_data[activation::kData], out_data[activation::kOut]}};
-  }
-
-  Operator* CreateOperator(Context ctx) const override {
-    LOG(FATAL) << "Not Implemented.";
-    return NULL;
-  }
-
-  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                             std::vector<int> *in_type) const override;
-
- private:
-  ActivationParam param_;
-};
-#endif  // DMLC_USE_CXX11
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_NN_ACTIVATION_INL_H_
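
After this refactor the header exposes stateless free functions (ActivationForward/Backward, ActivationCompute/GradCompute) instead of an Operator subclass, and it equips ActivationParam with operator== and a std::hash specialization. That pair makes the param usable as a key in unordered containers; a hypothetical use, not part of this header (act_cache and state_for are illustrative names):

    #include <unordered_map>

    static std::unordered_map<mxnet::op::ActivationParam, int> act_cache;

    int &state_for(const mxnet::op::ActivationParam &p) {
      // Hashes on p.act_type via std::hash<ActivationParam>; equal keys are
      // detected with ActivationParam::operator==.
      return act_cache[p];
    }
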
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index 401a9e3..0da644c 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -17,69 +17,130 @@
  * under the License.
  */
 
+
 /*!
  * Copyright (c) 2015 by Contributors
  * \file activation.cc
  * \brief activation op
- * \author Bing Xu
+ * \author Bing Xu, Da Zheng
 */
 #include "./activation-inl.h"
 #include "../mshadow_op.h"
-#if MXNET_USE_MKL2017 == 1
-#include <mkl_memory.h>
-#include "../mkl/mkl_memory-inl.h"
-#include "../mkl/mkl_relu-inl.h"
-#endif  // MXNET_USE_MKL2017
+#include "../tensor/elemwise_unary_op.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./mkldnn/mkldnn_base-inl.h"
+#include "./mkldnn/mkldnn_ops-inl.h"
+#endif  // MXNET_USE_MKLDNN
 
 namespace mxnet {
 namespace op {
-template<>
-Operator *CreateOp<cpu>(ActivationParam param, int dtype, const TShape& dshape) {
-  Operator *op = NULL;
-#if MXNET_USE_MKL2017 == 1
-  if (param.act_type == activation::kReLU && dshape.ndim() <= 4) {
-      switch (dtype) {
-      case mshadow::kFloat32:
-          return new MKLReluOp<cpu, float>();
-      case mshadow::kFloat64:
-          return new MKLReluOp<cpu, double>();
-      default:
-          break;
-      }
+
+DMLC_REGISTER_PARAMETER(ActivationParam);
+
+// This will determine the order of the inputs for backward computation.
+struct ActivationGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
+    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
+    heads.emplace_back(nnvm::NodeEntry{n, activation::kOut, 0});
+#if MXNET_USE_CUDNN == 1
+    heads.push_back(n->inputs[activation::kData]);
+#endif
+    return MakeGradNode(op_name, n, heads, n->attrs.dict);
+  }
+};
+
+#if MXNET_USE_MKLDNN == 1
+static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<NDArray>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<NDArray>& outputs) {
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  if (SupportMKLDNN(inputs[0])) {
+    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNActivationForward(attrs, ctx, inputs[0], req[0], outputs[0]);
+    MKLDNN_OPCHECK_RUN(ActivationCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
   }
-  if (enableMKLWarnGenerated())
-    LOG(INFO) << MKLReluOp<cpu, float>::getName() << " Skip MKL optimization";
+  ActivationComputeImpl<cpu>(param, ctx, inputs[0].data(), req[0], outputs[0].data());
+}
+
+void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<NDArray>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& outputs) {
+#if MXNET_USE_CUDNN == 1
+  CHECK_EQ(inputs.size(), 3U);
+#else
+  CHECK_EQ(inputs.size(), 2U);
 #endif
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    switch (param.act_type) {
-      case activation::kReLU:
-        op = new ActivationOp<cpu, mshadow_op::relu, mshadow_op::relu_grad, DType>();
-        break;
-      case activation::kSigmoid:
-        op = new ActivationOp<cpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad, DType>();
-        break;
-      case activation::kTanh:
-        op = new ActivationOp<cpu, mshadow_op::tanh, mshadow_op::tanh_grad, DType>();
-        break;
-      case activation::kSoftReLU:
-        op = new ActivationOp<cpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>();
-        break;
-      default:
-        LOG(FATAL) << "unknown activation type";
-    }
-  })
-  return op;
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  if (SupportMKLDNN(inputs[0])) {
+    MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
+    MKLDNNActivationBackward(attrs, ctx, inputs[0], inputs[1], req[0],
+                             outputs[0]);
+    MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  ActivationGradComputeImpl<cpu>(param, ctx, inputs[0].data(), inputs[1].data(),
+                                 req[0], outputs[0].data());
 }
+#endif
 
-// DO_BIND_DISPATCH comes from operator_common.h
-Operator *ActivationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                                           std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_shape)[0]);
+inline static bool ActivationStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int> *in_attrs,
+                                         std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+  bool ret = ElemwiseStorageType<1, 1, false, false, false>(attrs, dev_mask,
+                                                            dispatch_mode,
+                                                            in_attrs, out_attrs);
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNAct(param)) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+  }
+#endif
+  return ret;
 }
 
-DMLC_REGISTER_PARAMETER(ActivationParam);
+inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
+                                          const int dev_mask,
+                                          DispatchMode* dispatch_mode,
+                                          std::vector<int> *in_attrs,
+                                          std::vector<int> *out_attrs) {
+#if MXNET_USE_CUDNN == 1
+  CHECK_EQ(in_attrs->size(), 3U);
+#else
+  CHECK_EQ(in_attrs->size(), 2U);
+#endif
+  CHECK_EQ(out_attrs->size(), 1U);
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+#if MXNET_USE_CUDNN == 1
+  bool ret = ElemwiseStorageType<3, 1, false, false, false>(attrs, dev_mask,
+                                                            dispatch_mode,
+                                                            in_attrs, out_attrs);
+#else
+  bool ret = ElemwiseStorageType<2, 1, false, false, false>(attrs, dev_mask,
+                                                            dispatch_mode,
+                                                            in_attrs, out_attrs);
+#endif
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask && SupportMKLDNNAct(param)) {
+    *dispatch_mode = DispatchMode::kFComputeEx;
+  }
+#endif
+  return ret;
+}
 
-MXNET_REGISTER_OP_PROPERTY(Activation, ActivationProp)
+MXNET_OPERATOR_REGISTER_UNARY(Activation)
 .describe(R"code(Applies an activation function element-wise to the input.
 
 The following activation functions are supported:
@@ -90,8 +151,35 @@ The following activation functions are supported:
 - `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
 
 )code" ADD_FILELINE)
-.add_argument("data", "NDArray-or-Symbol", "Input array to activation function.")
+.set_attr_parser(ParamParser<ActivationParam>)
+.set_attr<FInferStorageType>("FInferStorageType", ActivationStorageType)
+.set_attr<FCompute>("FCompute<cpu>", ActivationCompute<cpu>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FComputeEx>("FComputeEx<cpu>", ActivationComputeExCPU)
+#endif
+.set_attr<nnvm::FGradient>("FGradient", ActivationGrad{"_backward_Activation"})
 .add_arguments(ActivationParam::__FIELDS__());
 
+NNVM_REGISTER_OP(_backward_Activation)
+.set_num_inputs(3)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
+.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<3, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
+  return std::vector<std::pair<int, int> >{{0, 0}};
+})
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+#endif
+.set_attr_parser(ParamParser<ActivationParam>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FComputeEx>("FComputeEx<cpu>", ActivationGradComputeExCPU)
+#endif
+.set_attr<FCompute>("FCompute<cpu>", ActivationGradCompute<cpu>);
+
 }  // namespace op
 }  // namespace mxnet
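
Two details in activation.cc are easy to miss. First, the ActivationGrad functor fixes the input order of _backward_Activation, which the CHECKs in BackwardActStorageType and ActivationGradCompute depend on:

    // Inputs of _backward_Activation, as assembled by ActivationGrad:
    //   inputs[0] = gradient of the output (ograd)
    //   inputs[1] = forward output (out_data)
    //   inputs[2] = forward input (in_data), appended only when CUDNN is enabled

Second, the storage-type functions first run the generic ElemwiseStorageType and then, on CPU with an MKLDNN-supported act_type, overwrite the dispatch mode to kFComputeEx so that the FComputeEx handlers (ActivationComputeExCPU and ActivationGradComputeExCPU) are chosen instead of the dense FCompute kernels.
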
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index c2f6be9..dc435b2 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -31,39 +31,73 @@
 
 namespace mxnet {
 namespace op {
+
+#if MXNET_USE_CUDNN == 1
+
+template<typename DType>
+static CuDNNActivationOp<DType> &get_cudnn_op(const ActivationParam& param) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local CuDNNActivationOp<DType> cudnn_op;
+#else
+  static MX_THREAD_LOCAL CuDNNActivationOp<DType> cudnn_op;
+#endif
+  cudnn_op.Init(param);
+  return cudnn_op;
+}
+
 template<>
-Operator *CreateOp<gpu>(ActivationParam param, int dtype, const TShape& dshape) {
-  Operator *op = NULL;
+void ActivationCompute<gpu>(const nnvm::NodeAttrs& attrs,
+    const OpContext& ctx,
+    const std::vector<TBlob>& inputs,
+    const std::vector<OpReqType>& req,
+    const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+
   // SoftReLU not supported by CUDNN yet
   if (param.act_type == activation::kSoftReLU) {
-    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      op = new ActivationOp<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>();
-    })
-    return op;
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      ActivationForward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>(ctx,
+          inputs[0], req[0], outputs[0]);
+    });
+  } else {
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      get_cudnn_op<DType>(param).Forward(ctx, inputs[0], req[0], outputs[0]);
+    });
   }
+}
 
-#if MXNET_USE_CUDNN == 1
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new CuDNNActivationOp<DType>(param);
-  })
-#else
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    switch (param.act_type) {
-      case activation::kReLU:
-        op = new ActivationOp<gpu, mshadow_op::relu, mshadow_op::relu_grad, DType>();
-        break;
-      case activation::kSigmoid:
-        op = new ActivationOp<gpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad, DType>();
-        break;
-      case activation::kTanh:
-        op = new ActivationOp<gpu, mshadow_op::tanh, mshadow_op::tanh_grad, DType>();
-        break;
-      default:
-        LOG(FATAL) << "unknown activation";
-    }
-  })
-#endif  // MXNET_USE_CUDNN
-  return op;
+template<>
+void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
+    const OpContext& ctx,
+    const std::vector<TBlob>& inputs,
+    const std::vector<OpReqType>& req,
+    const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const ActivationParam& param = nnvm::get<ActivationParam>(attrs.parsed);
+
+  // SoftReLU not supported by CUDNN yet
+  if (param.act_type == activation::kSoftReLU) {
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad, DType>(
+          ctx, inputs[0], inputs[1], req[0], outputs[0]);
+    });
+  } else {
+    MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[2], inputs[1], req[0], outputs[0]);
+    });
+  }
 }
+#endif
+
+NNVM_REGISTER_OP(Activation)
+.set_attr<FCompute>("FCompute<gpu>", ActivationCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_Activation)
+.set_attr<FCompute>("FCompute<gpu>", ActivationGradCompute<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
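
get_cudnn_op above keeps one CuDNNActivationOp per thread and per DType and re-applies Init(param) on every call, so a single cached object serves every Activation node executed by that thread. The caching idiom in isolation, with per_thread as an illustrative name (MX_THREAD_LOCAL is MXNet's fallback macro for compilers without C++11 thread_local):

    // One lazily constructed instance of T per thread per instantiation.
    template<typename T>
    T &per_thread() {
    #if DMLC_CXX11_THREAD_LOCAL
      static thread_local T obj;
    #else
      static MX_THREAD_LOCAL T obj;
    #endif
      return obj;
    }

This only works because the refactored ops are stateless apart from the parameter passed to Init.
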
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 2a9dee2..27e0a84 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -21,7 +21,7 @@
  * Copyright (c) 2017 by Contributors
  * \file batch_norm-inl.h
  * \brief
- * \author Bing Xu, Chris Olivier
+ * \author Bing Xu, Chris Olivier, Da Zheng
  */
 #ifndef MXNET_OPERATOR_NN_BATCH_NORM_INL_H_
 #define MXNET_OPERATOR_NN_BATCH_NORM_INL_H_
@@ -47,8 +47,10 @@ namespace mxnet {
 namespace op {
 
 namespace batchnorm {
-enum BatchNormOpInputs {kData, kGamma, kBeta};  // kGamma: weights, kBeta: biases
+enum BatchNormOpInputs {kData, kGamma, kBeta, kInMovingMean,
+  kInMovingVar};  // kGamma: weights, kBeta: biases
 enum BatchNormOpOutputs {kOut, kMean, kVar};  // req, out_data
+enum BatchNormOpResource {kTempSpace};
 enum BatchNormOpAuxiliary {kMovingMean, kMovingVar};  // aux_states
 
 /*! \brief Default channel axis if none specified in the params */
@@ -83,280 +85,203 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
     DMLC_DECLARE_FIELD(cudnn_off).set_default(false)
       .describe("Do not select CUDNN operator, if available");
   }
-};
-
-/*! \brief Batch normalization operator */
-template <typename xpu, typename DType, typename AccReal>
-class BatchNormOp : public Operator {
- public:
-  explicit BatchNormOp(BatchNormParam param) {
-    this->param_ = param;
-  }
-
-  static inline bool IsWriting(const OpReqType ort) {
-    return ort == kWriteTo || ort == kWriteInplace;
-  }
-
-  /*!
-   * \brief perform a forward operation of Operator, save the output to TBlob.
-   * \param ctx runtime context available to this call
-   * \param in_data array of input data, it is const
-   * \param req the request types of saving operation, can only be kWriteTo or kWriteInplace.
-   * \param out_data array of output data; the pointer is used to indicate that this is a holder.
-   *        The space of the TBlobs in out_data must be pre-allocated with InferShape.
-   * \param aux_states Auxiliary states of the operator. Normally an operator doesn't
-   *        need them; special cases like Batch Norm do.
-   * \sa OpReqType, OpContext
-   */
-  virtual void Forward(const OpContext &ctx,
-                       const std::vector<TBlob> &in_data,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &out_data,
-                       const std::vector<TBlob> &aux_states) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-
-    CHECK_EQ(in_data.size(), 3U);
-    CHECK_EQ(aux_states.size(), 2U);
-    if (ctx.is_train) {
-      CHECK_EQ(out_data.size(), 3U);
-      CHECK_EQ(req.size(), 3U);
-    } else {
-      CHECK_GE(out_data.size(), 1U);
-      CHECK_GE(req.size(), 1U);
-      CHECK_EQ(req[batchnorm::kOut], kWriteTo);
-    }
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    DoForward(s, ctx, in_data, req, out_data, aux_states);
-  }
-
-  /*!
-   * \brief Perform a Backward Operation, write gradient to the in_grad.
-   *
-   * \note
-   * Convention:
-   *   out_grad.size() == OperatorProperty.NumVisibleOutputs()
-   *   out_data.size() == OperatorProperty.NumOutputs()
-   * out_data can contain additional invisible returns that remember the
-   * state carried over from the Forward pass, for example the mask in dropout.
-   * The gradients are passed from visible returns in this function.
-   *
-   * \par
-   * Not all the TBlobs in the arguments will be available
-   * if you override the DeclareBackwardDependency of the corresponding OperatorProperty class.
-   * Only the dependencies you declared will be available at the corresponding positions;
-   * the rest of the parameters are simply dummies for which you will get a nullptr.
-   * You will be safe if you use the default DeclareBackwardDependency,
-   * but declaring only what you need gives the engine more chances for optimization.
-   *
-   * \param ctx runtime context available to this call
-   * \param out_grad the gradient value we get from the output of the Operator.
-   * \param in_data the array of input data.
-   * \param out_data the array of output data.
-   * \param req request types of the saving operation, can be all types.
-   * \param in_grad the array of gradient we need to write to.
-   * \param aux_states Auxiliary states of the operator. Normally an operator doesn't need them.
-   * \sa OperatorProperty, OpReqType, OpContext
-   */
-  virtual void Backward(const OpContext &ctx,
-                        const std::vector<TBlob> &out_grad,
-                        const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data,
-                        const std::vector<OpReqType> &req,
-                        const std::vector<TBlob> &in_grad,
-                        const std::vector<TBlob> &aux_states) {
-    CHECK_EQ(out_grad.size(), param_.output_mean_var ? 3U : 1U);
-    CHECK_EQ(in_data.size(), 3U);
-    CHECK_EQ(out_data.size(), 3U);
-    CHECK_EQ(in_grad.size(), 3U);
-    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    DoBackward(s, ctx, out_grad, in_data,
-               out_data, req, in_grad, aux_states);
-  }
-
- private:
-  void DoForward(mshadow::Stream<cpu> *stream,
-                 const OpContext &ctx,
-                 const std::vector<TBlob> &in_data,
-                 const std::vector<OpReqType> &req,
-                 const std::vector<TBlob> &out_data,
-                 const std::vector<TBlob> &aux_states);
-
-  void DoBackward(mshadow::Stream<cpu> *stream,
-                  const OpContext &ctx,
-                  const std::vector<TBlob> &out_grad,
-                  const std::vector<TBlob> &in_data,
-                  const std::vector<TBlob> &out_data,
-                  const std::vector<OpReqType> &req,
-                  const std::vector<TBlob> &in_grad,
-                  const std::vector<TBlob> &aux_states);
-
-#if MXNET_USE_CUDA
-  void DoForward(mshadow::Stream<gpu> *stream,
-                 const OpContext &ctx,
-                 const std::vector<TBlob> &in_data,
-                 const std::vector<OpReqType> &req,
-                 const std::vector<TBlob> &out_data,
-                 const std::vector<TBlob> &aux_states);
-  void DoBackward(mshadow::Stream<gpu> *stream,
-                  const OpContext &ctx,
-                  const std::vector<TBlob> &out_grad,
-                  const std::vector<TBlob> &in_data,
-                  const std::vector<TBlob> &out_data,
-                  const std::vector<OpReqType> &req,
-                  const std::vector<TBlob> &in_grad,
-                  const std::vector<TBlob> &aux_states);
-#endif  // MXNET_USE_CUDA
-
-  /*! \brief Batch normalization operator parameters */
-  BatchNormParam param_;
-};  // class BatchNormOp
 
-template<typename xpu>
-Operator *CreateOp(BatchNormParam param, const int dtype, const TShape& shape);
-
-#if DMLC_USE_CXX11
-class BatchNormProp : public OperatorProperty {
- public:
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
-    param_.Init(kwargs);
-  }
-
-  std::map<std::string, std::string> GetParams() const override {
-    return param_.__DICT__();
-  }
-
-  bool InferShape(std::vector<TShape> *in_shape,
-                  std::vector<TShape> *out_shape,
-                  std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
-    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
-    const TShape &dshape = in_shape->at(0);
-
-    const size_t channelAxis = static_cast<size_t>(param_.axis < 0
-                            ? static_cast<int>(dshape.ndim()) + param_.axis
-                            : param_.axis);
-    CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param_.axis;
-
-    const int channelCount = dshape[channelAxis];
-
-    if (dshape.ndim() == 0) {
-      return false;
-    }
-
-    in_shape->at(1) = TShape(Shape1(channelCount));
-    in_shape->at(2) = TShape(Shape1(channelCount));
-
-    out_shape->clear();
-    out_shape->push_back(dshape);                // kOut
-    out_shape->push_back(Shape1(channelCount));  // kMean
-    out_shape->push_back(Shape1(channelCount));  // kVar
-
-    aux_shape->clear();
-    aux_shape->push_back(Shape1(channelCount));  // kMovingMean
-    aux_shape->push_back(Shape1(channelCount));  // kMovingVar
-    return true;
-  }
-
-  bool InferType(std::vector<int> *in_type,
-                 std::vector<int> *out_type,
-                 std::vector<int> *aux_type) const override {
-    using namespace mshadow;
-    CHECK_GE(in_type->size(), 1U);
-    const int dtype = (*in_type)[0];
-    CHECK_NE(dtype, -1) << "First input must have specified type";
-    // For float16 input type beta, gamma, mean, and average are stored in float32.
-    // For other input types, these parameters have the same type as input
-    // NOTE: This requirement is from cuDNN (v. 4 and 5)
-    int dtype_param;
-    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
-         dtype_param = mshadow::DataType<AccRealX>::kFlag; });
-    for (index_t i = 1; i < in_type->size(); ++i) {
-      if ((*in_type)[i] == -1) {
-        (*in_type)[i] = dtype_param;
-      } else {
-        UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, ListArguments()[i]);
-      }
-    }
-    for (index_t i = 0; i < aux_type->size(); ++i) {
-      if ((*aux_type)[i] != -1) {
-        UNIFORM_TYPE_CHECK((*aux_type)[i], dtype_param, ListArguments()[i]);
-      }
-    }
-    const size_t n_aux = this->ListAuxiliaryStates().size();
-    aux_type->clear();
-    for (size_t i = 0; i < n_aux; ++i) {
-      aux_type->push_back(dtype_param);
-    }
-    const size_t n_out = this->ListOutputs().size();
-    out_type->clear();
-    out_type->push_back(dtype);
-    for (size_t i = 1; i < n_out; ++i) {
-      out_type->push_back(dtype_param);
-    }
-    return true;
+  bool operator==(const BatchNormParam& other) const {
+    return this->eps == other.eps &&
+           this->momentum == other.momentum &&
+           this->fix_gamma == other.fix_gamma &&
+           this->use_global_stats == other.use_global_stats &&
+           this->output_mean_var == other.output_mean_var &&
+           this->axis == other.axis &&
+           this->cudnn_off == other.cudnn_off;
   }
+};
 
-  OperatorProperty* Copy() const override {
-    auto ptr = new BatchNormProp();
-    ptr->param_ = param_;
-    return ptr;
-  }
+}  // namespace op
+}  // namespace mxnet
 
-  std::string TypeString() const override {
-    return "BatchNorm";
+namespace std {
+template<>
+struct hash<mxnet::op::BatchNormParam> {
+  size_t operator()(const mxnet::op::BatchNormParam& val) const {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.momentum);
+    ret = dmlc::HashCombine(ret, val.fix_gamma);
+    ret = dmlc::HashCombine(ret, val.use_global_stats);
+    ret = dmlc::HashCombine(ret, val.output_mean_var);
+    ret = dmlc::HashCombine(ret, val.axis);
+    return ret;
   }
+};
+}  // namespace std
 
-  std::vector<int> DeclareBackwardDependency(
-    const std::vector<int> &out_grad,
-    const std::vector<int> &in_data,
-    const std::vector<int> &out_data) const override {
-    return {out_grad[batchnorm::kOut],
-            out_data[batchnorm::kMean],
-            out_data[batchnorm::kVar],
-            in_data[batchnorm::kData],
-            in_data[batchnorm::kGamma]
-           };
-  }
+namespace mxnet {
+namespace op {
 
-  int NumVisibleOutputs() const override {
-    if (param_.output_mean_var) {
-      return 3;
-    }
-    return 1;
-  }
+static inline bool IsBNWriting(const OpReqType ort) {
+  return ort == kWriteTo || ort == kWriteInplace;
+}
 
-  int NumOutputs() const override {
-    return 3;
-  }
+template <typename xpu, typename DType, typename AccReal>
+void BatchNormForwardImpl(mshadow::Stream<cpu> *stream,
+                          const OpContext &ctx, const BatchNormParam& param,
+                          const std::vector<TBlob> &in_data,
+                          const std::vector<OpReqType> &req,
+                          const std::vector<TBlob> &out_data,
+                          const std::vector<TBlob> &aux_states);
 
-  std::vector<std::string> ListArguments() const override {
-    return {"data", "gamma", "beta"};
-  }
+template <typename xpu, typename DType, typename AccReal>
+void BatchNormBackwardImpl(mshadow::Stream<cpu> *stream,
+                           const OpContext &ctx, const BatchNormParam& param,
+                           const std::vector<TBlob> &out_grad,
+                           const std::vector<TBlob> &in_data,
+                           const std::vector<TBlob> &out_data,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<TBlob> &in_grad,
+                           const std::vector<TBlob> &aux_states);
 
-  std::vector<std::string> ListOutputs() const override {
-    return {"output", "mean", "var"};
-  }
+#if MXNET_USE_CUDA
+template <typename xpu, typename DType, typename AccReal>
+void BatchNormForwardImpl(mshadow::Stream<gpu> *stream,
+                          const OpContext &ctx, const BatchNormParam& param,
+                          const std::vector<TBlob> &in_data,
+                          const std::vector<OpReqType> &req,
+                          const std::vector<TBlob> &out_data,
+                          const std::vector<TBlob> &aux_states);
+template <typename xpu, typename DType, typename AccReal>
+void BatchNormBackwardImpl(mshadow::Stream<gpu> *stream,
+                           const OpContext &ctx, const BatchNormParam& param,
+                           const std::vector<TBlob> &out_grad,
+                           const std::vector<TBlob> &in_data,
+                           const std::vector<TBlob> &out_data,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<TBlob> &in_grad,
+                           const std::vector<TBlob> &aux_states);
+#endif  // MXNET_USE_CUDA
 
-  std::vector<std::string> ListAuxiliaryStates() const override {
-    return {"moving_mean", "moving_var"};
-  }
+/*!
+ * \brief perform a forward operation of Operator, save the output to TBlob.
+ * \param ctx runtime context available to this call
+ * \param in_data array of input data, it is const
+ * \param req the request types of the saving operation, can only be kWriteTo or kWriteInplace.
+ * \param out_data array of output data; the pointer is used to indicate that this is a holder.
+ *        The space of the TBlobs in out_data must be pre-allocated with InferShape.
+ * \param aux_states Auxiliary states of the operator. Normally an operator doesn't
+ *        need them; special cases like Batch Norm do.
+ * \sa OpReqType, OpContext
+ */
+template <typename xpu, typename DType, typename AccReal>
+void BatchNormForward(const OpContext &ctx, const BatchNormParam& param,
+                      const std::vector<TBlob> &in_data,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<TBlob> &out_data,
+                      const std::vector<TBlob> &aux_states) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+
+  CHECK_EQ(in_data.size(), 3U);
+  CHECK_EQ(aux_states.size(), 2U);
+  if (ctx.is_train) {
+    CHECK_EQ(out_data.size(), 3U);
+    CHECK_EQ(req.size(), 3U);
+  } else {
+    CHECK_GE(out_data.size(), 1U);
+    CHECK_GE(req.size(), 1U);
+    CHECK_EQ(req[batchnorm::kOut], kWriteTo);
+  }
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  BatchNormForwardImpl<xpu, DType, AccReal>(s, ctx, param, in_data, req,
+                                            out_data, aux_states);
+}
 
-  Operator* CreateOperator(Context ctx) const override {
-      LOG(FATAL) << "Not Implemented.";
-      return NULL;
-  }
+/*!
+ * \brief Perform a Backward Operation, write gradient to the in_grad.
+ *
+ * \note
+ * Convention:
+ *   out_grad.size() == OperatorProperty.NumVisibleOutputs()
+ *   out_data.size() == OperatorProperty.NumOutputs()
+ * out_data can contain additional invisible returns that remember the
+ * state carried over from the Forward pass, for example the mask in dropout.
+ * The gradients are passed from visible returns in this function.
+ *
+ * \par
+ * Not all the TBlobs in the arguments will be available
+ * if you override the DeclareBackwardDependency of the corresponding OperatorProperty class.
+ * Only the dependencies you declared will be available at the corresponding positions;
+ * the rest of the parameters are simply dummies for which you will get a nullptr.
+ * You will be safe if you use the default DeclareBackwardDependency,
+ * but declaring only what you need gives the engine more chances for optimization.
+ *
+ * \param ctx runtime context available to this call
+ * \param out_grad the gradient value we get from the output of the Operator.
+ * \param in_data the array of input data.
+ * \param out_data the array of output data.
+ * \param req request types of the saving operation, can be all types.
+ * \param in_grad the array of gradient we need to write to.
+ * \param aux_states Auxiliary states of the operator. Normally an operator doesn't need them.
+ * \sa OperatorProperty, OpReqType, OpContext
+ */
+template <typename xpu, typename DType, typename AccReal>
+void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param,
+                       const std::vector<TBlob> &out_grad,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &in_grad,
+                       const std::vector<TBlob> &aux_states) {
+  CHECK_EQ(out_grad.size(), param.output_mean_var ? 3U : 1U);
+  CHECK_EQ(in_data.size(), 3U);
+  CHECK_EQ(out_data.size(), 3U);
+  CHECK_EQ(in_grad.size(), 3U);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  BatchNormBackwardImpl<xpu, DType, AccReal>(s, ctx, param, out_grad, in_data,
+                                             out_data, req, in_grad, aux_states);
+}
 
-  Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-      std::vector<int> *in_type) const override;
+template<typename xpu>
+void BatchNormCompute(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx, const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 5U);
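+  // inputs = {data, gamma, beta, moving_mean, moving_var}; the first three
+  // become in_data and the trailing two become aux_states below.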
+  std::vector<TBlob> in_data(inputs.begin(),
+                             inputs.begin() + batchnorm::kInMovingMean);
+  std::vector<TBlob> aux_states(inputs.begin() + batchnorm::kInMovingMean,
+                                inputs.end());
+  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
+    BatchNormForward<xpu, DType, AccReal>(ctx, param, in_data, req, outputs,
+                                          aux_states);
+  });
+}
 
-  inline const BatchNormParam& getParam() const {
-    return param_;
-  }
+template<typename xpu>
+void BatchNormGradCompute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx, const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 11U);
+  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+  int num_out_grads = param.output_mean_var ? 3U : 1U;
+  int in_data_start = 3;
+  int aux_states_start = in_data_start + batchnorm::kInMovingMean;
+  int out_data_start = in_data_start + batchnorm::kInMovingVar + 1;
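+  // Layout of the 11 inputs: [0, 3) gradients of (out, mean, var), of which
+  // only num_out_grads are consumed; [3, 6) data, gamma, beta; [6, 8)
+  // moving_mean, moving_var; [8, 11) forward outputs (out, mean, var).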
+  std::vector<TBlob> out_grad(inputs.begin(), inputs.begin() + num_out_grads);
+  std::vector<TBlob> in_data(inputs.begin() + in_data_start,
+                             inputs.begin() + aux_states_start);
+  std::vector<TBlob> aux_states(inputs.begin() + aux_states_start,
+                                inputs.begin() + out_data_start);
+  std::vector<TBlob> out_data(inputs.begin() + out_data_start, inputs.end());
+  std::vector<TBlob> in_grad(outputs.begin(), outputs.begin() + 3);
+
+  MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, {
+    BatchNormBackward<xpu, DType, AccReal>(ctx, param, out_grad, in_data, out_data, req,
+                                           in_grad, aux_states);
+  });
+}
 
- private:
-  BatchNormParam param_;
-};  // class BatchNormProp
+#if DMLC_USE_CXX11
 
 namespace batchnorm {
 
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index ca28832..ba6c413 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -21,16 +21,15 @@
  * Copyright (c) 2015 by Contributors
  * \file batch_norm.cc
  * \brief
- * \author Bing Xu, Chris Olivier
+ * \author Bing Xu, Chris Olivier, Da Zheng
 */
 
 #include "batch_norm-inl.h"
 #include <nnvm/op_attr_types.h>
-#if MXNET_USE_MKL2017 == 1
-#include <mkl_memory.h>
-#include "../mkl/mkl_memory-inl.h"
-#include "../mkl/mkl_batch_norm-inl.h"
-#endif  // MXNET_USE_MKL2017
+#include "../elemwise_op_common.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./mkldnn/mkldnn_batch_norm-inl.h"
+#endif
 
 /*! \brief inverse standard deviation <-> variance */
 #define VARIANCE_TO_INVSTD(__var$,    __eps$)   (1.0/sqrt((__var$) + DType(__eps$)))
@@ -89,12 +88,12 @@ static inline void ForEachFast(const BNTensor3<DType1> &in_data,
 
 /*! \brief Forward CPU */
 template <typename xpu, typename DType, typename AccReal>
-void BatchNormOp<xpu, DType, AccReal>::DoForward(mshadow::Stream<cpu> *,
-                                                 const OpContext &ctx,
-                                                 const std::vector<TBlob> &in_data,
-                                                 const std::vector<OpReqType> &req,
-                                                 const std::vector<TBlob> &out_data,
-                                                 const std::vector<TBlob> &aux_states) {
+void BatchNormForwardImpl(mshadow::Stream<cpu> *,
+                          const OpContext &ctx, const BatchNormParam& param_,
+                          const std::vector<TBlob> &in_data,
+                          const std::vector<OpReqType> &req,
+                          const std::vector<TBlob> &out_data,
+                          const std::vector<TBlob> &aux_states) {
   // Input
   batchnorm::BNTensor3<DType> inputData(in_data[batchnorm::kData], param_.axis);
   const TBlob &weights         = in_data[batchnorm::kGamma];
@@ -164,7 +163,7 @@ void BatchNormOp<xpu, DType, AccReal>::DoForward(mshadow::Stream<cpu> *,
 
     // note that var is still invstd
     if (!param_.fix_gamma) {
-      if (IsWriting(req[batchnorm::kData])) {
+      if (IsBNWriting(req[batchnorm::kData])) {
         ForEachFast(inputData, outputData, channel,
                     [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
                                                                  DType *out_data) {
@@ -173,10 +172,10 @@ void BatchNormOp<xpu, DType, AccReal>::DoForward(mshadow::Stream<cpu> *,
                     });
       }
     } else {
-      if (IsWriting(req[batchnorm::kGamma])) {
+      if (IsBNWriting(req[batchnorm::kGamma])) {
         w[channel] = AccReal(1);
       }
-      if (IsWriting(req[batchnorm::kData])) {
+      if (IsBNWriting(req[batchnorm::kData])) {
         ForEachFast(inputData, outputData, channel,
                     [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
                                                                  DType *out_data) {
@@ -189,14 +188,14 @@ void BatchNormOp<xpu, DType, AccReal>::DoForward(mshadow::Stream<cpu> *,
 }
 
 template <typename xpu, typename DType, typename AccReal>
-void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<cpu> *,
-                                                  const OpContext &ctx,
-                                                  const std::vector<TBlob> &out_grad,
-                                                  const std::vector<TBlob> &in_data,
-                                                  const std::vector<TBlob> &out_data,
-                                                  const std::vector<OpReqType> &req,
-                                                  const std::vector<TBlob> &in_grad,
-                                                  const std::vector<TBlob> &aux_states) {
+void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
+                           const OpContext &ctx, const BatchNormParam& param_,
+                           const std::vector<TBlob> &out_grad,
+                           const std::vector<TBlob> &in_data,
+                           const std::vector<TBlob> &out_data,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<TBlob> &in_grad,
+                           const std::vector<TBlob> &aux_states) {
   // Input Data
   batchnorm::BNTensor3<DType> inputData(in_data[batchnorm::kData], param_.axis);
   const TBlob &weights   = in_data[batchnorm::kGamma];
@@ -264,7 +263,7 @@ void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<cpu> *,
                   dotp += (*thisInputData - mean) * (*gradOut_data);
                 });
 
-    if (!gradIn.IsEmpty() && IsWriting(req[batchnorm::kData])) {  // if there's a grad input
+    if (!gradIn.IsEmpty() && IsBNWriting(req[batchnorm::kData])) {  // if there's a grad input
       if (is_train_and_not_global_stats) {
         // when in training mode
         // Q(X) = X - E[x] ; i.e. input centered to zero mean
@@ -300,7 +299,7 @@ void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<cpu> *,
     // May want to make this a param eventually
     const AccReal scale = 1.0f;
 
-    if (IsWriting(req[batchnorm::kGamma])) {
+    if (IsBNWriting(req[batchnorm::kGamma])) {
       if (!param_.fix_gamma) {
         gradWeightData[channel] = scale * dotp * invstd;
       } else {
@@ -308,51 +307,185 @@ void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<cpu> *,
       }
     }
 
-    if (IsWriting(req[batchnorm::kBeta])) {
+    if (IsBNWriting(req[batchnorm::kBeta])) {
       gradBiasData[channel] = scale * sumGradOut;
     }
   }
 }
 
-template<>
-Operator *CreateOp<cpu>(BatchNormParam param, const int dtype, const TShape& shape) {
-  param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
-  Operator *op = nullptr;
-#if MXNET_USE_MKL2017 == 1
-  if (shape.ndim() == 4
+DMLC_REGISTER_PARAMETER(BatchNormParam);
+
+static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
+                           std::vector<TShape> *in_shape,
+                           std::vector<TShape> *out_shape) {
+  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 5U) << "Input: [data, gamma, beta, MovingMean, MovingVar]";
+  const TShape &dshape = in_shape->at(batchnorm::kData);
+  if (dshape.ndim() == 0) {
+    return false;  // input shape unknown yet; wait for another inference pass
+  }
+
+  const size_t channelAxis = static_cast<size_t>(param.axis < 0
+      ? static_cast<int>(dshape.ndim()) + param.axis
+      : param.axis);
+  CHECK_LT(channelAxis, dshape.ndim()) << "Channel axis out of range: " << param.axis;
+
+  const int channelCount = dshape[channelAxis];
+
+  in_shape->at(batchnorm::kGamma) = TShape(Shape1(channelCount));
+  in_shape->at(batchnorm::kBeta) = TShape(Shape1(channelCount));
+  in_shape->at(batchnorm::kInMovingMean) = TShape(Shape1(channelCount));  // kMovingMean
+  in_shape->at(batchnorm::kInMovingVar) = TShape(Shape1(channelCount));  // kMovingVar
+
+  out_shape->clear();
+  out_shape->push_back(dshape);                // kOut
+  out_shape->push_back(Shape1(channelCount));  // kMean
+  out_shape->push_back(Shape1(channelCount));  // kVar
+
+  return true;
+}
+
+static bool BatchNormType(const nnvm::NodeAttrs& attrs,
+                          std::vector<int> *in_type, std::vector<int> *out_type) {
+  using namespace mshadow;
+  CHECK_GE(in_type->size(), 1U);
+  const int dtype = (*in_type)[0];
+  CHECK_NE(dtype, -1) << "First input must have specified type";
+  // For a float16 input, beta, gamma, mean, and variance are stored in float32.
+  // For other input types, these parameters have the same type as the input.
+  // NOTE: This requirement comes from cuDNN (v. 4 and 5).
+  int dtype_param;
+  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+      dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+  std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
+  CHECK_LE(in_type->size(), args.size());
+  for (index_t i = 1; i < in_type->size(); ++i) {
+    if ((*in_type)[i] == -1) {
+      (*in_type)[i] = dtype_param;
+    } else {
+      UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
+    }
+  }
+  const size_t n_out = 3;
+  out_type->clear();
+  out_type->push_back(dtype);
+  for (size_t i = 1; i < n_out; ++i) {
+    out_type->push_back(dtype_param);
+  }
+  return true;
+}
+
+#if MXNET_USE_MKLDNN == 1
+static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
+  TShape shape = input.shape();
+  return SupportMKLDNN(input) && shape.ndim() == 4
       && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
-      && !mxnet::op::batchnorm::disable_mkl) {
-    switch (dtype) {
-      case mshadow::kFloat32:
-        op = new MKLBatchNormOp<cpu, float>(param);
-        break;
-      case mshadow::kFloat64:
-        op = new MKLBatchNormOp<cpu, double>(param);
-        break;
-      default:
-        // MKL operator doesn't support half_t, so fall through
-        break;
+      && shape[param.axis] % 8 == 0;
+}
+
+void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
+                           const OpContext &ctx,
+                           const std::vector<NDArray> &inputs,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<NDArray> &outputs) {
+  CHECK_EQ(inputs.size(), 5U);
+  const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
+  // MKLDNN batchnorm only works well on the special MKLDNN layout.
+  if (SupportMKLDNNBN(inputs[0], param) && inputs[0].IsMKLDNNData()) {
+    std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+    std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
+
+    if (inputs[0].dtype() == mshadow::kFloat32) {
+      MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+      MKLDNNBatchNormForward<float>(ctx, param, in_data, req, outputs, aux_states);
+      MKLDNN_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
+      return;
     }
   }
-#endif
-  if (!op) {
-    MSHADOW_REAL_TYPE_SWITCH_EX(dtype,
-                                DType,
-                                AccReal, {
-                                  op = new BatchNormOp<cpu, DType, AccReal>(param); });
+  FallBackCompute(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
+                               const OpContext &ctx,
+                               const std::vector<NDArray> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<NDArray> &outputs) {
+  CHECK_EQ(inputs.size(), 11U);
+  const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
+  int num_out_grads = param.output_mean_var ? 3U : 1U;
+  int in_data_start = 3;
+  int aux_states_start = in_data_start + batchnorm::kInMovingMean;
+  int out_data_start = in_data_start + batchnorm::kInMovingVar + 1;
+
+  TShape shape = inputs[0].shape();
+  // MKLDNN batchnorm only works well on the special MKLDNN layout.
+  if (SupportMKLDNNBN(inputs[0], param)
+      && (inputs[in_data_start].IsMKLDNNData() || inputs[0].IsMKLDNNData())) {
+    std::vector<NDArray> out_grad(inputs.begin(), inputs.begin() + num_out_grads);
+    std::vector<NDArray> in_data(inputs.begin() + in_data_start,
+                                 inputs.begin() + aux_states_start);
+    std::vector<NDArray> aux_states(inputs.begin() + aux_states_start,
+                                    inputs.begin() + out_data_start);
+    std::vector<NDArray> out_data(inputs.begin() + out_data_start, inputs.end());
+    std::vector<NDArray> in_grad(outputs.begin(), outputs.begin() + 3);
+
+    if (inputs[0].dtype() == mshadow::kFloat32) {
+      MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
+      MKLDNNBatchNormBackward<float>(ctx, param, out_grad, in_data,
+                                     out_data, req, in_grad, aux_states);
+      MKLDNN_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+      return;
+    }
   }
-  return op;
+  FallBackCompute(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
 }
+#endif
 
-// DO_BIND_DISPATCH comes from operator_common.h
-Operator *BatchNormProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
-                                          std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0], (*in_shape)[0]);
+static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
+                                        const int dev_mask,
+                                        DispatchMode *dispatch_mode,
+                                        std::vector<int> *in_attrs,
+                                        std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 5);
+  CHECK_EQ(out_attrs->size(), 3);
+  DispatchMode wanted_mode;
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask)
+    wanted_mode = DispatchMode::kFComputeEx;
+  else
+#endif
+    wanted_mode = DispatchMode::kFCompute;
+  for (int& v : *in_attrs) {
+    if (v == -1) v = kDefaultStorage;
+  }
+  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                             dispatch_mode, wanted_mode);
 }
 
-DMLC_REGISTER_PARAMETER(BatchNormParam);
+static inline bool backward_BatchNormStorageType(const nnvm::NodeAttrs &attrs,
+                                                 const int dev_mask,
+                                                 DispatchMode *dispatch_mode,
+                                                 std::vector<int> *in_attrs,
+                                                 std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 11);
+  CHECK_EQ(out_attrs->size(), 5);
+  DispatchMode wanted_mode;
+#if MXNET_USE_MKLDNN == 1
+  if (dev_mask == mshadow::cpu::kDevMask)
+    wanted_mode = DispatchMode::kFComputeEx;
+  else
+#endif
+    wanted_mode = DispatchMode::kFCompute;
+  for (int& v : *in_attrs) {
+    if (v == -1) v = kDefaultStorage;
+  }
+  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                             dispatch_mode, wanted_mode);
+}
 
-MXNET_REGISTER_OP_PROPERTY(BatchNorm, BatchNormProp)
+NNVM_REGISTER_OP(BatchNorm)
 .describe(R"code(Batch normalization.
 
 Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as
@@ -398,14 +531,44 @@ Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is tr
 then set ``gamma`` to 1 and its gradient to 0.
 
 )code" ADD_FILELINE)
+.set_num_inputs(5)
+.set_num_outputs(3)
+.set_attr_parser(ParamParser<BatchNormParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+    [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"data", "gamma", "beta", "moving_mean", "moving_var"};
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+    [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output", "mean", "var"};
+})
+.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
+    [](const NodeAttrs& attrs) {
+  const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
+  return param.output_mean_var ? 3 : 1;
+})
+.set_attr<nnvm::FMutateInputs>("FMutateInputs", [](const nnvm::NodeAttrs& attrs) {
+  return std::vector<uint32_t>{3, 4};
+})
+.set_attr<nnvm::FInferShape>("FInferShape", BatchNormShape)
+.set_attr<nnvm::FInferType>("FInferType", BatchNormType)
+.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
+.set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
+#endif
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_BatchNorm"})
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+#endif
 .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
 .add_argument("gamma", "NDArray-or-Symbol", "gamma array")
 .add_argument("beta", "NDArray-or-Symbol", "beta array")
 .add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input")
 .add_argument("moving_var", "NDArray-or-Symbol", "running variance of input")
-.add_arguments(BatchNormParam::__FIELDS__());
-
-NNVM_REGISTER_OP(BatchNorm)
+.add_arguments(BatchNormParam::__FIELDS__())
 .set_attr<nnvm::FSetInputVarAttrOnCompose>(
   "FSetInputVarAttrOnCompose",
   [](const nnvm::NodeAttrs& attrs, nnvm::NodePtr var, const int index) {
@@ -417,5 +580,20 @@ NNVM_REGISTER_OP(BatchNorm)
     }
   });
 
+NNVM_REGISTER_OP(_backward_BatchNorm)
+.set_num_outputs(5)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", backward_BatchNormStorageType)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+#endif
+.set_attr_parser(ParamParser<BatchNormParam>)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormGradComputeExCPU)
+#endif
+.set_attr<FCompute>("FCompute<cpu>", BatchNormGradCompute<cpu>);
+
 }  // namespace op
 }  // namespace mxnet
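To make the registered op concrete, here is a minimal single-channel sketch of the transform the BatchNorm docstring describes, out = (data - mean) / sqrt(var + eps) * gamma + beta. The biased batch variance and the fix_gamma handling mirror the operator; everything else (plain floats instead of MXNet types) is illustrative only:

#include <cmath>
#include <cstdio>
#include <vector>

// Batch-normalize one channel's values with scale gamma and shift beta.
std::vector<float> batch_norm_1ch(const std::vector<float> &x,
                                  float gamma, float beta,
                                  float eps, bool fix_gamma) {
  float mean = 0.f;
  for (float v : x) mean += v;
  mean /= static_cast<float>(x.size());
  float var = 0.f;  // biased (divide by N) batch variance
  for (float v : x) var += (v - mean) * (v - mean);
  var /= static_cast<float>(x.size());
  if (fix_gamma) gamma = 1.f;  // fix_gamma pins the scale to 1
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i)
    out[i] = (x[i] - mean) / std::sqrt(var + eps) * gamma + beta;
  return out;
}

int main() {
  std::vector<float> out = batch_norm_1ch({1.f, 2.f, 3.f, 4.f},
                                          1.f, 0.f, 1e-3f, false);
  for (float v : out) std::printf("%f\n", v);
  return 0;
}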
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 59317b7..80c1597 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -21,7 +21,7 @@
  * Copyright (c) 2017 by Contributors
  * \file batch_norm.cu
  * \brief CUDA Batch Normalization code
- * \author Chris Olivier, Bing Xu
+ * \author Chris Olivier, Bing Xu, Da Zheng
  * Adapted from Torch
 */
 #include <cuda_runtime_api.h>
@@ -579,13 +579,13 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
   flags |= ctx.is_train ? IS_TRAINING_FLAG : 0;
   flags |= params.fix_gamma ? FIX_GAMMA_FLAG : 0;
   flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
-  if (BatchNormOp<xpu, DType, AccReal>::IsWriting(req[batchnorm::kData])) {
+  if (IsBNWriting(req[batchnorm::kData])) {
     flags |= WRITE_DATA_FLAG;
   }
-  if (BatchNormOp<xpu, DType, AccReal>::IsWriting(req[batchnorm::kGamma])) {
+  if (IsBNWriting(req[batchnorm::kGamma])) {
     flags |= WRITE_GAMMA_FLAG;
   }
-  if (BatchNormOp<xpu, DType, AccReal>::IsWriting(req[batchnorm::kBeta])) {
+  if (IsBNWriting(req[batchnorm::kBeta])) {
     flags |= WRITE_BETA_FLAG;
   }
   return flags;
@@ -593,12 +593,12 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
 
 /*! \brief Forward batch-norm pass on GPU */
 template<typename xpu, typename DType, typename AccReal>
-void BatchNormOp<xpu, DType, AccReal>::DoForward(mshadow::Stream<gpu> *stream,
-                                                 const OpContext &ctx,
-                                                 const std::vector<TBlob> &in_data,
-                                                 const std::vector<OpReqType> &req,
-                                                 const std::vector<TBlob> &out_data,
-                                                 const std::vector<TBlob> &aux_states) {
+void BatchNormForwardImpl(mshadow::Stream<gpu> *stream,
+                          const OpContext &ctx, const BatchNormParam& param_,
+                          const std::vector<TBlob> &in_data,
+                          const std::vector<OpReqType> &req,
+                          const std::vector<TBlob> &out_data,
+                          const std::vector<TBlob> &aux_states) {
   batchnorm::cuda::BatchNormalizationUpdateOutput<DType, AccReal>(
     stream,
     ctx,
@@ -614,14 +614,14 @@ void BatchNormOp<xpu, DType, AccReal>::DoForward(mshadow::Stream<gpu> *stream,
 
 /*! \brief Backward batch-norm pass on GPU */
 template<typename xpu, typename DType, typename AccReal>
-void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<gpu> *stream,
-                                                  const OpContext &ctx,
-                                                  const std::vector<TBlob> &out_grad,
-                                                  const std::vector<TBlob> &in_data,
-                                                  const std::vector<TBlob> &out_data,
-                                                  const std::vector<OpReqType> &req,
-                                                  const std::vector<TBlob> &in_grad,
-                                                  const std::vector<TBlob> &aux_states) {
+void BatchNormBackwardImpl(mshadow::Stream<gpu> *stream,
+                           const OpContext &ctx, const BatchNormParam& param_,
+                           const std::vector<TBlob> &out_grad,
+                           const std::vector<TBlob> &in_data,
+                           const std::vector<TBlob> &out_data,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<TBlob> &in_grad,
+                           const std::vector<TBlob> &aux_states) {
   batchnorm::cuda::BatchNormalizationBackward<DType, AccReal>(
     stream,
     ctx,
@@ -637,30 +637,92 @@ void BatchNormOp<xpu, DType, AccReal>::DoBackward(mshadow::Stream<gpu> *stream,
   MSHADOW_CUDA_POST_KERNEL_CHECK(BatchNormOp_DoBackward_gpu);
 }
 
-/*! \brief Create GPU operator for batch normalization */
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 4
+template<typename DType>
+static CuDNNBatchNormOp<DType> &GetCuDNNOp(const BatchNormParam& param) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local CuDNNBatchNormOp<DType> op;
+#else
+  static MX_THREAD_LOCAL CuDNNBatchNormOp<DType> op;
+#endif
+  op.Init(param);
+  return op;
+}
+#endif
+
 template<>
-Operator *CreateOp<gpu>(BatchNormParam param, const int dtype, const TShape& shape) {
+void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx, const std::vector<TBlob>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
+  BatchNormParam param = nnvm::get<BatchNormParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 5U);
+  std::vector<TBlob> in_data(inputs.begin(), inputs.begin() + 3);
+  std::vector<TBlob> aux_states(inputs.begin() + 3, inputs.end());
+  int dtype = inputs[0].type_flag_;
+  TShape shape = inputs[0].shape_;
+
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
-  Operator *op = NULL;
 #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
   if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
       && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-      op = new CuDNNBatchNormOp<DType>(param);
+      GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
     })
   } else {
     MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
-      op = new BatchNormOp<gpu, DType, AccReal>(param);
+      BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
     })
   }
 #else
-  MSHADOW_REAL_TYPE_SWITCH_EX(dtype,
-                              DType,
-                              AccReal,
-                              { op = new BatchNormOp<gpu, DType, AccReal>(param); });
+  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
+    BatchNormForward<gpu, DType, AccReal>(ctx, param, in_data, req, outputs, aux_states);
+  });
+#endif
+}
+
+template<>
+void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx, const std::vector<TBlob>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 11U);
+  BatchNormParam param = nnvm::get<BatchNormParam>(attrs.parsed);
+  std::vector<TBlob> out_grad(1, inputs[0]);
+  std::vector<TBlob> in_data(inputs.begin() + 3, inputs.begin() + 6);
+  std::vector<TBlob> aux_states(inputs.begin() + 6, inputs.begin() + 8);
+  std::vector<TBlob> out_data(inputs.begin() + 8, inputs.end());
+  std::vector<TBlob> in_grad(outputs.begin(), outputs.begin() + 3);
+  int dtype = inputs[0].type_flag_;
+  TShape shape = inputs[0].shape_;
+
+  param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
+  if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
+      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+    MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
+      GetCuDNNOp<DType>(param).Backward(ctx, out_grad, in_data, out_data,
+        req, in_grad, aux_states);
+    })
+  } else {
+    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, {
+      BatchNormBackward<gpu, DType, AccReal>(ctx, param, out_grad,
+          in_data, out_data, req, in_grad, aux_states);
+    })
+  }
+#else
+  MSHADOW_REAL_TYPE_SWITCH_EX(out_grad[0].type_flag_, DType, AccReal, {
+    BatchNormBackward<gpu, DType, AccReal>(ctx, param, out_grad,
+        in_data, out_data, req, in_grad, aux_states);
+  });
 #endif
-  return op;
 }
 
+NNVM_REGISTER_OP(BatchNorm)
+.set_attr<FCompute>("FCompute<gpu>", BatchNormCompute<gpu>);
+
+NNVM_REGISTER_OP(_backward_BatchNorm)
+.set_attr<FCompute>("FCompute<gpu>", BatchNormGradCompute<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
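The GetCuDNNOp helper above replaces per-call heap allocation of a CuDNN operator with one cached instance per thread and per data type, re-initialized with the current parameters on every call. A dependency-free sketch of the same pattern, with all names illustrative:

#include <iostream>

struct Param { int axis; };

template <typename DType>
class CachedOp {
 public:
  void Init(const Param &p) { param_ = p; }
  void Forward() { std::cout << "axis=" << param_.axis << '\n'; }
 private:
  Param param_;
};

template <typename DType>
CachedOp<DType> &GetOp(const Param &p) {
  static thread_local CachedOp<DType> op;  // one per thread and per DType
  op.Init(p);  // parameters are refreshed on every call, so the op stays stateless
  return op;
}

int main() {
  GetOp<float>(Param{1}).Forward();  // axis=1
  GetOp<float>(Param{3}).Forward();  // same instance, new params: axis=3
  return 0;
}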
diff --git a/src/operator/nn/concat-inl.h b/src/operator/nn/concat-inl.h
new file mode 100644
index 0000000..a7f1fa8
--- /dev/null
+++ b/src/operator/nn/concat-inl.h
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file concat-inl.h
+ * \brief
+ * \author Bing Xu
+*/
+#ifndef MXNET_OPERATOR_NN_CONCAT_INL_H_
+#define MXNET_OPERATOR_NN_CONCAT_INL_H_
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <cstring>
+#include <map>
+#include <string>
+#include <vector>
+#include <utility>
+#include "../operator_common.h"
+#include "../channel_op_common.h"
+#include "../tensor/broadcast_reduce_op.h"
+
+namespace mxnet {
+namespace op {
+
+namespace concat_enum {
+enum ConcatOpInputs {kData0, kData1, kData2, kData3, kData4};
+enum ConcatOpResource {kTempSpace};
+enum ConcatOpOutputs {kOut};
+}  // namespace concat_enum
+
+struct ConcatParam : public dmlc::Parameter<ConcatParam> {
+  int num_args;
+  int dim;
+  DMLC_DECLARE_PARAMETER(ConcatParam) {
+    DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
+    .describe("Number of inputs to be concated.");
+    DMLC_DECLARE_FIELD(dim).set_default(1)
+    .describe("the dimension to be concated.");
+  }
+};  // struct ConcatParam
+
+template<typename xpu, typename DType>
+class ConcatOp {
+ public:
+  void Init(const ConcatParam &param) {
+    this->size_ = param.num_args;
+    this->dimension_ = param.dim;
+  }
+
+  void Forward(const OpContext &ctx,
+               const std::vector<TBlob> &in_data,
+               const std::vector<OpReqType> &req,
+               const std::vector<TBlob> &out_data) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(static_cast<int>(in_data.size()), size_);
+    CHECK_EQ(out_data.size(), 1U);
+    int axis = CheckAxis(dimension_, in_data[concat_enum::kData0].ndim());
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    std::vector<Tensor<xpu, 3, DType> > data(size_);
+    Tensor<xpu, 3, DType> out;
+    size_t leading = 1, trailing = 1;
+    for (int i = 0; i < axis; ++i) {
+      leading *= out_data[concat_enum::kOut].shape_[i];
+    }
+    for (int i = axis + 1; i < out_data[concat_enum::kOut].ndim(); ++i) {
+      trailing *= out_data[concat_enum::kOut].shape_[i];
+    }
+    size_t mid = out_data[concat_enum::kOut].shape_[axis];
+    Shape<3> oshape = Shape3(leading, mid, trailing);
+    out = out_data[concat_enum::kOut].get_with_shape<xpu, 3, DType>(oshape, s);
+
+    for (int i = 0; i < size_; ++i) {
+      Shape<3> dshape = Shape3(leading, in_data[i].shape_[axis], trailing);
+      data[i] = in_data[i].get_with_shape<xpu, 3, DType>(dshape, s);
+    }
+    Concatenate(data, &out, 1, req[concat_enum::kOut]);
+  }
+
+  void Backward(const OpContext &ctx, const TBlob &out_grad,
+                const std::vector<OpReqType> &req,
+                const std::vector<TBlob> &in_grad) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    CHECK_EQ(in_grad.size(), static_cast<size_t>(size_));
+    int axis = CheckAxis(dimension_, out_grad.ndim());
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    std::vector<Tensor<xpu, 3, DType> > grad_in(size_);
+    Tensor<xpu, 3, DType> grad;
+    size_t leading = 1, trailing = 1;
+    for (int i = 0; i < axis; ++i) {
+      leading *= out_grad.shape_[i];
+    }
+    for (int i = axis + 1; i < out_grad.ndim(); ++i) {
+      trailing *= out_grad.shape_[i];
+    }
+    size_t mid = out_grad.shape_[axis];
+    Shape<3> oshape = Shape3(leading, mid, trailing);
+    grad = out_grad.get_with_shape<xpu, 3, DType>(oshape, s);
+
+    for (int i = 0; i < size_; ++i) {
+      Shape<3> dshape = Shape3(leading, in_grad[i].shape_[axis], trailing);
+      grad_in[i] = in_grad[i].get_with_shape<xpu, 3, DType>(dshape, s);
+    }
+    Split(grad, &grad_in, 1, req);
+  }
+
+ private:
+  int size_;
+  int dimension_;
+};  // class ConcatOp
+
+template<typename xpu>
+void ConcatCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                   const std::vector<TBlob>& inputs,
+                   const std::vector<OpReqType>& req,
+                   const std::vector<TBlob>& outputs) {
+  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
+  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kData0].type_flag_, DType, {
+    ConcatOp<xpu, DType> op;
+    op.Init(param);
+    op.Forward(ctx, inputs, req, outputs);
+  });
+}
+
+template<typename xpu>
+void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
+  MSHADOW_TYPE_SWITCH(inputs[concat_enum::kOut].type_flag_, DType, {
+    ConcatOp<xpu, DType> op;
+    op.Init(param);
+    op.Backward(ctx, inputs[concat_enum::kOut], req, outputs);
+  });
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NN_CONCAT_INL_H_
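ConcatOp above reduces an N-dimensional concat along `axis` to a 3-D one of shape (leading, dim[axis], trailing), where `leading` is the product of the dimensions before the axis and `trailing` the product of those after it. A small sketch of that flattening with plain arrays, no mshadow:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  std::vector<size_t> shape = {2, 3, 4, 5};  // e.g. NCHW
  const size_t axis = 1;                     // concatenate along channels
  size_t leading = 1, trailing = 1;
  for (size_t i = 0; i < axis; ++i) leading *= shape[i];
  for (size_t i = axis + 1; i < shape.size(); ++i) trailing *= shape[i];
  // A (2,3,4,5) tensor concatenated on axis 1 behaves like a (2,3,20) one:
  std::cout << leading << " x " << shape[axis] << " x " << trailing << '\n';
  return 0;
}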
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
new file mode 100644
index 0000000..81dc95f
--- /dev/null
+++ b/src/operator/nn/concat.cc
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file concat.cc
+ * \brief
+ * \author Bing Xu
+*/
+
+#include "./concat-inl.h"
+#include "./mkldnn/mkldnn_ops-inl.h"
+#include "./mkldnn/mkldnn_base-inl.h"
+#include "../../common/utils.h"
+
+namespace mxnet {
+namespace op {
+
+static bool ConcatShape(const nnvm::NodeAttrs& attrs,
+                        std::vector<TShape> *in_shape,
+                        std::vector<TShape> *out_shape) {
+  using namespace mshadow;
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
+  TShape dshape;
+  index_t size = 0;
+  bool has_zero = false;
+  int axis = -1;
+  for (int i = 0; i < param_.num_args; ++i) {
+    TShape tmp = (*in_shape)[i];
+    if (tmp.ndim()) {
+      axis = CheckAxis(param_.dim, tmp.ndim());
+      has_zero = tmp[axis] == 0 || has_zero;
+      size += tmp[axis];
+      tmp[axis] = 0;
+      shape_assign(&dshape, tmp);
+    }
+  }
+
+  TShape tmp = (*out_shape)[0];
+  if (tmp.ndim()) {
+    axis = CheckAxis(param_.dim, tmp.ndim());
+    tmp[axis] = 0;
+    shape_assign(&dshape, tmp);
+  }
+
+  if (dshape.ndim() == 0) return false;
+
+  for (int i = 0; i < param_.num_args; ++i) {
+    CHECK(shape_assign(&(*in_shape)[i], dshape))
+        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
+  }
+
+  if (!has_zero) dshape[axis] = size;
+  CHECK(shape_assign(&(*out_shape)[0], dshape))
+      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
+
+  return dshape.Size() != 0;
+}
+
+static bool ConcatType(const nnvm::NodeAttrs& attrs,
+                       std::vector<int> *in_type,
+                       std::vector<int> *out_type) {
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  int dtype = -1;
+
+  for (size_t i = 0; i < in_type->size(); ++i) {
+    if (dtype == -1) {
+      dtype = in_type->at(i);
+    } else {
+      CHECK(in_type->at(i) == dtype ||
+            in_type->at(i) == -1) <<
+          "Non-uniform data type in Concat";
+    }
+  }
+
+  if (dtype == -1) {
+    LOG(FATAL) << "Not enough information to infer type in Concat.";
+    return false;
+  }
+
+  size_t nin = param_.num_args;
+  in_type->clear();
+  for (size_t i = 0; i < nin; ++i) in_type->push_back(dtype);
+
+  out_type->clear();
+  out_type->push_back(dtype);
+
+  return true;
+}
+
+inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
+                                                 const int dev_mask,
+                                                 DispatchMode* dispatch_mode,
+                                                 std::vector<int> *in_attrs,
+                                                 std::vector<int> *out_attrs) {
+  CHECK(!in_attrs->empty());
+  CHECK_EQ(out_attrs->size(), 1U);
+  DispatchMode wanted_mode;
+#if MXNET_USE_MKLDNN == 1
+  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
+  if (dev_mask == mshadow::cpu::kDevMask
+      && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
+      && param.dim > 0)
+    wanted_mode = DispatchMode::kFComputeEx;
+  else
+#endif
+    wanted_mode = DispatchMode::kFCompute;
+  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                             dispatch_mode, wanted_mode);
+}
+
+inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
+                                             const int dev_mask,
+                                             DispatchMode* dispatch_mode,
+                                             std::vector<int> *in_attrs,
+                                             std::vector<int> *out_attrs) {
+  DispatchMode wanted_mode;
+#if MXNET_USE_MKLDNN == 1
+  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
+  if (dev_mask == mshadow::cpu::kDevMask
+      && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
+      && param.dim > 0)
+    wanted_mode = DispatchMode::kFComputeEx;
+  else
+#endif
+    wanted_mode = DispatchMode::kFCompute;
+  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
+                             dispatch_mode, wanted_mode);
+}
+
+#if MXNET_USE_MKLDNN == 1
+static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
+                               const OpContext& op_ctx,
+                               const std::vector<NDArray>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<NDArray>& outputs) {
+  CHECK(!inputs.empty());
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (req[0] == kNullOp) return;
+  // MKLDNN supports 2D and 4D concat
+  if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4)
+      && inputs[0].dtype() == mshadow::kFloat32) {
+    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs);
+    MKLDNN_OPCHECK_RUN(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
+    return;
+  }
+  FallBackCompute(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
+}
+
+static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<NDArray>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<NDArray>& outputs) {
+  if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4)
+      && inputs[0].dtype() == mshadow::kFloat32) {
+    MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
+    MKLDNNConcatBackward(attrs, ctx, inputs, req, outputs);
+    MKLDNN_OPCHECK_RUN(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  FallBackCompute(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+#endif
+
+struct ConcatGrad {
+  const char *op_name;
+  std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
+                                          const std::vector<nnvm::NodeEntry>& ograds) const {
... 11852 lines suppressed ...
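The storage-type inference functions in this patch (BatchNormStorageType, ConcatForwardInferStorageType, and their backward counterparts) share one dispatch rule: on a CPU context with MKLDNN compiled in, request the FComputeEx (NDArray) path so the MKLDNN kernels can run; otherwise fall back to the plain FCompute (TBlob) path. A dependency-free sketch of that gate, where the macro and enums are stand-ins rather than MXNet's own:

#include <iostream>

enum class DispatchMode { kFCompute, kFComputeEx };
enum DevMask { kCPU = 1, kGPU = 2 };

#define MXNET_USE_MKLDNN 1  // assumption for this sketch: built with MKLDNN

DispatchMode WantedMode(int dev_mask) {
#if MXNET_USE_MKLDNN == 1
  if (dev_mask == kCPU) return DispatchMode::kFComputeEx;  // MKLDNN path
#endif
  return DispatchMode::kFCompute;  // default dense path
}

int main() {
  std::cout << (WantedMode(kCPU) == DispatchMode::kFComputeEx) << '\n';  // 1
  std::cout << (WantedMode(kGPU) == DispatchMode::kFComputeEx) << '\n';  // 0
  return 0;
}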
