You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by jx...@apache.org on 2018/05/25 17:11:56 UTC

[incubator-mxnet] branch master updated: Fix bugs in MKLDNN. (#10979)

This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d497b37  Fix bugs in MKLDNN. (#10979)
d497b37 is described below

commit d497b37876ffb5d9bc01812bf7f295039bfe35f6
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Fri May 25 10:11:45 2018 -0700

    Fix bugs in MKLDNN. (#10979)
    
    * Fix bugs in MKLDNN.
    
    * add more test cases.
    
    * Fix CopyFrom when it's the view of an NDArray.
    
    * add test.
    
    * check same shape correctly.
    
    * add unit test for CopyFrom.
    
    * Fix warning.
    
    * Add test sum.
    
    * fix sum.
    
    * Fix fallback.
    
    * Fix fallback of sum.
    
    * add tests.
    
    * Update mkldnn.cc
---
 src/ndarray/ndarray.cc                          | 111 +++++++++-------
 src/operator/nn/mkldnn/mkldnn_base.cc           |   5 +-
 src/operator/nn/mkldnn/mkldnn_sum.cc            |  22 +++-
 src/operator/tensor/elemwise_binary_op_basic.cc |  12 +-
 tests/cpp/operator/mkldnn.cc                    | 165 +++++++++++++++++++++---
 5 files changed, 235 insertions(+), 80 deletions(-)

diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index d87e8bc..94d3d90 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -200,6 +200,7 @@ NDArray NDArray::MKLDNNDataReshape(const TShape &shape) const {
     ret.ptr_->delay_alloc = false;
     ret.ptr_->static_data = true;
     ret.byte_offset_ = byte_offset_;
+    ret.reuse_ = false;
     return ret;
   }
 }
@@ -217,6 +218,7 @@ NDArray NDArray::Reshape(const TShape &shape) const {
   // Otherwise, reshape only works on the default layout.
   CHECK_EQ(storage_type(), kDefaultStorage);
   ret.shape_ = shape;
+  ret.reuse_ = false;
   return ret;
 }
 
@@ -249,6 +251,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
   MSHADOW_TYPE_SWITCH(ret.dtype(), DType, {
     ret.byte_offset_ += begin * length * sizeof(DType);
   });
+  ret.reuse_ = false;
   ret.shape_[0] = end - begin;
   return ret;
 }
@@ -555,6 +558,7 @@ NDArray NDArray::Reorder2Default() const {
   // reshape as needed
   ret.shape_ = shape_;
   ret.byte_offset_ = byte_offset_;
+  ret.reuse_ = false;
   return ret;
 }
 
@@ -584,39 +588,39 @@ void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc)
 
 const mkldnn::memory *NDArray::GetMKLDNNData() const {
   CHECK(storage_type() == kDefaultStorage);
+  bool is_view = IsView();
   if (IsMKLDNNData()) {
     // If this array uses MKLDNN layout, we have to make sure it's not a view.
     // Otherwise, we'll have to change the layout inside the array.
-    CHECK(!IsView());
+    CHECK(!is_view);
     MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
     // If this array uses MKLDNN format, we should return now. Otherwise,
     // SetMKLMem may mess up mkl_mem_.
     return ptr_->mkl_mem_->GetRaw();
-  }
-  ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_);
-  MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
-  if (IsView()) {
-    mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc();
-    // Sliced array must use the default layout.
-    CHECK_EQ(GetDefaultFormat(pd.desc()), pd.desc().data.format);
-    void *off_addr = static_cast<char *>(ptr_->mkl_mem_->GetDataHandle())
-        + byte_offset_;
-
+  } else if (is_view) {
+    // If this is a view, we can't create a MKLDNN memory for the chunk
+    // because we don't have the complete data type and shape information for
+    // the chunk.
+    void *off_addr = static_cast<char *>(ptr_->shandle.dptr) + byte_offset_;
     // Create the primitive desc for the new mkldnn memory.
     mkldnn::memory::dims dims(shape().ndim());
     for (size_t i = 0; i < dims.size(); i++)
       dims[i] = shape()[i];
     mkldnn::memory::format cpp_format = static_cast<mkldnn::memory::format>(
         GetDefaultFormat(shape().ndim()));
-    mkldnn::memory::data_type cpp_type = static_cast<mkldnn::memory::data_type>(
-        pd.desc().data.data_type);
+    mkldnn::memory::data_type cpp_type = get_mkldnn_type(dtype_);
     mkldnn::memory::desc data_md(dims, cpp_type, cpp_format);
-    mkldnn::memory::primitive_desc new_pd(data_md, pd.get_engine());
+    mkldnn::memory::primitive_desc new_pd(data_md,
+                                          CpuEngine::Get()->get_engine());
 
     std::shared_ptr<mkldnn::memory> ret(new mkldnn::memory(new_pd, off_addr));
     MKLDNNStream::Get()->RegisterMem(ret);
     return ret.get();
   } else {
+    // If this isn't a view, we can create a MKLDNN memory and store it in the
+    // chunk.
+    ptr_->SetMKLMem(shape_, dtype_);
+    MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
     return ptr_->mkl_mem_->GetRaw();
   }
 }
