You are viewing a plain-text version of this content; the hyperlink to its canonical version was not preserved in this rendering.
Posted to commits@mxnet.apache.org by ma...@apache.org on 2018/03/01 10:54:39 UTC

[incubator-mxnet] branch master updated: Fix a race condition in converting data layouts in MKLDNN. (#9862)

This is an automated email from the ASF dual-hosted git repository.

marcoabreu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f9c2689  Fix a race condition in converting data layouts in MKLDNN. (#9862)
f9c2689 is described below

commit f9c2689ec2ffd61ce123dce5857f8a797f21e4df
Author: Da Zheng <zh...@gmail.com>
AuthorDate: Thu Mar 1 11:54:35 2018 +0100

    Fix a race condition in converting data layouts in MKLDNN. (#9862)
    
    * Fix a race condition in converting data layouts.
    
    * Avoid calling data() in elemwise sum.
    
    * Fix a compilation error.
    
    * Address comments.
    
    * Avoid data layout conversion inside NDArray.
    
    * Fix a compilation error.
    
    * Address comments.
    
    * Reorder weight arrays in convolution async.
    
    * Fix async data reordering in NDArray.
    
    * Fix race condition in deconv.
    
    * Update ndarray.cc
    
    * Check more in NDArray.
    
    * Fix a bug in MKLDNNDataReorder.
    
    * Fix a bug in NDArray.
    
    * Simplify weight reorder in (de-)conv.
---
 include/mxnet/ndarray.h                          |  23 +++-
 src/ndarray/ndarray.cc                           | 149 +++++++++++++++--------
 src/operator/nn/mkldnn/mkldnn_base.cc            |  17 +++
 src/operator/nn/mkldnn/mkldnn_convolution.cc     |  25 ++--
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc   |  22 ++--
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc |   5 +
 src/operator/tensor/cast_storage-inl.h           |   7 +-
 src/operator/tensor/elemwise_sum.cc              |  15 +--
 tests/python/gpu/test_gluon_model_zoo_gpu.py     |  10 +-
 9 files changed, 188 insertions(+), 85 deletions(-)

diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 7ce41ab..67d2a27 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -622,12 +622,29 @@ class NDArray {
   /*
    * Reorder the memory to the specified layout.
    */
-  void MKLDNNDataReorder(const mkldnn::memory::primitive_desc &desc);
+  void MKLDNNDataReorder(const mkldnn::memory::primitive_desc &desc) {
+    CHECK_EQ(storage_type(), kDefaultStorage);
+    ptr_->MKLDNNDataReorder(desc);
+  }
   void Reorder2Default() {
     CHECK_EQ(storage_type(), kDefaultStorage);
     ptr_->Reorder2Default();
   }
 
+  /*
+   * These are the async version of the methods above.
+   * It changes the layout of this NDArray, but it happens after all accesses to
+   * the array are complete.
+   */
+  void Reorder2DefaultAsync();
+  void MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc);
+
+  /*
+   * This creates a new NDArray with the reordered data.
+   * It doesn't affect the data of the original NDArray.
+   */
+  NDArray Reorder2Default() const;
+
   void InvalidateMKLDNNData() {
     // Removing mkl_mem_ means the NDArray will store data in the default format.
     ptr_->mkl_mem_ = nullptr;
@@ -880,9 +897,11 @@ class NDArray {
     // Have MKL memory reference to the data in the default storage
     // or create memory for MKLDNN.
     void SetMKLMem(const TShape &shape, int dtype);
-    // In the data is stored in MKLDNN layout, we reorder data in mkl_mem_ and
+    // If the data is stored in MKLDNN layout, we reorder data in mkl_mem_ and
     // save the result in shandle.
     void Reorder2Default();
+    // Reroder data to a specified layout.
+    void MKLDNNDataReorder(const mkldnn::memory::primitive_desc &desc);
     bool IsMKLDNN() const;
     bool IsDefault() const;
 #endif
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index ae7209e..84328ea 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -375,7 +375,45 @@ void NDArray::Chunk::Reorder2Default() {
   CheckAndAlloc(def_pd.get_size());
   // TODO(zhengda) We need to avoid memory copy here.
   memcpy(shandle.dptr, def_mem->get_data_handle(), def_pd.get_size());
-  mkl_mem_.reset(new mkldnn::memory(def_pd, shandle.dptr));
+  mkl_mem_ = nullptr;
+}
+
+void NDArray::Chunk::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd) {
+  // If the memory already uses the specified layout, don't do anything.
+  if (mkl_mem_ != nullptr && mkl_mem_->get_primitive_desc() == pd)
+    return;
+  auto _pd = pd;
+  auto _desc = _pd.desc();
+  auto def_format = GetDefaultFormat(_desc);
+  // If the memory is default, don't do anything.
+  if (def_format == _desc.data.format && IsDefault())
+    return;
+  // If the specified layout is default, we should use Reorder2Default.
+  if (def_format == _desc.data.format) {
+    Reorder2Default();
+    return;
+  }
+
+  std::shared_ptr<mkldnn::memory> new_mem(new mkldnn::memory(pd));
+  std::shared_ptr<mkldnn::memory> old_mem;
+  if (IsDefault()) {
+    auto def_pd = GetPrimitiveDesc(pd, def_format);
+    old_mem.reset(new mkldnn::memory(def_pd, shandle.dptr));
+  } else {
+    old_mem = this->mkl_mem_;
+  }
+  CHECK(old_mem->get_primitive_desc().desc().data.ndims == _desc.data.ndims);
+
+  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
+  std::vector<mkldnn::primitive> net;
+  net.push_back(mkldnn::reorder(*old_mem, *new_mem));
+  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+
+  CHECK(shandle.size >= pd.get_size());
+  CheckAndAlloc(pd.get_size());
+  // TODO(zhengda) We need to avoid memory copy here.
+  memcpy(shandle.dptr, new_mem->get_data_handle(), pd.get_size());
+  mkl_mem_.reset(new mkldnn::memory(pd, shandle.dptr));
 }
 
 void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
@@ -495,12 +533,56 @@ const mkldnn::memory *NDArray::GetMKLDNNDataReorder(
   }
 }
 
+NDArray NDArray::Reorder2Default() const {
+  CHECK(storage_type() == kDefaultStorage);
+
+  if (ptr_->mkl_mem_ == nullptr)
+    return *this;
+  auto format = GetDefaultFormat(ptr_->mkl_mem_->get_primitive_desc().desc());
+  if (format == ptr_->mkl_mem_->get_primitive_desc().desc().data.format)
+    return *this;
+
+  NDArray ret(shape(), ctx(), false, dtype());
+  auto def_pd = GetPrimitiveDesc(ptr_->mkl_mem_->get_primitive_desc(), format);
+  CHECK(ret.ptr_->shandle.size >= def_pd.get_size());
+  mkldnn::memory def_mem(def_pd, ret.ptr_->shandle.dptr);
+  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
+  std::vector<mkldnn::primitive> net;
+  net.push_back(mkldnn::reorder(*ptr_->mkl_mem_, def_mem));
+  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
+  return ret;
+}
+
+void NDArray::Reorder2DefaultAsync() {
+  std::vector<Engine::VarHandle> const_vars;
+  std::vector<Engine::VarHandle> mutable_vars(1, this->var());
+  NDArray tmp = *this;
+  Engine::Get()->PushAsync(
+    [tmp](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+      tmp.ptr_->Reorder2Default();
+      on_complete();
+    }, ctx(), const_vars, mutable_vars,
+    FnProperty::kNormal, 0, PROFILER_MESSAGE("Reorder2Default"));
+}
+
+void NDArray::MKLDNNDataReorderAsync(const mkldnn::memory::primitive_desc &desc) {
+  std::vector<Engine::VarHandle> const_vars;
+  std::vector<Engine::VarHandle> mutable_vars(1, this->var());
+  NDArray tmp = *this;
+  Engine::Get()->PushAsync(
+    [tmp, desc](RunContext ctx, Engine::CallbackOnComplete on_complete) {
+      tmp.ptr_->MKLDNNDataReorder(desc);
+      on_complete();
+    }, ctx(), const_vars, mutable_vars,
+    FnProperty::kNormal, 0, PROFILER_MESSAGE("Reorder"));
+}
+
 const mkldnn::memory *NDArray::GetMKLDNNData() const {
   CHECK(storage_type() == kDefaultStorage);
-  // If this array uses MKLDNN layout and it's a view, we have to change its
-  // layout to the default layout.
-  if (IsMKLDNNData() && IsView())
-    ptr_->Reorder2Default();
+  // If this array uses MKLDNN layout, we have to make sure it's not a view.
+  // Otherwise, we'll have to change the layout inside the array.
+  if (IsMKLDNNData())
+    CHECK(!IsView());
   ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_, dtype_);
   // If shandle has data, the data in shandle and mkl_mem_ should match.
   if (ptr_->shandle.dptr)
@@ -534,45 +616,6 @@ const mkldnn::memory *NDArray::GetMKLDNNData() const {
   }
 }
 
-void NDArray::MKLDNNDataReorder(const mkldnn::memory::primitive_desc &pd) {
-  CHECK_EQ(storage_type(), kDefaultStorage);
-  // If the memory already uses the specified layout, don't do anything.
-  if (ptr_->mkl_mem_ != nullptr && ptr_->mkl_mem_->get_primitive_desc() == pd)
-    return;
-  auto _pd = pd;
-  auto _desc = _pd.desc();
-  auto def_format = GetDefaultFormat(_desc);
-  // If the memory is default, don't do anything.
-  if (def_format == _desc.data.format && ptr_->IsDefault())
-    return;
-  // If the specified layout is default, we should use Reorder2Default.
-  if (def_format == _desc.data.format) {
-    ptr_->Reorder2Default();
-    return;
-  }
-
-  std::shared_ptr<mkldnn::memory> new_mem(new mkldnn::memory(pd));
-  ptr_->SetMKLMem(shape_, dtype_);
-  auto old_mem = ptr_->mkl_mem_;
-  // It's possible that the specified layout has a different number of dimensions.
-  if (old_mem->get_primitive_desc().desc().data.ndims != _desc.data.ndims) {
-    // For now, we only support reorder from the default layout.
-    CHECK(ptr_->IsDefault());
-    auto def_pd = GetPrimitiveDesc(pd, def_format);
-    old_mem.reset(new mkldnn::memory(def_pd, old_mem->get_data_handle()));
-  }
-  // This may be called in MKLDNN operators. We can't use MKLDNNStream here.
-  std::vector<mkldnn::primitive> net;
-  net.push_back(mkldnn::reorder(*old_mem, *new_mem));
-  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
-
-  CHECK(ptr_->shandle.size >= pd.get_size());
-  ptr_->CheckAndAlloc(pd.get_size());
-  // TODO(zhengda) We need to avoid memory copy here.
-  memcpy(ptr_->shandle.dptr, new_mem->get_data_handle(), pd.get_size());
-  ptr_->mkl_mem_.reset(new mkldnn::memory(pd, ptr_->shandle.dptr));
-}
-
 void NDArray::CopyFrom(const mkldnn::memory &mem) {
   CHECK(ptr_ != nullptr) << "The NDArray hasn't been initialized";
   if (ptr_->mkl_mem_.get() == &mem)
@@ -581,10 +624,10 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
   CHECK(mem.get_primitive_desc().get_size() == shape().Size() * GetTypeSize(dtype_))
       << "The size of NDArray doesn't match the requested MKLDNN memory desc";
   MKLDNNStream *stream = MKLDNNStream::Get();
-  // If this array uses MKLDNN layout and it's a view, we have to change its
-  // layout to the default layout.
-  if (IsMKLDNNData() && IsView())
-    ptr_->Reorder2Default();
+  // If this array uses MKLDNN layout, we have to make sure it's not a view.
+  // Otherwise, we'll have to change the layout inside the array.
+  if (IsMKLDNNData())
+    CHECK(!IsView());
   ptr_->SetMKLMem(IsView() ? ptr_->storage_shape : shape_,
                   dtype_);
   stream->RegisterMem(ptr_->mkl_mem_);
@@ -1017,6 +1060,7 @@ inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext
     // with Copy().
     NDArray tmp_from = from;
     if (tmp_from.IsMKLDNNData()) {
+      // TODO(zhengda) tmp_from should be cached.
       tmp_from = NDArray(from.shape(), from.ctx(), false, from.dtype());
       auto tmp_mem = from.GetMKLDNNData();
       tmp_from.CopyFrom(*tmp_mem);
@@ -1025,7 +1069,7 @@ inline void CopyFromToDnsImpl(const NDArray& from, const NDArray& to, RunContext
     CHECK(tmp_from.IsDefaultData());
     CHECK(to.IsDefaultData());
     TBlob tmp = to.data();
-    ndarray::Copy<from_xpu, to_xpu>(from.data(), &tmp,
+    ndarray::Copy<from_xpu, to_xpu>(tmp_from.data(), &tmp,
                                     from.ctx(), to.ctx(), ctx);
   }
 #endif
@@ -1849,7 +1893,12 @@ void NDArray::SyncCopyToCPU(void *data, size_t size) const {
   if (this->ctx().dev_mask() == cpu::kDevMask) {
     this->WaitToRead();
     RunContext rctx{this->ctx(), nullptr};
-    ndarray::Copy<cpu, cpu>(this->data(), &dst,
+    NDArray src = *this;
+#if MXNET_USE_MKLDNN == 1
+    if (src.IsMKLDNNData())
+      src = this->Reorder2Default();
+#endif
+    ndarray::Copy<cpu, cpu>(src.data(), &dst,
                             Context::CPU(), Context::CPU(), rctx);
   } else {
 #if MXNET_USE_CUDA
diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc
index f21111b..edc3482 100644
--- a/src/operator/nn/mkldnn/mkldnn_base.cc
+++ b/src/operator/nn/mkldnn/mkldnn_base.cc
@@ -270,9 +270,26 @@ void FallBackCompute(FCompute fn, const nnvm::NodeAttrs &attrs,
                      const std::vector<OpReqType> &req,
                      const std::vector<NDArray> &outputs) {
   std::vector<TBlob> in_blobs(inputs.size());
+  std::vector<NDArray> in_bufs;
   for (size_t i = 0; i < in_blobs.size(); i++) {
+    // If the input data isn't stored in the default format, we shouldn't
+    // call data() directly, which will change the layout of the NDArray.
+    // Instead, we should save the converted data in another NDArray.
+    // TODO(zhengda) we should use temp space to save the converted data.
+    if (inputs[i].IsDefaultData()) {
       in_blobs[i] = inputs[i].data();
+    } else {
+      if (in_bufs.empty())
+        in_bufs.reserve(inputs.size());
+      in_bufs.emplace_back(inputs[i].shape(), inputs[i].ctx(),
+                           false, inputs[i].dtype());
+      const mkldnn::memory *mem = inputs[i].GetMKLDNNData();
+      in_bufs.back().CopyFrom(*mem);
+      in_blobs[i] = in_bufs.back().data();
+    }
   }
+  MKLDNNStream::Get()->Submit();
+
   std::vector<TBlob> out_blobs(outputs.size());
   for (size_t i = 0; i < out_blobs.size(); i++) {
     if (req[i] == kWriteTo)
diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc
index b94850a..76efc24 100644
--- a/src/operator/nn/mkldnn/mkldnn_convolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -262,8 +262,8 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
                                const std::vector<NDArray> &out_data) {
   TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]);
   const ConvolutionParam& param = nnvm::get<ConvolutionParam>(attrs.parsed);
-  MKLDNNConvForward &fwd = GetConvFwd(attrs,
-      ctx.is_train, in_data[conv::kData], in_data[conv::kWeight],
+  NDArray weight = in_data[conv::kWeight];
+  MKLDNNConvForward &fwd = GetConvFwd(attrs, ctx.is_train, in_data[conv::kData], weight,
       param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]);
 
   auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc());
@@ -271,16 +271,23 @@ void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
   if (ctx.is_train) {
     // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
     // to the default format for now.
-    if (in_data[conv::kWeight].IsMKLDNNData())
-      const_cast<NDArray &>(in_data[conv::kWeight]).Reorder2Default();
-    weight_mem = GetWeights(in_data[conv::kWeight], fwd.fwd_pd.weights_primitive_desc(),
-                            param.num_group);
+    if (weight.IsMKLDNNData())
+      // This asks the engine to change the layout of the weight array after
+      // it's used.
+      weight.Reorder2DefaultAsync();
+    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
   } else {
     // For inference, we want to reorder the weight array so we don't need to
     // reorder data every time.
-    const_cast<NDArray &>(in_data[conv::kWeight]).MKLDNNDataReorder(
-        fwd.fwd_pd.weights_primitive_desc());
-    weight_mem = in_data[conv::kWeight].GetMKLDNNData();
+    if (weight.IsDefaultData()) {
+      weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
+      // We also need to modify the layout on the original weight array. The
+      // data conversion happens after the weight array is used.
+      weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
+    } else {
+      weight_mem = weight.GetMKLDNNData();
+      CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
+    }
   }
   auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(),
                                  req[conv::kOut]);
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index d336d6d..a0d3df7 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -234,21 +234,27 @@ void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
                                         const std::vector<NDArray> &out_data) {
   auto data_mem = in_data[deconv::kData].GetMKLDNNDataReorder(
       fwd_pd.diff_dst_primitive_desc());
+  NDArray weight = in_data[deconv::kWeight];
   const mkldnn::memory *weight_mem;
   if (ctx.is_train) {
     // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
     // to the default format for now.
-    if (in_data[deconv::kWeight].IsMKLDNNData())
-      const_cast<NDArray &>(in_data[deconv::kWeight]).Reorder2Default();
-    weight_mem = GetWeights(in_data[deconv::kWeight],
-                            fwd_pd.weights_primitive_desc(),
-                            param.num_group);
+    if (weight.IsMKLDNNData())
+      // This asks the engine to reorder data after the weight array is used.
+      weight.Reorder2DefaultAsync();
+    weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
   } else {
     // For inference, we want to reorder the weight array so we don't need to
     // reorder data every time.
-    const_cast<NDArray &>(in_data[deconv::kWeight]).MKLDNNDataReorder(
-        fwd_pd.weights_primitive_desc());
-    weight_mem = in_data[deconv::kWeight].GetMKLDNNData();
+    if (weight.IsDefaultData()) {
+      weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
+      // We also need to modify the layout on the original weight array. The
+      // data conversion happens after the weight array is used.
+      weight.MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc());
+    } else {
+      weight_mem = weight.GetMKLDNNData();
+      CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
+    }
   }
   auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
       fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]);
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index a8b85bb..eb379f2 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -90,6 +90,11 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   const TShape& oshape = out_data[fullc::kOut].shape();
   NDArray weight = in_data[fullc::kWeight];
   NDArray data = in_data[fullc::kData];
+  // If the input data is a view of an MKLDNN array, we should create a new
+  // NDArray with reordered data.
+  if (data.IsMKLDNNData() && data.IsView())
+    data = in_data[fullc::kData].Reorder2Default();
+
   auto out_md = GetMemDesc(out_data[fullc::kOut]);
   if (data.shape().ndim() != 2 && !param.flatten) {
     data = data.MKLDNNDataReshape(Shape2(ishape.ProdShape(0, ishape.ndim()-1),
diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h
index e345bb2..46de10a 100644
--- a/src/operator/tensor/cast_storage-inl.h
+++ b/src/operator/tensor/cast_storage-inl.h
@@ -351,7 +351,12 @@ void CastStorageComputeImpl(const OpContext& ctx,
     CHECK_EQ(output.ctx().dev_type, input.ctx().dev_type);
     // If one of them uses the MKLDNN layout.
     if (input.IsMKLDNNData() || output.IsMKLDNNData()) {
-      auto in_mem = input.GetMKLDNNData();
+      NDArray tmp_input = input;
+      // If the input data is MKLDNN and is a view, we need to reorder the input
+      // data first.
+      if (input.IsMKLDNNData() && input.IsView())
+        tmp_input = input.Reorder2Default();
+      const mkldnn::memory *in_mem = tmp_input.GetMKLDNNData();
       const_cast<NDArray &>(output).CopyFrom(*in_mem);
       MKLDNNStream::Get()->Submit();
     } else {
diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc
index 10154bc..8efeb85 100644
--- a/src/operator/tensor/elemwise_sum.cc
+++ b/src/operator/tensor/elemwise_sum.cc
@@ -25,6 +25,7 @@
 #include "./elemwise_sum.h"
 #include "../../ndarray/ndarray_function.h"
 #include "../nn/mkldnn/mkldnn_ops-inl.h"
+#include "../nn/mkldnn/mkldnn_base-inl.h"
 #include "../../common/utils.h"
 
 namespace mxnet {
@@ -122,19 +123,9 @@ void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs,
 #if MXNET_USE_MKLDNN == 1
   } else if (IsMKLDNNData(inputs)) {
     MKLDNNSumForward(attrs, ctx, inputs, req[0], outputs[0]);
-#endif
   } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
-    // This case happens when we want to create an MKLDNN NDArray but the type
-    // or the shape isn't supported by MKLDNN. In this case, NDArray falls back
-    // to the default storage type and, thus, we have to handle the default
-    // storage in FComputeEx.
-    std::vector<TBlob> in_blobs(inputs.size());
-    std::vector<TBlob> out_blobs(outputs.size());
-    for (size_t i = 0; i < in_blobs.size(); i++)
-      in_blobs[i] = inputs[i].data();
-    for (size_t i = 0; i < out_blobs.size(); i++)
-      out_blobs[i] = outputs[i].data();
-    ElementWiseSumCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
+    FallBackCompute(ElementWiseSumCompute<cpu>, attrs, ctx, inputs, req, outputs);
+#endif
   } else {
     LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
   }
diff --git a/tests/python/gpu/test_gluon_model_zoo_gpu.py b/tests/python/gpu/test_gluon_model_zoo_gpu.py
index 6456436..378a822 100644
--- a/tests/python/gpu/test_gluon_model_zoo_gpu.py
+++ b/tests/python/gpu/test_gluon_model_zoo_gpu.py
@@ -37,7 +37,6 @@ def download_data():
     return mx.test_utils.download(
         'http://data.mxnet.io/data/val-5k-256.rec', VAL_DATA)
 
-@unittest.skip("test fails intermittently. temporarily disabled.")
 @with_seed()
 def test_inference():
     all_models = ['resnet50_v1', 'vgg19_bn', 'alexnet', #'inceptionv3',
@@ -87,7 +86,9 @@ def test_inference():
             cpu_out = cpu_model(mx.nd.array(data, ctx=mx.cpu()))
             gpu_out = gpu_model(gpu_data)
         out = cpu_out.asnumpy()
-        max_val = np.max(out)
+        max_val = np.max(np.abs(out))
+        gpu_max_val = np.max(np.abs(gpu_out.asnumpy()))
+        eprint(model_name + ": CPU " + str(max_val) + ", GPU " + str(gpu_max_val))
         assert_almost_equal(out / max_val, gpu_out.asnumpy() / max_val, rtol=1e-3, atol=1e-3)
 
 def get_nn_model(name):
@@ -156,7 +157,10 @@ def test_training():
             gpu_out = gpu_model(gpu_data)
             cpu_loss = softmax_cross_entropy(cpu_out, label)
             gpu_loss = softmax_cross_entropy(gpu_out, gpu_label)
-        assert_almost_equal(cpu_out.asnumpy(), gpu_out.asnumpy(), rtol=1e-2, atol=1e-2)
+        max_val = np.max(np.abs(cpu_out.asnumpy()))
+        gpu_max_val = np.max(np.abs(gpu_out.asnumpy()))
+        eprint(model_name + ": CPU " + str(max_val) + ", GPU " + str(gpu_max_val))
+        assert_almost_equal(cpu_out.asnumpy() / max_val, gpu_out.asnumpy() / max_val, rtol=1e-3, atol=1e-3)
         cpu_loss.backward()
         gpu_loss.backward()
         cpu_trainer.step(batch_size)

-- 
To stop receiving notification emails like this one, please contact
marcoabreu@apache.org.