You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/02/15 16:08:54 UTC

[GitHub] [tvm] mbrookhart commented on a change in pull request #10239: Add a conversion of individual operations in FQ2I pass.

mbrookhart commented on a change in pull request #10239:
URL: https://github.com/apache/tvm/pull/10239#discussion_r807010832



##########
File path: src/relay/transforms/fake_quantization_to_integer.cc
##########
@@ -270,8 +293,233 @@ class FakeQuantizationRewriter : public MixedModeMutator {
   const bool hard_fail_;
 };
 
+bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) {

Review comment:
       Could we add some comments and advice about what ops should be included in this list?

##########
File path: src/relay/transforms/fake_quantization_to_integer.cc
##########
@@ -270,8 +293,233 @@ class FakeQuantizationRewriter : public MixedModeMutator {
   const bool hard_fail_;
 };
 
+// Returns true iff `call_node`'s operator is eligible for the "optional"
+// (single-op) FQ2I conversion: it must be in the allow-list below AND have a
+// registered FTVMFakeQuantizationToInteger rewrite rule.
+// NOTE(review): Downcast<Op> will ICHECK-fail if the callee is not an Op
+// (e.g. a relay Function / GlobalVar); consider guarding with
+// call_node->op.as<OpNode>() — confirm callers only pass Op calls.
+bool is_op_enabled_for_optional_fq2i(const CallNode* call_node) {
+  const Op op = Downcast<Op>(call_node->op);
+  // Per-op rewrite rules registered under the FTVMFakeQuantizationToInteger attribute.
+  static auto fqfq = Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
+  // Allow-list of ops considered safe to convert one at a time.
+  static std::unordered_set<Op, tvm::ObjectHash, tvm::ObjectEqual> ops = {
+      Op::Get("reshape"),
+      Op::Get("squeeze"),
+      Op::Get("strided_slice"),
+      Op::Get("transpose"),
+      Op::Get("expand_dims"),
+      Op::Get("nn.max_pool2d"),
+      Op::Get("nn.batch_flatten"),
+      Op::Get("nn.depth_to_space"),
+      Op::Get("max"),
+      Op::Get("min"),
+      Op::Get("nn.avg_pool2d"),
+      Op::Get("nn.global_avg_pool2d"),
+      Op::Get("nn.bias_add"),
+      Op::Get("nn.conv2d"),
+      Op::Get("nn.conv2d_transpose"),
+      Op::Get("nn.dense"),
+      Op::Get("nn.batch_matmul"),
+      Op::Get("split"),
+      Op::Get("clip"),
+      Op::Get("nn.relu"),
+      Op::Get("nn.pad"),
+      Op::Get("broadcast_to"),
+      Op::Get("minimum"),
+      Op::Get("maximum")};
+
+  // NOTE(review): this is a linear O(n) scan over a hash set — `ops.count(op)`
+  // does the same membership test in O(1) (clang-tidy:
+  // performance-inefficient-algorithm). Also, `op` is already an Op, so the
+  // Downcast<Op>(op) in the return statement is redundant.
+  auto is_enabled = [&](const auto i) { return i == call_node->op; };
+  auto result = std::find_if(std::begin(ops), std::end(ops), is_enabled);
+  return result != ops.end() && fqfq.count(Downcast<Op>(op));
+}
+
+// Given a single call node that is enabled for optional FQ2I, collects the
+// set of upstream call nodes reachable from it (in practice its qnn.dequantize
+// producers) along with the affine types (scale / zero-point / dtype / axis)
+// recorded at each dequantize. The extraction is aborted (empty set returned)
+// if any non-dataflow construct is encountered.
+class OptionalSubgraphExtractor : public ExprVisitor {
+ public:
+  // Returns every visited call node except the root `expr` itself, or an
+  // empty set when a non-dataflow node made the region invalid.
+  // Precondition (ICHECKed): `expr` is a call to an op enabled for optional FQ2I.
+  const ExprSet GetSubgraph(const Expr& expr) {
+    expr_call_node_ = expr.as<CallNode>();
+    ICHECK(expr_call_node_ != nullptr);
+    ICHECK(is_op_enabled_for_optional_fq2i(expr_call_node_));
+
+    VisitExpr(expr);
+
+    ExprSet subgraph;
+    if (is_fake_quantized_) {
+      // visit_counter_ (from ExprVisitor) holds every node seen; keep only
+      // call nodes, excluding the root call being rewritten.
+      for (auto kv : this->visit_counter_) {
+        if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
+          if (call_node != expr_call_node_) {
+            subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
+          }
+        }
+      }
+    }
+    return subgraph;
+  }
+  // Affine types gathered at the qnn.dequantize nodes during GetSubgraph.
+  const AffineTypeMap GetAffineTypes() { return affine_types_; }
+  void VisitExpr(const Expr& expr) override {
+    // When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
+    // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
+    // abort the rewrite.
+    if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
+        expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
+        expr.as<ConstantNode>() == nullptr) {
+      DLOG(INFO) << "FakeQuantizationToInteger found a non - dataflow op inside a fake quantize "
+                    "region, aborting this rewrite";
+      is_fake_quantized_ = false;
+    } else {
+      ExprVisitor::VisitExpr(expr);
+    }
+  }
+
+ protected:
+  void VisitExpr_(const CallNode* call_node) override {
+    if (call_node->op == dequantize_op_) {
+      // Record the affine type at the dequantize boundary: args[1] is the
+      // scale, args[2] the zero point; the dtype is inferred from the
+      // quantized input args[0]. Recursion stops here — the dequantize's
+      // inputs are outside the fake-quantized region.
+      const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
+      ICHECK(attrs != nullptr);
+
+      affine_types_.Set(
+          GetRef<Expr>(call_node),
+          TensorAffineType(
+              call_node->args[1], call_node->args[2],
+              tvm::relay::transform::InferTypeLocal(call_node->args[0]).as<TensorTypeNode>()->dtype,
+              attrs->axis));
+    } else if (call_node == expr_call_node_) {
+      // For the root call, visit only its arguments (skip the callee Op node).
+      for (auto arg : call_node->args) {
+        VisitExpr(arg);
+      }
+    } else {
+      // run normally on everything else.
+      ExprVisitor::VisitExpr_(call_node);
+    }
+  }
+
+  const Op dequantize_op_ = Op::Get("qnn.dequantize");
+  // Cleared as soon as a non-dataflow node is seen; gates GetSubgraph's result.
+  bool is_fake_quantized_ = true;
+  // Map from dequantize call -> TensorAffineType recorded above.
+  AffineTypeMap affine_types_;
+  // The root call node passed to GetSubgraph (excluded from the subgraph).
+  const CallNode* expr_call_node_ = nullptr;
+};
+
+// Rewrites one enabled call node (whose inputs are qnn.dequantize calls) into
+// its integer form via the op's FTVMFakeQuantizationToInteger rule, then wraps
+// the result in a qnn.dequantize so the expression keeps its floating-point
+// interface for downstream consumers.
+class OptionalSubgraphMutator : public ExprMutator {
+ public:
+  // `subgraph`/`affine_types` come from OptionalSubgraphExtractor; `hard_fail`
+  // chooses between LOG(FATAL)/rethrow and silently skipping the rewrite.
+  OptionalSubgraphMutator(ExprSet subgraph, AffineTypeMap affine_types, bool hard_fail)
+      : subgraph_(subgraph), affine_types_(affine_types), hard_fail_(hard_fail) {}
+
+  // Returns the rewritten expression, or `expr` unchanged when the subgraph is
+  // empty, contains an unsupported node, or (with hard_fail == false) the
+  // rewrite throws.
+  Expr MutateSubgraph(const Expr& expr) {
+    if (subgraph_.size() == 0) {
+      return expr;
+    }
+
+    quantize_node_ = expr.as<CallNode>();
+    ICHECK(quantize_node_);
+    ICHECK(is_op_enabled_for_optional_fq2i(quantize_node_));
+
+    for (auto node : subgraph_) {
+      // NOTE(review): node.as<CallNode>() is dereferenced unchecked — safe only
+      // because the extractor inserts call nodes exclusively; confirm.
+      const Op op = Downcast<Op>(node.as<CallNode>()->op);
+
+      // NOTE(review): despite the comment below, this bails out for ANY
+      // subgraph node that is not qnn.dequantize, not just ops lacking a
+      // rewrite rule — confirm that is the intended behavior.
+      if (node.as<CallNode>()->op != dequantize_op_) {
+        // Only modify the subgraph if we have translation
+        // rules for every op
+        if (hard_fail_) {
+          LOG(FATAL) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
+        } else {
+          DLOG(INFO) << "Found no rewrite rule for " << AsText(op, false) << std::endl;
+          return expr;
+        }
+      }
+    }
+    try {
+      return Mutate(expr);
+    } catch (std::exception& e) {
+      if (hard_fail_) {
+        // NOTE(review): `throw e;` rethrows a sliced copy of the caught
+        // std::exception; plain `throw;` preserves the original exception type.
+        throw e;
+      } else {
+        DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping" << expr << std::endl;
+        return expr;
+      }
+    }
+  }
+
+ protected:
+  // NOTE(review): the three VisitExpr_ overloads below are missing `override`.
+  Expr VisitExpr_(const CallNode* call_node) {
+    Expr out;
+    static auto fqfq =
+        Op::GetAttrMap<FTVMFakeQuantizationToInteger>("FTVMFakeQuantizationToInteger");
+
+    Op op = Downcast<Op>(call_node->op);
+    if (fqfq.count(op)) {
+      Expr expr;
+      if (op == dequantize_op_) {
+        // Dequantize nodes are handed to their rewrite rule as-is (their
+        // affine type was already recorded by the extractor).
+        expr = GetRef<Expr>(call_node);
+      } else {
+        expr = ExprMutator::VisitExpr_(call_node);
+      }
+      // Call the rewrite
+      Array<ObjectRef> vals = fqfq[op](expr, affine_types_);
+      // Save the outputs of the rewrite
+      ICHECK(vals.size() == 2)
+          << "got the wrong number of returned arguments from FTVMFakeQuantizationToInteger for "
+          << AsText(op, false);
+      out = Downcast<Expr>(vals[0]);
+
+      affine_types_.Set(out, Downcast<AffineType>(vals[1]));
+
+      if (call_node == quantize_node_) {
+        // At the root of the single-op region, re-wrap the integer result in a
+        // dequantize so the expression's float interface is preserved.
+        out = qnn::MakeDequantize(out, vals[1].as<TensorAffineTypeNode>()->scale,
+                                  vals[1].as<TensorAffineTypeNode>()->zero_point,
+                                  vals[1].as<TensorAffineTypeNode>()->axis);
+      }
+    } else {
+      ICHECK(false) << "When rewriting a fake quantized graph, found an invalid node "
+                    << AsText(GetRef<Expr>(call_node), false);
+    }
+    return out;
+  }
+
+  // Tuples: the affine type of the tuple is the tuple of its fields' types.
+  Expr VisitExpr_(const TupleNode* node) {
+    Expr expr = ExprMutator::VisitExpr_(node);
+    auto new_node = expr.as<TupleNode>();
+    Array<TensorAffineType> types;
+    for (Expr field : new_node->fields) {
+      ICHECK(affine_types_[field].as<TensorAffineTypeNode>());
+      types.push_back(Downcast<TensorAffineType>(affine_types_[field]));
+    }
+    affine_types_.Set(expr, TupleAffineType(types));
+    return expr;
+  }
+
+  // TupleGetItem: project the element's affine type out of the tuple's type.
+  Expr VisitExpr_(const TupleGetItemNode* node) {
+    Expr expr = ExprMutator::VisitExpr_(node);
+    auto tuple_type = affine_types_[expr.as<TupleGetItemNode>()->tuple].as<TupleAffineTypeNode>();
+    affine_types_.Set(expr, tuple_type->types[node->index]);
+    return expr;
+  }
+
+  ExprSet subgraph_;
+  AffineTypeMap affine_types_;
+  const bool hard_fail_;
+  const Op dequantize_op_ = Op::Get("qnn.dequantize");
+  // Set by MutateSubgraph to the root call; marks where to re-insert dequantize.
+  const CallNode* quantize_node_ = nullptr;
+};
+
+// Post-order pass driver: for every call to an op enabled for optional FQ2I,
+// extracts its immediate fake-quantized subgraph and rewrites it via
+// OptionalSubgraphMutator; all other nodes pass through unchanged.
+class OptionalFakeQuantizationRewriter : public MixedModeMutator {
+ public:
+  explicit OptionalFakeQuantizationRewriter(bool hard_fail) : hard_fail_(hard_fail) {}
+
+ protected:
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    if (const CallNode* call_node = post.as<CallNode>()) {
+      // NOTE(review): `op` is unused below; also Downcast<Op> will ICHECK-fail
+      // when the callee is not an Op (e.g. a relay Function) — confirm this
+      // rewriter never sees such calls, or guard with op.as<OpNode>().
+      const Op op = Downcast<Op>(call_node->op);
+      if (is_op_enabled_for_optional_fq2i(call_node)) {
+        OptionalSubgraphExtractor extractor;
+        ExprSet subgraph = extractor.GetSubgraph(post);
+        AffineTypeMap affine_types = extractor.GetAffineTypes();
+        Expr out = OptionalSubgraphMutator(subgraph, affine_types, hard_fail_).MutateSubgraph(post);
+        return out;
+      }
+    }
+    return post;
+  }
+  const bool hard_fail_;
+};
+
 Expr FakeQuantizationToInteger(const Expr& expr, const IRModule& mod, bool hard_fail) {
-  return FakeQuantizationRewriter(hard_fail).Mutate(expr);
+  // First run the original region-based FQ2I rewrite.
+  auto fq_expr = FakeQuantizationRewriter(hard_fail).Mutate(expr);
+  // Re-infer types on the rewritten module — presumably so the optional
+  // rewriter sees up-to-date type information (it calls InferTypeLocal on
+  // dequantize inputs); confirm this is required.
+  auto fq_inferred_expr = tvm::relay::InferType(fq_expr);
+  // Then convert any remaining individually-enabled fake-quantized ops.
+  auto ofq_expr = OptionalFakeQuantizationRewriter(hard_fail).Mutate(fq_inferred_expr);
+  return ofq_expr;
 }

Review comment:
       I'm not sure how problematic this will be in non-QAT models, but would it make sense to add another bool to make the "Optional" part of the pass actually optional?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org