You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/01/14 21:48:10 UTC

[incubator-mxnet] Diff for: [GitHub] azai91 closed pull request #13084: Test/mkldnn batch norm op

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 6254a1e1866..e7108241257 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -380,12 +380,27 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
 }
 
 #if MXNET_USE_MKLDNN == 1
-static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
-  TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
+// Returns true when the MKLDNN batch-norm path may be used for `inputs`:
+//  * inputs[0] is 4-D, accepted by SupportMKLDNN(), and normalized along
+//    batchnorm::DEFAULT_AXIS, whose extent must be a multiple of 8;
+//  * MKLDNN batch norm is not disabled (batchnorm::disable_mkl);
+//  * no other input holds MKLDNN-layout data. NOTE(review): presumably
+//    because the MKLDNN kernel reads those arrays in the default layout --
+//    confirm against mkldnn_batch_norm-inl.h.
+// An empty `inputs` is rejected instead of being indexed.
+static inline bool SupportMKLDNNBN(const std::vector<NDArray> &inputs,
+                                   const BatchNormParam &param) {
+  if (inputs.empty()) return false;
+  TShape shape = inputs[0].shape();
+  const bool params_valid = shape.ndim() == 4
       && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
       && shape[param.axis] % 8 == 0
       && !mxnet::op::batchnorm::disable_mkl;
+  if (!params_valid || !SupportMKLDNN(inputs[0])) return false;
+  for (size_t i = 1; i < inputs.size(); ++i) {
+    if (inputs[i].IsMKLDNNData()) return false;
+  }
+  return true;
 }
 
 void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
@@ -396,7 +404,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
   // MKLDNN batchnorm only works well on the special MKLDNN layout.
