You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by an...@apache.org on 2018/11/15 02:47:09 UTC
[incubator-mxnet] branch master updated: adding unit test for
MKLDNN FullyConnected operator (#12985)
This is an automated email from the ASF dual-hosted git repository.
anirudh2290 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new cf991ff adding unit test for MKLDNN FullyConnected operator (#12985)
cf991ff is described below
commit cf991ff3dcdd20d99ec996e5f501ba65f9ca65b4
Author: Manu Seth <22...@users.noreply.github.com>
AuthorDate: Wed Nov 14 18:46:54 2018 -0800
adding unit test for MKLDNN FullyConnected operator (#12985)
* adding unit test for MKLDNN FullyConnected operator
* removing mkldnn filter
* removing mkldnn filter
---
tests/cpp/operator/mkldnn_operator_test.cc | 155 +++++++++++++++++++++++++++++
1 file changed, 155 insertions(+)
diff --git a/tests/cpp/operator/mkldnn_operator_test.cc b/tests/cpp/operator/mkldnn_operator_test.cc
index 21b257e..9e30cd8 100644
--- a/tests/cpp/operator/mkldnn_operator_test.cc
+++ b/tests/cpp/operator/mkldnn_operator_test.cc
@@ -213,6 +213,36 @@ OpAttrs GetLRNBackwardsOp() {
return attrs;
}
+// Builds the forward attributes for the FullyConnected operator with
+// num_hidden = 20. The op is configured for 3 inputs (data, weight and bias, as
+// wired up by TestFullyConnectedOp) and 1 output, and only kWriteTo requests.
+// Both inputs and outputs are drawn from every default/MKLDNN array kind,
+// including their reshaped variants.
+OpAttrs GetFullyConnectedOp() {
+  const auto array_kinds = ArrayTypes::Normal | ArrayTypes::MKLDNN |
+                           ArrayTypes::NormalReshaped | ArrayTypes::MKLDNNReshaped;
+  OpAttrs fc;
+  fc.attrs.op = Op::Get("FullyConnected");
+  fc.attrs.dict.insert({"num_hidden", "20"});
+  fc.attrs.op->attr_parser(&fc.attrs);
+  fc.num_inputs = 3;
+  fc.num_outputs = 1;
+  fc.requests.insert(OpReqType::kWriteTo);
+  fc.input_types = array_kinds;
+  fc.output_types = array_kinds;
+  return fc;
+}
+
+// Builds the attributes for _backward_FullyConnected with num_hidden = 20,
+// matching GetFullyConnectedOp(). It is configured for 3 inputs and 3 outputs,
+// and only kWriteTo requests. TestFullyConnectedOp feeds it output grad, data
+// and weight, and collects data, weight and bias gradients.
+OpAttrs GetFullyConnectedBackwardsOp() {
+  OpAttrs fc_back;
+  fc_back.attrs.op = Op::Get("_backward_FullyConnected");
+  fc_back.attrs.dict.insert({"num_hidden", "20"});
+  fc_back.attrs.op->attr_parser(&fc_back.attrs);
+  fc_back.requests.insert(OpReqType::kWriteTo);
+  fc_back.num_inputs = 3;
+  fc_back.num_outputs = 3;
+  return fc_back;
+}
+
void AssertEqual(const std::vector<NDArray *> &in_arrs,
const std::vector<NDArray *> &out_arrs,
float rtol = 1e-5, float atol = 1e-8) {
@@ -557,6 +587,125 @@ void TestOpEx(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
}
}
+// Computes the second dimension of the FC weight matrix for an input of shape
+// `arr`. The result is the product of every dimension after the first (batch)
+// axis. Returns 1 when `arr` has fewer than two dimensions.
+uint32_t GetFCWeightDim2(const nnvm::TShape &arr) {
+  uint32_t dim = 1;
+  // Unsigned index avoids a signed/unsigned comparison with ndim().
+  for (uint32_t i = 1; i < arr.ndim(); ++i) {
+    dim *= arr[i];
+  }
+  return dim;
+}
+
+// Checks FullyConnected by running the default (kFCompute) and MKLDNN
+// (kFComputeEx) implementations on the same inputs and asserting equal results.
+// Every test input array of rank >= 2 is paired with every candidate output
+// array. The backward op consumes the forward output, so it is checked inside
+// the same loop.
+//
+// forward_attrs:   forward attributes (see GetFullyConnectedOp). Must contain a
+//                  "num_hidden" entry in attrs.dict; std::out_of_range is thrown
+//                  otherwise.
+// backwards_attrs: backward attributes (see GetFullyConnectedBackwardsOp).
+void TestFullyConnectedOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
+  std::vector<NDArray*> inputs(forward_attrs.num_inputs);
+  std::vector<NDArray*> outputs(forward_attrs.num_outputs);
+  std::vector<NDArray*> ex_outputs(forward_attrs.num_outputs);
+
+  std::vector<NDArray*> backwards_input(backwards_attrs.num_inputs);
+  std::vector<NDArray*> backwards_outputs(backwards_attrs.num_outputs);
+  std::vector<NDArray*> backwards_ex_outputs(backwards_attrs.num_outputs);
+
+  std::vector<OpReqType> req(forward_attrs.num_outputs);
+  std::vector<OpReqType> back_req(backwards_attrs.num_outputs);
+
+  TestArrayShapes tas = GetTestArrayShapes();
+  std::vector<mkldnn::memory::primitive_desc> pds = tas.pds;
+
+  std::vector<NDArrayAttrs> in_arrs = GetTestInputArrays(forward_attrs.input_types, true);
+  std::vector<std::vector<NDArrayAttrs>> out_arrs(forward_attrs.num_outputs);
+  std::vector<std::vector<NDArrayAttrs>> ex_out_arrs(forward_attrs.num_outputs);
+
+  // Read through the const map API. operator[] needed a const_cast and would
+  // silently insert an empty entry if the key were missing.
+  const int num_hid = std::stoi(forward_attrs.attrs.dict.at("num_hidden"));
+
+  if (forward_attrs.requests.find(OpReqType::kWriteTo) != forward_attrs.requests.end()) {
+    for (size_t i1 = 0; i1 < in_arrs.size(); i1++) {
+      // Bind by reference; in_arrs is not resized inside this loop, so the
+      // pointers stored below stay valid for the whole iteration.
+      auto &in_arr = in_arrs[i1];
+      auto in_shape = in_arr.arr.shape();
+      // Axis 0 is treated as the batch axis, so at least 2 dims are required.
+      if (in_shape.ndim() < 2)
+        continue;
+
+      // weight: (num_hidden, product of non-batch input dims)
+      nnvm::TShape wt_shape(2);
+      wt_shape[0] = num_hid;
+      wt_shape[1] = GetFCWeightDim2(in_shape);
+      NDArray weights(wt_shape, Context());
+      InitDefaultArray(&weights, false);
+
+      // bias: (num_hidden)
+      nnvm::TShape bias_shape(1);
+      bias_shape[0] = num_hid;
+      NDArray bias(bias_shape, Context());
+      InitDefaultArray(&bias, false);
+
+      inputs[0] = &in_arr.arr;
+      inputs[1] = &weights;
+      inputs[2] = &bias;
+
+      // output: (batch, num_hidden)
+      nnvm::TShape out_shape(2);
+      out_shape[0] = in_shape[0];
+      out_shape[1] = num_hid;
+
+      for (int i = 0; i < forward_attrs.num_outputs; i++) {
+        out_arrs[i] =
+            GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types);
+        ex_out_arrs[i] =
+            GetTestOutputArrays(out_shape, pds, {1}, forward_attrs.output_types);
+      }
+
+      for (size_t output_i = 0; output_i < out_arrs[0].size(); output_i++) {
+        for (int i = 0; i < forward_attrs.num_outputs; i++) {
+          req[i] = kWriteTo;
+          outputs[i] = &out_arrs[i][output_i].arr;
+          ex_outputs[i] = &ex_out_arrs[i][output_i].arr;
+        }
+        Imperative::Get()->set_is_training(true);
+
+        // Forward: default and MKLDNN implementations must agree.
+        PrintVerifyMsg(in_arr, out_arrs[0][output_i]);
+        Imperative::Get()->InvokeOp(
+            Context(), forward_attrs.attrs, inputs, outputs, req,
+            DispatchMode::kFCompute, mxnet::OpStatePtr());
+        Imperative::Get()->InvokeOp(
+            Context(), forward_attrs.attrs, inputs, ex_outputs, req,
+            DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+        Engine::Get()->WaitForAll();
+        AssertEqual(outputs, ex_outputs);
+
+        // Backward is tested here because it feeds the forward output back in
+        // as the output gradient.
+        backwards_input[0] = outputs[0];  // output grad
+        backwards_input[1] = inputs[0];   // input
+        backwards_input[2] = inputs[1];   // weights
+
+        // Data-gradient arrays take the i1-th generated test input. This
+        // assumes GetTestInputArrays returns arrays in a stable order, so that
+        // index i1 matches in_arr (TODO confirm).
+        auto tmp_output = GetTestInputArrays(forward_attrs.input_types, true)[i1];
+        NDArray back_weights(wt_shape, Context());
+        NDArray back_bias(bias_shape, Context());
+        backwards_outputs[0] = &tmp_output.arr;
+        backwards_outputs[1] = &back_weights;
+        backwards_outputs[2] = &back_bias;
+
+        auto tmp_output2 = GetTestInputArrays(forward_attrs.input_types, true)[i1];
+        NDArray back_ex_weights(wt_shape, Context());
+        NDArray back_ex_bias(bias_shape, Context());
+        backwards_ex_outputs[0] = &tmp_output2.arr;
+        backwards_ex_outputs[1] = &back_ex_weights;
+        backwards_ex_outputs[2] = &back_ex_bias;
+
+        for (int i = 0; i < backwards_attrs.num_outputs; i++)
+          back_req[i] = kWriteTo;
+
+        std::cout << "Backwards: ";
+        PrintVerifyMsg(out_arrs[0][output_i], tmp_output);
+        Imperative::Get()->InvokeOp(
+            Context(), backwards_attrs.attrs, backwards_input, backwards_outputs,
+            back_req, DispatchMode::kFCompute, mxnet::OpStatePtr());
+        Imperative::Get()->InvokeOp(
+            Context(), backwards_attrs.attrs, backwards_input, backwards_ex_outputs,
+            back_req, DispatchMode::kFComputeEx, mxnet::OpStatePtr());
+        Engine::Get()->WaitForAll();
+        AssertEqual(backwards_outputs, backwards_ex_outputs);
+      }
+    }
+  }
+}
+
void TestPoolingOp(const OpAttrs &forward_attrs, const OpAttrs &backwards_attrs) {
std::vector<NDArray*> inputs(forward_attrs.num_inputs);
std::vector<NDArray*> outputs(forward_attrs.num_outputs);
@@ -717,6 +866,12 @@ TEST(IMPERATIVE, LRNOp) {
TestOpEx(forward_attrs, backwards_attrs);
}
+// Runs the combined forward/backward FullyConnected comparison between the
+// default and MKLDNN implementations.
+TEST(IMPERATIVE, FullyConnectedOp) {
+  const OpAttrs fwd = GetFullyConnectedOp();
+  const OpAttrs bwd = GetFullyConnectedBackwardsOp();
+  TestFullyConnectedOp(fwd, bwd);
+}
+
TEST(IMPERATIVE, PoolingOp) {
for (int dim = 2; dim < 4; dim++) {
for (int kernel = 1; kernel < 4; kernel++) {