You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by an...@apache.org on 2022/02/02 04:58:37 UTC
[tvm] branch main updated: [Relay][Pass] Add a relay pass to extract fake quantized ops (#10089)
This is an automated email from the ASF dual-hosted git repository.
andrewzhaoluo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new efe662f [Relay][Pass] Add a relay pass to extract fake quantized ops (#10089)
efe662f is described below
commit efe662fe66fbbd186cc9af9dddf61f3fe200f6bb
Author: Margaret Qian <ym...@gmail.com>
AuthorDate: Tue Feb 1 20:57:45 2022 -0800
[Relay][Pass] Add a relay pass to extract fake quantized ops (#10089)
* add relay pass to collect fake quantized ops
* add more tests
* more tests
* lint
* lint
* remove unused imports
* update comment
* lint
* reuse SubgraphExtractor and update test assertions
* remove print
* lint
* remove unneeded comment
Co-authored-by: Margaret Qian <mq...@octoml.ai>
---
python/tvm/relay/analysis/analysis.py | 16 +++
src/relay/analysis/extract_fake_quantized_ops.cc | 80 +++++++++++++
.../transforms/fake_quantization_to_integer.cc | 109 ++++++++---------
.../transforms/fake_quantization_to_integer.h | 54 +++++++++
.../test_analysis_extract_fake_quantized_ops.py | 133 +++++++++++++++++++++
5 files changed, 335 insertions(+), 57 deletions(-)
diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py
index b627005..3b38c07 100644
--- a/python/tvm/relay/analysis/analysis.py
+++ b/python/tvm/relay/analysis/analysis.py
@@ -352,6 +352,22 @@ def list_op_freqs(mod):
return _ffi_api.ExtractOperators(mod)
+def list_fake_quantized_op_freqs(mod):
+ """Pass to extract fake quantized op names and the frequency that they appear
+ in fake quantized regions of an IRModule.
+
+ Parameters
+ ----------
+ mod : tvm.IRModule
+
+ Returns
+ -------
+ ret : Dict[str, int]
+ Dict of fake quantized operator names to frequency
+ """
+ return _ffi_api.ExtractFakeQuantizedOps(mod)
+
+
def search_fc_transpose(expr):
"""Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0]))
diff --git a/src/relay/analysis/extract_fake_quantized_ops.cc b/src/relay/analysis/extract_fake_quantized_ops.cc
new file mode 100644
index 0000000..68cee85
--- /dev/null
+++ b/src/relay/analysis/extract_fake_quantized_ops.cc
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file extract_fake_quantized_ops.cc
+ * \brief Extract fake quantized operators from an IRModule
+ */
+#include <tvm/relay/expr.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+
+#include "../transforms/fake_quantization_to_integer.h"
+
+namespace tvm {
+namespace relay {
+
+using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;
+// Counts, by operator name, the calls found inside the fake quantized regions of "main".
+class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor {
+ public:
+ Map<String, tvm::Integer> Extract(const IRModule& m) {
+ // Types must be inferred first: SubgraphExtractor reads checked_type() of dequantize inputs.
+ IRModule mod = transform::InferType()(m);
+ VisitExpr(mod->Lookup("main"));
+ return fake_quantized_op_freqs_;
+ }
+
+ private:
+ using MixedModeVisitor::VisitExpr_;
+
+ void VisitExpr_(const CallNode* call_node) override {
+ if (call_node->op == quantize_op_) {
+ // Each qnn.quantize ends a region; count every call in it except qnn.dequantize.
+ ExprSet subgraph = SubgraphExtractor().GetSubgraph(GetRef<Expr>(call_node));
+ for (const Expr& expr : subgraph) {
+ const CallNode* call = expr.as<CallNode>();
+ // The callee may be a GlobalVar or Function rather than an Op; only Ops are counted.
+ const auto* op = call->op.as<OpNode>();
+ if (op != nullptr && call->op != dequantize_op_) {
+ Optional<tvm::Integer> previous = fake_quantized_op_freqs_.Get(op->name);
+ const int64_t count = previous.defined() ? previous.value()->value : 0;
+ fake_quantized_op_freqs_.Set(op->name, count + 1);
+ }
+ }
+ }
+ // Delegate to the base visitor so the callee expression (e.g. a function literal) is traversed.
+ MixedModeVisitor::VisitExpr_(call_node);
+ }
+
+ // Operator name -> number of calls seen inside fake quantized regions.
+ Map<String, tvm::Integer> fake_quantized_op_freqs_;
+ const Op quantize_op_ = Op::Get("qnn.quantize");
+ const Op dequantize_op_ = Op::Get("qnn.dequantize");
+};
+
+Map<String, tvm::Integer> ExtractFakeQuantizedOpsPacked(const IRModule& mod) {
+ ExtractFakeQuantizedOpsWrapper extractor;  // fresh instance, so the counts start empty
+ return extractor.Extract(mod);
+}
+TVM_REGISTER_GLOBAL("relay.analysis.ExtractFakeQuantizedOps")  // _ffi_api name used in analysis.py
+ .set_body_typed(ExtractFakeQuantizedOpsPacked);
+
+} // namespace relay
+} // namespace tvm
diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc
index fa6a1a5..4273fc2 100644
--- a/src/relay/transforms/fake_quantization_to_integer.cc
+++ b/src/relay/transforms/fake_quantization_to_integer.cc
@@ -23,12 +23,15 @@
* to actual integer operations.
*/
-#include <tvm/ir/affine_type.h>
+#include "fake_quantization_to_integer.h"
+
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h>
+#include <unordered_map>
+
namespace tvm {
namespace relay {
@@ -75,69 +78,61 @@ using AffineTypeMap = Map<Expr, AffineType>;
using FTVMFakeQuantizationToInteger =
runtime::TypedPackedFunc<Array<ObjectRef>(const Expr& expr, const AffineTypeMap& map)>;
-class SubgraphExtractor : public ExprVisitor {
- public:
- const ExprSet GetSubgraph(const Expr& expr) {
- VisitExpr(expr);
- ExprSet subgraph;
- if (is_fake_quantized_) {
- for (auto kv : this->visit_counter_) {
- if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
- if (call_node->op != quantize_op_) {
- subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
- }
+const ExprSet SubgraphExtractor::GetSubgraph(const Expr& expr) {
+ VisitExpr(expr);
+ ExprSet subgraph;  // left empty unless the visited region is fake quantized
+ if (is_fake_quantized_) {
+ for (auto kv : this->visit_counter_) {
+ if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
+ if (call_node->op != quantize_op_) {
+ subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
}
}
}
- return subgraph;
}
- const AffineTypeMap GetAffineTypes() { return affine_types_; }
- void VisitExpr(const Expr& expr) override {
- // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
- // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
- // abort the rewrite.
- if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
- expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
- expr.as<ConstantNode>() == nullptr) {
- DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside"
- << " a fake quantize region, aborting this rewrite";
- is_fake_quantized_ = false;
- } else {
- ExprVisitor::VisitExpr(expr);
- }
+ return subgraph;
+}
+const AffineTypeMap SubgraphExtractor::GetAffineTypes() { return affine_types_; }
+void SubgraphExtractor::VisitExpr(const Expr& expr) {
+ // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
+ // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
+ // abort the rewrite.
+ if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
+ expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
+ expr.as<ConstantNode>() == nullptr) {
+ DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside"
+ << " a fake quantize region, aborting this rewrite";
+ is_fake_quantized_ = false;
+ } else {
+ ExprVisitor::VisitExpr(expr);
}
+}
- protected:
- void VisitExpr_(const CallNode* call_node) override {
- if (call_node->op == quantize_op_) {
- const auto* attrs = call_node->attrs.as<qnn::QuantizeAttrs>();
- ICHECK(attrs != nullptr);
- // Only look at arg0 for quantize
- VisitExpr(call_node->args[0]);
- // Collect type of quantize ops
- affine_types_.Set(
- GetRef<Expr>(call_node),
- TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis));
- } else if (call_node->op == dequantize_op_) {
- const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
- ICHECK(attrs != nullptr);
- // Collect type of dequantize ops
- affine_types_.Set(
- GetRef<Expr>(call_node),
- TensorAffineType(call_node->args[1], call_node->args[2],
- call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype,
- attrs->axis));
- } else {
- // run normally on everything else.
- ExprVisitor::VisitExpr_(call_node);
- }
+void SubgraphExtractor::VisitExpr_(const CallNode* call_node) {
+ // Compare call_node->op directly: it is not necessarily an Op (e.g. a GlobalVar or Function).
+ if (call_node->op == quantize_op_) {
+ const auto* attrs = call_node->attrs.as<qnn::QuantizeAttrs>();
+ ICHECK(attrs != nullptr);
+ // Only look at arg0 for quantize
+ VisitExpr(call_node->args[0]);
+ // Collect type of quantize ops
+ affine_types_.Set(
+ GetRef<Expr>(call_node),
+ TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis));
+ } else if (call_node->op == dequantize_op_) {
+ const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
+ ICHECK(attrs != nullptr);
+ // Collect type of dequantize ops
+ affine_types_.Set(
+ GetRef<Expr>(call_node),
+ TensorAffineType(call_node->args[1], call_node->args[2],
+ call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype,
+ attrs->axis));
+ } else {
+ // run normally on everything else.
+ ExprVisitor::VisitExpr_(call_node);
}
-
- const Op quantize_op_ = Op::Get("qnn.quantize");
- const Op dequantize_op_ = Op::Get("qnn.dequantize");
- bool is_fake_quantized_ = true;
- AffineTypeMap affine_types_;
-};
+}
class SubgraphMutator : public ExprMutator {
public:
diff --git a/src/relay/transforms/fake_quantization_to_integer.h b/src/relay/transforms/fake_quantization_to_integer.h
new file mode 100644
index 0000000..1956f94
--- /dev/null
+++ b/src/relay/transforms/fake_quantization_to_integer.h
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/fake_quantization_to_integer.h
+ * \brief Extract subgraph of a fake quantized region.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_
+#define TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_
+
+#include <tvm/ir/affine_type.h>
+#include <tvm/relay/expr_functor.h>
+
+#include <unordered_set>
+
+namespace tvm {
+namespace relay {
+// Visitor that, started on a qnn.quantize call, collects the fake quantized region feeding it.
+class SubgraphExtractor : public ExprVisitor {
+ public:
+ const std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> GetSubgraph(const Expr& expr);
+ const Map<Expr, AffineType> GetAffineTypes();  // affine types of visited (de)quantize calls
+ void VisitExpr(const Expr& expr) override;  // clears is_fake_quantized_ on non-dataflow nodes
+ // GetSubgraph returns the visited calls other than qnn.quantize, or {} if not fake quantized.
+ protected:
+ void VisitExpr_(const CallNode* call_node) override;  // does not descend past qnn.dequantize
+ // VisitExpr_ follows only arg 0 of qnn.quantize and records both ops' affine types.
+ private:
+ const Op quantize_op_ = Op::Get("qnn.quantize");
+ const Op dequantize_op_ = Op::Get("qnn.dequantize");
+ bool is_fake_quantized_ = true;
+ Map<Expr, AffineType> affine_types_;
+};
+
+} // namespace relay
+} // namespace tvm
+
+#endif // TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_
diff --git a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py
new file mode 100644
index 0000000..54594a2
--- /dev/null
+++ b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test function extraction"""
+import tvm
+from tvm import relay
+
+
+def test_fake_quantize_conv():
+ x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+ w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8")
+ zero = relay.const(0)
+
+ op = relay.op.nn.conv2d(
+ relay.qnn.op.dequantize(x, relay.const(2.0), zero),
+ relay.qnn.op.dequantize(w, relay.const(0.5), zero),
+ kernel_size=[5, 5],
+ )
+ op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")
+
+ mod = tvm.IRModule.from_expr(op)
+ fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+ assert dict(fake_quantized_op_freqs) == {"nn.conv2d": 1}
+
+
+def test_fake_quantize_dense():
+ x = relay.var("x", shape=[128, 64], dtype="int8")
+ w = relay.var("w", shape=[256, 64], dtype="int8")
+ zero = relay.const(0)
+
+ op = relay.op.nn.dense(
+ relay.qnn.op.dequantize(x, relay.const(2.0), zero),
+ relay.qnn.op.dequantize(w, relay.const(0.5), zero),
+ )
+ op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")
+
+ mod = tvm.IRModule.from_expr(op)
+ fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+ assert dict(fake_quantized_op_freqs) == {"nn.dense": 1}
+
+
+def test_fake_quantize_multiple_regions():
+ x = relay.var("x", shape=[128, 64], dtype="int8")
+ w = relay.var("w", shape=[256, 64], dtype="int8")
+ zero = relay.const(0)
+
+ op = relay.op.nn.dense(
+ relay.qnn.op.dequantize(x, relay.const(2.0), zero),
+ relay.qnn.op.dequantize(w, relay.const(0.5), zero),
+ )
+ op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")
+
+ op = relay.qnn.op.dequantize(op, relay.const(2.0), relay.const(114))
+ op = relay.op.nn.relu(op)
+ op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")
+
+ w2 = relay.var("w2", shape=[64, 256], dtype="int8")
+ op = relay.op.nn.dense(
+ relay.qnn.op.dequantize(op, relay.const(1.0), zero),
+ relay.qnn.op.dequantize(w2, relay.const(0.5), zero),
+ )
+ op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")
+
+ # We expect to ignore this sigmoid op since it's just outside a fake
+ # quantized region
+ op = relay.op.sigmoid(op)
+
+ mod = tvm.IRModule.from_expr(op)
+ fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+ assert dict(fake_quantized_op_freqs) == {"nn.dense": 2, "nn.relu": 1}
+
+
+def test_fake_quantize_maxpool():
+ x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+
+ zero = relay.const(0)
+ x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+ op = relay.op.nn.max_pool2d(x, [3, 3])
+ op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
+
+ mod = tvm.IRModule.from_expr(op)
+ fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+ assert dict(fake_quantized_op_freqs) == {"nn.max_pool2d": 1}
+
+
+def test_fake_quantize_transpose_reshape():
+ x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8")
+
+ zero = relay.const(0)
+ x = relay.qnn.op.dequantize(x, relay.const(2.0), zero)
+ op = relay.op.transpose(x, [1, 0, 2, 3])
+ op = relay.op.reshape(op, [3, -1])
+ op = relay.qnn.op.quantize(op, relay.const(2.0), zero)
+
+ mod = tvm.IRModule.from_expr(op)
+ fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+ assert dict(fake_quantized_op_freqs) == {"transpose": 1, "reshape": 1}
+
+
+def test_fake_quantize_concat():
+ zero = relay.const(0)
+ inputs = []
+ for i in range(4):
+ inputs.append(
+ relay.qnn.op.dequantize(
+ relay.var("x%d" % i, shape=[1, 4], dtype="int8"), relay.const(i + 0.5), zero
+ )
+ )
+ concat = relay.op.concatenate(inputs, axis=1)
+ op = relay.qnn.op.quantize(concat, relay.const(3.5), zero)
+
+ mod = tvm.IRModule.from_expr(op)
+ fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod)
+
+ assert dict(fake_quantized_op_freqs) == {"concatenate": 1}