Posted to commits@mxnet.apache.org by re...@apache.org on 2019/03/06 07:20:39 UTC

[incubator-mxnet] branch master updated: Register fake grad to subgraph and quantized operators (#14275)

This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b486594  Register fake grad to subgraph and quantized operators (#14275)
b486594 is described below

commit b48659429ff15eeea9aeb9055b6b27527c2d108f
Author: Xinyu Chen <xi...@intel.com>
AuthorDate: Wed Mar 6 15:20:17 2019 +0800

    Register fake grad to subgraph and quantized operators (#14275)
    
    * add fake grad
    
    * Skip inference only subgraph pass when gradient is needed.
    
    * add fake grad to quantizev2
    
    * add TODO
    
    * modify prop_name to property_name
    
    * add test case
---
 src/executor/graph_executor.cc                     | 34 +++++++++++---
 src/operator/quantization/dequantize.cc            |  3 ++
 src/operator/quantization/quantize.cc              |  3 ++
 src/operator/quantization/quantize_v2.cc           |  3 ++
 src/operator/quantization/quantized_concat.cc      |  3 ++
 src/operator/quantization/quantized_conv.cc        |  3 ++
 src/operator/quantization/quantized_flatten.cc     |  3 ++
 .../quantization/quantized_fully_connected.cc      |  3 ++
 src/operator/quantization/quantized_pooling.cc     |  3 ++
 src/operator/quantization/requantize.cc            |  3 ++
 src/operator/subgraph/mkldnn/mkldnn_conv.cc        |  3 ++
 .../mkldnn/mkldnn_conv_post_quantize_property.cc   |  6 ++-
 .../subgraph/mkldnn/mkldnn_conv_property.cc        |  5 +-
 src/operator/subgraph/subgraph_property.h          |  7 +++
 tests/python/mkl/test_subgraph.py                  | 53 +++++++++++++++++-----
 15 files changed, 114 insertions(+), 21 deletions(-)

diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index ca2cea0..436eae3 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1506,8 +1506,26 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
                                    const Context& default_ctx,
                                    const std::map<std::string, Context>& ctx_map,
                                    const std::vector<Context>& in_arg_ctxes,
-                                   const std::vector<Context>& aux_state_ctxes) {
+                                   const std::vector<Context>& aux_state_ctxes,
+                                   const std::vector<OpReqType>& grad_req_types) {
   auto subgraph_prop = op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+  bool need_grad = false;
+  for (OpReqType req : grad_req_types) {
+    if (req != kNullOp) {
+      need_grad = true;
+      break;
+    }
+  }
+  if (subgraph_prop->HasAttr("inference_only") &&
+      subgraph_prop->GetAttr<bool>("inference_only") == true) {
+    if (need_grad) {
+      auto full_name = subgraph_prop->HasAttr("prop_name")
+                            ? subgraph_prop->GetAttr<std::string>("prop_name")
+                            : prop_name;
+      LOG(INFO) << "Skip subgraph " << full_name << " as it requires `grad_req=null`.";
+      return src;
+    }
+  }
   nnvm::Symbol ret = src.Copy();
   nnvm::Graph g;
   g.outputs = ret.outputs;
@@ -1539,7 +1557,8 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
                                    const Context& default_ctx,
                                    const std::map<std::string, Context>& ctx_map,
                                    const std::vector<Context>& in_arg_ctxes,
-                                   const std::vector<Context>& aux_state_ctxes) {
+                                   const std::vector<Context>& aux_state_ctxes,
+                                   const std::vector<OpReqType>& grad_req_types) {
   const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
   mxnet::ShapeVector arg_shapes(input_names.size(), mxnet::TShape());
   nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
@@ -1559,7 +1578,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
     }
   }
   return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
-                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes, grad_req_types);
 }
 
 // Given input ndarrays, partition the graph using the backend name equal to prop_name.
@@ -1569,7 +1588,8 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
                                    std::vector<NDArray> *in_args,
                                    const std::vector<NDArray> &aux_states,
                                    const Context& default_ctx,
-                                   const std::map<std::string, Context>& ctx_map) {
+                                   const std::map<std::string, Context>& ctx_map,
+                                   const std::vector<OpReqType>& grad_req_types) {
   const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
   const std::vector<std::string> arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
   const std::vector<std::string> aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
@@ -1609,7 +1629,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
     in_args_map[arg_names[i]] = in_args->at(i);
   }
   auto result = PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
-                               ctx_map, in_arg_ctxes, aux_state_ctxes);
+                               ctx_map, in_arg_ctxes, aux_state_ctxes, grad_req_types);
   // Reorder in_args into new_in_args according to partitioned symbol input sequence
   std::vector<NDArray> new_in_args(in_args->size());
   // get new symbol in_arg names
@@ -1644,7 +1664,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
   if (!exec->subgraph_property().empty()) {
     symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
                                   arg_stype_map, default_ctx, group2ctx, in_arg_ctxes,
-                                  aux_state_ctxes);
+                                  aux_state_ctxes, grad_req_types);
   }
   exec->Init(symbol, default_ctx, group2ctx,
              in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
@@ -1667,7 +1687,7 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
   std::vector<NDArray> tmp_in_args = in_args;
   if (!exec->subgraph_property().empty()) {
     symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), &tmp_in_args, aux_states,
-                                  default_ctx, group2ctx);
+                                  default_ctx, group2ctx, grad_req_type);
   }
   exec->Init(symbol, default_ctx, group2ctx,
              tmp_in_args, arg_grad_store, grad_req_type, aux_states,
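
Note: the practical effect of the new need_grad check above is that an inference-only subgraph pass is silently skipped (with an INFO log) whenever any gradient is requested at bind time. A minimal sketch of that behaviour, assuming an MKLDNN build and that the backend is selected through the MXNET_SUBGRAPH_BACKEND environment variable:

import os
os.environ['MXNET_SUBGRAPH_BACKEND'] = 'MKLDNN'
import mxnet as mx

data = mx.sym.Variable('data')
conv = mx.sym.Convolution(data=data, kernel=(3, 3), num_filter=16, name='conv')

# grad_req='null': no gradients are needed, so the inference-only
# MKLDNN convolution pass still partitions the graph.
exe_infer = conv.simple_bind(mx.cpu(), data=(1, 3, 32, 32), grad_req='null')

# grad_req='write': gradients are requested, so PartitionGraph now logs
# "Skip subgraph ..." and returns the original symbol unchanged.
exe_train = conv.simple_bind(mx.cpu(), data=(1, 3, 32, 32), grad_req='write')
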
diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc
index a4d57b9..7c84673 100644
--- a/src/operator/quantization/dequantize.cc
+++ b/src/operator/quantization/dequantize.cc
@@ -71,6 +71,9 @@ by keep zero centered for the quantized value:
 .set_attr<mxnet::FInferShape>("FInferShape", DequantizeShape)
 .set_attr<nnvm::FInferType>("FInferType", DequantizeType)
 .set_attr<FInferStorageType>("FInferStorageType", DequantizeStorageType)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNDequantizeCompute)
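
Note: MakeZeroGradNodes is the stock NNVM gradient helper that emits one zero-gradient node per operator input. Registering it here (and on the other quantized operators below) is the "fake grad" from the commit title: it lets backward graph construction succeed for the Gluon/CachedOp flow mentioned in the TODO, while the operators themselves remain inference-only. A conceptual Python analogue of what such a gradient returns (not the actual C++ helper, just an illustration):

import mxnet as mx

def fake_grad(inputs):
    # One zero NDArray per operator input, matching shape and dtype.
    return [mx.nd.zeros_like(x) for x in inputs]

print([g.shape for g in fake_grad([mx.nd.ones((2, 3)), mx.nd.array([0.5])])])
# -> [(2, 3), (1,)]
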
diff --git a/src/operator/quantization/quantize.cc b/src/operator/quantization/quantize.cc
index c28d8c8..6346750 100644
--- a/src/operator/quantization/quantize.cc
+++ b/src/operator/quantization/quantize.cc
@@ -82,6 +82,9 @@ where
 .set_attr<mxnet::FInferShape>("FInferShape", QuantizeShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizeType)
 .set_attr<FInferStorageType>("FInferStorageType", QuantizeStorageType)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeCompute)
diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc
index 300cdfe..920100b 100644
--- a/src/operator/quantization/quantize_v2.cc
+++ b/src/operator/quantization/quantize_v2.cc
@@ -83,6 +83,9 @@ If min_calib_range isn't presented, the output type will be int8.
 .set_attr<mxnet::FInferShape>("FInferShape", QuantizeV2Shape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizeV2Type)
 .set_attr<FInferStorageType>("FInferStorageType", QuantizeV2StorageType)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeV2Compute)
diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc
index f5c1e8e..e32bb5a 100644
--- a/src/operator/quantization/quantized_concat.cc
+++ b/src/operator/quantization/quantized_concat.cc
@@ -127,6 +127,9 @@ If any input holds int8, then the output will be int8. Otherwise output will be
 .set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
   return std::vector<std::string>{"output", "min_output", "max_output"};
 })
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<nnvm::FInferType>("FInferType", ConcatType)
 .set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
 .set_attr<std::string>("key_var_num_args", "num_args")
diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc
index 7841c3a..1a801ee 100644
--- a/src/operator/quantization/quantized_conv.cc
+++ b/src/operator/quantization/quantized_conv.cc
@@ -160,6 +160,9 @@ and max thresholds representing the threholds for quantizing the float32 output
 .set_attr<mxnet::FInferShape>("FInferShape", QuantizedConvShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedConvType)
 .set_attr<FInferStorageType>("FInferStorageType", QuantizedConvStorageType)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
     return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
diff --git a/src/operator/quantization/quantized_flatten.cc b/src/operator/quantization/quantized_flatten.cc
index f283d98..7e6d27b 100644
--- a/src/operator/quantization/quantized_flatten.cc
+++ b/src/operator/quantization/quantized_flatten.cc
@@ -34,6 +34,9 @@ NNVM_REGISTER_OP(_contrib_quantized_flatten)
 .set_attr<mxnet::FInferShape>("FInferShape", QuantizedFlattenShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedFlattenType)
 .set_attr<FCompute>("FCompute<cpu>", QuantizedFlattenCompute<cpu>)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<nnvm::FListInputNames>("FListInputNames",
   [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"data", "min_data", "max_data"};
diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc
index f51b6fd..3b18e65 100644
--- a/src/operator/quantization/quantized_fully_connected.cc
+++ b/src/operator/quantization/quantized_fully_connected.cc
@@ -264,6 +264,9 @@ and max thresholds representing the threholds for quantizing the float32 output
 .set_attr<mxnet::FInferShape>("FInferShape", QuantizedFullyConnectedShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedFullyConnectedType)
 .set_attr<FInferStorageType>("FInferStorageType", QuantizedFullyConnectedStorageType)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
 .set_attr<FComputeEx>("FComputeEx<cpu>",
     QuantizedFullyConnectedForward<int8_t>)
diff --git a/src/operator/quantization/quantized_pooling.cc b/src/operator/quantization/quantized_pooling.cc
index cdc98ee..af60408 100644
--- a/src/operator/quantization/quantized_pooling.cc
+++ b/src/operator/quantization/quantized_pooling.cc
@@ -157,6 +157,9 @@ the float32 data into int8.
 .set_attr<mxnet::FInferShape>("FInferShape", QuantizedPoolingShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizedPoolingType)
 .set_attr<FInferStorageType>("FInferStorageType", QuantizedPoolingStorageType)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<FNeedRequantize>("FNeedRequantize",
   [](const NodeAttrs& attrs) {
     const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc
index edfb58e..4807226 100644
--- a/src/operator/quantization/requantize.cc
+++ b/src/operator/quantization/requantize.cc
@@ -64,6 +64,9 @@ inference accuracy.
 .set_attr<mxnet::FInferShape>("FInferShape", QuantizeShape)
 .set_attr<nnvm::FInferType>("FInferType", RequantizeType)
 .set_attr<FInferStorageType>("FInferStorageType", RequantizeStorageType)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNRequantizeForward)
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
index e53ab25..d61b461 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -689,6 +689,9 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv)
 .set_attr<FInferStorageType>("FInferStorageType", SgMKLDNNConvOpStorageType)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgMKLDNNConvOpForward)
 .set_attr<bool>("TIsMKLDNN", true)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
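
Note: the "GluonCV INT8 flow" referenced in the TODO comments is exercised by the new check_qsym_gluon_forward helper in the test diff below: a quantized symbol and its parameters are saved, imported back through gluon.SymbolBlock, and run forward. A condensed sketch of that flow (the file names and the 'data' input name are taken from the test and are purely illustrative):

import mxnet as mx

net = mx.gluon.SymbolBlock.imports('quantized-symbol.json', ['data'],
                                   'quantized-0000.params')
net.collect_params().reset_ctx(mx.current_context())
net.hybridize()
out = net(mx.random.uniform(-1.0, 1.0, shape=(4, 4, 10, 10)))
out.wait_to_read()
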
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc
index fc68287..654f6e7 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc
@@ -107,7 +107,11 @@ class SgMKLDNNConvPostQuantizeProperty : public SubgraphProperty {
     }
   }
   static SubgraphPropertyPtr Create() {
-    return std::make_shared<SgMKLDNNConvPostQuantizeProperty>();
+    auto property = std::make_shared<SgMKLDNNConvPostQuantizeProperty>();
+    property->SetAttr<std::string>("property_name",
+                                   "MKLDNN Convolution post-quantization optimization pass");
+    property->SetAttr<bool>("inference_only", true);
+    return property;
   }
   nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
                                    const int subgraph_id = 0) const override {
diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc
index e462191..56ce729 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc
@@ -149,7 +149,10 @@ class SgMKLDNNConvProperty : public SubgraphProperty {
     }
   }
   static SubgraphPropertyPtr Create() {
-    return std::make_shared<SgMKLDNNConvProperty>();
+    auto property = std::make_shared<SgMKLDNNConvProperty>();
+    property->SetAttr<std::string>("prop_name", "MKLDNN Convolution optimization pass");
+    property->SetAttr<bool>("inference_only", true);
+    return property;
   }
   nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
                                    const int subgraph_id = 0) const override {
diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h
index e9fdd66..d115d34 100644
--- a/src/operator/subgraph/subgraph_property.h
+++ b/src/operator/subgraph/subgraph_property.h
@@ -145,6 +145,13 @@ class SubgraphProperty {
     CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty";
     return nnvm::get<T>(*it->second);
   }
+  /*!
+   * \brief Check if the attr exist.
+   */
+  bool HasAttr(const std::string& name) const {
+    auto it = attrs_.find(name);
+    return it != attrs_.end();
+  }
 
  protected:
   std::unordered_map<std::string, std::shared_ptr<nnvm::any>> attrs_;
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 313668c..8de854c 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -66,15 +66,37 @@ def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape):
     output.wait_to_read()
   return mod.get_outputs()
 
-def check_quantize(sym, data_shape, out_type, check_conv=True):
+def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape):
+  # save qsym to JSON file
+  qsym.save('quantized-symbol.json')
+  # save params
+  save_dict = {('arg:%s' % k): v.as_in_context(mx.current_context()) for k, v in qarg_params.items()}
+  save_dict.update({('aux:%s' % k): v.as_in_context(mx.current_context()) for k, v in qaux_params.items()})
+  mx.nd.save('quantized-0000.params', save_dict)
+  # load back with SymbolBlock
+  net = mx.gluon.SymbolBlock.imports('quantized-symbol.json', ['data'], 'quantized-0000.params')
+  net.collect_params().reset_ctx(ctx = mx.current_context())
+  net.hybridize()
+
+  data = mx.random.uniform(-1.0, 1.0, shape=data_shape)
+  net(data)
+
+def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=False):
   fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc')
-  sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
-  sym_sg = sym.get_backend_symbol("MKLDNN")
-  label_shape = (data_shape[0], 10)
-  mod = Module(symbol=sym)
-  mod.bind(for_training=False,
-           data_shapes=[('data', data_shape)],
-           label_shapes=[('softmax_label', label_shape)])
+  if gluon_forward == True:
+    sym = fc
+    sym_sg = fc.get_backend_symbol("MKLDNN")
+    mod = Module(symbol=sym, label_names=[])
+    mod.bind(for_training=False,
+            data_shapes=[('data', data_shape)])
+  else:
+    sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+    sym_sg = sym.get_backend_symbol("MKLDNN")
+    label_shape = (data_shape[0], 10)
+    mod = Module(symbol=sym)
+    mod.bind(for_training=False,
+            data_shapes=[('data', data_shape)],
+            label_shapes=[('softmax_label', label_shape)])
   mod.init_params(mx.init.Normal(0.5))
   arg_params, aux_params = mod.get_params()
 
@@ -107,10 +129,13 @@ def check_quantize(sym, data_shape, out_type, check_conv=True):
   qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE")
   if check_conv:
     check_qsym_calibrated(qsym, out_type)
-  quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape)
-  for i in range(len(ref_out)):
-    assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
-  check_qsym_dummy_forward(qsym, batch, data_shape, label_shape)
+  if gluon_forward == True:
+    check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape)
+  else:
+    check_qsym_dummy_forward(qsym, batch, data_shape, label_shape)
+    quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape)
+    for i in range(len(ref_out)):
+      assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
 
 
 @with_seed()
@@ -137,6 +162,7 @@ def check_fusion(sym, data_shape, attrs_op):
   # fp32 to int8
   for out_type in ('uint8', 'int8', 'auto'):
     check_quantize(sym, data_shape, out_type)
+    check_quantize(sym, data_shape, out_type, gluon_forward=True)
 
 def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10)):
   for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs):
@@ -478,10 +504,13 @@ def test_pos_single_concat():
     for out_type in ('uint8', 'int8', 'auto'):
       net = single_concat(data_shape, 2, 1)
       check_quantize(net, data_shape, out_type, False)
+      check_quantize(net, data_shape, out_type, False, True)
       net = single_concat(data_shape, 4, 2)
       check_quantize(net, data_shape, out_type, False)
+      check_quantize(net, data_shape, out_type, False, True)
       net = single_concat(data_shape, 4, 3)
       check_quantize(net, data_shape, out_type, False)
+      check_quantize(net, data_shape, out_type, False, True)
 
 @with_seed()
 def test_neg_conv_bn():