@@ -637,20 +641,23 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
   MKLDNNStream *stream = MKLDNNStream::Get();
   // If this array uses MKLDNN layout, we have to make sure it's not a view.
   // Otherwise, we'll have to change the layout inside the array.
-  if (IsMKLDNNData())
-    CHECK(!IsView());
-  ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_,
-                  dtype_);
-  stream->RegisterMem(ptr_->mkl_mem_->GetMem());
-  mkldnn::memory::desc from_desc = mem.get_primitive_desc().desc();
-  mkldnn::memory::desc this_desc = ptr_->mkl_mem_->GetPrimitiveDesc().desc();
+
+  if (IsMKLDNNData() && IsView())
+    ptr_->Reorder2Default();
+
+  const mkldnn::memory *this_mem = GetMKLDNNData();
+  mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc();
+  mkldnn::memory::desc from_desc = from_pd.desc();
+  mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
+  mkldnn::memory::desc this_desc = this_pd.desc();
   mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
+  mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc);
   if (IsView()) {
     // Sliced array must use the default layout.
     CHECK_EQ(GetDefaultFormat(this_desc), this_desc.data.format);
   }
   // It's possible that the memory and the NDArray don't have the same shape.
-  if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)
+  if (!same_shape(this_desc, from_desc)
       // If the source memory uses the default layout, we can reshape directly.
       && from_def_format == from_desc.data.format) {
     // In this case, we can simply create a new MKLDNN memory for the required
@@ -660,15 +667,14 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
     auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
     auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
     mkldnn::memory::desc data_md(dims, this_dtype, this_format);
-    mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
+    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
     mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
     stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
-  } else if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)) {
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+  } else if (!same_shape(this_desc, from_desc)) {
     // In this case, the source memory stores data in a customized layout. We
     // need to reorganize the data in memory before we can reshape.
-    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(mem.get_primitive_desc(),
-                                                             from_def_format);
+    mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format);
     mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
     stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
     // Now we can reshape it
@@ -677,45 +683,40 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
     auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
     auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
     mkldnn::memory::desc data_md(dims, this_dtype, this_format);
-    mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
+    mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
     mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
     stream->RegisterMem(tmp_mem);
-    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
-  } else if (mem.get_primitive_desc() == ptr_->mkl_mem_->GetPrimitiveDesc()) {
+    stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
+  } else if (from_pd == this_pd) {
     // If the layout is the same, we can just copy data.
-    stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw()));
+    stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
   } else {
-    mkldnn_memory_format_t src_def = GetDefaultFormat(mem.get_primitive_desc().desc());
-    mkldnn_memory_format_t dst_def = ptr_->mkl_mem_->GetDefaultFormat();
     // If both are not using the default layouts. There isn't much we can do,
     // other than reorder data layout directly.
-    if (dst_def != ptr_->mkl_mem_->GetFormat()
-        && src_def != mem.get_primitive_desc().desc().data.format) {
-      stream->RegisterPrim(mkldnn::reorder(mem, *ptr_->mkl_mem_->GetRaw()));
-    } else if (dst_def == ptr_->mkl_mem_->GetFormat()) {
+    if (this_def_format != this_desc.data.format
+        && from_def_format != from_desc.data.format) {
+      stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
+    } else if (this_def_format == this_desc.data.format) {
       // If the dest mem uses the default memory layout, we can simply use
       // the default format of the source memory to improve perf of reorder.
-      mkldnn::memory::primitive_desc pd = ptr_->mkl_mem_->GetPrimitiveDesc(src_def);
-      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, ptr_->mkl_mem_->GetDataHandle()));
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd,
+                                                           from_def_format);
+      mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle()));
       stream->RegisterMem(tmp_mem);
       stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
     } else {
       // If the src mem uses the default memory layout, we can use
       // the default format of the source memory to improve perf.
-      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(mem.get_primitive_desc(), dst_def);
+      mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd,
+                                                           this_def_format);
       mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
       stream->RegisterMem(tmp_mem);
