You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/12/15 05:35:10 UTC

[GitHub] TaoLv closed pull request #12922: Support Quantized Fully Connected by INT8 GEMM

TaoLv closed pull request #12922: Support Quantized Fully Connected by INT8 GEMM
URL: https://github.com/apache/incubator-mxnet/pull/12922
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index e334fe7ec9b..64ce73ba1cf 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -23,11 +23,17 @@
  * \brief
  * \author Ziheng Jiang, Jun Wu
 */
+#include <vector>
+#include "quantization_utils.h"
 #include "../nn/fully_connected-inl.h"
 
 namespace mxnet {
 namespace op {
 
+namespace quantized_fc {
+enum QuantizedfcOpResource {kTempSpace};  // kTempSpace indexes ctx.requested in the forward pass
+}
+
 bool QuantizedFullyConnectedShape(const nnvm::NodeAttrs& attrs,
                                   std::vector<TShape> *in_shape,
                                   std::vector<TShape> *out_shape) {
@@ -79,6 +85,151 @@ bool QuantizedFullyConnectedType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+bool QuantizedFullyConnectedStorageType(const nnvm::NodeAttrs& attrs,  // not read in this body
+                                        const int dev_mask,  // compared with the CPU mask
+                                        DispatchMode* dispatch_mode,  // [out] chosen dispatch
+                                        std::vector<int> *in_attrs,  // [out] input stypes
+                                        std::vector<int> *out_attrs) {  // [out] output stypes
+  *dispatch_mode = DispatchMode::kFCompute;  // default for any non-CPU device mask
+  if (dev_mask == mshadow::cpu::kDevMask) {
+    *dispatch_mode = DispatchMode::kFComputeEx;  // CPU goes through the FComputeEx path
+  }
+
+  for (auto &v : *out_attrs) {  // force every output to default storage
+    v = kDefaultStorage;
+    if (common::stype_string(v).compare("unknown") == 0) {  // NOTE(review): v was just assigned,
+      return false;  // so this branch looks unreachable - confirm against stype_string
+    }
+  }
+
+  for (auto &v : *in_attrs) {  // same treatment for every input
+    v = kDefaultStorage;
+    if (common::stype_string(v).compare("unknown") == 0) {  // NOTE(review): same as above
+      return false;
+    }
+  }
+  return true;  // reached whenever neither "unknown" check above fired
+}
+
+struct QuantizedSumInitKernelWithBias {
+  //  Map(i) seeds out[i] with bias[i] rescaled from the bias scale to the output scale
+  MSHADOW_XINLINE static void Map(int i, int32_t *out,
+                                  const int8_t *bias, const float *min_out,
+                                  const float *max_out, const float *min_bias,
+                                  const float *max_bias) {
+    typedef int32_t T1;  // element type of out
+    typedef int8_t  T2;  // element type of bias
+    using mshadow::red::limits::MinValue;  // NOTE(review): MinValue is not referenced below
+    using mshadow::red::limits::MaxValue;
+    float float_for_one_out_quant  =  // MaxAbs of the output range divided by the int32 max
+        MaxAbs(*min_out, *max_out) / static_cast<double>(MaxValue<T1>());
+    float float_for_one_bias_quant =  // MaxAbs of the bias range divided by the int8 max
+        MaxAbs(*min_bias, *max_bias) / static_cast<double>(MaxValue<T2>());
+    if (float_for_one_out_quant != 0) {  // guards the division below
+      out[i] = bias[i] * float_for_one_bias_quant /  // float result is stored into int32 out[i]
+          float_for_one_out_quant;
+    } else {
+      LOG(INFO) << "float_for_one_out_quant is 0,"  // emitted on every Map call taking this branch
+                << " need to check the why MaxAbs(*min_out, *max_out) of out_data is 0!";
+      out[i] = 0;  // fall back to a zero contribution for this element
+    }
+  }
+};
+
+
+template<typename SrcType>  // CPU forward of quantized FC; SrcType = data/weight element type
+void QuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<NDArray> &in_data,
+                                    const std::vector<OpReqType> &req,  // not read in this body
+                                    const std::vector<NDArray> &out_data) {
+#if MSHADOW_USE_MKL == 1
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  using namespace mshadow;
+  using namespace mxnet_op;
+  size_t num_inputs = param.no_bias ? 2 : 3;  // data, weight and, unless no_bias, bias
+  CHECK_EQ(in_data.size(),  num_inputs * 3);  // tensors first, float range entries after them
+  CHECK_EQ(out_data.size(), 3U);  // out plus the two float entries handed to the range kernel
+  const NDArray& data = in_data[0];
+  const NDArray& weight = in_data[1];
+  const NDArray& out = out_data[0];
+  TShape dshape = data.shape();
+  TShape wshape = weight.shape();
+  TShape oshape = out.shape();  // NOTE(review): oshape is not referenced again in this function
+  auto output_temp = out.data().dptr<int32_t>();
+  auto weight_temp = weight.data().dptr<SrcType>();
+  auto data_temp = data.data().dptr<SrcType>();
+  const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  const float alpha = 1.0f;
+  const float beta  = 1.0f;  // out is pre-filled before the GEMM call below
+  const CBLAS_OFFSET offsetc = CblasFixOffset;
+  const MKL_INT8 oa = 0;
+  const MKL_INT8 ob = 0;
+  MKL_INT32 oc = 0;
+  const int m = dshape[0], n = wshape[0], k = dshape.ProdShape(1, dshape.ndim());
+  Stream<cpu> *s = ctx.get_stream<cpu>();
+  //  cblas_gemm_s8u8s32 requires its first matrix to be uint8, so the int8 data
+  //  (-128..127) is shifted by +128 into a uint8 (0..255) temp-space buffer
+  int shift = 128;
+  Tensor<cpu, 1, uint8_t> shiftdata =
+    ctx.requested[quantized_fc::kTempSpace].get_space_typed<cpu, 1, uint8_t>(
+      Shape1(m * k), s);  // m * k: dshape[0] times the product of the remaining data dims
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < m * k; ++i) {
+    shiftdata.dptr_[i] = data_temp[i] + shift;
+  }
+
+  Kernel<QuantizationRangeForMultiplicationStruct, cpu>::Launch(s, 1,
+      out_data[1].data().dptr<float>(), out_data[2].data().dptr<float>(),  // presumably out min/max
+      in_data[num_inputs].data().dptr<float>(), in_data[num_inputs+1].data().dptr<float>(),
+      in_data[num_inputs+2].data().dptr<float>(), in_data[num_inputs+3].data().dptr<float>());
+  if (!param.no_bias) {  // with bias: launch the bias-init kernel with N == n on out
+    const NDArray& bias = in_data[2];
+    Kernel<QuantizedSumInitKernelWithBias, cpu>::Launch(s, n, out.data().dptr<int32_t>(),
+        bias.data().dptr<int8_t>(), out_data[1].data().dptr<float>(),
+        out_data[2].data().dptr<float>(), in_data[7].data().dptr<float>(),
+        in_data[8].data().dptr<float>());  // in_data[7]/[8] land on Map's min_bias/max_bias
+  } else {  // no bias: zero all m * n output elements
+    #pragma omp parallel for num_threads(omp_threads)
+    for (int i = 0; i < m * n; ++i) {
+      output_temp[i] = 0;
+    }
+  }
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = 0; i < n; ++i) {  // first n outputs: subtract shift * sum of weight row i
+    for (int j = 0; j < k; ++j) {
+      output_temp[i] -= shift * weight_temp[i * k + j];  // offsets the +shift added to data
+    }
+  }
+  #pragma omp parallel for num_threads(omp_threads)
+  for (int i = n; i < m * n; ++i) {  // copy the first n outputs into every later group of n
+    output_temp[i] = output_temp[i % n];
+  }
+  cblas_gemm_s8u8s32(CblasRowMajor,
+                     CblasNoTrans,
+                     CblasTrans,
+                     offsetc,
+                     m,
+                     n,
+                     k,
+                     alpha,
+                     shiftdata.dptr_,  // uint8 shifted copy of data
+                     k,
+                     oa,
+                     weight.data().dptr<SrcType>(),  // same expression as weight_temp
+                     k,
+                     ob,
+                     beta,  // 1.0f - presumably adds onto the pre-filled out; confirm in MKL docs
+                     out.data().dptr<int32_t>(),  // same expression as output_temp
+                     n,
+                     &oc);
+#else
+  LOG(FATAL) << "Quantized fully connected operator relies on cblas_gemm_s8u8s32"
+             << " which is only supported by MKL BLAS."
+             << " Please build MXNet with USE_BLAS=mkl to leverage this operator.";
+#endif  // MSHADOW_USE_MKL == 1
+}
+
 NNVM_REGISTER_OP(_contrib_quantized_fully_connected)
 .describe(R"code(Fully Connected operator for input, weight and bias data type of int8,
 and accumulates in type int32 for the output. For each argument, two more arguments of type
@@ -112,7 +263,14 @@ and max thresholds representing the threholds for quantizing the float32 output
   })
 .set_attr<nnvm::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
