You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@mxnet.apache.org by bg...@apache.org on 2022/03/31 08:58:03 UTC
[incubator-mxnet] branch master updated: Post quantize property improvement (#20929)
This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new f16c3c7 Post quantize property improvement (#20929)
f16c3c7 is described below
commit f16c3c77cb5ceddb5b8a8327e84a7d60c237a313
Author: DominikaJedynak <do...@intel.com>
AuthorDate: Thu Mar 31 10:55:47 2022 +0200
Post quantize property improvement (#20929)
* Post quantize property improvement
* Review suggestion
* Sanity fix
* Review change
---
.../subgraph/dnnl/dnnl_post_quantize_property.h | 52 +++++++++++-----------
1 file changed, 27 insertions(+), 25 deletions(-)
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 456a0d1..0a7439b 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -37,15 +37,19 @@
namespace mxnet {
namespace op {
namespace {
-const std::set<std::string> support_req_fusion_op = {
- "_contrib_quantized_elemwise_add",
- "_contrib_quantized_elemwise_mul",
- // "_contrib_quantized_npi_add" - to be added later on
- "_sg_onednn_conv",
- "_sg_onednn_fully_connected",
- "_sg_onednn_selfatt_qk",
- "_sg_onednn_selfatt_valatt",
- "_sg_onednn_batch_dot"};
+bool SupportsRequantizeFusion(const Op* op) {
+ static const std::set<const Op*> support_requantize_fusion_ops = {
+ Op::Get("_contrib_quantized_elemwise_add"),
+ Op::Get("_contrib_quantized_elemwise_mul"),
+ // Op::Get("_contrib_quantized_npi_add") - to be added later on
+ Op::Get("_sg_onednn_conv"),
+ Op::Get("_sg_onednn_fully_connected"),
+ Op::Get("_sg_onednn_selfatt_qk"),
+ Op::Get("_sg_onednn_selfatt_valatt"),
+ Op::Get("_sg_onednn_batch_dot")};
+
+ return support_requantize_fusion_ops.count(op) > 0;
+}
} // namespace
class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
@@ -62,18 +66,15 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
bool float_output;
SelectStatus status;
std::vector<const BiDirectedNode*> matched_list;
- std::set<std::string> support_requantize_fusion_op_name;
public:
explicit SgDNNLPostQuantizeSelector(const bool fuse_all, const bool float_output)
- : fuse_all(fuse_all), float_output(float_output) {
- support_requantize_fusion_op_name = support_req_fusion_op;
- }
+ : fuse_all(fuse_all), float_output(float_output) {}
bool Select(const BiDirectedNode& n) override {
const nnvm::Node* raw_node = n.node;
- if (fuse_all && raw_node->op() &&
- support_requantize_fusion_op_name.count(raw_node->op()->name)) {
+
+ if (fuse_all && raw_node->op() && SupportsRequantizeFusion(raw_node->op())) {
status = SelectStatus::kStart;
matched_list.clear();
matched_list.emplace_back(&n);
@@ -89,6 +90,10 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node) override {
const nnvm::Node* raw_node = n.node;
const nnvm::Node* raw_new_node = new_node.node;
+
+ static const std::set<const Op*> dequantize_fusion_unsupported_ops = {
+ Op::Get("_contrib_quantized_elemwise_add")};
+
if (status == SelectStatus::kFail || status == SelectStatus::kSuccess ||
raw_new_node->is_variable())
return false;
@@ -111,9 +116,9 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
matched_list.emplace_back(&new_node);
status = SelectStatus::kRequantize;
- // For now there is no support for dequantize fusion for contrib_quantized_elemwise_add
- // so with this operator we finish after finding requantize node:
- if (raw_node->op() == Op::Get("_contrib_quantized_elemwise_add")) {
+ // Dequantize fusion is not yet supported for some operators,
+ // so for those we finish as soon as the requantize node is found:
+ if (dequantize_fusion_unsupported_ops.count(raw_node->op()) != 0) {
status = SelectStatus::kSuccess;
}
return true;
@@ -170,13 +175,11 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
private:
bool fuse_all;
bool float_output;
- std::set<std::string> support_requantize_fusion_op_name;
public:
SgDNNLPostQuantizeProperty() {
- fuse_all = dmlc::GetEnv("MXNET_ONEDNN_FUSE_REQUANTIZE", true);
- float_output = dmlc::GetEnv("MXNET_ONEDNN_FUSE_DEQUANTIZE", true);
- support_requantize_fusion_op_name = support_req_fusion_op;
+ fuse_all = dmlc::GetEnv("MXNET_ONEDNN_FUSE_REQUANTIZE", true);
+ float_output = dmlc::GetEnv("MXNET_ONEDNN_FUSE_DEQUANTIZE", true);
}
static SubgraphPropertyPtr Create() {
@@ -196,7 +199,7 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
if (node->is_variable())
return;
- if (node->op() && support_requantize_fusion_op_name.count(node->op()->name)) {
+ if (node->op() && SupportsRequantizeFusion(node->op())) {
fuse_node = node;
} else if (node->op() == Op::Get("_contrib_requantize")) {
requantize_node = node;
@@ -213,8 +216,7 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
// When only fused quantized operator and requantize, set min/max_calib_range,
// When fused quantized operator + requantize + dequantize, set dequantize flag to true.
- if ((dequantize_node != nullptr) &&
- (fuse_node->op() != Op::Get("_contrib_quantized_elemwise_add"))) {
+ if (dequantize_node != nullptr) {
fuse_node->attrs.dict["enable_float_output"] = "True";
} else {
fuse_node->attrs.dict["min_calib_range"] =