-      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->mkl_mem_->GetRaw()));
+      stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
     }
   }
 }
-mkldnn::memory::primitive_desc GetPrimitiveDesc(mkldnn::memory::primitive_desc pd,
-                                                mkldnn_memory_format_t format);
 
 mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &desc) {
-  // This array shouldn't be a view.
-  CHECK(!IsView());
-
   if (desc.get_size() != shape().Size() * GetTypeSize(dtype_)) {
     LOG(FATAL) << "The size of NDArray doesn't match the requested MKLDNN memory desc";
     return nullptr;
@@ -726,10 +727,26 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::primitive_desc &
   mkldnn_memory_format_t def_format = GetDefaultFormat(_desc.desc());
   // If the required format is a default format, we don't need to worry about the shape.
   // If the shape isn't the same, it actually implicitly reshapes data.
-  if (required_format == def_format) {
+  if (required_format == def_format && !IsView()) {
     ptr_->SetMKLMem(shape_, dtype_);
     MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
     return GetMKLDNNExact(ptr_->mkl_mem_->GetRaw(), desc);
+  } else if (required_format == def_format) {
+    ptr_->CheckAndAlloc();
+    CHECK(ptr_->shandle.dptr);
+    // When this is a view and a user wants the default layout, we can simply
+    // create a new mkldnn memory that points to the right memory.
+    std::shared_ptr<mkldnn::memory> mem(new mkldnn::memory(
+            desc, static_cast<char *>(ptr_->shandle.dptr) + byte_offset_));
+    MKLDNNStream::Get()->RegisterMem(mem);
+    return mem.get();
+  } else if (IsView()) {
+    // If this is a view and a user wants to write data to it with a special
+    // MKLDNN format, we should reorder the data in the array and return NULL.
+    // In this way, the user will create a new NDArray for the special format
+    // and copy data back.
+    ptr_->Reorder2Default();
+    return nullptr;
   }
 
   if (ptr_->mkl_mem_)
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index 56abc6b..1bd1581 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -297,10 +297,7 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
     } else {
       if (in_bufs.empty())
         in_bufs.reserve(inputs.size());
-      in_bufs.emplace_back(inputs[i].shape(), inputs[i].ctx(),
-                           false, inputs[i].dtype());
-      const mkldnn::memory *mem = inputs[i].GetMKLDNNData();
-      in_bufs.back().CopyFrom(*mem);
+      in_bufs.push_back(inputs[i].Reorder2Default());
       in_blobs[i] = in_bufs.back().data();
     }
   }
diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc
index 361ce58..fdbfb15 100644
--- a/src/operator/nn/mkldnn/mkldnn_sum.cc
+++ b/src/operator/nn/mkldnn/mkldnn_sum.cc
@@ -59,8 +59,15 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   std::vector<float> scales(inputs.size(), 1);
   in_prims.reserve(inputs.size());
   bool pd_same = true;
+  std::vector<NDArray> in_bufs(inputs.size());
   for (size_t i = 0; i < inputs.size(); i++) {
-    auto in_mem = inputs[i].GetMKLDNNData();
+    const mkldnn::memory *in_mem;
+    if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) {
+      in_bufs[i] = inputs[i].Reorder2Default();
+      in_mem = in_bufs[i].GetMKLDNNData();
+    } else {
+      in_mem = inputs[i].GetMKLDNNData();
+    }
     in_prims.push_back(*in_mem);
     in_pds[i] = in_mem->get_primitive_desc();
   }