+.set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+.set_attr<FComputeEx>("FComputeEx<cpu>",
+    QuantizedFullyConnectedForward<int8_t>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .add_argument("data", "NDArray-or-Symbol", "Input data.")
 .add_argument("weight", "NDArray-or-Symbol", "weight.")
 .add_argument("bias", "NDArray-or-Symbol", "bias.")
@@ -135,6 +293,5 @@ NNVM_REGISTER_OP(FullyConnected)
     }
     return node;
   });
-
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index ca8070cfc22..a27d5b322e3 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -26,6 +26,7 @@
 from mxnet.module import Module
 from mxnet.io import NDArrayIter
 import unittest
+import operator
 
 def is_test_for_gpu():
     return mx.current_context().device_type == 'gpu'
@@ -278,8 +279,15 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
 def test_quantized_fc():
     def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         if mx.current_context().device_type != 'gpu':
-            print('skipped testing quantized_fc on cpu since it is not supported yet')
-            return
+            hasMKL = False;
+            for key in os.environ.keys():
+                if operator.eq(key, "BUILD_TAG"):
+                    if os.environ['BUILD_TAG'].find("MKL") != -1:
+                        hasMKL = True
+                    break
+            if hasMKL == False:
+                print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library')
+                return
         elif qdtype == 'uint8' and is_test_for_gpu():
             print('skipped testing quantized_fc for gpu uint8 since it is not supported yet')
             return
