You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/02/02 04:58:37 UTC

[tvm] branch main updated: [Relay][Pass] Add a relay pass to extract fake quantized ops (#10089)

This is an automated email from the ASF dual-hosted git repository.

andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new efe662f  [Relay][Pass] Add a relay pass to extract fake quantized ops (#10089)
efe662f is described below

commit efe662fe66fbbd186cc9af9dddf61f3fe200f6bb
Author: Margaret Qian <ym...@gmail.com>
AuthorDate: Tue Feb 1 20:57:45 2022 -0800

    [Relay][Pass] Add a relay pass to extract fake quantized ops (#10089)
    
    * add relay pass to collect fake quantized ops
    
    * add more tests
    
    * more tests
    
    * lint
    
    * lint
    
    * remove unused imports
    
    * update comment
    
    * lint
    
    * reuse SubgraphExtractor and update test assertions
    
    * remove print
    
    * lint
    
    * remove unneeded comment
    
    Co-authored-by: Margaret Qian <mq...@octoml.ai>
---
 python/tvm/relay/analysis/analysis.py              |  16 +++
 src/relay/analysis/extract_fake_quantized_ops.cc   |  80 +++++++++++++
 .../transforms/fake_quantization_to_integer.cc     | 109 ++++++++---------
 .../transforms/fake_quantization_to_integer.h      |  54 +++++++++
 .../test_analysis_extract_fake_quantized_ops.py    | 133 +++++++++++++++++++++
 5 files changed, 335 insertions(+), 57 deletions(-)

diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index b627005..3b38c07 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -352,6 +352,22 @@ def list_op_freqs(mod):
     return _ffi_api.ExtractOperators(mod)
 
 
+def list_fake_quantized_op_freqs(mod):
+    """Count how often each operator appears in fake quantized regions of an IRModule.
+
+    Parameters
+    ----------
+    mod : tvm.IRModule
+        The module to analyze. Only its ``main`` function is traversed.
+
+    Returns
+    -------
+    ret : Dict[str, int]
+        Mapping from operator name to the number of times it appears in a region.
+    """
+    return _ffi_api.ExtractFakeQuantizedOps(mod)
+
+
 def search_fc_transpose(expr):
     """Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0]))
 
diff --git a/src/relay/analysis/extract_fake_quantized_ops.cc b/src/relay/analysis/extract_fake_quantized_ops.cc
new file mode 100644
index 0000000..68cee85
--- /dev/null
+++ b/src/relay/analysis/extract_fake_quantized_ops.cc
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file extract_fake_quantized_ops.cc
+ * \brief Extract fake quantized operators from an IRModule
+ */
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../transforms/fake_quantization_to_integer.h"
+
+namespace tvm {
+namespace relay {
+
+using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
+
+class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor {  // tallies ops per region
+ public:
+  Map<String, tvm::Integer> Extract(const IRModule& m) {  // returns op name -> count
+    IRModule mod(m);
+    mod = transform::InferType()(mod);  // SubgraphExtractor reads checked_type()
+    VisitExpr(mod->Lookup("main"));  // only "main" is traversed
+
+    return fake_quantized_op_freqs_;
+  }
+
+ private:
+  using MixedModeVisitor::VisitExpr_;
+
+  void VisitExpr_(const CallNode* call_node) override {
+    if (call_node->op == quantize_op_) {  // each quantize call roots a candidate region
+      SubgraphExtractor extractor;
+      ExprSet subgraph = extractor.GetSubgraph(GetRef<Expr>(call_node));  // may be empty
+
+      for (auto expr : subgraph) {
+        const Op op = Downcast<Op>(expr.as<CallNode>()->op);
+        if (op != dequantize_op_) {  // dequantize calls are not counted
+          if (fake_quantized_op_freqs_.find(op->name) != fake_quantized_op_freqs_.end()) {
+            fake_quantized_op_freqs_.Set(op->name,  // seen before: bump the count
+                                         int64_t(fake_quantized_op_freqs_.at(op->name)) + 1);
+          } else {
+            fake_quantized_op_freqs_.Set(op->name, 1);  // first occurrence
+          }
+        }
+      }
+    }
+  }
+
+  Map<String, tvm::Integer> fake_quantized_op_freqs_;  // accumulated across all regions
+  const Op quantize_op_ = Op::Get("qnn.quantize");
+  const Op dequantize_op_ = Op::Get("qnn.dequantize");
+};
+
+Map<String, tvm::Integer> ExtractFakeQuantizedOpsPacked(const IRModule& mod) {
+  return ExtractFakeQuantizedOpsWrapper().Extract(mod);  // fresh visitor for every call
+}
+
+TVM_REGISTER_GLOBAL("relay.analysis.ExtractFakeQuantizedOps")  // used by analysis.py
+    .set_body_typed(ExtractFakeQuantizedOpsPacked);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc
index fa6a1a5..4273fc2 100644
--- a/src/relay/transforms/fake_quantization_to_integer.cc
+++ b/src/relay/transforms/fake_quantization_to_integer.cc
@@ -23,12 +23,15 @@
  * to actual integer operations.
  */
 
-#include <tvm/ir/affine_type.h>
+#include "fake_quantization_to_integer.h"
+
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/qnn/attrs.h>
 #include <tvm/relay/transform.h>
 
+#include <unordered_map>
+
 namespace tvm {
 namespace relay {
 
@@ -75,69 +78,61 @@ using AffineTypeMap = Map<Expr, AffineType>;
 using FTVMFakeQuantizationToInteger =
     runtime::TypedPackedFunc<Array<ObjectRef>(const Expr& expr, const AffineTypeMap& map)>;
 
-class SubgraphExtractor : public ExprVisitor {
- public:
-  const ExprSet GetSubgraph(const Expr& expr) {
-    VisitExpr(expr);
-    ExprSet subgraph;
-    if (is_fake_quantized_) {
-      for (auto kv : this->visit_counter_) {
-        if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
-          if (call_node->op != quantize_op_) {
-            subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
-          }
+const ExprSet SubgraphExtractor::GetSubgraph(const Expr& expr) {
+  VisitExpr(expr);  // walk the region rooted at expr
+  ExprSet subgraph;
+  if (is_fake_quantized_) {  // stays empty once a non-dataflow node was hit
+    for (auto kv : this->visit_counter_) {
+      if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
+        if (call_node->op != quantize_op_) {  // the quantize call itself is excluded
+          subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
         }
       }
     }
-    return subgraph;
   }
-  const AffineTypeMap GetAffineTypes() { return affine_types_; }
-  void VisitExpr(const Expr& expr) override {
-    // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
-    // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
-    // abort the rewrite.
-    if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
-        expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
-        expr.as<ConstantNode>() == nullptr) {
-      DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside"
-                 << " a fake quantize region, aborting this rewrite";
-      is_fake_quantized_ = false;
-    } else {
-      ExprVisitor::VisitExpr(expr);
-    }
+  return subgraph;
+}
+const AffineTypeMap SubgraphExtractor::GetAffineTypes() { return affine_types_; }
+void SubgraphExtractor::VisitExpr(const Expr& expr) {
+  // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
+  // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
+  // abort the rewrite.
+  if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
+      expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
+      expr.as<ConstantNode>() == nullptr) {
+    DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside"
+               << " a fake quantize region, aborting this rewrite";
+    is_fake_quantized_ = false;  // makes GetSubgraph() return an empty set
+  } else {
+    ExprVisitor::VisitExpr(expr);
   }
+}
 
- protected:
-  void VisitExpr_(const CallNode* call_node) override {
-    if (call_node->op == quantize_op_) {
-      const auto* attrs = call_node->attrs.as<qnn::QuantizeAttrs>();
-      ICHECK(attrs != nullptr);
-      // Only look at arg0 for quantize
-      VisitExpr(call_node->args[0]);
-      // Collect type of quantize ops
-      affine_types_.Set(
-          GetRef<Expr>(call_node),
-          TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis));
-    } else if (call_node->op == dequantize_op_) {
-      const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
-      ICHECK(attrs != nullptr);
-      // Collect type of dequantize ops
-      affine_types_.Set(
-          GetRef<Expr>(call_node),
-          TensorAffineType(call_node->args[1], call_node->args[2],
-                           call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype,
-                           attrs->axis));
-    } else {
-      // run normally on everything else.
-      ExprVisitor::VisitExpr_(call_node);
-    }
+void SubgraphExtractor::VisitExpr_(const CallNode* call_node) {
+  // Compare op handles directly: a Downcast<Op> here would fail on non-Op callees.
+  if (call_node->op == quantize_op_) {
+    const auto* attrs = call_node->attrs.as<qnn::QuantizeAttrs>();
+    ICHECK(attrs != nullptr);
+    // Only look at arg0 for quantize
+    VisitExpr(call_node->args[0]);
+    // Collect type of quantize ops
+    affine_types_.Set(
+        GetRef<Expr>(call_node),
+        TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis));
+  } else if (call_node->op == dequantize_op_) {
+    const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
+    ICHECK(attrs != nullptr);
+    // Collect type of dequantize ops; their arguments are not visited from here
+    affine_types_.Set(
+        GetRef<Expr>(call_node),
+        TensorAffineType(call_node->args[1], call_node->args[2],
+                         call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype,
+                         attrs->axis));
+  } else {
+    // run normally on everything else.
+    ExprVisitor::VisitExpr_(call_node);
   }
-
-  const Op quantize_op_ = Op::Get("qnn.quantize");
-  const Op dequantize_op_ = Op::Get("qnn.dequantize");
-  bool is_fake_quantized_ = true;
-  AffineTypeMap affine_types_;
-};
+}
 
 class SubgraphMutator : public ExprMutator {
  public:
diff --git a/src/relay/transforms/fake_quantization_to_integer.h b/src/relay/transforms/fake_quantization_to_integer.h
new file mode 100644
index 0000000..1956f94
--- /dev/null
+++ b/src/relay/transforms/fake_quantization_to_integer.h
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/fake_quantization_to_integer.h
+ * \brief Extract subgraph of a fake quantized region.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_
+#define TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_
+
+#include <tvm/ir/affine_type.h>
+#include <tvm/relay/expr_functor.h>
+
+#include <unordered_set>
+
+namespace tvm {
+namespace relay {
+
+class SubgraphExtractor : public ExprVisitor {  // collects the calls of one region
+ public:
+  const std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> GetSubgraph(const Expr& expr);
+  const Map<Expr, AffineType> GetAffineTypes();  // types recorded while visiting
+  void VisitExpr(const Expr& expr) override;  // rejects non-dataflow nodes
+
+ protected:
+  void VisitExpr_(const CallNode* call_node) override;
+
+ private:
+  const Op quantize_op_ = Op::Get("qnn.quantize");
+  const Op dequantize_op_ = Op::Get("qnn.dequantize");
+  bool is_fake_quantized_ = true;  // cleared when a non-dataflow node is visited
+  Map<Expr, AffineType> affine_types_;  // quantize/dequantize call -> TensorAffineType
+};
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_
diff --git a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py
new file mode 100644
index 0000000..54594a2
--- /dev/null
+++ b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test fake quantized operator extraction"""
+import tvm
+from tvm import relay
+
+
+def test_fake_quantize_conv():  # conv2d between dequantize and quantize is reported once
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+    w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8")
+    zero = relay.const(0)
+
+    op = relay.op.nn.conv2d(
+        relay.qnn.op.dequantize(x, relay.const(2.0), zero),
+        relay.qnn.op.dequantize(w, relay.const(0.5), zero),
+        kernel_size=[5, 5],
+    )
+    op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")
+
+    mod = tvm.IRModule.from_expr(op)
+    fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+    assert dict(fake_quantized_op_freqs) == {"nn.conv2d": 1}
+
+
+def test_fake_quantize_dense():  # dense between dequantize and quantize is reported once
+    x = relay.var("x", shape=[128, 64], dtype="int8")
+    w = relay.var("w", shape=[256, 64], dtype="int8")
+    zero = relay.const(0)
+
+    op = relay.op.nn.dense(
+        relay.qnn.op.dequantize(x, relay.const(2.0), zero),
+        relay.qnn.op.dequantize(w, relay.const(0.5), zero),
+    )
+    op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")
+
+    mod = tvm.IRModule.from_expr(op)
+    fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+    assert dict(fake_quantized_op_freqs) == {"nn.dense": 1}
+
+
+def test_fake_quantize_multiple_regions():  # counts accumulate over three chained regions
+    x = relay.var("x", shape=[128, 64], dtype="int8")
+    w = relay.var("w", shape=[256, 64], dtype="int8")
+    zero = relay.const(0)
+
+    op = relay.op.nn.dense(
+        relay.qnn.op.dequantize(x, relay.const(2.0), zero),
+        relay.qnn.op.dequantize(w, relay.const(0.5), zero),
+    )
+    op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")
+
+    op = relay.qnn.op.dequantize(op, relay.const(2.0), relay.const(114))
+    op = relay.op.nn.relu(op)
+    op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")
+
+    w2 = relay.var("w2", shape=[64, 256], dtype="int8")
+    op = relay.op.nn.dense(
+        relay.qnn.op.dequantize(op, relay.const(1.0), zero),
+        relay.qnn.op.dequantize(w2, relay.const(0.5), zero),
+    )
+    op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")
+
+    # We expect to ignore this sigmoid op since it's just outside a fake
+    # quantized region
+    op = relay.op.sigmoid(op)
+
+    mod = tvm.IRModule.from_expr(op)
+    fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+    assert dict(fake_quantized_op_freqs) == {"nn.dense": 2, "nn.relu": 1}
+
+
+def test_fake_quantize_maxpool():  # a region holding a single max_pool2d call
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+
+    zero = relay.const(0)
+    x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+    op = relay.op.nn.max_pool2d(x, [3, 3])
+    op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
+
+    mod = tvm.IRModule.from_expr(op)
+    fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+    assert dict(fake_quantized_op_freqs) == {"nn.max_pool2d": 1}
+
+
+def test_fake_quantize_transpose_reshape():  # two ops in one region are both reported
+    x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+
+    zero = relay.const(0)
+    x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+    op = relay.op.transpose(x, [1, 0, 2, 3])
+    op = relay.op.reshape(op, [3, -1])
+    op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
+
+    mod = tvm.IRModule.from_expr(op)
+    fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+    assert dict(fake_quantized_op_freqs) == {"transpose": 1, "reshape": 1}
+
+
+def test_fake_quantize_concat():  # concatenate over four dequantized inputs, counted once
+    zero = relay.const(0)
+    inputs = []
+    for i in range(4):
+        inputs.append(
+            relay.qnn.op.dequantize(
+                relay.var("x%d" % i, shape=[1, 4], dtype="int8"), relay.const(i + 0.5), zero
+            )
+        )
+    concat = relay.op.concatenate(inputs, axis=1)
+    op = relay.qnn.op.quantize(concat, relay.const(3.5), zero)
+
+    mod = tvm.IRModule.from_expr(op)
+    fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+    assert dict(fake_quantized_op_freqs) == {"concatenate": 1}