You are viewing a plain text version of this content; the canonical (HTML) version is available at the mailing-list archive.
Posted to commits@mxnet.apache.org by pa...@apache.org on 2019/10/13 09:15:40 UTC

[incubator-mxnet] branch mkldnn-v1.0 updated: [mkldnn-v1.0] Add MKL-DNN int8 fc (#16457)

This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
     new 2b363a0  [mkldnn-v1.0] Add MKL-DNN int8 fc (#16457)
2b363a0 is described below

commit 2b363a0c11cfeb272fb99a3d79d09c51a9241d22
Author: Wuxun Zhang <wu...@intel.com>
AuthorDate: Sun Oct 13 17:15:12 2019 +0800

    [mkldnn-v1.0] Add MKL-DNN int8 fc (#16457)
    
    * Add mkldnn_v1.0 int8 fc
    
    * trigger CI
    
    * trigger CI
---
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc   | 10 +++----
 .../mkldnn/mkldnn_quantized_fully_connected.cc     | 31 +++++++++++++---------
 .../quantization/mkldnn/mkldnn_quantized_ops-inl.h |  2 +-
 .../quantization/quantized_fully_connected.cc      | 10 +++----
 tests/python/quantization/test_quantization.py     |  2 +-
 5 files changed, 31 insertions(+), 24 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index 80eb2d6..24d31dd 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -216,7 +216,7 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
                                  fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data);
 
-  std::unordered_map<int, mkldnn::memory> args = {
+  mkldnn_args_map_t args = {
       {MKLDNN_ARG_SRC, *data_mem},
       {MKLDNN_ARG_WEIGHTS, *weight_mem},
       {MKLDNN_ARG_DST, *out_mem.second},
@@ -224,7 +224,7 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
   if (!full_param.default_param.no_bias) {
     auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
         fwd->fwd_pd.bias_desc());
-    args.insert({ MKLDNN_ARG_BIAS, *bias_mem});
+    args[MKLDNN_ARG_BIAS] = *bias_mem;
   }
   MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
   CommitOutput(out_data[fullc::kOut], out_mem);
@@ -298,7 +298,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
                                        ipBwdData_pd.diff_src_desc(),
                                        req[fullc::kData]);
-    std::unordered_map<int, mkldnn::memory> args = {
+    mkldnn_args_map_t args = {
       {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
       {MKLDNN_ARG_WEIGHTS, *weight_mem},
       {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
@@ -317,7 +317,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
                                                  ipBwdWeights_pd.diff_weights_desc(),
                                                  req[fullc::kWeight]);
-    std::unordered_map<int, mkldnn::memory> args = {
+    mkldnn_args_map_t args = {
       {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
       {MKLDNN_ARG_SRC, *data_mem},
       {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second},
@@ -328,7 +328,7 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
                                      ipBwdWeights_pd.diff_bias_desc(),
                                      req[fullc::kBias]);
-      args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
+      args[MKLDNN_ARG_DIFF_BIAS] = *in_grad_bias.second;
     }
     MKLDNNStream::Get()->RegisterPrimArgs(
         mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args);
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
index aca129a..1e35edd 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc
@@ -24,7 +24,7 @@
  * \author Ciyong Chen
  */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../../nn/mkldnn/mkldnn_fully_connected-inl.h"
 #include "../quantization_utils.h"
 
@@ -89,28 +89,35 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
   auto &fwd = GetFCFwd(param, is_train, data, weight,
       param.no_bias ? nullptr : &quantized_bias, out_md);
 
-  auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc());
+  auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_desc());
   const mkldnn::memory *weight_mem = nullptr;
 
   if (weight.IsDefaultData()) {
     // We also need to modify the layout on the original weight array.
    // Don't switch the order of the statements below, because the naive
    // engine will execute pushAsync synchronously.
-    weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
-    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
+    weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_desc());
+    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_desc(), 1);
   } else {
     weight_mem = weight.GetMKLDNNData();
-    CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
+    CHECK(weight_mem->get_desc() == fwd.fwd_pd.weights_desc());
   }
-  auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_primitive_desc(),
+  auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], fwd.fwd_pd.dst_desc(),
                                  req[fullc::kOut]);
-  const mkldnn::memory *bias_mem = nullptr;
-  if (!param.no_bias)
-    bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc());
 
-  fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+  mkldnn_args_map_t args = {
+      {MKLDNN_ARG_SRC, *data_mem},
+      {MKLDNN_ARG_WEIGHTS, *weight_mem},
+      {MKLDNN_ARG_DST, *out_mem.second},
+  };
+
+  const mkldnn::memory *bias_mem = nullptr;
+  if (!param.no_bias) {
+    bias_mem = quantized_bias.GetMKLDNNDataReorder(fwd.fwd_pd.bias_desc());
+    args[MKLDNN_ARG_BIAS] = *bias_mem;
+  }
 
+  MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), args);
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
@@ -118,4 +125,4 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
 }  // namespace op
 }  // namespace mxnet
 
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h b/src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h
index 88d77c8..6de26e5 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_ops-inl.h
@@ -27,7 +27,7 @@
 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_
 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZED_OPS_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include <mxnet/ndarray.h>
 #include <vector>
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index 4c9d9d2..89e0235 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -26,7 +26,7 @@
 #include <vector>
 #include "quantization_utils.h"
 #include "../nn/fully_connected-inl.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../nn/mkldnn/mkldnn_fully_connected-inl.h"
 #include "mkldnn/mkldnn_quantized_ops-inl.h"
 #endif
@@ -94,7 +94,7 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_type->size(), num_inputs * 3);
   CHECK_EQ(out_type->size(), 3U);
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8)
       << "QuantizedFullyConnected only supports int8/uint8 input, while "
       << in_type->at(0) << " is given.";
@@ -124,7 +124,7 @@ bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), num_inputs * 3);
   CHECK_EQ(out_attrs->size(), 3U);
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   return MKLDNNStorageType(attrs, dev_mask, true,
                            dispatch_mode, in_attrs, out_attrs);
 #else
@@ -292,7 +292,7 @@ void QuantizedFullyConnectedForwardCPU(const nnvm::NodeAttrs& attrs,
 #endif
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 void QuantizedFullyConnectedForwardExCPU(const nnvm::NodeAttrs &attrs,
                                          const OpContext &ctx,
                                          const std::vector<NDArray> &in_data,
@@ -341,7 +341,7 @@ and max thresholds representing the thresholds for quantizing the float32 output
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
 .set_attr<FCompute>("FCompute<cpu>", QuantizedFullyConnectedForwardCPU)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", QuantizedFullyConnectedForwardExCPU)
 #endif
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 9721460..6d6ba41 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -407,7 +407,7 @@ def test_quantized_pooling():
 def test_quantized_fc():
     def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         if is_test_for_native_cpu():
-            hasMKL = False;
+            hasMKL = False
             for key in os.environ.keys():
                 if operator.eq(key, "BUILD_TAG"):
                     if os.environ['BUILD_TAG'].find("MKL") != -1: