You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/31 08:58:03 UTC

[incubator-mxnet] branch master updated: Post quantize property improvement (#20929)

This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f16c3c7  Post quantize property improvement (#20929)
f16c3c7 is described below

commit f16c3c77cb5ceddb5b8a8327e84a7d60c237a313
Author: DominikaJedynak <do...@intel.com>
AuthorDate: Thu Mar 31 10:55:47 2022 +0200

    Post quantize property improvement (#20929)
    
    * Post quantize property improvement
    
    * Review suggestion
    
    * Sanity fix
    
    * Review change
---
 .../subgraph/dnnl/dnnl_post_quantize_property.h    | 52 +++++++++++-----------
 1 file changed, 27 insertions(+), 25 deletions(-)

diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 456a0d1..0a7439b 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -37,15 +37,19 @@
 namespace mxnet {
 namespace op {
 namespace {
-const std::set<std::string> support_req_fusion_op = {
-    "_contrib_quantized_elemwise_add",
-    "_contrib_quantized_elemwise_mul",
-    // "_contrib_quantized_npi_add" - to be added later on
-    "_sg_onednn_conv",
-    "_sg_onednn_fully_connected",
-    "_sg_onednn_selfatt_qk",
-    "_sg_onednn_selfatt_valatt",
-    "_sg_onednn_batch_dot"};
+bool SupportsRequantizeFusion(const Op* op) {
+  static const std::set<const Op*> support_requantize_fusion_ops = {
+      Op::Get("_contrib_quantized_elemwise_add"),
+      Op::Get("_contrib_quantized_elemwise_mul"),
+      // Op::Get("_contrib_quantized_npi_add") - to be added later on
+      Op::Get("_sg_onednn_conv"),
+      Op::Get("_sg_onednn_fully_connected"),
+      Op::Get("_sg_onednn_selfatt_qk"),
+      Op::Get("_sg_onednn_selfatt_valatt"),
+      Op::Get("_sg_onednn_batch_dot")};
+
+  return support_requantize_fusion_ops.count(op) > 0;
+}
 }  // namespace
 
 class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
@@ -62,18 +66,15 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
   bool float_output;
   SelectStatus status;
   std::vector<const BiDirectedNode*> matched_list;
-  std::set<std::string> support_requantize_fusion_op_name;
 
  public:
   explicit SgDNNLPostQuantizeSelector(const bool fuse_all, const bool float_output)
-      : fuse_all(fuse_all), float_output(float_output) {
-    support_requantize_fusion_op_name = support_req_fusion_op;
-  }
+      : fuse_all(fuse_all), float_output(float_output) {}
 
   bool Select(const BiDirectedNode& n) override {
     const nnvm::Node* raw_node = n.node;
-    if (fuse_all && raw_node->op() &&
-        support_requantize_fusion_op_name.count(raw_node->op()->name)) {
+
+    if (fuse_all && raw_node->op() && SupportsRequantizeFusion(raw_node->op())) {
       status = SelectStatus::kStart;
       matched_list.clear();
       matched_list.emplace_back(&n);
@@ -89,6 +90,10 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
   bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node) override {
     const nnvm::Node* raw_node     = n.node;
     const nnvm::Node* raw_new_node = new_node.node;
+
+    static const std::set<const Op*> dequantize_fusion_unsupported_ops = {
+        Op::Get("_contrib_quantized_elemwise_add")};
+
     if (status == SelectStatus::kFail || status == SelectStatus::kSuccess ||
         raw_new_node->is_variable())
       return false;
@@ -111,9 +116,9 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
           if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
             matched_list.emplace_back(&new_node);
             status = SelectStatus::kRequantize;
-            // For now there is no support for dequantize fusion for contrib_quantized_elemwise_add
-            // so with this operator we finish after finding requantize node:
-            if (raw_node->op() == Op::Get("_contrib_quantized_elemwise_add")) {
+            // Some operators do not support dequantize fusion yet, so for those
+            // we stop matching as soon as the requantize node has been found:
+            if (dequantize_fusion_unsupported_ops.count(raw_node->op()) != 0) {
               status = SelectStatus::kSuccess;
             }
             return true;
@@ -170,13 +175,11 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
  private:
   bool fuse_all;
   bool float_output;
-  std::set<std::string> support_requantize_fusion_op_name;
 
  public:
   SgDNNLPostQuantizeProperty() {
-    fuse_all                          = dmlc::GetEnv("MXNET_ONEDNN_FUSE_REQUANTIZE", true);
-    float_output                      = dmlc::GetEnv("MXNET_ONEDNN_FUSE_DEQUANTIZE", true);
-    support_requantize_fusion_op_name = support_req_fusion_op;
+    fuse_all     = dmlc::GetEnv("MXNET_ONEDNN_FUSE_REQUANTIZE", true);
+    float_output = dmlc::GetEnv("MXNET_ONEDNN_FUSE_DEQUANTIZE", true);
   }
 
   static SubgraphPropertyPtr Create() {
@@ -196,7 +199,7 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
     DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
       if (node->is_variable())
         return;
-      if (node->op() && support_requantize_fusion_op_name.count(node->op()->name)) {
+      if (node->op() && SupportsRequantizeFusion(node->op())) {
         fuse_node = node;
       } else if (node->op() == Op::Get("_contrib_requantize")) {
         requantize_node = node;
@@ -213,8 +216,7 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
 
     // When only fused quantized operator and requantize, set min/max_cablib_range,
     // When fused quantized operator + requantize + dequantize, set dequantize flag to true.
-    if ((dequantize_node != nullptr) &&
-        (fuse_node->op() != Op::Get("_contrib_quantized_elemwise_add"))) {
+    if (dequantize_node != nullptr) {
       fuse_node->attrs.dict["enable_float_output"] = "True";
     } else {
       fuse_node->attrs.dict["min_calib_range"] =