You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/30 19:42:09 UTC

[GitHub] [tvm] tkonolige commented on a change in pull request #7518: [RFC]TECompiler: Staged refactor and removal of compile engine

tkonolige commented on a change in pull request #7518:
URL: https://github.com/apache/tvm/pull/7518#discussion_r661762009



##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -420,170 +387,57 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
     return fields;
   }
 
-  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& op_name,
-                                             const std::string& func_name, GraphAttrs attrs) {
+  bool ShareSameStorage(const Expr& lhs, const Expr& rhs) {
+    StorageInfo lit = GetStorageInfo(lhs);
+    StorageInfo rit = GetStorageInfo(rhs);
+    int64_t lhs_storage_id = lit->storage_ids[0];
+    int64_t rhs_storage_id = rit->storage_ids[0];
+    return lhs_storage_id == rhs_storage_id;
+  }
+
+  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* op, const std::string& func_name,
+                                             GraphAttrs op_attrs) {
     std::vector<GraphNodeRef> inputs;
     for (auto arg : op->args) {
       auto res = VisitExpr(arg);
       for (auto nr : res) {
         inputs.push_back(nr);
       }
     }
-    auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs);
-    return AddNode(node, GetRef<Expr>(op));
-  }
-
-  bool ShareSameStorage(const Expr& lhs, const Expr& rhs) {
-    auto lit = storage_device_map_.find(lhs);
-    auto rit = storage_device_map_.find(rhs);
-    ICHECK(lit != storage_device_map_.end());
-    ICHECK(rit != storage_device_map_.end());
-    int64_t lhs_storage_id = ((*lit).second)[0][0]->value;
-    int64_t rhs_storage_id = ((*rit).second)[0][0]->value;
-    return lhs_storage_id == rhs_storage_id;
-  }
 
-  /*!
-   * \brief Obtain the Target from the device type.
-   * If homogenous compilation, this will return the only target.
-   * If heteregenous compilation, this will select associated using the targets_ Map.
-   *
-   * \param dev_type
-   * \return Target
-   */
-  Target GetTargetFromInteger(int64_t dev_type) {
-    if (targets_.size() == 1) {
-      // homogeneous execution.
-      const auto& it = targets_.begin();
-      return (*it).second;
-    } else {
-      // heterogeneous execution.
-      std::string call_dev_name;
-      if (dev_type == 0) {
-        call_dev_name = "llvm";
-      } else {
-        call_dev_name = runtime::DeviceName(dev_type);
-      }
-      if (targets_.count(dev_type) == 0) {
-        LOG(FATAL) << "No target is provided for device " << call_dev_name;
-      }
-      return targets_[dev_type];
+    /// An adapted version of the storage optimization for the time being.
+    bool reshape_only = false;
+    if (op->attrs.defined() && op->attrs.as<TIRCallAttrs>()) {
+      reshape_only = true;
     }
-  }
 
-  /*!
-   * \brief Update the function metadata for a given cached function and its relay
-   * primitive function.
-   *
-   * \param cfunc The cached function as provided the by the compile engine
-   * \param relay_func The source relay primitive function
-   * \param relay_target The target associated with relay primitive function
-   */
-  void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func,
-                              const Target& relay_target) {
-    auto fi_node = make_object<FunctionInfoNode>();
-    for (const auto& kv : cfunc->funcs->functions) {
-      auto primfunc = Downcast<tir::PrimFunc>(kv.second);
-      auto workspace_byte_alignment = relay_target->GetAttr<Integer>("workspace-byte-alignment")
-                                          .value_or(tvm::runtime::kDefaultWorkspaceAlignment);
-      Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
-      Target primfunc_target = relay_target;
-      if (primfunc->attrs->dict.count("target")) {
-        primfunc_target = Downcast<Target>(primfunc->attrs->dict["target"]);
-      }
-      fi_node->workspace_sizes.Set(primfunc_target, workspace_size);
-      // Calculating size for I/O
-      for (auto const& param : primfunc->params) {
-        auto p_shape = primfunc->buffer_map[param]->shape;
-        int num_of_elements = 1;
-        for (const auto& dim_index_expr : p_shape) {
-          if (dim_index_expr->IsInstance<IntImmNode>()) {
-            num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
-          } else {
-            // If shape is dynamic, we cannot calculate workspace in compile time.
-            num_of_elements = 0;
-          }
-        }
-        int element_size = primfunc->buffer_map[param]->dtype.bytes();
-        fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements);
-      }
-      fi_node->constant_sizes.Set(primfunc_target, 0);
-      fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
-      fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
-    }
-    function_metadata_.Set(cfunc->func_name, FunctionInfo(fi_node));
-  }
-
-  std::vector<GraphNodeRef> VisitExpr_(const CallNode* op) override {
-    Expr expr = GetRef<Expr>(op);
-    Function func;
-    if (op->op.as<OpNode>()) {
-      LOG(FATAL) << "Operators should be transformed away; try applying"
-                 << "the fuse_ops transformation to the expression.";
-    } else if (op->op.as<GlobalVarNode>()) {
-      LOG(FATAL) << "Not implemented";
-    } else if (op->op.as<FunctionNode>()) {
-      func = GetRef<Function>(op->op.as<FunctionNode>());
-    } else {
-      LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
-    }
-    if (!func->HasNonzeroAttr(attr::kPrimitive)) {
-      LOG(FATAL) << "TVM only support calls to primitive functions "
-                 << "(i.e functions composed of fusable operator invocations)";
-    }
-
-    // Copy attrs from function into the graph node
-    // For now we only handle strings
-    GraphAttrs attrs;
-    for (auto p : func->attrs->dict) {
-      if (p.second.as<StringObj>()) {
-        attrs[p.first] = std::string(Downcast<String>(p.second));
-      }

Review comment:
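       You don't want to remove this (the loop just above that copies the function's string-valued attrs into the graph node's `GraphAttrs`).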




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org