-  if (SupportMKLDNNBN(inputs[0], param) && inputs[0].IsMKLDNNData()) {
+  if (SupportMKLDNNBN(inputs, param) && inputs[0].IsMKLDNNData()) {
     std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
     std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
 
@@ -420,7 +428,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
 
   TShape shape = inputs[0].shape();
   // MKLDNN batchnorm only works well on the special MKLDNN layout.
-  if (SupportMKLDNNBN(inputs[0], param)
+  if (SupportMKLDNNBN(inputs, param)
       && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) {
     std::vector<NDArray> out_grad(1);
     std::vector<NDArray> out_data(3);
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 403baaa94ab..7638e8bcf52 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -216,14 +216,24 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
   const NDArray &out  = out_data[batchnorm::kOut];
 
+  auto gamma_buffer = in_data[batchnorm::kGamma];
+  if (gamma_buffer.IsMKLDNNData()) {
+    gamma_buffer = gamma_buffer.Reorder2Default();
+  }
+
+  auto beta_buffer = in_data[batchnorm::kBeta];
+  if (beta_buffer.IsMKLDNNData()) {
+    beta_buffer = beta_buffer.Reorder2Default();
+  }
+
   // for output memory
   auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc());
 
   // mxnet will always use scale shift.
   // But if fix_gamma is true, then all scale elements will be set to 1.0f
   if (flags & use_scale_shift) {
-    const NDArray &gamma    = in_data[batchnorm::kGamma];
-    const NDArray &beta     = in_data[batchnorm::kBeta];
+    const NDArray &gamma    = gamma_buffer;
+    const NDArray &beta     = beta_buffer;
     CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
     CHECK_EQ(beta.storage_type(), mxnet::kDefaultStorage);
 
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index a500d4c2df6..f1deda57dd5 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -347,6 +347,31 @@ OpAttrs GetDeconvBackwardOp(int kernel, int num_filters, int dim, int stride, in
   return attrs;
 }
 
+// Forward BatchNorm under test: 5 inputs (data, gamma, beta, moving mean,
+// moving var), 3 outputs, 4-D data only, kWriteTo requests only.
+OpAttrs GetBNOp() {
+  OpAttrs bn;
+  bn.attrs.op = Op::Get("BatchNorm");
+  bn.attrs.op->attr_parser(&bn.attrs);
+  bn.num_inputs = 5;
+  bn.num_outputs = 3;
+  bn.accept_dims.insert(4);
+  bn.requests.insert(OpReqType::kWriteTo);
+  bn.input_types = ArrayTypes::Normal | ArrayTypes::MKLDNN;
+  bn.output_types = ArrayTypes::Normal | ArrayTypes::MKLDNN;
+  return bn;
+}
+
+OpAttrs GetBNBackwardOp() {
+  OpAttrs bn_grad;  // _backward_BatchNorm: 8 inputs, 3 outputs, kWriteTo
+  bn_grad.attrs.op = Op::Get("_backward_BatchNorm");
+  bn_grad.attrs.op->attr_parser(&bn_grad.attrs);
+  bn_grad.requests.insert(OpReqType::kWriteTo);
+  bn_grad.num_inputs = 8;
+  bn_grad.num_outputs = 3;
+  return bn_grad;
+}
+
 void AssertEqual(const std::vector<NDArray *> &in_arrs,
                  const std::vector<NDArray *> &out_arrs,
                  float rtol = 1e-5, float atol = 1e-8) {
@@ -710,7 +735,7 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
 
       // If the array is a view, we shouldn't write data to it.
       if (in_arr.arr.IsView())
-          continue;
+        continue;
 
       NDArrayAttrs orig(in_arr.arr.Copy(in_arr.arr.ctx()), "InPlace Copy");
       for (int i = 0; i < forward_attrs.num_inputs; i++)
@@ -735,6 +760,147 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
   }
 }
 
+
+// Runs _backward_BatchNorm through FCompute and FComputeEx on identical
+// inputs and asserts that the two sets of gradients agree.
+//   inputs:  the forward inputs (data, gamma, beta, moving mean, moving var).
+//   outputs: the forward outputs; outputs[0] doubles as the output gradient,
+//            outputs[1] / outputs[2] are passed on as mean / var.
+//   in_arr:  gradient buffers are seeded with copies of in_arr.arr.
+//            NOTE(review): so every gradient, including gamma/beta, gets the
+//            data shape -- confirm this is intended.
+//   out_arr: only used for the verification log message.
+// Only req == kWriteTo is exercised; other requests return immediately.
+// forward_attrs is currently unused.
+void TestOpExBNBackward(const OpAttrs &forward_attrs,
+                        const OpAttrs &backwards_attrs,
+                        const OpReqType &req,
+                        const std::vector<NDArray*> &inputs,
+                        const std::vector<NDArray*> &outputs,
+                        const NDArrayAttrs &in_arr,
+                        const NDArrayAttrs &out_arr) {
+  if (req != kWriteTo) return;
+
+  const std::vector<NDArray*> backwards_input = {
+      outputs[0],  // output grad
+      outputs[1],  // mean
+      outputs[2],  // var
+      inputs[0],   // data
+      inputs[1],   // gamma
+      inputs[2],   // beta
+      inputs[3],   // moving mean
+      inputs[4],   // moving var
+  };
+  CHECK_EQ(backwards_input.size(),
+           static_cast<size_t>(backwards_attrs.num_inputs));
+
+  // Gradient storage. reserve() guarantees no reallocation while filling,
+  // so the element pointers handed to InvokeOp below stay valid.
+  const size_t num_outputs = static_cast<size_t>(backwards_attrs.num_outputs);
+  std::vector<NDArray> backwards_buffer;
+  std::vector<NDArray> backwards_buffer2;
+  backwards_buffer.reserve(num_outputs);
+  backwards_buffer2.reserve(num_outputs);
+  std::vector<NDArray*> backwards_outputs(num_outputs);
+  std::vector<NDArray*> backwards_ex_outputs(num_outputs);
+  for (size_t i = 0; i < num_outputs; ++i) {
+    backwards_buffer.emplace_back(in_arr.arr.Copy(Context()));
+    backwards_buffer2.emplace_back(in_arr.arr.Copy(Context()));
+    backwards_outputs[i] = &backwards_buffer[i];
+    backwards_ex_outputs[i] = &backwards_buffer2[i];
+  }
+  Engine::Get()->WaitForAll();
+  const std::vector<OpReqType> back_req(num_outputs, kWriteTo);
+
+  std::cout << "Backwards: ";
+  PrintVerifyMsg(out_arr, in_arr);
+  Imperative::Get()->InvokeOp(
+      Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
+      back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
+  Imperative::Get()->InvokeOp(
+      Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
+      back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+  Engine::Get()->WaitForAll();
+  AssertEqual(backwards_outputs, backwards_ex_outputs);
+}
+
+// Compares the outputs of FCompute and FComputeEx for forward BatchNorm and,
+// when backwards_attrs has any requests, for the backward pass as well.
+// NOTE(review): all forward inputs (data, gamma, beta, moving mean, moving
+// var) are copies of the same test array and so share its shape -- confirm
+// this is intended.
+void TestOpExBN(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
+  if (forward_attrs.requests.find(OpReqType::kWriteTo) ==
+      forward_attrs.requests.end())
+    return;  // only kWriteTo is exercised
+  // Loop-invariant sanity check, hoisted out of the per-array loop.
+  CHECK_NE(forward_attrs.accept_dims.size(), 0U);
+
+  const size_t num_inputs = static_cast<size_t>(forward_attrs.num_inputs);
+  const size_t num_outputs = static_cast<size_t>(forward_attrs.num_outputs);
+  std::vector<NDArray*> inputs(num_inputs);
+  std::vector<NDArray*> inputs2(num_inputs);
+  std::vector<NDArray*> outputs(num_outputs);
+  std::vector<NDArray*> ex_outputs(num_outputs);
+  const std::vector<OpReqType> req(num_outputs, kWriteTo);
+  // Input storage; reserve() keeps the pointers in inputs/inputs2 valid
+  // while the buffers are refilled (clear() preserves capacity).
+  std::vector<NDArray> inputs_buffer;
+  std::vector<NDArray> inputs2_buffer;
+  inputs_buffer.reserve(num_inputs);
+  inputs2_buffer.reserve(num_inputs);
+
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, false);
+  std::vector<std::vector<NDArrayAttrs>> out_arrs(num_outputs);
+  std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(num_outputs);
+
+  for (const NDArrayAttrs &in_arr : in_arrs) {
+    if (forward_attrs.accept_dims.find(in_arr.arr.shape().ndim()) ==
+        forward_attrs.accept_dims.end())
+      continue;
+    for (size_t i = 0; i < num_outputs; ++i) {
+      out_arrs[i] =
+          GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
+      ex_out_arrs[i] =
+          GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
+    }
+    for (size_t output_i = 0; output_i < out_arrs[0].size(); ++output_i) {
+      // Independent, freshly copied inputs for each dispatch mode.
+      inputs_buffer.clear();
+      inputs2_buffer.clear();
+      for (size_t i = 0; i < num_inputs; ++i) {
+        inputs_buffer.emplace_back(in_arr.arr.Copy(Context()));
+        inputs2_buffer.emplace_back(in_arr.arr.Copy(Context()));
+        inputs[i] = &inputs_buffer[i];
+        inputs2[i] = &inputs2_buffer[i];
+      }
+      Engine::Get()->WaitForAll();
+      for (size_t i = 0; i < num_outputs; ++i) {
+        outputs[i] = &out_arrs[i][output_i].arr;
+        ex_outputs[i] = &ex_out_arrs[i][output_i].arr;
+      }
+      Imperative::Get()->set_is_training(true);
+
+      PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+      Imperative::Get()->InvokeOp(
+          Context(), forward_attrs.attrs, inputs, outputs, req,
+          DispatchMode::kFCompute, mxnet::OpStatePtr());
+      Imperative::Get()->InvokeOp(
+          Context(), forward_attrs.attrs, inputs2, ex_outputs, req,
+          DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+      Engine::Get()->WaitForAll();
+      AssertEqual(outputs, ex_outputs);
+
+      if (!backwards_attrs.requests.empty()) {
+        TestOpExBNBackward(forward_attrs, backwards_attrs, OpReqType::kWriteTo,
+                           inputs, outputs, in_arr, out_arrs[0][output_i]);
+      }
+    }
+  }
+}
+
 // Computes second dimension of FC weight matrix based on input shape
 uint32_t GetFCWeightDim2(const nnvm::TShape arr) {
   uint32_t dim = 1;
@@ -1204,4 +1351,10 @@ TEST(IMPERATIVE, DeconvOp) {
   }
 }
 
+TEST(IMPERATIVE, BNOp) {
+  const OpAttrs fwd = GetBNOp();
+  const OpAttrs bwd = GetBNBackwardOp();
+  TestOpExBN(fwd, bwd);  // FCompute vs. FComputeEx, forward and backward
+}
+
 #endif


With regards,
Apache Git Services