You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/11/30 01:50:45 UTC

[GitHub] [tvm] electriclilies commented on a change in pull request #9483: [Relay] Switch the VM to use the LowerTE pass instead of TECompiler::{Lower,LowerShapeFunc}.

electriclilies commented on a change in pull request #9483:
URL: https://github.com/apache/tvm/pull/9483#discussion_r758844510



##########
File path: include/tvm/runtime/vm/executable.h
##########
@@ -133,21 +133,20 @@ class Executable : public ModuleNode {
 
   /*!
    * \brief Returns a description of all the constants in the executable in human-readable
-   * format. Not intended to be machine readable, but rather to help with debugging and
-   * diffing generated code.
+   * format. Intended for debugging and diff-testing.

Review comment:
       did this get changed? 

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -722,18 +725,21 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     for (const auto& kv : targets_) {
       VLOG(1) << "target: " << kv.second->ToDebugString();
     }
-    if (target_host_.defined()) {
-      VLOG(1) << "target host: " << target_host_->ToDebugString();
-    }
+    ICHECK(target_host_.defined());

Review comment:
       Can you add a message here?

##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -314,9 +316,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   /*!
    * brief Create a function call
    * \param call_lowered_props The lowered function and the arguments to call it with
-   * \param call The call we got func and args from
+   * \param call_node The call we got func and args from
    */
-  void CreateFuncCall(CallLoweredProps call_lowered_props, Call call) {
+  void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) {

Review comment:
       Please update the documentation to match the new parameter name (`result_expr`). Also, since `result_expr` is only used for PackSid, it would be nice to add a note in the documentation explaining why it needs to be passed in.
   
   (Also note: I think the CI sanity check should have caught this mismatch?)

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -435,42 +497,56 @@ using AnalysisRemapping = std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual
  */
 class LowerTensorExprMutator : public DeviceAwareExprMutator {
  public:
-  LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, const String& module_name,
-                         TECompiler compiler)
+  LowerTensorExprMutator(const IRModule& module, ProcessFn process_fn, String module_name,
+                         TECompiler compiler, SEScope host_se_scope)

Review comment:
       [placeholder]

##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -215,24 +215,35 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
 
     if (memory_plan_.defined()) {
       // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
-      func_info =
-          relay::tec::UpdateMainWorkspaceSize(mod, targets_, memory_plan_->expr_to_storage_info);
+      tec::TargetMap tec_target_map;
+      for (const auto& pair : targets_) {
+        tec_target_map.emplace(static_cast<DLDeviceType>(pair.first->value), pair.second);
+      }

Review comment:
       Is this copy just to make sure UpdateMainWorkspaceSize doesn't have access to targets_?

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".
+    for (const auto& kv1 : cache_) {
+      auto src_func = kv1.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
+          if (const auto* function_node = kv2.second.as<FunctionNode>()) {
+            // Abandon the existing function annotations.
+            Function function(function_node->params, function_node->body, function_node->ret_type,
+                              function_node->type_params, /*attrs=*/{}, function_node->span);
+            // Mark function as 'extern' using the "ExternalSymbol" attribute.
+            function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
+            module->Add(kv2.first, function);
+          }
+        }
+      }
+    }
+  }
+
   Array<tvm::runtime::Module> LowerExternalFunctions() {
     Array<tvm::runtime::Module> ret;
-    std::unordered_map<std::string, std::string> cached_symbol;
     std::vector<CCacheKey> cached_ext_funcs;
 
     for (const auto& it : cache_) {
       auto src_func = it.first->source_func;
       ICHECK(src_func.defined());
-      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
-        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
-        std::string code_gen_name = code_gen.value();
+      Optional<String> opt_compiler = src_func->GetAttr<String>(attr::kCompiler);
+      if (opt_compiler.defined()) {
+        Optional<String> opt_symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(opt_symbol_name.defined()) << "No external symbol is set for:" << std::endl
+                                          << PrettyPrint(src_func);
+        VLOG(1) << "using external codegen '" << opt_compiler.value() << "' for name '"
+                << opt_symbol_name.value() << "' and function:" << std::endl
+                << PrettyPrint(src_func);
         cached_ext_funcs.push_back(it.first);
 
-        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
-                                      << AsText(src_func, false);
-
-        std::string sn = symbol_name.value();
-        if (cached_symbol.count(sn)) {
-          cached_symbol[sn] = code_gen_name;
-        } else {
-          ICHECK_NE(sn, code_gen_name)
-              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
-        }
-
-        std::string ext_name = "relay.ext." + code_gen_name;
+        std::string ext_name = "relay.ext." + opt_compiler.value();
         auto pf = tvm::runtime::Registry::Get(ext_name);
         ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
         // No need to keep compiler attribute at this point, functions have been
         // extracted for specific codegen.
         src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        VLOG_CONTEXT << ext_name;
         runtime::Module ext_mod = (*pf)(src_func);
-
-        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
-        ret.push_back(ext_mod);
+        if (ext_mod.defined()) {
+          if (ext_mod->GetFunction(opt_symbol_name.value(), /*query_imports=*/true) == nullptr) {
+            // It's possible the codegen yielded C or C++ tracked separately and thus the
+            // returned runtime module can be empty.

Review comment:
       Is there a way to check that the codegen produced C/C++ that is tracked separately? It would be nice to have a check for that here. Otherwise, maybe add some of the comment to the VLOG message?

##########
File path: include/tvm/relay/function.h
##########
@@ -148,6 +148,19 @@ Function WithFields(Function function, Optional<Array<Var>> opt_params = Optiona
                     Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(),
                     Optional<Span> opt_span = Optional<Span>());
 
+/*
+ * \brief Returns the Relay FunctionNode represented by base_func if it should be optimized.
+ * Otherwise returns nullptr.
+ *  - PrimFuncs are obviously not Relay Functions.

Review comment:
       To clarify, if the BaseFunc is a PrimFunc then it returns nullptr?

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".

Review comment:
       Why do you retrieve from the cache instead of using the to_be_deleted list you just constructed?

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".
+    for (const auto& kv1 : cache_) {
+      auto src_func = kv1.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
+          if (const auto* function_node = kv2.second.as<FunctionNode>()) {
+            // Abandon the existing function annotations.
+            Function function(function_node->params, function_node->body, function_node->ret_type,
+                              function_node->type_params, /*attrs=*/{}, function_node->span);

Review comment:
       Hmm, it just occurred to me that WithFields won't erase the attrs here if you pass {} in as attrs. I should add a note about that to the WithFields documentation.

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -134,41 +144,78 @@ class TECompilerImpl : public TECompilerNode {
     return mod;
   }
 
+  void AddExterns(IRModule module) {
+    // Everything tagged with "Compiler" has been compiled, so remove those definitions.
+    std::vector<GlobalVar> to_be_deleted;
+    for (const auto& kv : module->functions) {
+      if (kv.second->GetAttr<String>(attr::kCompiler).defined()) {
+        to_be_deleted.push_back(kv.first);
+      }
+    }
+    for (const auto& global_var : to_be_deleted) {
+      module->Remove(global_var);
+    }
+    // HOWEVER we still need a Relay definition to go with those now external functions, so
+    // retrieve them from the cache and mark them with "ExternalSymbol".
+    for (const auto& kv1 : cache_) {
+      auto src_func = kv1.first->source_func;
+      ICHECK(src_func.defined());
+      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
+        for (const auto& kv2 : kv1.second->cached_func->funcs->functions) {
+          if (const auto* function_node = kv2.second.as<FunctionNode>()) {
+            // Abandon the existing function annotations.
+            Function function(function_node->params, function_node->body, function_node->ret_type,
+                              function_node->type_params, /*attrs=*/{}, function_node->span);
+            // Mark function as 'extern' using the "ExternalSymbol" attribute.
+            function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint);
+            module->Add(kv2.first, function);
+          }
+        }
+      }
+    }
+  }
+
   Array<tvm::runtime::Module> LowerExternalFunctions() {
     Array<tvm::runtime::Module> ret;
-    std::unordered_map<std::string, std::string> cached_symbol;
     std::vector<CCacheKey> cached_ext_funcs;
 
     for (const auto& it : cache_) {
       auto src_func = it.first->source_func;
       ICHECK(src_func.defined());
-      if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
-        auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
-        std::string code_gen_name = code_gen.value();
+      Optional<String> opt_compiler = src_func->GetAttr<String>(attr::kCompiler);
+      if (opt_compiler.defined()) {
+        Optional<String> opt_symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        ICHECK(opt_symbol_name.defined()) << "No external symbol is set for:" << std::endl
+                                          << PrettyPrint(src_func);
+        VLOG(1) << "using external codegen '" << opt_compiler.value() << "' for name '"
+                << opt_symbol_name.value() << "' and function:" << std::endl
+                << PrettyPrint(src_func);
         cached_ext_funcs.push_back(it.first);
 
-        auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-        ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
-                                      << AsText(src_func, false);
-
-        std::string sn = symbol_name.value();
-        if (cached_symbol.count(sn)) {
-          cached_symbol[sn] = code_gen_name;
-        } else {
-          ICHECK_NE(sn, code_gen_name)
-              << "Found duplicated symbol: " << sn << " for: " << code_gen_name;
-        }
-
-        std::string ext_name = "relay.ext." + code_gen_name;
+        std::string ext_name = "relay.ext." + opt_compiler.value();
         auto pf = tvm::runtime::Registry::Get(ext_name);
         ICHECK(pf) << "Failed to find the codegen tool for " << ext_name;
         // No need to keep compiler attribute at this point, functions have been
         // extracted for specific codegen.
         src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
+        VLOG_CONTEXT << ext_name;
         runtime::Module ext_mod = (*pf)(src_func);
-
-        ICHECK(ext_mod.defined()) << "No external runtime is generated.";
-        ret.push_back(ext_mod);
+        if (ext_mod.defined()) {
+          if (ext_mod->GetFunction(opt_symbol_name.value(), /*query_imports=*/true) == nullptr) {
+            // It's possible the codegen yielded C or C++ tracked separately and thus the
+            // returned runtime module can be empty.

Review comment:
       Or maybe log a warning.

##########
File path: src/relay/analysis/call_graph.cc
##########
@@ -64,9 +67,19 @@ void CallGraphNode::AddToCallGraph(const GlobalVar& gv, const Function& func) {
   // post-order visitor will visit each AST node of the current function to
   // figure out the dependencies between functions.
   PostOrderVisit(func, [&](const Expr& expr) {
-    if (const GlobalVarNode* gvn = expr.as<GlobalVarNode>()) {
-      auto callee = GetRef<GlobalVar>(gvn);
-      cg_node->AddCalledGlobal(LookupGlobalVar(callee));
+    // TODO(mbs): Cleanup shapes functions.
+    if (const auto* call_node = expr.as<CallNode>()) {
+      CallLoweredProps props = GetCallLoweredProps(call_node);
+      if (props.lowered_func.defined() && props.attrs.metadata.count("prim_shape_fn_var")) {

Review comment:
       It might be nice to add a shape_func field to CallLoweredProps; that way we don't need to extract prim_shape_fn_var from props.attrs directly.
   
   Then this check would look like:
   `if (props.lowered_func.defined() && props.shape_func.defined())`

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -224,23 +271,28 @@ class TECompilerImpl : public TECompilerNode {
     }
     cur_ccache_key_ = key;
 
-    // No need to lower external functions for now. We will invoke the external
-    // codegen tool once and lower all functions together.
-    if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
-      auto ir_module = IRModule();
-      const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-      ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
-      auto func_name = GetUniqueName(name_node.value(), &name_map_);
+    Optional<String> opt_compiler = key->source_func->GetAttr<String>(attr::kCompiler);
+    if (opt_compiler.defined()) {
+      // Don't compile now since we don't have anywhere to put the resulting runtime module.
+      // Instead place the original definition in the cache and wait for LowerExternalFunctions.
+      IRModule ir_module;
+      // Note that the source_func may already be bound to a global function in the module

Review comment:
       Since the code ICHECKs that opt_global_symbol is defined, I'd suggest changing "may" to "should be" or "is".




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org