@@ -68,9 +75,16 @@ void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   mkldnn::sum::primitive_desc pdesc(scales, in_pds);
   pd_same = pd_same && (pdesc.dst_primitive_desc() == in_pds[0]);
   auto out_mem = const_cast<NDArray&>(out_data).CreateMKLDNNData(pdesc.dst_primitive_desc());
-  bool addr_same = out_mem->get_data_handle() == inputs[0].GetMKLDNNData()->get_data_handle();
-  if ((req == kWriteTo) ||
-      (req == kWriteInplace && pd_same && addr_same)) {
+  bool addr_same = false;
+  const void *first_data_handle;
+  if (in_bufs[0].is_none())
+    first_data_handle = inputs[0].GetMKLDNNData()->get_data_handle();
+  else
+    first_data_handle = in_bufs[0].GetMKLDNNData()->get_data_handle();
+  if (out_mem)
+    addr_same = out_mem->get_data_handle() == first_data_handle;
+  if (((req == kWriteTo) || (req == kWriteInplace && pd_same && addr_same))
+      && out_mem) {
     // do sum computation directly on output NDArray
     MKLDNNStream *stream = MKLDNNStream::Get();
     stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *out_mem));
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc
index dbb26ea..9b5b9d3 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -43,16 +43,8 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
     return;
   } else if (inputs[0].storage_type() == kDefaultStorage
              && inputs[1].storage_type() == kDefaultStorage) {
-    // This happens if inputs are supposed to be in MKLDNN format
-    // but MKLDNN doesn't support the data type or the shape. We're
-    // forced to convert it to the default format.
-    std::vector<TBlob> in_blobs(2);
-    std::vector<TBlob> out_blobs(1);
-    in_blobs[0] = inputs[0].data();
-    in_blobs[1] = inputs[1].data();
-    out_blobs[0] = outputs[0].data();
-    ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::plus>(attrs, ctx, in_blobs,
-                                                         req, out_blobs);
+    FallBackCompute(ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::plus>,
+                    attrs, ctx, inputs, req, outputs);
     return;
   }
 #endif
diff --git a/tests/cpp/operator/mkldnn.cc b/tests/cpp/operator/mkldnn.cc
index bbff531..1d29219 100644
--- a/tests/cpp/operator/mkldnn.cc
+++ b/tests/cpp/operator/mkldnn.cc
@@ -89,12 +89,17 @@ TEST(MKLDNN_UTIL_FUNC, MemFormat) {
 }
 
 // Init arrays with the default layout.
