Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/09/19 00:30:13 UTC

[GitHub] zheng-da commented on a change in pull request #12530: Implement mkldnn convolution fusion and quantization.

zheng-da commented on a change in pull request #12530: Implement mkldnn convolution fusion and quantization.
URL: https://github.com/apache/incubator-mxnet/pull/12530#discussion_r218615667
 
 

 ##########
 File path: src/operator/subgraph/mkldnn/mkldnn_conv.cc
 ##########
 @@ -0,0 +1,670 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+
+#if MXNET_USE_MKLDNN == 1
+#include "../common.h"
+#include "../../nn/convolution-inl.h"
+#include "../../nn/batch_norm-inl.h"
+#include "../../nn/activation-inl.h"
+#include "../../nn/mkldnn/mkldnn_base-inl.h"
+#include "../../nn/mkldnn/mkldnn_ops-inl.h"
+#include "../../nn/mkldnn/mkldnn_convolution-inl.h"
+#include "../../quantization/quantization_utils.h"
+namespace mxnet {
+namespace op {
+
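+// Parameters of the fused convolution: the full MKL-DNN convolution
+// parameters plus the batch norm parameters when BN is folded in.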
+struct MKLDNNConvFusionParam {
+  MKLDNNConvFullParam full_conv_param;
+  std::shared_ptr<BatchNormParam> bn_param;
+};
+
+static const size_t uint8_range = 255;
+static const size_t int8_range = 127;
+
+enum MKLDNNConvOpOutputs { kOut, kMin, kMax };
+
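+// Fold batch norm into the convolution weight and bias:
+//   alpha[c] = (fix_gamma ? 1 : gamma[c]) / sqrt(var[c] + eps)
+//   w'[c, ...] = alpha[c] * w[c, ...]
+//   b'[c] = beta[c] + alpha[c] * (bias[c] - mean[c])   (bias[c] = 0 when no_bias)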
+template <typename DType>
+static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias,
+                                 const NDArray &gamma, const NDArray &beta,
+                                 const NDArray &mean, const NDArray &variance,
+                                 const BatchNormParam *param) {
+  // TODO(Zhennan): Handle the case where the weight is not 4-dimensional.
+  NDArray update_weight = NDArray(weight->storage_type(), weight->shape(),
+                                  weight->ctx(), true, weight->dtype());
+  NDArray update_bias = NDArray(beta.storage_type(), beta.shape(), beta.ctx(),
+                                true, beta.dtype());
+  DType *weight_ptr = weight->data().dptr<DType>();
+  DType *bias_ptr = no_bias ? nullptr : bias->data().dptr<DType>();
+  DType *gamma_ptr = gamma.Reorder2Default().data().dptr<DType>();
+  DType *beta_ptr = beta.Reorder2Default().data().dptr<DType>();
+  DType *mean_ptr = mean.Reorder2Default().data().dptr<DType>();
+  DType *var_ptr = variance.Reorder2Default().data().dptr<DType>();
+  DType *update_weight_ptr = update_weight.data().dptr<DType>();
+  DType *update_bias_ptr = update_bias.data().dptr<DType>();
+  size_t channel = gamma.shape()[0];
+  size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3];
+#pragma omp parallel for
+  for (int c = 0; c < static_cast<int>(channel); ++c) {
+    DType *p1 = reinterpret_cast<DType *>(weight_ptr + c * offset);
+    DType *p2 = reinterpret_cast<DType *>(update_weight_ptr + c * offset);
+    DType alpha = (param->fix_gamma ? static_cast<DType>(1.0f) : gamma_ptr[c]) /
+                  sqrt(var_ptr[c] + param->eps);
+
+    if (bias_ptr)
+      update_bias_ptr[c] = beta_ptr[c] + alpha * (bias_ptr[c] - mean_ptr[c]);
+    else
+      update_bias_ptr[c] = beta_ptr[c] - alpha * mean_ptr[c];
+
+    for (size_t k = 0; k < offset; ++k) {
+      p2[k] = p1[k] * alpha;
+    }
+  }
+  *weight = update_weight;
+  *bias = update_bias;
+}
+
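+// Index of the sum (residual) input: it follows data, weight, the optional
+// bias, and the four optional batch norm inputs (gamma, beta, mean, var).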
+static inline size_t GetInSumIndex(const MKLDNNConvFusionParam &param) {
+  return 2 + (param.full_conv_param.conv_param.no_bias ? 0 : 1) +
+         (param.full_conv_param.mkldnn_param.with_bn ? 4 : 0);
+}
+
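+// Symmetric int8 quantization of the weight, either per output channel or
+// with a single tensor-wide scale: scale = 127 / max(|w_min|, |w_max|).
+// The bias, if present, is quantized to int32 with
+//   bias_scale = weight_scale * data_scale, data_scale = 255 / max(|data_min|, |data_max|).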
+template <typename DType>
+static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias,
+                                   bool has_bias, float data_min,
+                                   float data_max,
+                                   bool weight_channelwise_scale,
+                                   std::vector<float> *weight_scales) {
+  using red::limits::MaxValue;
+  using red::limits::MinValue;
+  DType *weight_ptr = weight->data().dptr<DType>();
+  NDArray quantized_weight = NDArray(weight->storage_type(), weight->shape(),
+                                     weight->ctx(), true, mshadow::kInt8);
+  int8_t *quan_weight_ptr = quantized_weight.data().dptr<int8_t>();
+  size_t channel = weight->shape()[0];
+
+  // TODO(Zhennan): Handle the case where the weight is not 4-dimensional.
+  size_t offset = weight->shape()[1] * weight->shape()[2] * weight->shape()[3];
+  std::vector<DType> weight_c_min(channel, MaxValue<DType>());
+  std::vector<DType> weight_c_max(channel, MinValue<DType>());
+#pragma omp parallel for
+  for (int c = 0; c < static_cast<int>(channel); ++c) {
+    DType *p1 = weight_ptr + c * offset;
+    for (size_t k = 0; k < offset; ++k) {
+      if (weight_c_min[c] > p1[k])
+        weight_c_min[c] = p1[k];
+      if (weight_c_max[c] < p1[k])
+        weight_c_max[c] = p1[k];
+    }
+  }
+
+  if (weight_channelwise_scale) {
+    weight_scales->resize(channel);
+#pragma omp parallel for
+    for (int c = 0; c < static_cast<int>(channel); ++c) {
+      DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]);
+      weight_scales->at(c) = int8_range / weight_range;
+      DType *fp_ptr = weight_ptr + c * offset;
+      int8_t *quan_ptr = quan_weight_ptr + c * offset;
+      for (size_t k = 0; k < offset; ++k) {
+        quan_ptr[k] = std::round(weight_scales->at(c) * fp_ptr[k]);
+      }
+    }
+  } else {
+    DType total_min = weight_c_min[0];
+    DType total_max = weight_c_max[0];
+    for (size_t c = 0; c < channel; ++c) {
+      if (total_min > weight_c_min[c]) total_min = weight_c_min[c];
+      if (total_max < weight_c_max[c]) total_max = weight_c_max[c];
+    }
+    weight_scales->resize(1);
+    DType weight_range = MaxAbs(total_min, total_max);
+    weight_scales->at(0) = int8_range / weight_range;
+#pragma omp parallel for
+    for (int c = 0; c < static_cast<int>(channel); ++c) {
+      DType *fp_ptr = weight_ptr + c * offset;
+      int8_t *quan_ptr = quan_weight_ptr + c * offset;
+      for (size_t k = 0; k < offset; ++k) {
+        quan_ptr[k] = std::round(weight_scales->at(0) * fp_ptr[k]);
+      }
+    }
+  }
+
+  *weight = quantized_weight;
+  if (has_bias) {
+    DType *bias_ptr = bias->data().dptr<DType>();
+    NDArray quantized_bias = NDArray(bias->storage_type(), bias->shape(),
+                                     bias->ctx(), true, mshadow::kInt32);
+    int32_t *quan_bias_ptr = quantized_bias.data().dptr<int32_t>();
+    DType data_scale = uint8_range / MaxAbs(data_min, data_max);
+    for (size_t c = 0; c < channel; ++c) {
+      auto weight_scale =
+          weight_channelwise_scale ? weight_scales->at(c) : weight_scales->at(0);
+      float bias_scale = weight_scale * data_scale;
+      quan_bias_ptr[c] = std::round(bias_scale * bias_ptr[c]);
+    }
+    *bias = quantized_bias;
+  }
+}
+
+static void ConvFusionFallBackCompute() {
+  LOG(FATAL) << "Don't know how to do ConvFusionFallBackCompute!";
+}
+
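+// Run the MKL-DNN fused convolution forward when the input is supported;
+// otherwise fall back (currently a fatal error, see ConvFusionFallBackCompute).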
+static void ConvolutionFusionComputeExCPU(const MKLDNNConvFullParam &full_param,
+                                          const OpContext &ctx,
+                                          MKLDNNConvForward *fwd,
+                                          const std::vector<NDArray> &inputs,
+                                          const std::vector<OpReqType> &req,
+                                          const std::vector<NDArray> &outputs) {
+  if (SupportMKLDNNConv(full_param.conv_param, inputs[0])) {
+    // MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNConvolutionForwardFullFeature(full_param, ctx, fwd, inputs, req, outputs);
+    // MKLDNN_OPCHECK_RUN(ConvolutionCompute<cpu>, attrs, ctx, inputs, req,
+    // outputs);
+    return;
+  }
+  ConvFusionFallBackCompute();
+}
+
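+// Stateful, inference-only operator for the fused (and optionally quantized)
+// MKL-DNN convolution subgraph. Converted weights, bias and min/max values
+// are cached between forward calls.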
+class SgMKLDNNConvOperator {
+ public:
+  explicit SgMKLDNNConvOperator(const nnvm::NodeAttrs &attrs)
+      : initialized(false),
+        subgraph_sym_(*attrs.subgraphs[0]),
+        param(nnvm::get<MKLDNNConvFusionParam>(attrs.parsed)) {}
+
+  void Forward(const OpContext &ctx,
+               const std::vector<NDArray> &inputs,
+               const std::vector<OpReqType> &req,
+               const std::vector<NDArray> &outputs);
+
+  void Backward(const OpContext &ctx, const std::vector<NDArray> &inputs,
+                const std::vector<OpReqType> &req,
+                const std::vector<NDArray> &outputs) {
+    LOG(FATAL) << "Not implemented: subgraph mkldnn Conv only supports "
+                  "inference computation";
+  }
+
+ private:
+  bool initialized;
+  nnvm::Symbol subgraph_sym_;
+  MKLDNNConvFusionParam param;
+  std::shared_ptr<MKLDNNConvForward> fwd;
+  NDArray cached_weight_;
+  NDArray cached_bias_;
+  NDArray cached_data_;
+  NDArray cached_output_;
+  float cached_data_min;
+  float cached_data_max;
+  float cached_sum_min;
+  float cached_sum_max;
+  size_t weight_ver;
+  size_t bias_ver;
+  std::vector<float> weight_scales;
+};
+
+void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
+                                   const std::vector<NDArray> &inputs,
+                                   const std::vector<OpReqType> &req,
+                                   const std::vector<NDArray> &outputs) {
+  auto &full_conv_param = param.full_conv_param;
+  auto &mkldnn_param = param.full_conv_param.mkldnn_param;
+  auto &conv_param = param.full_conv_param.conv_param;
+  auto bn_param = param.bn_param.get();
+  size_t input_size =
+      2 + (conv_param.no_bias ? 0 : 1) + (mkldnn_param.with_bn ? 4 : 0) +
+      (mkldnn_param.with_sum ? 1 : 0) +
+      (mkldnn_param.quantized
+           ? 2 + (param.full_conv_param.mkldnn_param.with_sum ? 2 : 0)
+           : 0);
+  CHECK_EQ(inputs.size(), input_size);
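+  // Expected input order: data, weight, [bias], [gamma, beta, mean, var],
+  // [sum], then for the quantized path: data_min, data_max, [sum_min, sum_max].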
+  size_t idx = 0;
+
+  auto in_data = idx++;
+  auto in_weight = idx++;
+  auto in_bias = conv_param.no_bias ? 0 : (idx++);
+  auto in_gamma = mkldnn_param.with_bn ? (idx++) : 0;
+  auto in_beta = mkldnn_param.with_bn ? (idx++) : 0;
+  auto in_mean = mkldnn_param.with_bn ? (idx++) : 0;
+  auto in_var = mkldnn_param.with_bn ? (idx++) : 0;
+  auto in_sum = mkldnn_param.with_sum ? (idx++) : 0;
+  float data_min =
+      mkldnn_param.quantized ? inputs[idx++].data().dptr<float>()[0] : 0.0;
+  float data_max =
+      mkldnn_param.quantized ? inputs[idx++].data().dptr<float>()[0] : 0.0;
+  float sum_min = (mkldnn_param.with_sum && mkldnn_param.quantized)
+                      ? inputs[idx++].data().dptr<float>()[0]
+                      : 0.0;
+  float sum_max = (mkldnn_param.with_sum && mkldnn_param.quantized)
+                      ? inputs[idx++].data().dptr<float>()[0]
+                      : 0.0;
+  float *out_min_ptr =
+      mkldnn_param.quantized ? outputs[kMin].data().dptr<float>() : nullptr;
+  float *out_max_ptr =
+      mkldnn_param.quantized ? outputs[kMax].data().dptr<float>() : nullptr;
+  CHECK_EQ(input_size, idx);
+  bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias;
+  cached_data_ = inputs[in_data];
+  if (mkldnn_param.with_sum)
+    cached_output_ = inputs[in_sum];
+  else
+    cached_output_ = outputs[kOut];
 
 Review comment:
  Why cache the input data and output? The input and output data may be reused somewhere else because of the memory planning in MXNet.
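  A minimal sketch of the aliasing concern, assuming NDArray::Copy(Context) and
  ctx() as declared in MXNet's ndarray.h; the copy shown is an illustrative
  alternative, not part of this PR:

      // Assigning an NDArray only copies the handle, so cached_output_ shares
      // storage with outputs[kOut]; the memory planner may later hand that
      // chunk to another node, invalidating the cached view.
      cached_output_ = outputs[kOut];

      // Defensive alternative (assumed API, sketch only): materialize a
      // private copy so later reuse of outputs[kOut] cannot clobber the cache.
      cached_output_ = outputs[kOut].Copy(outputs[kOut].ctx());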

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services