Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/20 06:26:38 UTC

[GitHub] anirudh2290 closed pull request #12884: adding unittest for MKLDNN Softmax operator

URL: https://github.com/apache/incubator-mxnet/pull/12884

This is a pull request merged from a forked repository. Because GitHub
hides the original diff from a fork once the pull request is merged, it
is reproduced below for the sake of provenance:

diff --git a/tests/cpp/include/test_mkldnn.h b/tests/cpp/include/test_mkldnn.h
index 31fe5c7d7bb..67371b4ec3e 100644
--- a/tests/cpp/include/test_mkldnn.h
+++ b/tests/cpp/include/test_mkldnn.h
@@ -227,6 +227,7 @@ struct OpAttrs {
   nnvm::NodeAttrs attrs;
   std::vector<DispatchMode> dispatches;
   std::set<OpReqType> requests;
+  std::unordered_set<int> accept_dims;
   int num_inputs;
   int num_outputs;
   int input_types;
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index 9e30cd8fa62..55cb887b41b 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -188,7 +188,7 @@ OpAttrs GetLRNOp() {
   attrs.num_outputs = 2;
   attrs.attrs.dict.insert({"nsize" , "3"});
   attrs.attrs.op->attr_parser(&attrs.attrs);
-  attrs.dispatches.resize(2);
+  attrs.accept_dims.insert(4);
   attrs.requests.insert(OpReqType::kWriteTo);
   attrs.input_types = ArrayTypes::Normal |
       ArrayTypes::MKLDNN |
@@ -213,6 +213,26 @@ OpAttrs GetLRNBackwardsOp() {
   return attrs;
 }
 
+OpAttrs GetSoftmaxOp() {
+  OpAttrs attrs;
+  attrs.attrs.op = Op::Get("softmax");
+  attrs.num_inputs = 1;
+  attrs.num_outputs = 1;
+  attrs.attrs.op->attr_parser(&attrs.attrs);
+  attrs.accept_dims.insert({1, 2, 3, 4, 5});
+  attrs.requests.insert(OpReqType::kWriteTo);
+  attrs.requests.insert(OpReqType::kWriteInplace);
+  attrs.input_types = ArrayTypes::Normal |
+    ArrayTypes::MKLDNN |
+    ArrayTypes::NormalReshaped |
+    ArrayTypes::MKLDNNReshaped;
+  attrs.output_types = ArrayTypes::Normal |
+    ArrayTypes::MKLDNN |
+    ArrayTypes::NormalReshaped |
+    ArrayTypes::MKLDNNReshaped;
+  return attrs;
+}
+
 OpAttrs GetFullyConnectedOp() {
   OpAttrs attrs;
   attrs.attrs.op = Op::Get("FullyConnected");
@@ -498,19 +518,51 @@ void TestConcatOp(const OpAttrs &attrs, VerifyFunc verify_fn,
   }
 }
 
-// compares output of fcompute with fcomputex
-void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
-  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
-  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
-  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
-
+void TestOpExBackward(const OpAttrs &forward_attrs,
+                      const OpAttrs &backwards_attrs,
+                      const OpReqType &req,
+                      const std::vector<NDArray*> &inputs,
+                      const std::vector<NDArray*> &outputs,
+                      const NDArrayAttrs &in_arr,
+                      const NDArrayAttrs &out_arr) {
   std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
   std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
   std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
+  std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
+
+  if (req == kWriteTo) {
+    // backwards test performed same time since output needed
+    backwards_input[0] = outputs[0];  // output grad
+    backwards_input[1] = inputs[0];  // input
+    backwards_input[2] = outputs[1];  // out norm
+
+    auto tmp_output = in_arr.arr;
+    backwards_outputs[0] = &tmp_output;
+    backwards_ex_outputs[0] = &tmp_output;
+
+    for (int i = 0; i < backwards_attrs.num_outputs; i++)
+      back_req[i] = kWriteTo;
+
+    std::cout << "Backwards: ";
+    PrintVerifyMsg(out_arr, in_arr);
+    Imperative::Get()->InvokeOp(
+        Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
+        back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
+    Imperative::Get()->InvokeOp(
+        Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
+        back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+    Engine::Get()->WaitForAll();
+    AssertEqual(backwards_outputs, backwards_ex_outputs);
+  }
+}
 
 
+// compares output of fcompute with fcomputex
+void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
+  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
+  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
+  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
   std::vector<OpReqType> req(forward_attrs.num_outputs);
-  std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
 
   TestArrayShapes tas = GetTestArrayShapes();
   std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
@@ -523,8 +575,9 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
     for (int i1 = 0; i1 < in_arrs.size(); i1++) {
       auto in_arr = in_arrs[i1];
 
-      // TODO(alex): (MXNET-845) Remove when MKLDNN supports other dims
-      if (in_arr.arr.shape().ndim() != 4)
+      CHECK_NE(forward_attrs.accept_dims.size(), 0);
+      if (forward_attrs.accept_dims.find(in_arr.arr.shape().ndim()) ==
+          forward_attrs.accept_dims.end())
         continue;
 
       for (int i = 0; i < forward_attrs.num_outputs; i++) {
@@ -538,9 +591,6 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
         inputs[i] = &in_arr.arr;
 
       for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
-        if (out_arrs[0][output_i].arr.IsMKLDNNData())
-          continue;
-
         for (int i = 0; i < forward_attrs.num_outputs; i++) {
           req[i] = kWriteTo;
           outputs[i] = &out_arrs[i][output_i].arr;
@@ -558,31 +608,41 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
         Engine::Get()->WaitForAll();
         AssertEqual(outputs, ex_outputs);
 
-        // backwards test performed same time since output needed
-        backwards_input[0] = outputs[0];  // output grad
-        backwards_input[1] = inputs[0];  // input
-        backwards_input[2] = outputs[1];  // out norm
+        if (!backwards_attrs.requests.empty()) {
+          TestOpExBackward(forward_attrs, backwards_attrs, OpReqType::kWriteTo,
+                           inputs, outputs, in_arr, out_arrs[0][output_i]);
+        }
+      }
+    }
+  }
 
-        auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true)[i1];
-        backwards_outputs[0] = &tmp_output.arr;
+  if (forward_attrs.requests.find(OpReqType::kWriteInplace) != forward_attrs.requests.end()) {
+    for (int i1 = 0; i1 < in_arrs.size(); i1++) {
+      auto in_arr = in_arrs[i1];
 
-        auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true)[i1];
-        backwards_ex_outputs[0] = &tmp_output2.arr;
+      // If the array is a view, we shouldn't write data to it.
+      if (in_arr.arr.IsView())
+          continue;
 
-        for (int i = 0; i < backwards_attrs.num_outputs; i++)
-          back_req[i] = kWriteTo;
+      NDArrayAttrs orig(in_arr.arr.Copy(in_arr.arr.ctx()), "InPlace Copy");
+      for (int i = 0; i < forward_attrs.num_inputs; i++)
+        inputs[i] = &in_arr.arr;
 
-        std::cout << "Backwards: ";
-        PrintVerifyMsg(out_arrs[0][output_i], tmp_output);
-        Imperative::Get()->InvokeOp(
-            Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
-            back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
-        Imperative::Get()->InvokeOp(
-            Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
-            back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
-        Engine::Get()->WaitForAll();
-        AssertEqual(backwards_outputs, backwards_ex_outputs);
+      for (int i = 0; i < forward_attrs.num_outputs; i++) {
+        req[i] = kWriteInplace;
+        outputs[i] = &in_arr.arr;
+        ex_outputs[i] = &in_arr.arr;
       }
+      Imperative::Get()->set_is_training(true);
+      PrintVerifyMsg(orig, in_arr);
+      Imperative::Get()->InvokeOp(
+          Context(), forward_attrs.attrs, inputs, outputs, req,
+          DispatchMode::kFCompute, mxnet::OpStatePtr());
+      Imperative::Get()->InvokeOp(
+          Context(), forward_attrs.attrs, inputs, ex_outputs, req,
+          DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+      Engine::Get()->WaitForAll();
+      AssertEqual(outputs, ex_outputs);
     }
   }
 }
@@ -866,6 +926,12 @@ TEST(IMPERATIVE, LRNOp) {
   TestOpEx(forward_attrs, backwards_attrs);
 }
 
+TEST(IMPERATIVE, SoftmaxOp) {
+  OpAttrs forward_attrs = GetSoftmaxOp();
+  OpAttrs backwards_attrs;
+  TestOpEx(forward_attrs, backwards_attrs);
+}
+
 TEST(IMPERATIVE, FullyConnectedOp) {
   OpAttrs forward_attrs = GetFullyConnectedOp();
   OpAttrs backwards_attrs = GetFullyConnectedBackwardsOp();

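The centerpiece of the diff is the new accept_dims set on OpAttrs: instead
of the hard-coded ndim() != 4 skip that the LRN test relied on, each
operator's setup helper now declares which input ranks its MKLDNN kernel
accepts, and TestOpEx filters the generated test arrays against that set
(GetSoftmaxOp() registers ranks 1 through 5, GetLRNOp() only rank 4).
Softmax also registers kWriteInplace, which exercises the new in-place pass
in TestOpEx, while TEST(IMPERATIVE, SoftmaxOp) passes a default-constructed
backwards_attrs whose empty requests set skips the backward comparison.
What follows is a minimal standalone sketch of the rank-gating logic only;
FakeOpAttrs and candidate_ndims are illustrative stand-ins, not part of the
MXNet test harness:

#include <cassert>
#include <iostream>
#include <unordered_set>
#include <vector>

// Illustrative stand-in for OpAttrs; the real struct lives in
// tests/cpp/include/test_mkldnn.h and carries much more state.
struct FakeOpAttrs {
  std::unordered_set<int> accept_dims;  // input ranks the op's test covers
};

int main() {
  FakeOpAttrs lrn;
  lrn.accept_dims.insert(4);                    // LRN: 4-D inputs only
  FakeOpAttrs softmax;
  softmax.accept_dims.insert({1, 2, 3, 4, 5});  // softmax: 1-D through 5-D

  // Ranks of the generated test arrays, standing in for the shapes
  // produced by GetTestInputArrays().
  const std::vector<int> candidate_ndims = {1, 2, 3, 4, 5};

  for (const FakeOpAttrs &attrs : {lrn, softmax}) {
    assert(!attrs.accept_dims.empty());  // mirrors CHECK_NE(accept_dims.size(), 0)
    for (int ndim : candidate_ndims) {
      // Same filter as TestOpEx: skip ranks the operator does not accept.
      if (attrs.accept_dims.find(ndim) == attrs.accept_dims.end())
        continue;
      std::cout << "would run FCompute/FComputeEx comparison on a "
                << ndim << "-D input\n";
    }
  }
  return 0;
}

Declaring the supported ranks at setup time keeps per-operator knowledge
out of the shared test loop: a future operator test only needs to fill in
its own accept_dims rather than patch TestOpEx, which is what the removed
TODO(alex) comment had called for.
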
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services