You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by an...@apache.org on 2018/11/15 02:47:09 UTC

[incubator-mxnet] branch master updated: adding unit test for MKLDNN FullyConnected operator (#12985)

This is an automated email from the ASF dual-hosted git repository.

anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new cf991ff  adding unit test for MKLDNN FullyConnected operator (#12985)
cf991ff is described below

commit cf991ff3dcdd20d99ec996e5f501ba65f9ca65b4
Author: Manu Seth <22...@users.noreply.github.com>
AuthorDate: Wed Nov 14 18:46:54 2018 -0800

    adding unit test for MKLDNN FullyConnected operator (#12985)
    
    * adding unit test for MKLDNN FullyConnected operator
    
    * removing mkldnn filter
    
    * removing mkldnn filter
---
 tests/cpp/operator/mkldnn_operator_test.cc | 155 +++++++++++++++++++++++++++++
 1 file changed, 155 insertions(+)

diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index 21b257e..9e30cd8 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -213,6 +213,36 @@ OpAttrs GetLRNBackwardsOp() {
   return attrs;
 }
 
+// Builds the attributes for the forward "FullyConnected" operator test:
+// num_hidden is set to "20", with three inputs, one output and a kWriteTo request.
+OpAttrs GetFullyConnectedOp() {
+  OpAttrs fc;
+  fc.num_inputs = 3;
+  fc.num_outputs = 1;
+  fc.requests.insert(OpReqType::kWriteTo);
+  // Inputs and outputs are exercised with the same four array kinds.
+  const auto array_kinds = ArrayTypes::Normal | ArrayTypes::MKLDNN |
+      ArrayTypes::NormalReshaped | ArrayTypes::MKLDNNReshaped;
+  fc.input_types = array_kinds;
+  fc.output_types = array_kinds;
+  // Fill dict first, then hand the attrs to attr_parser (presumably it reads dict).
+  fc.attrs.op = Op::Get("FullyConnected");
+  fc.attrs.dict.insert({"num_hidden" , "20"});
+  fc.attrs.op->attr_parser(&fc.attrs);
+  return fc;
+}
+
+// Attributes for "_backward_FullyConnected": num_hidden "20" and a kWriteTo request.
+OpAttrs GetFullyConnectedBackwardsOp() {
+  OpAttrs fc_grad;
+  fc_grad.num_inputs = fc_grad.num_outputs = 3;  // three inputs, three outputs
+  fc_grad.requests.insert(OpReqType::kWriteTo);
+  fc_grad.attrs.op = Op::Get("_backward_FullyConnected");
+  fc_grad.attrs.dict.insert({"num_hidden" , "20"});
+  fc_grad.attrs.op->attr_parser(&fc_grad.attrs);
+  return fc_grad;
+}
+
 void AssertEqual(const std::vector<NDArray *> &in_arrs,
                  const std::vector<NDArray *> &out_arrs,
                  float rtol = 1e-5, float atol = 1e-8) {
@@ -557,6 +587,125 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
   }
 }
 
+// Second dimension of the FC weight matrix: the product of all input dims after the first
+// (1 when the shape has fewer than two dims). Shape is taken by reference to avoid a copy.
+uint32_t GetFCWeightDim2(const nnvm::TShape &arr) {
+  uint32_t dim = 1;
+  for (int i = 1; i < arr.ndim(); i++)
+    dim *= arr[i];
+  return dim;
+}
+
+// Verifies that FullyConnected and its backward op produce matching results under
+// DispatchMode::kFCompute and DispatchMode::kFComputeEx. Every test input array with
+// at least two dims is paired with newly created weights and bias; the forward outputs
+// of both paths are compared, then the kFCompute output is reused as the output
+// gradient of the backward op, whose three outputs are compared as well.
+// Nothing is run unless forward_attrs.requests contains kWriteTo.
+void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
+  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
+  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
+  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
+
+  std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
+  std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
+  std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
+
+  std::vector<OpReqType> req(forward_attrs.num_outputs);
+  std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
+
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, true);
+  std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
+  std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);
+
+  // at() reads the const dict without a const_cast and cannot insert an empty entry.
+  const int num_hid = std::stoi(forward_attrs.attrs.dict.at("num_hidden"));
+
+  if (forward_attrs.requests.find(OpReqType::kWriteTo) != forward_attrs.requests.end()) {
+    for (size_t i1 = 0; i1 < in_arrs.size(); i1++) {
+      auto in_arr = in_arrs[i1];
+      auto in_shape = in_arr.arr.shape();
+      if (in_shape.ndim() < 2)  // inputs with fewer than two dims are skipped
+        continue;
+      // Weights: (num_hidden, product of the input dims after the first).
+      nnvm::TShape wt_shape(2);
+      wt_shape[0] = num_hid;
+      wt_shape[1] = GetFCWeightDim2(in_shape);
+      NDArray weights(wt_shape, Context());
+      InitDefaultArray(&weights, false);
+      // Bias: a single dimension of size num_hidden.
+      nnvm::TShape bias_shape(1);
+      bias_shape[0] = num_hid;
+      NDArray bias(bias_shape, Context());
+      InitDefaultArray(&bias, false);
+
+      inputs[0] = &in_arr.arr;
+      inputs[1] = &weights;
+      inputs[2] = &bias;
+      // Output: (first input dim, num_hidden).
+      nnvm::TShape out_shape(2);
+      out_shape[0] = in_shape[0];
+      out_shape[1] = num_hid;
+
+      for (int i = 0; i < forward_attrs.num_outputs; i++) {
+        out_arrs[i] = GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types);
+        ex_out_arrs[i] = GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types);
+      }
+
+      for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+        for (int i = 0; i < forward_attrs.num_outputs; i++) {
+          req[i] = kWriteTo;
+          outputs[i] = &out_arrs[i][output_i].arr;
+          ex_outputs[i] = &ex_out_arrs[i][output_i].arr;
+        }
+        Imperative::Get()->set_is_training(true);
+
+        PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+        Imperative::Get()->InvokeOp(Context(), forward_attrs.attrs, inputs, outputs, req,
+            DispatchMode::kFCompute, mxnet::OpStatePtr());
+        Imperative::Get()->InvokeOp(Context(), forward_attrs.attrs, inputs, ex_outputs, req,
+            DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+        Engine::Get()->WaitForAll();
+        AssertEqual(outputs, ex_outputs);
+
+        // The backward test runs here because it needs the forward output just computed.
+        backwards_input[0] = outputs[0];  // output grad
+        backwards_input[1] = inputs[0];  // input
+        backwards_input[2] = inputs[1];  // weights
+        // Destination buffers for the kFCompute backward run.
+        auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true)[i1];
+        NDArray back_weights(wt_shape, Context());
+        NDArray back_bias(bias_shape, Context());
+        backwards_outputs[0] = &tmp_output.arr;
+        backwards_outputs[1] = &back_weights;
+        backwards_outputs[2] = &back_bias;
+        // A second, separate set of destination buffers for the kFComputeEx backward run.
+        auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true)[i1];
+        NDArray back_ex_weights(wt_shape, Context());
+        NDArray back_ex_bias(bias_shape, Context());
+        backwards_ex_outputs[0] = &tmp_output2.arr;
+        backwards_ex_outputs[1] = &back_ex_weights;
+        backwards_ex_outputs[2] = &back_ex_bias;
+
+        for (int i = 0; i < backwards_attrs.num_outputs; i++)
+          back_req[i] = kWriteTo;
+
+        std::cout << "Backwards: ";
+        PrintVerifyMsg(out_arrs[0][output_i], tmp_output);
+        Imperative::Get()->InvokeOp(Context(), backwards_attrs.attrs, backwards_input,
+            backwards_outputs, back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
+        Imperative::Get()->InvokeOp(Context(), backwards_attrs.attrs, backwards_input,
+            backwards_ex_outputs, back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+        Engine::Get()->WaitForAll();
+        AssertEqual(backwards_outputs, backwards_ex_outputs);
+      }
+    }
+  }
+}
+
 void TestPoolingOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
   std::vector<NDArray*> inputs(forward_attrs.num_inputs);
   std::vector<NDArray*> outputs(forward_attrs.num_outputs);
@@ -717,6 +866,12 @@ TEST(IMPERATIVE, LRNOp) {
   TestOpEx(forward_attrs, backwards_attrs);
 }
 
+TEST(IMPERATIVE, FullyConnectedOp) {
+  // Forward attrs are built first, then the backward attrs, then both are tested together.
+  const OpAttrs fc_forward = GetFullyConnectedOp();
+  TestFullyConnectedOp(fc_forward, GetFullyConnectedBackwardsOp());
+}
+
 TEST(IMPERATIVE, PoolingOp) {
   for (int dim = 2; dim < 4; dim++) {
     for (int kernel = 1; kernel < 4; kernel++) {