You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2019/01/14 21:48:10 UTC
[incubator-mxnet] Diff for: [GitHub] azai91 closed pull request #13084:
Test/mkldnn batch norm op
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 6254a1e1866..e7108241257 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -380,12 +380,20 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
}
#if MXNET_USE_MKLDNN == 1
-static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
- TShape shape = input.shape();
- return SupportMKLDNN(input) && shape.ndim() == 4
+static inline bool SupportMKLDNNBN(const std::vector<NDArray> &inputs,
+ const BatchNormParam &param) {
+ TShape shape = inputs[0].shape();
+ bool params_valid = shape.ndim() == 4
&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
&& shape[param.axis] % 8 == 0
&& !mxnet::op::batchnorm::disable_mkl;
+ bool inputs_valid = SupportMKLDNN(inputs[0]);
+ for (size_t i = 1; i < inputs.size(); i++) {
+ if (inputs[i].IsMKLDNNData()) {
+ inputs_valid = false;
+ }
+ }
+ return params_valid && inputs_valid;
}
void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
@@ -396,7 +404,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
CHECK_EQ(inputs.size(), 5U);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
// MKLDNN batchnorm only works well on the special MKLDNN layout.
- if (SupportMKLDNNBN(inputs[0], param) && inputs[0].IsMKLDNNData()) {
+ if (SupportMKLDNNBN(inputs, param) && inputs[0].IsMKLDNNData()) {
std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
@@ -420,7 +428,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
TShape shape = inputs[0].shape();
// MKLDNN batchnorm only works well on the special MKLDNN layout.
- if (SupportMKLDNNBN(inputs[0], param)
+ if (SupportMKLDNNBN(inputs, param)
&& (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) {
std::vector<NDArray> out_grad(1);
std::vector<NDArray> out_data(3);
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 403baaa94ab..7638e8bcf52 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -216,14 +216,24 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
const NDArray &out = out_data[batchnorm::kOut];
+ auto gamma_buffer = in_data[batchnorm::kGamma];
+ if (gamma_buffer.IsMKLDNNData()) {
+ gamma_buffer = gamma_buffer.Reorder2Default();
+ }
+
+ auto beta_buffer = in_data[batchnorm::kBeta];
+ if (beta_buffer.IsMKLDNNData()) {
+ beta_buffer = beta_buffer.Reorder2Default();
+ }
+
// for output memory
auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc());
// mxnet will always use scale shift.
// But if fix_gamma is true, then all scale elements will be set to 1.0f
if (flags & use_scale_shift) {
- const NDArray &gamma = in_data[batchnorm::kGamma];
- const NDArray &beta = in_data[batchnorm::kBeta];
+ const NDArray &gamma = gamma_buffer;
+ const NDArray &beta = beta_buffer;
CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
CHECK_EQ(beta.storage_type(), mxnet::kDefaultStorage);
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index a500d4c2df6..f1deda57dd5 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -347,6 +347,31 @@ OpAttrs GetDeconvBackwardOp(int kernel, int num_filters, int dim, int stride, in
return attrs;
}
+OpAttrs GetBNOp() {
+ OpAttrs attrs;
+ attrs.attrs.op = Op::Get("BatchNorm");
+ attrs.num_inputs = 5;
+ attrs.num_outputs = 3;
+ attrs.accept_dims.insert(4);
+ attrs.requests.insert(OpReqType::kWriteTo);
+ attrs.attrs.op->attr_parser(&attrs.attrs);
+ attrs.input_types = ArrayTypes::Normal |
+ ArrayTypes::MKLDNN;
+ attrs.output_types = ArrayTypes::Normal |
+ ArrayTypes::MKLDNN;
+ return attrs;
+}
+
+OpAttrs GetBNBackwardOp() {
+ OpAttrs attrs;
+ attrs.attrs.op = Op::Get("_backward_BatchNorm");
+ attrs.num_inputs = 8;
+ attrs.num_outputs = 3;
+ attrs.attrs.op->attr_parser(&attrs.attrs);
+ attrs.requests.insert(OpReqType::kWriteTo);
+ return attrs;
+}
+
void AssertEqual(const std::vector<NDArray *> &in_arrs,
const std::vector<NDArray *> &out_arrs,
float rtol = 1e-5, float atol = 1e-8) {
@@ -710,7 +735,7 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
// If the array is a view, we shouldn't write data to it.
if (in_arr.arr.IsView())
- continue;
+ continue;
NDArrayAttrs orig(in_arr.arr.Copy(in_arr.arr.ctx()), "InPlace Copy");
for (int i = 0; i < forward_attrs.num_inputs; i++)
@@ -735,6 +760,128 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
}
}
+
+void TestOpExBNBackward(const OpAttrs &forward_attrs,
+ const OpAttrs &backwards_attrs,
+ const OpReqType &req,
+ const std::vector<NDArray*> &inputs,
+ const std::vector<NDArray*> &outputs,
+ const NDArrayAttrs &in_arr,
+ const NDArrayAttrs &out_arr) {
+ std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
+
+ std::vector<NDArray> backwards_buffer(backwards_attrs.num_outputs);
+ std::vector<NDArray> backwards_buffer2(backwards_attrs.num_outputs);
+
+ std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
+ std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
+ std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
+
+ if (req == kWriteTo) {
+ backwards_input[0] = outputs[0]; // output grad
+ backwards_input[1] = outputs[1]; // mean
+ backwards_input[2] = outputs[2]; // var
+ backwards_input[3] = inputs[0]; // data
+ backwards_input[4] = inputs[1]; // gamma
+ backwards_input[5] = inputs[2]; // beta
+ backwards_input[6] = inputs[3]; // moving mean
+ backwards_input[7] = inputs[4]; // moving var
+
+
+ for (size_t i = 0; i < backwards_attrs.num_outputs; i++) {
+ auto tmp_output = in_arr.arr;
+ backwards_buffer.emplace_back(tmp_output.Copy(Context()));
+ backwards_buffer2.emplace_back(tmp_output.Copy(Context()));
+ backwards_outputs[i] = &backwards_buffer.back();
+ backwards_ex_outputs[i] = &backwards_buffer2.back();
+ Engine::Get()->WaitForAll();
+ }
+
+
+ for (int i = 0; i < backwards_attrs.num_outputs; i++)
+ back_req[i] = kWriteTo;
+
+ std::cout << "Backwards: ";
+ PrintVerifyMsg(out_arr, in_arr);
+ Imperative::Get()->InvokeOp(
+ Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
+ back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
+ Imperative::Get()->InvokeOp(
+ Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
+ back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+ Engine::Get()->WaitForAll();
+ AssertEqual(backwards_outputs, backwards_ex_outputs);
+ }
+}
+
+// compares output of fcompute with fcomputex
+void TestOpExBN(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
+ std::vector<NDArray*> inputs(forward_attrs.num_inputs);
+ std::vector<NDArray*> inputs2(forward_attrs.num_inputs);
+ std::vector<NDArray> inputs_buffer(forward_attrs.num_inputs);
+ std::vector<NDArray> inputs2_buffer(forward_attrs.num_inputs);
+ std::vector<NDArray*> outputs(forward_attrs.num_outputs);
+ std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
+ std::vector<OpReqType> req(forward_attrs.num_outputs);
+
+ TestArrayShapes tas = GetTestArrayShapes();
+ std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+ std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, false);
+ std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
+ std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);
+
+ if (forward_attrs.requests.find(OpReqType::kWriteTo) != forward_attrs.requests.end()) {
+ for (int i1 = 0; i1 < in_arrs.size(); i1++) {
+ auto in_arr = in_arrs[i1];
+
+ CHECK_NE(forward_attrs.accept_dims.size(), 0);
+ if (forward_attrs.accept_dims.find(in_arr.arr.shape().ndim()) ==
+ forward_attrs.accept_dims.end())
+ continue;
+ for (int i = 0; i < forward_attrs.num_outputs; i++) {
+ out_arrs[i] =
+ GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
+ ex_out_arrs[i] =
+ GetTestOutputArrays(in_arr.arr.shape(), pds, {1}, true, forward_attrs.output_types);
+ }
+ for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+ inputs_buffer.clear();
+ inputs2_buffer.clear();
+
+ for (int i = 0; i < forward_attrs.num_inputs; i++) {
+ inputs_buffer.emplace_back(in_arr.arr.Copy(Context()));
+ inputs2_buffer.emplace_back(in_arr.arr.Copy(Context()));
+ Engine::Get()->WaitForAll();
+ inputs[i] = &inputs_buffer.back();
+ inputs2[i] = &inputs2_buffer.back();
+ }
+ for (int i = 0; i < forward_attrs.num_outputs; i++) {
+ req[i] = kWriteTo;
+ outputs[i] = &out_arrs[i][output_i].arr;
+ ex_outputs[i] = &ex_out_arrs[i][output_i].arr;
+ }
+ Imperative::Get()->set_is_training(true);
+
+ PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+ Imperative::Get()->InvokeOp(
+ Context(), forward_attrs.attrs, inputs, outputs, req,
+ DispatchMode::kFCompute, mxnet::OpStatePtr());
+ Imperative::Get()->InvokeOp(
+ Context(), forward_attrs.attrs, inputs2, ex_outputs, req,
+ DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+ Engine::Get()->WaitForAll();
+ AssertEqual(outputs, ex_outputs);
+
+ if (!backwards_attrs.requests.empty()) {
+ TestOpExBNBackward(forward_attrs, backwards_attrs, OpReqType::kWriteTo,
+ inputs, outputs, in_arr, out_arrs[0][output_i]);
+ }
+ }
+ }
+ }
+}
+
// Computes second dimension of FC weight matrix based on input shape
uint32_t GetFCWeightDim2(const nnvm::TShape arr) {
uint32_t dim = 1;
@@ -1204,4 +1351,10 @@ TEST(IMPERATIVE, DeconvOp) {
}
}
+TEST(IMPERATIVE, BNOp) {
+ OpAttrs forward_attrs = GetBNOp();
+ OpAttrs backwards_attrs = GetBNBackwardOp();
+ TestOpExBN(forward_attrs, backwards_attrs);
+}
+
#endif
With regards,
Apache Git Services