@@ -291,16 +299,16 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         fc_fp32_exe = fc_fp32.simple_bind(ctx=mx.current_context(), grad_req='null')
         if qdtype == 'uint8':
             data_low = 0.0
-            data_high = 127.0
+            data_high = 63.0
         else:
-            data_low = -127.0
-            data_high = 127.0
+            data_low = -63.0
+            data_high = 63.0
         fc_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
                                                                      shape=data_shape).astype('int32')
-        fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+        fc_fp32_exe.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
                                                                      shape=arg_shapes[1]).astype('int32')
         if not no_bias:
-            fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0,
+            fc_fp32_exe.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=data_low, high=data_high,
                                                                          shape=arg_shapes[2]).astype('int32')
         output = fc_fp32_exe.forward()[0]
 
@@ -343,6 +351,10 @@ def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True):
         check_quantized_fc((32, 111, 2, 2), 100, True, qdtype)
         check_quantized_fc((32, 512, 2, 2), 100, False, qdtype)
         check_quantized_fc((32, 111, 2, 2), 100, False, qdtype)
+        check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype)
+        check_quantized_fc((256, 111, 2, 2), 800, False, qdtype)
+        check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype)
+        check_quantized_fc((256, 111, 2, 2), 800, True, qdtype)
 
 @with_seed()
 def test_quantized_flatten():


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services