Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2018/11/28 08:39:23 UTC

[GitHub] TaoLv closed pull request #13297: [MKLDNN]Add quantized concat

TaoLv closed pull request #13297: [MKLDNN]Add quantized concat
URL: https://github.com/apache/incubator-mxnet/pull/13297

This is a pull request merged from a forked repository. As GitHub hides the
original diff on merge, it is reproduced below for the sake of provenance:

diff --git a/example/ssd/quantization.py b/example/ssd/quantization.py
index 5cb74ba11a8..231cc99f93b 100644
--- a/example/ssd/quantization.py
+++ b/example/ssd/quantization.py
@@ -139,7 +139,6 @@ def save_params(fname, arg_params, aux_params, logger=None):
     mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}
 
     if calib_mode == 'none':
-        logger.info('Quantizing FP32 model %s' % args.model)
         qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
                                                        ctx=ctx, excluded_sym_names=excluded_sym_names,
                                                        calib_mode=calib_mode, quantized_dtype=args.quantized_dtype,
diff --git a/src/operator/nn/mkldnn/mkldnn_concat-inl.h b/src/operator/nn/mkldnn/mkldnn_concat-inl.h
new file mode 100644
index 00000000000..d3866cc3d23
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_concat-inl.h
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_concat-inl.h
+ * \brief MKLDNN concat operator
+ * \author Wenting Jiang
+*/
+#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_
+#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_
+
+
+#if MXNET_USE_MKLDNN == 1
+#include <vector>
+#include <utility>
+#include <memory>
+#include <unordered_map>
+#include "../concat-inl.h"
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class MKLDNNConcatFwd {
+ public:
+  mkldnn::concat::primitive_desc fwd_pd;
+
+  MKLDNNConcatFwd(int concat_dim, const std::vector<mkldnn::memory::primitive_desc> &data_md)
+      : fwd_pd(concat_dim, data_md) {
+    data.resize(data_md.size());
+  }
+
+  void SetNewMem(const std::vector<const mkldnn::memory *> &in_data, const mkldnn::memory &output);
+
+  const mkldnn::concat &GetFwd() const;
+
+ private:
+  std::shared_ptr<mkldnn::concat> fwd;
+  std::vector<std::shared_ptr<mkldnn::memory>> data;
+  std::vector<mkldnn::primitive::at> data_mem;
+  std::shared_ptr<mkldnn::memory> out;
+};
+
+static MKLDNNConcatFwd &GetConcatForward(
+    int concat_dim, const std::vector<NDArray> &in_data,
+    const std::vector<mkldnn::memory::primitive_desc> &data_md) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds;
+#endif
+  OpSignature key;
+  key.AddSign(concat_dim);
+  key.AddSign(in_data);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNConcatFwd fwd(concat_dim, data_md);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_CONCAT_INL_H_
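
For reference, GetConcatForward above memoizes one MKLDNNConcatFwd per
(concat_dim, input signature) in a thread-local map, so repeated forward calls
over identically shaped inputs reuse the same primitive. The same caching idea
in a minimal Python sketch (illustrative only, not part of the PR; make_fwd
stands in for the MKLDNNConcatFwd constructor):

import threading

_tls = threading.local()

def get_concat_forward(concat_dim, input_signature, make_fwd):
    # One cache per thread, keyed by the op signature, as in GetConcatForward.
    if not hasattr(_tls, 'fwds'):
        _tls.fwds = {}
    key = (concat_dim, input_signature)
    if key not in _tls.fwds:
        _tls.fwds[key] = make_fwd()
    return _tls.fwds[key]
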
diff --git a/src/operator/nn/mkldnn/mkldnn_concat.cc b/src/operator/nn/mkldnn/mkldnn_concat.cc
index 03eeb61eccb..8e2b57781a1 100644
--- a/src/operator/nn/mkldnn/mkldnn_concat.cc
+++ b/src/operator/nn/mkldnn/mkldnn_concat.cc
@@ -22,76 +22,36 @@
  * \brief
  * \author Wenting Jiang
 */
-#include "../concat-inl.h"
-#include "./mkldnn_ops-inl.h"
-#include "./mkldnn_base-inl.h"
 
 #if MXNET_USE_MKLDNN == 1
+#include "mkldnn_concat-inl.h"
+
 namespace mxnet {
 namespace op {
 
-class MKLDNNConcatFwd {
-  std::shared_ptr<mkldnn::concat> fwd;
-  std::vector<std::shared_ptr<mkldnn::memory>> data;
-  std::vector<mkldnn::primitive::at> data_mem;
-  std::shared_ptr<mkldnn::memory> out;
-
- public:
-  mkldnn::concat::primitive_desc fwd_pd;
-
-  MKLDNNConcatFwd(
-      int concat_dim,
-      const std::vector<mkldnn::memory::primitive_desc> &data_md): fwd_pd(concat_dim, data_md) {
-    data.resize(data_md.size());
-  }
-
-  void SetNewMem(const std::vector<const mkldnn::memory *> &in_data,
-                 const mkldnn::memory &output) {
-    CHECK_EQ(in_data.size(), data.size());
-    for (size_t i = 0; i < data.size(); i++) {
-      if (this->data[i] == nullptr) {
-        this->data[i] = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-                in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle()));
-        this->data_mem.push_back(*this->data[i]);
-      } else {
-        this->data[i]->set_data_handle(in_data[i]->get_data_handle());
-      }
+void MKLDNNConcatFwd::SetNewMem(const std::vector<const mkldnn::memory *> &in_data,
+                                const mkldnn::memory &output) {
+  CHECK_EQ(in_data.size(), data.size());
+  for (size_t i = 0; i < data.size(); i++) {
+    if (this->data[i] == nullptr) {
+      this->data[i] = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle()));
+      this->data_mem.push_back(*this->data[i]);
+    } else {
+      this->data[i]->set_data_handle(in_data[i]->get_data_handle());
     }
-    if (this->out == nullptr)
-      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              fwd_pd.dst_primitive_desc(), output.get_data_handle()));
-    else
-      this->out->set_data_handle(output.get_data_handle());
-
-    if (this->fwd == nullptr)
-      fwd.reset(new mkldnn::concat(fwd_pd, data_mem, *out));
   }
+  if (this->out == nullptr)
+    this->out = std::shared_ptr<mkldnn::memory>(
+        new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle()));
+  else
+    this->out->set_data_handle(output.get_data_handle());
 
-  const mkldnn::concat &GetFwd() const {
-    return *fwd;
-  }
-};
-
-static MKLDNNConcatFwd &GetConcatForward(
-    int concat_dim, const std::vector<NDArray> &in_data,
-    const std::vector<mkldnn::memory::primitive_desc> &data_md) {
-#if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds;
-#else
-  static MX_THREAD_LOCAL std::unordered_map<OpSignature, MKLDNNConcatFwd, OpHash> fwds;
-#endif
-  OpSignature key;
-  key.AddSign(concat_dim);
-  key.AddSign(in_data);
-
-  auto it = fwds.find(key);
-  if (it == fwds.end()) {
-    MKLDNNConcatFwd fwd(concat_dim, data_md);
-    it = AddToCache(&fwds, key, fwd);
-  }
-  return it->second;
+  if (this->fwd == nullptr) fwd.reset(new mkldnn::concat(fwd_pd, data_mem, *out));
 }
 
+const mkldnn::concat &MKLDNNConcatFwd::GetFwd() const { return *fwd; }
+
 void MKLDNNConcatForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                          const std::vector<NDArray> &in_data,
                          const std::vector<OpReqType> &req,
diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
new file mode 100644
index 00000000000..d9e884e8280
--- /dev/null
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file mkldnn_quantized_concat.cc
+ * \brief MKLDNN quantized concat operator
+ */
+
+#if MXNET_USE_MKLDNN == 1
+#include "../../nn/mkldnn/mkldnn_concat-inl.h"
+#include "../quantization_utils.h"
+
+namespace mxnet {
+namespace op {
+
+namespace quantized_concat_enum {
+enum QuantizedConcatOutputs { kOut, kMin, kMax };
+}
+
+static float GetScale(const NDArray& data, float min, float max) {
+  auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range;
+  return data_range / MaxAbs(min, max);
+}
+
+static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
+                                         const std::vector<NDArray>& in_data,
+                                         const std::vector<OpReqType>& req,
+                                         const std::vector<NDArray>& out_data) {
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_data.size(), static_cast<size_t>(param_.num_args * 3));
+  CHECK_EQ(out_data.size(), 3U);
+  // Collect data min/max and output_neg_min, output_pos_max
+  std::vector<float> data_min(param_.num_args);
+  std::vector<float> data_max(param_.num_args);
+  float output_neg_min = 0.f;  // 0.f is the upper bound for output_neg_min
+  float output_pos_max = 0.f;  // 0.f is the lower bound for output_pos_max
+  for (int i = 0; i < param_.num_args; ++i) {
+    data_min[i] = in_data[param_.num_args + 2 * i].data().dptr<float>()[0];
+    if (data_min[i] < output_neg_min) output_neg_min = data_min[i];
+    data_max[i] = in_data[param_.num_args + 2 * i + 1].data().dptr<float>()[0];
+    if (data_max[i] > output_pos_max) output_pos_max = data_max[i];
+  }
+  out_data[quantized_concat_enum::kMin].data().dptr<float>()[0] = output_neg_min;
+  out_data[quantized_concat_enum::kMax].data().dptr<float>()[0] = output_pos_max;
+  auto out_scale = GetScale(out_data[quantized_concat_enum::kOut], output_neg_min, output_pos_max);
+  std::vector<mkldnn::memory::primitive_desc> data_md;
+  std::vector<const mkldnn::memory*> data_mem;
+  // new_data_mem holds newly created mkldnn memory so it is freed automatically
+  std::vector<std::shared_ptr<mkldnn::memory>> new_data_mem;
+  for (int i = 0; i < param_.num_args; ++i) {
+    auto i_scale = GetScale(in_data[i], data_min[i], data_max[i]);
+    if (i_scale == out_scale) {
+      auto mem = in_data[i].GetMKLDNNData();
+      data_mem.push_back(mem);
+      data_md.push_back(mem->get_primitive_desc());
+    } else {
+      auto mem = in_data[i].GetMKLDNNData();
+      auto pd = mem->get_primitive_desc();
+      const auto rescaled_mem = std::make_shared<mkldnn::memory>(pd);
+      new_data_mem.push_back(rescaled_mem);
+      std::vector<float> reorder_scale = {out_scale / i_scale};
+      primitive_attr reorder_attr;
+      reorder_attr.set_int_output_round_mode(round_mode::round_nearest);
+      reorder_attr.set_output_scales(0, reorder_scale);
+      const auto reorder_pd = mkldnn::reorder::primitive_desc(pd, pd, reorder_attr);
+      MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *mem, *rescaled_mem));
+      data_mem.push_back(rescaled_mem.get());
+      data_md.push_back(pd);
+    }
+  }
+  MKLDNNConcatFwd& fwd = GetConcatForward(param_.dim, in_data, data_md);
+  mxnet::mkldnn_output_t out_mem =
+      CreateMKLDNNMem(out_data[quantized_concat_enum::kOut], fwd.fwd_pd.dst_primitive_desc(),
+                      req[concat_enum::kOut]);
+  fwd.SetNewMem(data_mem, *out_mem.second);
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+  CommitOutput(out_data[concat_enum::kOut], out_mem);
+  MKLDNNStream::Get()->Submit();
+}
+
+inline static bool ConcatStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask,
+                                     DispatchMode* dispatch_mode, std::vector<int>* in_attrs,
+                                     std::vector<int>* out_attrs) {
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), static_cast<size_t>(param_.num_args * 3));
+  CHECK_EQ(out_attrs->size(), 3U);
+
+  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_concat)
+.set_attr<FInferStorageType>("FInferStorageType", ConcatStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedConcatForward)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<bool>("TIsMKLDNN", true);
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_MKLDNN == 1
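
For reference, the rescaling above reduces to simple arithmetic: each input
carries a calibrated [min, max] range, the output range is the union of those
ranges, and any input whose scale differs from the output scale is rescaled
through an MKL-DNN reorder with scale out_scale / i_scale. A minimal Python
sketch of that arithmetic (illustrative only, not part of the PR; the
constants mirror kInt8Range/kUint8Range from quantization_utils.h):

K_INT8_RANGE, K_UINT8_RANGE = 127, 255

def get_scale(dtype, dmin, dmax):
    # Mirrors GetScale: quantized range divided by the largest absolute value.
    data_range = K_INT8_RANGE if dtype == 'int8' else K_UINT8_RANGE
    return data_range / max(abs(dmin), abs(dmax))

inputs = [('uint8', 0.0, 1.0), ('uint8', 0.0, 4.0)]     # per-input calibration
out_min = min([0.0] + [dmin for _, dmin, _ in inputs])  # output_neg_min
out_max = max([0.0] + [dmax for _, _, dmax in inputs])  # output_pos_max
out_scale = get_scale('uint8', out_min, out_max)        # 255 / 4.0 = 63.75

for dtype, dmin, dmax in inputs:
    i_scale = get_scale(dtype, dmin, dmax)
    if i_scale != out_scale:
        # The PR registers a reorder primitive with this output scale.
        print('reorder scale:', out_scale / i_scale)    # 0.25 for input 0
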
diff --git a/src/operator/quantization/quantization_utils.h b/src/operator/quantization/quantization_utils.h
index 5b096ac0057..ee711220589 100644
--- a/src/operator/quantization/quantization_utils.h
+++ b/src/operator/quantization/quantization_utils.h
@@ -31,6 +31,8 @@
 namespace mxnet {
 namespace op {
 
+static const size_t kUint8Range = 255;
+static const size_t kInt8Range = 127;
 
 template<typename T>
 MSHADOW_XINLINE int Sign(T val) {
diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc
new file mode 100644
index 00000000000..3504df82d24
--- /dev/null
+++ b/src/operator/quantization/quantized_concat.cc
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file quantized_concat.cc
+ * \brief quantized concat operator
+*/
+
+#include "../nn/concat-inl.h"
+
+namespace mxnet {
+namespace op {
+
+static bool ConcatShape(const nnvm::NodeAttrs& attrs, std::vector<TShape>* in_shape,
+                        std::vector<TShape>* out_shape) {
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args * 3));
+  CHECK_EQ(out_shape->size(), 3U);
+  TShape dshape;
+  index_t size = 0;
+  bool has_zero = false;
+  int axis = -1;
+  for (int i = 0; i < param_.num_args; ++i) {
+    TShape tmp = (*in_shape)[i];
+    if (tmp.ndim()) {
+      axis = CheckAxis(param_.dim, tmp.ndim());
+      has_zero = tmp[axis] == 0 || has_zero;
+      size += tmp[axis];
+      tmp[axis] = 0;
+      shape_assign(&dshape, tmp);
+    }
+  }
+
+  TShape tmp = (*out_shape)[0];
+  if (tmp.ndim()) {
+    axis = CheckAxis(param_.dim, tmp.ndim());
+    tmp[axis] = 0;
+    shape_assign(&dshape, tmp);
+  }
+
+  if (dshape.ndim() == 0) return false;
+
+  for (int i = 0; i < param_.num_args; ++i) {
+    CHECK(shape_assign(&(*in_shape)[i], dshape))
+        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
+  }
+
+  if (!has_zero) dshape[axis] = size;
+  CHECK(shape_assign(&(*out_shape)[0], dshape))
+      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
+
+  for (int i = param_.num_args; i < param_.num_args * 3; ++i) {
+    SHAPE_ASSIGN_CHECK(*in_shape, i, TShape{1});
+  }
+  SHAPE_ASSIGN_CHECK(*out_shape, 1, TShape{1});
+  SHAPE_ASSIGN_CHECK(*out_shape, 2, TShape{1});
+  return dshape.Size() != 0;
+}
+
+static bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector<int>* in_type,
+                       std::vector<int>* out_type) {
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_type->size(), static_cast<size_t>(param_.num_args * 3));
+  CHECK_EQ(out_type->size(), 3U);
+  int dtype = mshadow::kUint8;
+
+  for (int i = 0; i < param_.num_args; ++i) {
+    if (in_type->at(i) == mshadow::kInt8) {
+      dtype = mshadow::kInt8;
+    } else {
+      TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kUint8);
+    }
+  }
+  TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
+  TYPE_ASSIGN_CHECK(*out_type, 1, mshadow::kFloat32);
+  TYPE_ASSIGN_CHECK(*out_type, 2, mshadow::kFloat32);
+
+  return true;
+}
+
+NNVM_REGISTER_OP(_contrib_quantized_concat)
+.describe(R"code(Joins input arrays along a given axis.
+
+The dimensions of the input arrays should be the same except along the axis
+on which they will be concatenated. The dimension of the output array along
+the concatenated axis will be equal to the sum of the corresponding
+dimensions of the input arrays. Inputs with different min/max values will be
+rescaled using the largest [min, max] pair. If any input holds int8, the
+output will be int8; otherwise the output will be uint8.
+
+)code" ADD_FILELINE)
+.set_num_inputs([](const NodeAttrs& attrs) {
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  return params.num_args * 3;
+})
+.set_num_outputs(3)
+.set_attr_parser(ParamParser<ConcatParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
+  std::vector<std::string> ret;
+  for (int i = 0; i < params.num_args; ++i) {
+    ret.push_back(std::string("arg") + std::to_string(i));
+  }
+  for (int i = 0; i < params.num_args; ++i) {
+    ret.push_back(std::string("arg") + std::to_string(i) + "_min");
+    ret.push_back(std::string("arg") + std::to_string(i) + "_max");
+  }
+  return ret;
+})
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output", "min_output", "max_output"};
+})
+.set_attr<nnvm::FInferType>("FInferType", ConcatType)
+.set_attr<nnvm::FInferShape>("FInferShape", ConcatShape)
+.set_attr<std::string>("key_var_num_args", "num_args")
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+.add_arguments(ConcatParam::__FIELDS__());
+
+NNVM_REGISTER_OP(Concat)
+.set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
+  nnvm::NodePtr node = nnvm::Node::Create();
+  node->attrs.op = Op::Get("_contrib_quantized_concat");
+  node->attrs.name = "quantized_" + attrs.name;
+  node->attrs.dict = attrs.dict;
+  if (node->op()->attr_parser != nullptr) {
+    node->op()->attr_parser(&(node->attrs));
+  }
+  return node;
+});
+
+}  // namespace op
+}  // namespace mxnet
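
As registered above, the quantized op takes num_args * 3 inputs (the data
arrays, followed by one min/max pair per argument) and produces three outputs
(output, min_output, max_output). A quick Python sketch of the input ordering
that FListInputNames yields (illustrative only, not part of the PR):

def quantized_concat_input_names(num_args):
    # Data arguments first, then a min/max pair per argument.
    names = ['arg%d' % i for i in range(num_args)]
    for i in range(num_args):
        names += ['arg%d_min' % i, 'arg%d_max' % i]
    return names

print(quantized_concat_input_names(2))
# ['arg0', 'arg1', 'arg0_min', 'arg0_max', 'arg1_min', 'arg1_max']
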
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h
index 8675446f5a1..b44f2fb0e31 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv-inl.h
@@ -36,9 +36,6 @@ struct MKLDNNConvFusionParam {
   std::shared_ptr<BatchNormParam> bn_param;
 };
 
-static const size_t uint8_range = 255;
-static const size_t int8_range = 127;
-
 enum MKLDNNConvOpOutputs { kOut, kMin, kMax };
 
 }  // namespace op
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index a1083d09b7b..dfa98d1f5ee 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -109,7 +109,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias,
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
     for (int c = 0; c < static_cast<int>(channel); ++c) {
       DType weight_range = MaxAbs(weight_c_min[c], weight_c_max[c]);
-      weight_scales->at(c) = int8_range / weight_range;
+      weight_scales->at(c) = kInt8Range / weight_range;
       const DType *fp_ptr = weight_ptr + c * offset;
       int8_t *quan_ptr = quan_weight_ptr + c * offset;
       for (size_t k = 0; k < offset; ++k) {
@@ -125,7 +125,7 @@ static void QuantizeConvWeightBias(NDArray *weight, NDArray *bias,
     }
     weight_scales->resize(1);
     DType weight_range = MaxAbs(total_min, total_max);
-    weight_scales->at(0) = int8_range / weight_range;
+    weight_scales->at(0) = kInt8Range / weight_range;
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
     for (int c = 0; c < static_cast<int>(channel); ++c) {
       const DType *fp_ptr = weight_ptr + c * offset;
@@ -327,7 +327,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
     // Quantize weight and bias.
     if (mkldnn_param.quantized) {
       CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
-      auto data_range = (data.dtype() == mshadow::kInt8) ? int8_range : uint8_range;
+      auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range;
       float data_scale = data_range / MaxAbs(cached_data_min_, cached_data_max_);
       MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
         QuantizeConvWeightBias<DType>(&cached_weight_, &cached_bias_,
@@ -346,12 +346,12 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
        LOG(FATAL) << "Can't handle negative value for QuantizeData";
       }
       if (mkldnn_param.with_sum) {
-        auto quantized_sum_range = cached_sum_min_ < 0 ? int8_range : uint8_range;
+        auto quantized_sum_range = cached_sum_min_ < 0 ? kInt8Range : kUint8Range;
         sum_in_scale = quantized_sum_range / MaxAbs(cached_sum_min_, cached_sum_max_);
       }
       if (post_requantize) {
         quantized_out_range =
-            IsOutputUInt8(mkldnn_param) ? uint8_range : int8_range;
+            IsOutputUInt8(mkldnn_param) ? kUint8Range : kInt8Range;
         out_range = MaxAbs(*out_min_ptr, *out_max_ptr);
         output_scale = quantized_out_range / out_range;
         full_conv_param.requantize_scales.resize(channel);
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 71784dcd3bf..be6feaeb94a 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -66,7 +66,7 @@ def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape):
     output.wait_to_read()
   return mod.get_outputs()
 
-def check_quantize(sym, data_shape):
+def check_quantize(sym, data_shape, check_conv=True):
   fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc')
   sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
   sym_sg = sym.get_backend_symbol("MKLDNN")
@@ -106,7 +106,8 @@ def check_quantize(sym, data_shape):
                                                                    calib_quantize_op=True,
                                                                    num_calib_examples=5)
   qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE")
-  check_qsym_calibrated(qsym)
+  if check_conv:
+    check_qsym_calibrated(qsym)
   quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape)
   for i in range(len(ref_out)):
     assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
@@ -229,6 +230,15 @@ def conv_bn_sum_relu(no_bias, data_shape):
   relu = mx.symbol.Activation(data=sum1, name='relu', act_type="relu")
   return relu, conv_bn_add_relu_attr
 
+# single concat case
+def single_concat(data_shape, input_num, dim):
+  data, weight = head_symbol(data_shape)
+  inputs = []
+  for i in range(input_num):
+    inputs.append(data)
+  concat = mx.symbol.Concat(*inputs, name="concat", dim=dim)
+  return concat
+
 def tail_neg_symbol(sym1, sym2):
   fc1 = mx.sym.FullyConnected(data=sym1, num_hidden=10, flatten=True, name='fc1')
   fc2 = mx.sym.FullyConnected(data=sym2, num_hidden=10, flatten=True, name='fc2')
@@ -463,6 +473,15 @@ def test_pos_conv_bn_sum_relu():
     net, attrs = conv_bn_sum_relu(True, data_shape)
     check_fusion(net, data_shape, attrs)
 
+def test_pos_single_concat():
+  for data_shape in DATA_SHAPE:
+    net = single_concat(data_shape, 2, 1)
+    check_quantize(net, data_shape, False)
+    net = single_concat(data_shape, 4, 2)
+    check_quantize(net, data_shape, False)
+    net = single_concat(data_shape, 4, 3)
+    check_quantize(net, data_shape, False)
+
 @with_seed()
 def test_neg_conv_bn():
   for data_shape in DATA_SHAPE:
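
The new test exercises this path end to end: single_concat builds a plain
Concat over repeated copies of one input, and check_quantize runs it through
quantization with check_conv=False so the conv-specific calibration check is
skipped. A rough standalone sketch of the same graph (illustrative only; the
NCHW shape is an arbitrary example, not one of the test's DATA_SHAPE values):

import mxnet as mx

data = mx.sym.Variable('data')
# Same graph shape as single_concat(data_shape, 2, 1) in the test above.
net = mx.sym.Concat(data, data, name='concat', dim=1)
print(net.infer_shape(data=(4, 4, 10, 10))[1])  # [(4, 8, 10, 10)]
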
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services