You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/04/08 07:08:48 UTC

[GitHub] [incubator-tvm] windclarion commented on a change in pull request #5277: [BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes

windclarion commented on a change in pull request #5277: [BYOC] Refine AnnotateTarget and MergeCompilerRegion Passes
URL: https://github.com/apache/incubator-tvm/pull/5277#discussion_r405303431
 
 

 ##########
 File path: src/relay/transforms/annotate_target.cc
 ##########
 @@ -19,131 +19,155 @@
 
 /*!
  * \file src/relay/transforms/annotate_target.cc
- * \brief Wraps a call with compiler_begin and compiler_end to indicate that
- * the op of this call node will use external compiler.
+ * \brief Wraps an expr with compiler_begin and compiler_end to indicate that
+ * this expr should be handled by the external compiler.
  */
 
 #include <tvm/relay/attrs/annotation.h>
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
+#include <tvm/runtime/container.h>
 
 namespace tvm {
 namespace relay {
 namespace annotate_target {
 
-// Cache compiler_begin op for equivalence check.
-static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+const PackedFunc* begin_op = runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+const PackedFunc* end_op = runtime::Registry::Get("relay.op.annotation._make.compiler_end");
 
 // A helper class to insert annotation boundaries for a program region that will
 // be handled by a specific compiler.
 class AnnotateTargetWrapper : public ExprMutator {
  public:
-  explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}
+  explicit AnnotateTargetWrapper(Array<runtime::String> targets) : targets_(std::move(targets)) {}
+
+  /*!
+   * \brief This function annotates a compiler end and a compiler begin to all arguments.
+   *
+   *  The compiler end is based on the arg target while the compiler begin is based on the given
+   *  target. If target is not given and all arguments are going to the same target, then we will
+   *  use that target; otherwise we use default for this op. Note that all arg exprs must be
+   *  available in op_expr_to_target before calling this function.
+   *
+   * \param args An array of arguments of the given node.
+   * \param target The target of the current node.
+   * \return A pair of target and annotated argument expressions.
+   */
+  std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
+                                                   const std::string& target = "") {
+    std::string ref_target = "";
+    Array<Expr> compiler_ends;
+    for (auto arg : args) {
+      if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
+        std::string arg_target = op_expr_to_target_[arg];
+        compiler_ends.push_back(InsertAnnotation(arg, arg_target, end_op));
+        if (ref_target == "") {
+          ref_target = arg_target;
+        } else if (ref_target != arg_target) {
+          ref_target = "default";
+        }
+      } else {
+        // Input vars.
+        compiler_ends.push_back(arg);
+      }
+    }
+
+    // Determine compiler begin target.
+    std::string op_target = (target == "") ? ref_target : target;
+
+    Array<Expr> compiler_begins;
+    for (const auto& end : compiler_ends) {
+      compiler_begins.push_back(InsertAnnotation(end, op_target, begin_op));
+    }
 
-  Expr Annotate(const Expr& expr) {
-    return InsertEnd(Mutate(expr));
+    return {op_target, compiler_begins};
   }
 
-  bool IsSupported(const Expr& expr) {
-    if (expr->IsInstance<CallNode>()) {
-      Call call = Downcast<Call>(expr);
-      auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
-      if (call->op->IsInstance<OpNode>()) {
-        Op op = Downcast<Op>(call->op);
-        CHECK(op.defined());
-        if (fannotate.count(op)) {
-          return fannotate[op](call->attrs, call->args);
+  Expr InsertAnnotation(const Expr& expr, const std::string& target, const PackedFunc* ann_op) {
+    Expr new_op = (*ann_op)(expr, target);
+    new_op->checked_type_ = expr->checked_type_;
+    return new_op;
+  }
+
+  Expr VisitExpr_(const CallNode* cn) final {
+    // Supported targets for this node. The order implies the priority.
+    std::vector<std::string> supported_targets;
+
+    // Check which targets this op can be offloaded.
+    if (cn->op->IsInstance<OpNode>()) {
+      // TVM operators: Check target specific op checking function and add to supported_targets
+      // if it is supported.
+      Op op = Downcast<Op>(cn->op);
+      CHECK(op.defined());
+      for (const auto& target : this->targets_) {
+        auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target));
 
 Review comment:
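  We need to add one line before `auto fannotate`:

    if (!Op::HasAttr("target." + std::string(target))) return false;

  The guard is needed because `Op::GetAttr` fails with an error when no operator has registered the attribute `target.<name>`. This can happen for a target that is supported only through composite functions. Note that inside this loop in `VisitExpr_`, which returns an `Expr`, the guard should use `continue` rather than `return false`. `continue` skips only the unsupported target and still checks the remaining ones.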

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services