-static void InitArray(NDArray *arr) {
+static void InitArray(NDArray *arr, bool is_rand = false) {
   const TBlob &blob = arr->data();
   mshadow::default_real_t *data = blob.dptr<mshadow::default_real_t>();
   size_t size = blob.Size();
-  for (size_t i = 0; i < size; i++)
-    data[i] = i;
+  if (is_rand) {
+    for (size_t i = 0; i < size; i++)
+      data[i] = std::rand();
+  } else {
+    for (size_t i = 0; i < size; i++)
+      data[i] = i;
+  }
 }
 
 // Init arrays with the specified layout.
@@ -361,6 +366,15 @@ OpAttrs GetLeakyReluOp() {
   return attrs;
 }
 
+OpAttrs GetSumOp() {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("elemwise_add");
+  attrs.dispatches.resize(2);
+  attrs.dispatches[0] = DispatchMode::kFCompute;
+  attrs.dispatches[1] = DispatchMode::kFComputeEx;
+  return attrs;
+}
+
 /*
  * We want to get a few types of NDArrays for testing:
  * 1. Normal NDArray
@@ -418,34 +432,65 @@ std::vector<NDArray> GetTestInputArrays() {
  *    pass them to all operators.
  *    In the inference mode, the MKLDNN memory in the weight array will be
  *    reordered to 5 dimensions.
- * 4. Reused NDArray (this is created by the MXNet executor). This type of
+ * 4. Reshaped/sliced NDArray
+ * 5. Reused NDArray (this is created by the MXNet executor). This type of
  *    NDArrays can only be used as output arrays.
+ * 6. Reused NDArray converted from an array with a different data type.
+ * 7. Reused reshaped/sliced NDArray.
+ * 8. Reused NDArray with MKLDNN layout.
+ * 9. Reused NDArray with MKLDNN layout of different dimensions.
  */
 std::vector<NDArray> GetTestOutputArrays(const TShape &shape,
                                          const std::vector<mkldnn::memory::primitive_desc> &pds) {
   std::vector<NDArray> in_arrs;
+  // Type 1.
   in_arrs.emplace_back(shape, Context());
-  InitArray(&in_arrs.back());
+  InitArray(&in_arrs.back(), true);
+
+  // Type 4.
+  TShape tmp_shape = shape;
+  tmp_shape[0] = shape[0] * 2;
+  NDArray arr0(tmp_shape, Context());
+  InitArray(&arr0, true);
+  in_arrs.emplace_back(arr0.Slice(1, shape[0] + 1));
 
+  // Type 5.
   // Get a reused version.
   nnvm::TShape s(1);
   s[0] = shape.Size();
-  NDArray arr(s, Context());
-  arr = arr.AsArray(shape, arr.dtype());
-  InitArray(&arr);
-  in_arrs.emplace_back(arr);
+  NDArray arr1(s, Context());
+  arr1 = arr1.AsArray(shape, arr1.dtype());
+  InitArray(&arr1, true);
+  in_arrs.emplace_back(arr1);
+
+  // Type 6.
+  s[0] = shape.Size() * GetTypeSize(mshadow::default_type_flag);
+  NDArray arr2(s, Context(), true, mshadow::kUint8);
+  arr2 = arr2.AsArray(shape, mshadow::default_type_flag);
+  InitArray(&arr2, true);
+  in_arrs.emplace_back(arr2);
+
+  // Type 7
+  s[0] = shape.Size() * GetTypeSize(mshadow::default_type_flag) * 2;
+  NDArray arr3(s, Context(), true, mshadow::kUint8);
+  tmp_shape[0] = shape[0] * 2;
+  arr3 = arr3.AsArray(tmp_shape, mshadow::default_type_flag);
+  InitArray(&arr3, true);
+  in_arrs.emplace_back(arr3.Slice(1, shape[0] + 1));
 
   for (auto pd : pds) {
     if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t))
       continue;
 
+    // Type 2, 3.
     in_arrs.emplace_back(shape, Context());
     InitMKLDNNArray(&in_arrs.back(), pd, true);
 
+    // Type 8, 9.
     // Get a reused version.
     nnvm::TShape s(1);
     s[0] = shape.Size();
-    arr = NDArray(s, Context());
+    NDArray arr = NDArray(s, Context());
     arr = arr.AsArray(shape, arr.dtype());
     InitMKLDNNArray(&arr, pd, true);
     in_arrs.emplace_back(arr);
@@ -453,10 +498,10 @@ std::vector<NDArray> GetTestOutputArrays(const TShape &shape,
   return in_arrs;
 }
 
-using VerifyFunc = std::function<void (const NDArray &in_arr, const NDArray &arr)>;
+using VerifyFunc = std::function<void (const std::vector<NDArray *> &in_arrs, const NDArray &arr)>;
 
-void VerifyCopyResult(const NDArray &in_arr, const NDArray &arr) {
-  NDArray tmp1 = in_arr.Reorder2Default();
+void VerifyCopyResult(const std::vector<NDArray *> &in_arrs, const NDArray &arr) {
+  NDArray tmp1 = in_arrs[0]->Reorder2Default();
   NDArray tmp2 = arr.Reorder2Default();
   EXPECT_EQ(tmp1.shape().Size(), tmp2.shape().Size());
   TBlob d1 = tmp1.data();
@@ -465,6 +510,40 @@ void VerifyCopyResult(const NDArray &in_arr, const NDArray &arr) {
                    tmp1.shape().Size() * sizeof(mshadow::default_real_t)), 0);
 }
 
+void VerifySumResult(const std::vector<NDArray *> &in_arrs, const NDArray &arr) {
+  NDArray in1 = in_arrs[0]->Reorder2Default();
+  NDArray in2 = in_arrs[1]->Reorder2Default();
+  NDArray out = arr.Reorder2Default();
+  EXPECT_EQ(in1.shape().Size(), in2.shape().Size());
+  EXPECT_EQ(in1.shape().Size(), out.shape().Size());
+
+  mshadow::default_real_t *d1 = in1.data().dptr<mshadow::default_real_t>();
+  mshadow::default_real_t *d2 = in2.data().dptr<mshadow::default_real_t>();
+  mshadow::default_real_t *o = out.data().dptr<mshadow::default_real_t>();
+  for (size_t i = 0; i < in1.shape().Size(); i++)
+    EXPECT_EQ(d1[i] + d2[i], o[i]);
+}
+
+TEST(MKLDNN_NDArray, CopyFrom) {
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+  std::vector<NDArray> in_arrs = GetTestInputArrays();
+  for (auto in_arr : in_arrs) {
+    std::vector<NDArray> out_arrs = GetTestOutputArrays(in_arr.shape(), pds);
+    for (auto out_arr : out_arrs) {
+      if (in_arr.IsMKLDNNData() && in_arr.IsView())
+        in_arr = in_arr.Reorder2Default();
+      const mkldnn::memory *mem = in_arr.GetMKLDNNData();
+      out_arr.CopyFrom(*mem);
+      MKLDNNStream::Get()->Submit();
+      std::vector<NDArray *> inputs(1);
+      inputs[0] = &in_arr;
+      VerifyCopyResult(inputs, out_arr);
+    }
+  }
+}
+
 void TestUnaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
   std::vector<NDArray*> inputs(1);
   std::vector<NDArray*> outputs(1);
@@ -485,7 +564,53 @@ void TestUnaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
         Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
                                     outputs, req, dispatch, mxnet::OpStatePtr());
         out_arr.WaitToRead();
