You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/06/29 15:54:52 UTC

[GitHub] szha closed pull request #11371: [MXNET-486] Create CPP test for concat MKLDNN operator

szha closed pull request #11371: [MXNET-486] Create CPP test for concat MKLDNN operator
URL: https://github.com/apache/incubator-mxnet/pull/11371
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 04332456cda..266ccb1b1a1 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -157,7 +157,19 @@ inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
   return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                              dispatch_mode, wanted_mode);
 }
-
+#if MXNET_USE_MKLDNN == 1
+bool SupportMKLDNNConcat(const std::vector<NDArray> &arrs) {
+  for (const auto &arr : arrs) {
+    if (arr.IsView() || arr.dtype() != mshadow::kFloat32) return false;
+    // Reject unsupported ranks before GetMKLDNNData(), which may not handle them.
+    const int ndim = static_cast<int>(arr.shape().ndim());
+    if (ndim != 2 && ndim != 4) return false;
+    const int mkldnn_ndims = arr.GetMKLDNNData()->get_primitive_desc().desc().data.ndims;
+    if (ndim != mkldnn_ndims) return false;
+  }
+  return true;
+}
+#endif
 static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const OpContext& op_ctx,
                                const std::vector<NDArray>& inputs,
@@ -171,8 +183,7 @@ static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
       outputs[0].storage_type() == kCSRStorage) {
     ConcatCSRImpl<cpu>(attrs, op_ctx, inputs, req, outputs);
 #if MXNET_USE_MKLDNN == 1
-  } else if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4)
-      && inputs[0].dtype() == mshadow::kFloat32) {
+  } else if (SupportMKLDNNConcat(inputs)) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNConcatForward(attrs, op_ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
@@ -190,8 +201,7 @@ static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
-  if ((inputs[0].shape().ndim() == 2 || inputs[0].shape().ndim() == 4)
-      && inputs[0].dtype() == mshadow::kFloat32) {
+  if (SupportMKLDNNConcat(inputs)) {
     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
     MKLDNNConcatBackward(attrs, ctx, inputs, req, outputs);
     MKLDNN_OPCHECK_RUN(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
index dbc0e94c630..af81e1fe3ee 100644
--- a/src/operator/nn/mkldnn/mkldnn_concat.cc
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -107,7 +107,7 @@ void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   std::vector<const mkldnn::memory *> data_mem;
   data_md.reserve(num_in_data);
   data_mem.reserve(num_in_data);
-  for (int i =0; i < num_in_data; i++) {
+  for (int i = 0; i < num_in_data; i++) {
     const mkldnn::memory *tmp_mem = in_data[i].GetMKLDNNData();
     mkldnn::memory::primitive_desc tmp_pd = tmp_mem->get_primitive_desc();
     data_md.push_back(tmp_pd);
@@ -138,11 +138,11 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   mkldnn::memory::dims offsets = {0, 0, 0, 0};
   for (int i = 0; i < num_in_data; i++) {
     mkldnn::memory::dims diff_src_tz
-        = {static_cast<int>(inputs[i+1].shape()[0]),
-          static_cast<int>(inputs[i+1].shape()[1]),
-          static_cast<int>(inputs[i+1].shape()[2]),
-          static_cast<int>(inputs[i+1].shape()[3])};
-    auto diff_src_mpd = inputs[i+1].GetMKLDNNData()->get_primitive_desc();
+        = {static_cast<int>(outputs[i].shape()[0]),
+          static_cast<int>(outputs[i].shape()[1]),
+          static_cast<int>(outputs[i].shape()[2]),
+          static_cast<int>(outputs[i].shape()[3])};
+    auto diff_src_mpd = outputs[i].GetMKLDNNData()->get_primitive_desc();
     auto gradi_mem_ = CreateMKLDNNMem(outputs[i], diff_src_mpd, req[i]);
     // create view from gy to gxs[i]
     std::shared_ptr<mkldnn::view::primitive_desc> view_pd;
diff --git a/tests/cpp/operator/mkldnn.cc b/tests/cpp/operator/mkldnn.cc
index e593d00a0de..8e01216527c 100644
--- a/tests/cpp/operator/mkldnn.cc
+++ b/tests/cpp/operator/mkldnn.cc
@@ -160,6 +160,17 @@ static mkldnn::memory::primitive_desc GetMemPD(const TShape s, int dtype,
   return mkldnn::memory::primitive_desc(desc, CpuEngine::Get()->get_engine());
 }
 
+static mkldnn::memory::primitive_desc GetExpandedMemPD(
+    mkldnn::memory::primitive_desc pd, float num_input, int dim = 0) {
+  const auto md = pd.desc().data;  // result keeps pd's format, with dims[dim] * num_input
+  CHECK(dim >= 0 && dim < md.ndims) << "dimension must be in [0, ndims) of input";
+  nnvm::TShape s(md.ndims);
+  for (int i = 0; i < md.ndims; i++) s[i] = md.dims[i];
+  s[dim] = static_cast<int>(s[dim] * num_input);
+  return GetMemPD(s, mshadow::DataType<mshadow::default_real_t>::kFlag,
+                  static_cast<mkldnn::memory::format>(md.format));
+}
+
 // This function gets special MKLDNN formats without knowing the specific
 // hardware configuration. Certainly, it potentially misses some format if
 // it's specific for certain array shapes. It covers at least one special format
@@ -359,9 +370,9 @@ struct OpAttrs {
 OpAttrs GetCopyOp() {
   OpAttrs attrs;
   attrs.attrs.op = Op::Get("_copy");
-  attrs.dispatches.resize(2);
   attrs.num_inputs = 1;
   attrs.num_outputs = 1;
+  attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
   return attrs;
@@ -407,9 +418,9 @@ OpAttrs GetReluBackwardsOp() {
 OpAttrs GetSumOp() {
   OpAttrs attrs;
   attrs.attrs.op = Op::Get("elemwise_add");
-  attrs.dispatches.resize(2);
   attrs.num_inputs = 2;
   attrs.num_outputs = 1;
+  attrs.dispatches.resize(2);
   attrs.dispatches[0] = DispatchMode::kFCompute;
   attrs.dispatches[1] = DispatchMode::kFComputeEx;
   return attrs;
@@ -426,6 +437,42 @@ OpAttrs GetSumBackwardsOp() {
   return attrs;
 }
 
+OpAttrs GetConcatOp(int num_args, int dim) {
+  // Forward concat: num_args inputs joined along `dim` into a single output.
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("concat");
+  attrs.attrs.dict.insert({"num_args", std::to_string(num_args)});
+  attrs.attrs.dict.insert({"dim", std::to_string(dim)});
+  attrs.attrs.op->attr_parser(&attrs.attrs);  // parse the dict entries set above
+  attrs.num_inputs = num_args;
+  attrs.num_outputs = 1;
+  // Exercise both the FCompute and FComputeEx paths.
+  attrs.dispatches = {DispatchMode::kFCompute, DispatchMode::kFComputeEx};
+  return attrs;
+}
+
+OpAttrs GetConcatBackwardsOp(int num_args, int dim) {
+  // Backward concat: 2 inputs, num_args outputs along `dim`.
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("_backward_Concat");
+  attrs.attrs.dict.insert({"num_args", std::to_string(num_args)});
+  attrs.attrs.dict.insert({"dim", std::to_string(dim)});
+  attrs.attrs.op->attr_parser(&attrs.attrs);  // parse the dict entries set above
+  attrs.num_inputs = 2;
+  attrs.num_outputs = num_args;
+  // Exercise both the FCompute and FComputeEx paths.
+  attrs.dispatches = {DispatchMode::kFCompute, DispatchMode::kFComputeEx};
+  return attrs;
+}
+
+void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
+  // Log which pair of test arrays is being checked, with their shapes.
+  const TShape &t1 = arr1.arr.shape();
+  const TShape &t2 = arr2.arr.shape();
+  std::cout << "Verifying: " << arr1.desc << " " <<
+     t1 << " with " << arr2.desc << " " << t2 << "\n";
+}
+
 /*
  * We want to get a few types of NDArrays for testing:
  * 1. Normal NDArray
@@ -446,20 +493,37 @@ OpAttrs GetSumBackwardsOp() {
  *    In the inference mode, the MKLDNN memory in the weight array will be
  *    reordered to 5 dimensions.
  *
+ *  num_inputs / dim arguments used to scale shape (used for concat backwards to enlarge input shapes)
  */
-std::vector<NDArrayAttrs> GetTestInputArrays(bool rand = false) {
+std::vector<NDArrayAttrs> GetTestInputArrays(bool rand = false, int num_inputs = 1, int dim = 0) {
   TestArrayShapes tas = GetTestArrayShapes();
   std::vector<nnvm::TShape> shapes = tas.shapes;
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
 
   std::vector<NDArrayAttrs> in_arrs;
   std::string desc;
+
+  int slice_amount = 1;
+  if (dim == 0)
+    slice_amount = num_inputs;
   for (auto shape : shapes) {
+    if (dim >= shape.ndim())
+      continue;
+    shape[dim] = shape[dim] * num_inputs;
+
     // Type 1.
     NDArray arr(shape, Context());
     in_arrs.emplace_back(arr, "Normal NDArray");
     InitDefaultArray(&in_arrs.back().arr, rand);
     for (auto pd : pds) {
+      if (num_inputs > 1) {
+        // preserve if matching layout else just expand on 0 dim
+        if (shape.ndim() == pd.desc().data.ndims)
+          pd = GetExpandedMemPD(pd, num_inputs, dim);
+        else
+          pd = GetExpandedMemPD(pd, num_inputs);
+      }
+
       if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t))
         continue;
 
@@ -472,8 +536,8 @@ std::vector<NDArrayAttrs> GetTestInputArrays(bool rand = false) {
            shape.ndim() << "/" << pd.desc().data.ndims;
         desc = ss.str();
       }
+      InitMKLDNNArray(&arr, pd);
       in_arrs.emplace_back(arr, desc);
-      InitMKLDNNArray(&in_arrs.back().arr, pd);
 
       // Type 4, 5, 6.
       arr = NDArray(shape, Context());
@@ -485,31 +549,12 @@ std::vector<NDArrayAttrs> GetTestInputArrays(bool rand = false) {
         desc = ss.str();
       }
       InitMKLDNNArray(&arr, pd);
-      in_arrs.emplace_back(arr.Slice(1, arr.shape()[0] - 1), desc);
+      in_arrs.emplace_back(arr.Slice(slice_amount, arr.shape()[0] - slice_amount), desc);
     }
   }
   return in_arrs;
 }
 
-TEST(MKLDNN_NDArray, GetTestInputArrays) {
-  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
-  int mkldnn_count = 0, mkldnn_view_count = 0;
-  for (auto arr : in_arrs) {
-    if (arr.arr.IsView() && arr.arr.IsMKLDNNData()) {
-      mkldnn_view_count++;
-      continue;
-    }
-
-    if (arr.arr.IsMKLDNNData()) {
-      mkldnn_count++;
-      continue;
-    }
-  }
-
-  EXPECT_GT(mkldnn_view_count, 0);
-  EXPECT_GT(mkldnn_count, 0);
-}
-
 /*
  * We want to get a few types of NDArrays for testing:
  * 1. Normal NDArray
@@ -527,9 +572,17 @@ TEST(MKLDNN_NDArray, GetTestInputArrays) {
  * 7. Reused reshaped/sliced NDArray.
  * 8. Reused NDArray with MKLDNN layout.
  * 9. Reused NDArray with MKLDNN layout of different dimensions.
+ *
+ * Optional num_inputs / dim args can be passed to modify input shape (used for Concat test)
  */
-std::vector<NDArrayAttrs> GetTestOutputArrays(const TShape &shape,
-                                         const std::vector<mkldnn::memory::primitive_desc> &pds) {
+std::vector<NDArrayAttrs> GetTestOutputArrays(
+    const TShape &shp,
+    const std::vector<mkldnn::memory::primitive_desc> &pds,
+    float num_inputs = 0, int dim = 0) {
+  TShape shape = shp;
+  if (num_inputs != 0)
+    shape[dim] = static_cast<int>(shape[dim] * num_inputs);
+
   std::vector<NDArrayAttrs> in_arrs;
   std::string desc;
   // Type 1.
@@ -568,11 +621,14 @@ std::vector<NDArrayAttrs> GetTestOutputArrays(const TShape &shape,
   InitDefaultArray(&arr3, true);
   in_arrs.emplace_back(arr3.Slice(1, shape[0] + 1), "Reused+Reshaped NDArray");
 
-
   for (auto pd : pds) {
     if (shape.Size() != pd.get_size() / sizeof(mshadow::default_real_t))
       continue;
 
+    if (num_inputs != 0)
+      pd = GetExpandedMemPD(pd, num_inputs);
+
+
     // Type 2, 3.
 
     arr = NDArray(shape, Context());
@@ -605,6 +661,47 @@ std::vector<NDArrayAttrs> GetTestOutputArrays(const TShape &shape,
   return in_arrs;
 }
 
+TEST(MKLDNN_NDArray, GetTestInputArraysConcat) {
+  auto in_arrs = GetTestInputArrays();
+  for (int dim = 0; dim < 5; dim++) {
+    for (int num_inputs = 2; num_inputs < 5; num_inputs++) {
+      std::vector<NDArrayAttrs> expanded_arrs = GetTestInputArrays(false, num_inputs, dim);
+      size_t i = 0;  // index of the expanded counterpart of the next eligible array
+      for (auto &arr : in_arrs) {
+        if (dim >= static_cast<int>(arr.arr.shape().ndim()))
+          continue;
+        ASSERT_LT(i, expanded_arrs.size()) << "fewer expanded arrays than inputs";
+        const auto &ex_arr = expanded_arrs[i++];
+        PrintVerifyMsg(arr, ex_arr);
+        EXPECT_EQ(arr.arr.shape().Size() * num_inputs, ex_arr.arr.shape().Size());
+        EXPECT_EQ(arr.arr.shape()[dim] * num_inputs, ex_arr.arr.shape()[dim]);
+      }
+    }
+  }
+}
+
+TEST(MKLDNN_NDArray, GetTestOutputArraysConcat) {
+  auto shapes_pds = GetTestArrayShapes();
+  const std::vector<nnvm::TShape> &shapes = shapes_pds.shapes;
+  const std::vector<mkldnn::memory::primitive_desc> &pds = shapes_pds.pds;
+  for (const auto &shape : shapes) {
+    for (int dim = 0; dim < 5; dim++) {
+      if (dim >= static_cast<int>(shape.ndim()))
+        break;  // every larger dim is out of range as well
+      for (int num_inputs = 2; num_inputs < 5; num_inputs++) {
+        std::cout << "Extending " << shape << " dim " <<
+                  dim << " and " << num_inputs << " num_inputs\n";
+        auto output_arrs = GetTestOutputArrays(shape, pds, num_inputs, dim);
+        for (const auto &out_arr : output_arrs) {
+          const TShape &out_shape = out_arr.arr.shape();
+          EXPECT_EQ(shape.Size() * num_inputs, out_shape.Size());
+          EXPECT_EQ(shape[dim] * num_inputs, out_shape[dim]);
+        }
+      }
+    }
+  }
+}
+
 void VerifyCopyResult(const std::vector<NDArray *> &in_arrs,
                       const std::vector<NDArray *> &out_arrs) {
   NDArray tmp1 = in_arrs[0]->Reorder2Default();
@@ -676,17 +773,77 @@ void VerifySumBackwardsResult(const std::vector<NDArray *> &in_arrs,
   }
 }
 
-void PrintVerifyMsg(const NDArrayAttrs &arr1, const NDArrayAttrs &arr2) {
-  TShape t1 = arr1.arr.shape();
-  TShape t2 = arr2.arr.shape();
+/*
+ * Returns the first axis on which input_shape and output_shape differ, i.e. the
+ * axis the inputs were concatenated along, or -1 if no axis differs.
+ * Used to verify concat/concat backwards operator.
+ */
+int GetDim(const TShape &input_shape, const TShape &output_shape) {
+  CHECK_NE(input_shape.Size(), output_shape.Size());
+  for (size_t i = 0; i < input_shape.ndim(); i++) {
+    if (input_shape[i] != output_shape[i]) return static_cast<int>(i);
+  }
+  return -1;
+}
 
-  printf("Verifying: %s (", arr1.desc.c_str());
-  for (size_t i = 0; i < t1.ndim(); i++)
-    printf("%ld, ", t1[i]);
-  printf(") with %s (", arr2.desc.c_str());
-  for (size_t i = 0; i < t2.ndim(); i++)
-    printf("%ld, ", t2[i]);
-  printf(")\n");
+/*
+ * Returns the number of elements in one contiguous block of an input inside the
+ * larger concatenated array: the product of shape[dim..ndim-1].
+ */
+int GetBlockSize(const TShape &shape, int dim) {
+  CHECK_GE(dim, 0) << "invalid concat axis";
+  int block_size = 1;
+  for (int i = static_cast<int>(shape.ndim()) - 1; i >= dim; i--) block_size *= shape[i];
+  return block_size;
+}
+
+void VerifyConcatResult(const std::vector<NDArray *> &in_arrs,
+                        const std::vector<NDArray *> &out_arrs) {
+  const size_t num_inputs = in_arrs.size();  // assumes all inputs share in_arrs[0]'s shape
+  const size_t input_size = in_arrs[0]->shape().Size();
+  const TShape &input_shape = in_arrs[0]->shape();
+  NDArray output = out_arrs[0]->Reorder2Default();
+  // ASSERT (not EXPECT): the index arithmetic below would read out of bounds otherwise.
+  ASSERT_EQ(input_size * num_inputs, output.shape().Size());
+  const mshadow::default_real_t *out_data = output.data().dptr<mshadow::default_real_t>();
+
+  const int dim = GetDim(input_shape, output.shape());
+  const size_t block_size = GetBlockSize(input_shape, dim);
+  const size_t num_blocks = input_size / block_size;
+  for (size_t input_num = 0; input_num < num_inputs; input_num++) {
+    NDArray tmp = in_arrs[input_num]->Reorder2Default();
+    const mshadow::default_real_t *data = tmp.data().dptr<mshadow::default_real_t>();
+    for (size_t block_num = 0; block_num < num_blocks; block_num++) {
+      for (size_t i = 0; i < block_size; i++)
+        ASSERT_EQ(data[block_num * block_size + i],
+                  out_data[(block_num * num_inputs + input_num) * block_size + i]);
+    }
+  }
+}
+
+void VerifyConcatBackwardsResult(const std::vector<NDArray *> &in_arrs,
+                        const std::vector<NDArray *> &out_arrs) {
+  // in_arrs[0] is the larger (concatenated) array; out_arrs are the smaller pieces.
+  const size_t num_inputs = out_arrs.size();
+  const size_t input_size = out_arrs[0]->shape().Size();
+  const TShape &input_shape = out_arrs[0]->shape();
+  NDArray output = in_arrs[0]->Reorder2Default();
+  // ASSERT (not EXPECT): the index arithmetic below would read out of bounds otherwise.
+  ASSERT_EQ(input_size * num_inputs, output.shape().Size());
+  const mshadow::default_real_t *out_data = output.data().dptr<mshadow::default_real_t>();
+
+  const int dim = GetDim(input_shape, output.shape());
+  const size_t block_size = GetBlockSize(input_shape, dim);
+  const size_t num_blocks = input_size / block_size;
+  for (size_t input_num = 0; input_num < num_inputs; input_num++) {
+    NDArray tmp = out_arrs[input_num]->Reorder2Default();
+    const mshadow::default_real_t *data = tmp.data().dptr<mshadow::default_real_t>();
+    for (size_t block_num = 0; block_num < num_blocks; block_num++) {
+      for (size_t i = 0; i < block_size; i++)
+        ASSERT_EQ(data[block_num * block_size + i],
+                  out_data[(block_num * num_inputs + input_num) * block_size + i]);
+    }
+  }
 }
 
 void VerifyAddRequest(const std::vector<NDArray*> &in_arrs,
@@ -703,11 +860,11 @@ TEST(MKLDNN_NDArray, CopyFrom) {
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
 
   std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
-  for (auto in_arr : in_arrs) {
+  for (auto &in_arr : in_arrs) {
+    if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView())
+      continue;
     std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds);
-    for (auto out_arr : out_arrs) {
-      if (in_arr.arr.IsMKLDNNData() && in_arr.arr.IsView())
-        in_arr.arr = in_arr.arr.Reorder2Default();
+    for (auto &out_arr : out_arrs) {
       const mkldnn::memory *mem = in_arr.arr.GetMKLDNNData();
       out_arr.arr.CopyFrom(*mem);
       MKLDNNStream::Get()->Submit();
@@ -728,29 +885,30 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
 
   std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
-  for (auto in_arr : in_arrs) {
-    for (auto dispatch : dispatches) {
-      std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds);
-      for (auto out_arr : out_arrs) {
-        for (int i = 0; i < attrs.num_inputs; i++)
-          inputs[i] = &in_arr.arr;
+  for (auto &in_arr : in_arrs) {
+    for (auto &dispatch : dispatches) {
+      std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
+      for (int i = 0; i < attrs.num_outputs; i++)
+        out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds);
+      for (int i = 0; i < attrs.num_inputs; i++)
+        inputs[i] = &in_arr.arr;
+      for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
         for (int i = 0; i < attrs.num_outputs; i++) {
           req[i] = kWriteTo;
-          outputs[i] = &out_arr.arr;
+          outputs[i] = &out_arrs[i][output_i].arr;
         }
-        PrintVerifyMsg(in_arr, out_arr);
+        PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
         Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
                                     outputs, req, dispatch, mxnet::OpStatePtr());
-        for (auto output : outputs)
-          output->WaitToRead();
+        Engine::Get()->WaitForAll();
         verify_fn(inputs, outputs);
       }
     }
   }
 
-  for (auto dispatch : dispatches) {
+  for (auto &dispatch : dispatches) {
     in_arrs = GetTestInputArrays();
-    for (auto arr : in_arrs) {
+    for (auto &arr : in_arrs) {
       // If the array is a view, we shouldn't write data to it.
       if (arr.arr.IsView())
         continue;
@@ -764,8 +922,7 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
       PrintVerifyMsg(orig, arr);
       Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs, outputs, req,
                                   dispatch, mxnet::OpStatePtr());
-      for (auto output : outputs)
-        output->WaitToRead();
+      Engine::Get()->WaitForAll();
       std::vector<NDArray *> orig_inputs(attrs.num_inputs);
       for (int i = 0; i < attrs.num_inputs; i++)
         orig_inputs[i] = &orig.arr;
@@ -774,6 +931,57 @@ void TestOp(const OpAttrs &attrs, VerifyFunc verify_fn) {
   }
 }
 
+void TestConcatOp(const OpAttrs &attrs, VerifyFunc verify_fn,
+            bool backwards = false) {
+  // Runs the concat op (forward, or backward when `backwards`) on every test
+  // input/output combination for each dispatch mode and checks via verify_fn.
+  std::vector<NDArray*> inputs(attrs.num_inputs);
+  std::vector<NDArray*> outputs(attrs.num_outputs);
+  std::vector<OpReqType> req(attrs.num_outputs);
+  std::vector<DispatchMode> dispatches = attrs.dispatches;
+
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+  // Parse the concat axis once; at() needs no const_cast and never inserts a key.
+  const int dim = std::stoi(attrs.attrs.dict.at("dim"));
+
+  // concat backwards uses inputs scaled up along `dim` by num_outputs.
+  std::vector<NDArrayAttrs> in_arrs = backwards ?
+      GetTestInputArrays(false, attrs.num_outputs, dim) : GetTestInputArrays();
+  // Outputs are the input shape scaled along `dim` by this factor.
+  const float scale = backwards ? 1 / static_cast<float>(attrs.num_outputs) :
+      static_cast<float>(attrs.num_inputs);
+
+  for (auto &in_arr : in_arrs) {
+    // Arrays whose rank does not reach `dim` cannot be concatenated along it.
+    if (dim >= static_cast<int>(in_arr.arr.shape().ndim()))
+      continue;
+    for (int i = 0; i < attrs.num_inputs; i++)
+      inputs[i] = &in_arr.arr;
+
+    for (auto &dispatch : dispatches) {
+      std::vector<std::vector<NDArrayAttrs>> out_arrs(attrs.num_outputs);
+      for (int i = 0; i < attrs.num_outputs; i++)
+        out_arrs[i] = GetTestOutputArrays(in_arr.arr.shape(), pds, scale, dim);
+
+      // Fill every output slot with its output_i-th candidate array.
+      for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+        for (int i = 0; i < attrs.num_outputs; i++) {
+          req[i] = kWriteTo;
+          outputs[i] = &out_arrs[i][output_i].arr;
+        }
+
+        PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+        Imperative::Get()->InvokeOp(Context(), attrs.attrs, inputs,
+                                    outputs, req, dispatch, mxnet::OpStatePtr());
+        Engine::Get()->WaitForAll();
+        verify_fn(inputs, outputs);
+      }
+    }
+  }
+}
+
 TEST(IMPERATIVE, CopyOp) {
   OpAttrs attrs = GetCopyOp();
   TestOp(attrs, VerifyCopyResult);
@@ -804,6 +1012,24 @@ TEST(IMPERATIVE, SumBackwardsOp) {
   TestOp(attrs, VerifySumBackwardsResult);
 }
 
+TEST(IMPERATIVE, ConcatOp) {
+  // Forward concat with 2 and 3 inputs along every axis 0..4; TestConcatOp
+  // skips input arrays whose rank does not reach `axis`.
+  for (int n = 2; n <= 3; ++n) {
+    for (int axis = 0; axis <= 4; ++axis)
+      TestConcatOp(GetConcatOp(n, axis), VerifyConcatResult);
+  }
+}
+
+TEST(IMPERATIVE, ConcatBackwardsOp) {
+  // Backward concat with 2 and 3 outputs along every axis 0..4, fed with input
+  // arrays scaled up along that axis (backwards = true).
+  for (int n = 2; n <= 3; ++n) {
+    for (int axis = 0; axis <= 4; ++axis)
+      TestConcatOp(GetConcatBackwardsOp(n, axis), VerifyConcatBackwardsResult, true);
+  }
+}
+
 TEST(MKLDNN_BASE, MKLDNNSum) {
   std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays();
   std::vector<NDArrayAttrs> in_arrs2 = GetTestInputArrays(true);
@@ -819,7 +1045,7 @@ TEST(MKLDNN_BASE, MKLDNNSum) {
       continue;
     }
     std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds);
-    for (auto out_arr : out_arrs) {
+    for (auto &out_arr : out_arrs) {
       auto in_mem1 = in_arr.arr.GetMKLDNNData();
       auto in_mem2 = in_arr2.arr.GetMKLDNNData();
       if (out_arr.arr.IsView())
@@ -870,7 +1096,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
       continue;
     }
     std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds);
-    for (auto out_arr : out_arrs) {
+    for (auto &out_arr : out_arrs) {
       auto in_mem = in_arr.arr.GetMKLDNNData();
       auto in_mem2 = in_arr2.arr.GetMKLDNNData();
       NDArray orig_output = out_arr.arr.Copy(out_arr.arr.ctx());
@@ -919,7 +1145,7 @@ TEST(MKLDNN_BASE, CreateMKLDNNMem) {
       continue;
     }
     std::vector<NDArrayAttrs> out_arrs = GetTestOutputArrays(in_arr.arr.shape(), pds);
-    for (auto out_arr : out_arrs) {
+    for (auto &out_arr : out_arrs) {
       auto in_mem = in_arr.arr.GetMKLDNNData();
       auto in_mem2 = in_arr2.arr.GetMKLDNNData();
       NDArray orig_output = out_arr.arr.Copy(out_arr.arr.ctx());


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services