-        verify_fn(in_arr, out_arr);
+        verify_fn(inputs, out_arr);
+      }
+    }
+  }
+
+  for (auto dispatch : dispatches) {
+    in_arrs = GetTestInputArrays();
+    for (auto arr : in_arrs) {
+      // If the array is a view, we shouldn't write data to it.
+      if (arr.IsView())
+        continue;
+
+      NDArray orig = arr.Copy(arr.ctx());
+      req[0] = kWriteInplace;
+      inputs[0] = &arr;
+      outputs[0] = &arr;
+      Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
+                                  dispatch, mxnet::OpStatePtr());
+      arr.WaitToRead();
+      inputs[0] = &orig;
+      verify_fn(inputs, arr);
+    }
+  }
+}
+
+void TestBinaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
+  std::vector<NDArray*> inputs(2);
+  std::vector<NDArray*> outputs(1);
+  std::vector<OpReqType> req(1);
+  std::vector<DispatchMode> dispatches = attrs.dispatches;
+
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+  std::vector<NDArray> in_arrs = GetTestInputArrays();
+  for (auto in_arr1 : in_arrs) {
+    for (auto dispatch : dispatches) {
+      std::vector<NDArray> out_arrs = GetTestOutputArrays(in_arr1.shape(), pds);
+      for (auto out_arr : out_arrs) {
+        req[0] = kWriteTo;
+        inputs[0] = &in_arr1;
+        inputs[1] = &in_arr1;
+        outputs[0] = &out_arr;
+        Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
+                                    outputs, req, dispatch, mxnet::OpStatePtr());
+        out_arr.WaitToRead();
+        verify_fn(inputs, out_arr);
       }
     }
   }
@@ -500,11 +625,15 @@ void TestUnaryOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
       NDArray orig = arr.Copy(arr.ctx());
       req[0] = kWriteInplace;
       inputs[0] = &arr;
+      inputs[1] = &arr;
       outputs[0] = &arr;
       Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
                                   dispatch, mxnet::OpStatePtr());
       arr.WaitToRead();
-      verify_fn(orig, arr);
+      std::vector<NDArray *> orig_inputs(2);
+      orig_inputs[0] = &orig;
+      orig_inputs[1] = &orig;
+      verify_fn(orig_inputs, arr);
     }
   }
 }
@@ -514,4 +643,10 @@ TEST(IMPERATIVE, UnaryOp) {
   TestUnaryOp(attrs, VerifyCopyResult);
 }
 
+
+TEST(IMPERATIVE, BinaryOp) {
+  OpAttrs attrs = GetSumOp();
+  TestBinaryOp(attrs, VerifySumResult);
+}
+
 #endif

-- 
To stop receiving notification emails like this one, please contact
jxie@apache.org.